diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 335d64e30649..13bf391ccc1a 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -40,7 +40,7 @@ class Dialect { public: MLIRContext *getContext() const { return context; } - StringRef getOperationPrefix() const { return opPrefix; } + StringRef getNamespace() const { return namePrefix; } /// Registered fallback constant fold hook for the dialect. Like the constant /// fold hook of each operation, it attempts to constant fold the operation @@ -60,9 +60,13 @@ public: virtual ~Dialect(); protected: - /// The prefix should be common across all ops in this set, e.g. "" for the - /// standard operation set, and "tf." for the TensorFlow ops like "tf.add". - Dialect(StringRef opPrefix, MLIRContext *context); + /// Note: The namePrefix can be empty, but it must not contain '.' characters. + /// Note: If the name is non empty, then all operations belonging to this + /// dialect will need to start with the namePrefix followed by a '.'. + /// Example: + /// - "" for the standard operation set. + /// - "tf" for the TensorFlow ops like "tf.add". + Dialect(StringRef namePrefix, MLIRContext *context); /// This method is used by derived classes to add their operations to the set. /// @@ -99,9 +103,8 @@ private: /// takes ownership of the heap allocated dialect. void registerDialect(MLIRContext *context); - /// This is the prefix that all operations belonging to this operation set - /// start with. - StringRef opPrefix; + /// This is the namespace used as a prefix for IR defined by this dialect. + StringRef namePrefix; /// This is the context that owns this Dialect object. MLIRContext *context; diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index baf55b4b2c29..8705da343d4f 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -45,9 +45,9 @@ public: /// Return information about all registered IR dialects. std::vector getRegisteredDialects() const; - /// Get registered IR dialect which has the longest matching with the given - /// prefix. If none is found, returns nullptr. - Dialect *getRegisteredDialect(StringRef prefix) const; + /// Get a registered IR dialect with the given namespace. If an exact match is + /// not found, then return nullptr. + Dialect *getRegisteredDialect(StringRef name) const; /// Return information about all registered operations. This isn't very /// efficient: typically you should ask the operations about their properties diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index a0264fc11b05..17205ff260ff 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -32,7 +32,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// BuiltinDialect::BuiltinDialect(MLIRContext *context) - : Dialect(/*opPrefix=*/"", context) { + : Dialect(/*namePrefix=*/"", context) { addOperations(); } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 7640ce93c944..f6a163b18b3b 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -53,8 +53,10 @@ void mlir::registerAllDialects(MLIRContext *context) { fn(context); } -Dialect::Dialect(StringRef opPrefix, MLIRContext *context) - : opPrefix(opPrefix), context(context) { +Dialect::Dialect(StringRef namePrefix, MLIRContext *context) + : namePrefix(namePrefix), context(context) { + assert(!namePrefix.contains('.') && + "Dialect names cannot contain '.' characters."); registerDialect(context); } diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 92f3c4ecba34..08b3cd90b995 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -512,7 +512,8 @@ bool OperationInst::constantFold(ArrayRef operands, // If this operation hasn't been registered or doesn't have abstract // operation, fall back to a dialect which matches the prefix. auto opName = getName().getStringRef(); - if (auto *dialect = getContext()->getRegisteredDialect(opName)) { + auto dialectPrefix = opName.split('.').first; + if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) { return dialect->constantFoldHook(llvm::cast(this), operands, results); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index abc3e1cfda4f..07ca6aa634bf 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -506,17 +506,13 @@ std::vector MLIRContext::getRegisteredDialects() const { return result; } -/// Get registered IR dialect which has the longest matching with the given -/// prefix. If none is found, returns nullptr. -Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const { - Dialect *result = nullptr; - for (auto &dialect : getImpl().dialects) { - if (prefix.startswith(dialect->getOperationPrefix())) - if (!result || result->getOperationPrefix().size() < - dialect->getOperationPrefix().size()) - result = dialect.get(); - } - return result; +/// Get a registered IR dialect with the given namespace. If none is found, +/// then return nullptr. +Dialect *MLIRContext::getRegisteredDialect(StringRef name) const { + for (auto &dialect : getImpl().dialects) + if (name == dialect->getNamespace()) + return dialect.get(); + return nullptr; } /// Register this dialect object with the specified context. The context @@ -549,8 +545,8 @@ std::vector MLIRContext::getRegisteredOperations() const { } void Dialect::addOperation(AbstractOperation opInfo) { - assert(opInfo.name.startswith(opPrefix) && - "op name doesn't start with prefix"); + assert((namePrefix.empty() || (opInfo.name.split('.').first == namePrefix)) && + "op name doesn't start with dialect prefix"); assert(&opInfo.dialect == this && "Dialect object mismatch"); auto &impl = context->getImpl(); diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 44ca8277e783..4dfb3c5f3a81 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -36,7 +36,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// StandardOpsDialect::StandardOpsDialect(MLIRContext *context) - : Dialect(/*opPrefix=*/"", context) { + : Dialect(/*namePrefix=*/"", context) { addOperations(); }