Rename OperationPrefix to Namespace in Dialect. This is important as dialects will soon be able to define more than just operations.

Moving forward dialect namespaces cannot contain '.' characters.

This cl also standardizes that operation names must begin with the dialect namespace followed by a '.'.

PiperOrigin-RevId: 227532193
This commit is contained in:
River Riddle
2019-01-02 09:26:35 -08:00
committed by jpienaar
parent 0565067495
commit ae3f8a79ae
8 changed files with 31 additions and 29 deletions

View File

@@ -40,7 +40,7 @@ class Dialect {
public:
MLIRContext *getContext() const { return context; }
StringRef getOperationPrefix() const { return opPrefix; }
StringRef getNamespace() const { return namePrefix; }
/// Registered fallback constant fold hook for the dialect. Like the constant
/// fold hook of each operation, it attempts to constant fold the operation
@@ -60,9 +60,13 @@ public:
virtual ~Dialect();
protected:
/// The prefix should be common across all ops in this set, e.g. "" for the
/// standard operation set, and "tf." for the TensorFlow ops like "tf.add".
Dialect(StringRef opPrefix, MLIRContext *context);
/// Note: The namePrefix can be empty, but it must not contain '.' characters.
/// Note: If the name is non empty, then all operations belonging to this
/// dialect will need to start with the namePrefix followed by a '.'.
/// Example:
/// - "" for the standard operation set.
/// - "tf" for the TensorFlow ops like "tf.add".
Dialect(StringRef namePrefix, MLIRContext *context);
/// This method is used by derived classes to add their operations to the set.
///
@@ -99,9 +103,8 @@ private:
/// takes ownership of the heap allocated dialect.
void registerDialect(MLIRContext *context);
/// This is the prefix that all operations belonging to this operation set
/// start with.
StringRef opPrefix;
/// This is the namespace used as a prefix for IR defined by this dialect.
StringRef namePrefix;
/// This is the context that owns this Dialect object.
MLIRContext *context;

View File

@@ -45,9 +45,9 @@ public:
/// Return information about all registered IR dialects.
std::vector<Dialect *> getRegisteredDialects() const;
/// Get registered IR dialect which has the longest matching with the given
/// prefix. If none is found, returns nullptr.
Dialect *getRegisteredDialect(StringRef prefix) const;
/// Get a registered IR dialect with the given namespace. If an exact match is
/// not found, then return nullptr.
Dialect *getRegisteredDialect(StringRef name) const;
/// Return information about all registered operations. This isn't very
/// efficient: typically you should ask the operations about their properties

View File

@@ -32,7 +32,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
BuiltinDialect::BuiltinDialect(MLIRContext *context)
: Dialect(/*opPrefix=*/"", context) {
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
}

View File

@@ -53,8 +53,10 @@ void mlir::registerAllDialects(MLIRContext *context) {
fn(context);
}
Dialect::Dialect(StringRef opPrefix, MLIRContext *context)
: opPrefix(opPrefix), context(context) {
Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
: namePrefix(namePrefix), context(context) {
assert(!namePrefix.contains('.') &&
"Dialect names cannot contain '.' characters.");
registerDialect(context);
}

View File

@@ -512,7 +512,8 @@ bool OperationInst::constantFold(ArrayRef<Attribute> operands,
// If this operation hasn't been registered or doesn't have abstract
// operation, fall back to a dialect which matches the prefix.
auto opName = getName().getStringRef();
if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
auto dialectPrefix = opName.split('.').first;
if (auto *dialect = getContext()->getRegisteredDialect(dialectPrefix)) {
return dialect->constantFoldHook(llvm::cast<OperationInst>(this), operands,
results);
}

View File

@@ -506,17 +506,13 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() const {
return result;
}
/// Get registered IR dialect which has the longest matching with the given
/// prefix. If none is found, returns nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef prefix) const {
Dialect *result = nullptr;
for (auto &dialect : getImpl().dialects) {
if (prefix.startswith(dialect->getOperationPrefix()))
if (!result || result->getOperationPrefix().size() <
dialect->getOperationPrefix().size())
result = dialect.get();
}
return result;
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef name) const {
for (auto &dialect : getImpl().dialects)
if (name == dialect->getNamespace())
return dialect.get();
return nullptr;
}
/// Register this dialect object with the specified context. The context
@@ -549,8 +545,8 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() const {
}
void Dialect::addOperation(AbstractOperation opInfo) {
assert(opInfo.name.startswith(opPrefix) &&
"op name doesn't start with prefix");
assert((namePrefix.empty() || (opInfo.name.split('.').first == namePrefix)) &&
"op name doesn't start with dialect prefix");
assert(&opInfo.dialect == this && "Dialect object mismatch");
auto &impl = context->getImpl();

View File

@@ -36,7 +36,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
: Dialect(/*opPrefix=*/"", context) {
: Dialect(/*namePrefix=*/"", context) {
addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,

View File

@@ -34,7 +34,7 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
SuperVectorOpsDialect::SuperVectorOpsDialect(MLIRContext *context)
: Dialect(/*opPrefix=*/"", context) {
: Dialect(/*namePrefix=*/"", context) {
addOperations<VectorTransferReadOp, VectorTransferWriteOp,
VectorTypeCastOp>();
}