diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 71a533903baa..de53bf6073d6 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -179,6 +179,9 @@ public: OperationName(AbstractOperation *op) : representation(op) {} OperationName(StringRef name, MLIRContext *context); + /// Return the name of the dialect this operation is registered to. + StringRef getDialect() const; + /// Return the name of this operation. This always succeeds. StringRef getStringRef() const; diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 7e9248712bfc..843c49990c3f 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -274,8 +274,7 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) { return success(); // Otherwise, verify that the parent dialect allows un-registered operations. - auto opName = op.getName().getStringRef(); - auto dialectPrefix = opName.split('.').first; + auto dialectPrefix = op.getName().getDialect(); // Check for an existing answer for the operation dialect. auto it = dialectAllowsUnknownOps.find(dialectPrefix); @@ -291,7 +290,7 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) { } if (!it->second) { - return failure("unregistered operation '" + opName + + return failure("unregistered operation '" + op.getName().getStringRef() + "' found in dialect ('" + dialectPrefix + "') that does not allow unknown operations", op); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 95d474b74831..64fb8bcd0e2d 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -38,6 +38,11 @@ OperationName::OperationName(StringRef name, MLIRContext *context) { representation = Identifier::get(name, context); } +/// Return the name of the dialect this operation is registered to. +StringRef OperationName::getDialect() const { + return getStringRef().split('.').first; +} + /// Return the name of this operation. This always succeeds. StringRef OperationName::getStringRef() const { if (auto *op = representation.dyn_cast()) @@ -275,10 +280,8 @@ Dialect *Operation::getDialect() { return &abstractOp->dialect; // If this operation hasn't been registered or doesn't have abstract - // operation, fall back to a dialect which matches the prefix. - auto opName = getName().getStringRef(); - auto dialectPrefix = opName.split('.').first; - return getContext()->getRegisteredDialect(dialectPrefix); + // operation, try looking up the dialect name in the context. + return getContext()->getRegisteredDialect(getName().getDialect()); } Region *Operation::getContainingRegion() const { @@ -528,16 +531,9 @@ LogicalResult Operation::fold(ArrayRef operands, return success(); // Otherwise, fall back on the dialect hook to handle it. - Dialect *dialect; - if (abstractOp) { - dialect = &abstractOp->dialect; - } else { - // If this operation hasn't been registered, lookup the parent dialect. - auto opName = getName().getStringRef(); - auto dialectPrefix = opName.split('.').first; - if (!(dialect = getContext()->getRegisteredDialect(dialectPrefix))) - return failure(); - } + Dialect *dialect = getDialect(); + if (!dialect) + return failure(); SmallVector constants; if (failed(dialect->constantFoldHook(this, operands, constants)))