[mlir] Simplify various pieces of code now that Identifier has access to the Context/Dialect

This also exposed a bug in Dialect loading where it was not correctly identifying identifiers that had the dialect namespace as a prefix.

Differential Revision: https://reviews.llvm.org/D97431
This commit is contained in:
River Riddle
2021-02-26 17:57:03 -08:00
parent 16abacaea9
commit e6260ad043
28 changed files with 74 additions and 97 deletions

View File

@@ -93,7 +93,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -93,7 +93,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -93,7 +93,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -93,7 +93,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -93,7 +93,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -113,7 +113,7 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}

View File

@@ -56,8 +56,6 @@ public:
// Locations.
Location getUnknownLoc();
Location getFileLineColLoc(Identifier filename, unsigned line,
unsigned column);
Location getFusedLoc(ArrayRef<Location> locs,
Attribute metadata = Attribute());

View File

@@ -296,8 +296,7 @@ public:
using Base::getChecked;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
static OpaqueAttr get(MLIRContext *context, Identifier dialect,
StringRef attrData, Type type);
static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type);
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a

View File

@@ -293,6 +293,15 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
"Identifier":$dialectNamespace,
StringRefParameter<"">:$typeData
);
let builders = [
TypeBuilderWithInferredContext<(ins
"Identifier":$dialectNamespace, CArg<"StringRef", "{}">:$typeData
), [{
return $_get(dialectNamespace.getContext(), dialectNamespace, typeData);
}]>
];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
}

View File

@@ -129,8 +129,7 @@ public:
using Base::Base;
/// Return a uniqued FileLineCol location object.
static Location get(Identifier filename, unsigned line, unsigned column,
MLIRContext *context);
static Location get(Identifier filename, unsigned line, unsigned column);
static Location get(StringRef filename, unsigned line, unsigned column,
MLIRContext *context);
@@ -174,7 +173,7 @@ public:
static Location get(Identifier name, Location child);
/// Return a uniqued name location object with an unknown child.
static Location get(Identifier name, MLIRContext *context);
static Location get(Identifier name);
/// Return the name identifier.
Identifier getName() const;

View File

@@ -491,7 +491,7 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
class OpaqueType<string dialect, string name, string summary>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
summary, "::mlir::OpaqueType">,
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
BuildableType<"::mlir::OpaqueType::get("
"$_builder.getIdentifier(\"" # dialect # "\"), \""
# name # "\")">;

View File

@@ -314,7 +314,15 @@ public:
OperationName(StringRef name, MLIRContext *context);
/// Return the name of the dialect this operation is registered to.
StringRef getDialect() const;
StringRef getDialectNamespace() const;
/// Return the Dialect this operation is registered to if it is loaded in the
/// context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
if (const auto *abstractOp = getAbstractOperation())
return &abstractOp->dialect;
return representation.get<Identifier>().getDialect();
}
/// Return the operation name with dialect name stripped, if it has one.
StringRef stripDialect() const;

View File

@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
return wrap(OpaqueAttr::get(
unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
StringRef(data, dataLength), unwrap(type)));
return wrap(
OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
StringRef(data, dataLength), unwrap(type)));
}
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {

View File

@@ -29,11 +29,6 @@ Identifier Builder::getIdentifier(StringRef str) {
Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
unsigned column) {
return FileLineColLoc::get(filename, line, column, context);
}
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
return FusedLoc::get(locs, metadata, context);
}

View File

@@ -382,9 +382,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
// OpaqueAttr
//===----------------------------------------------------------------------===//
OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
StringRef attrData, Type type) {
return Base::get(context, dialect, attrData, type);
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) {
return Base::get(dialect.getContext(), dialect, attrData, type);
}
OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,

View File

@@ -127,8 +127,8 @@ Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
Type Dialect::parseType(DialectAsmParser &parser) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
Identifier ns = Identifier::get(getNamespace(), getContext());
return OpaqueType::get(ns, parser.getFullSymbolSpec());
}
parser.emitError(parser.getNameLoc())

View File

@@ -48,14 +48,14 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
//===----------------------------------------------------------------------===//
Location FileLineColLoc::get(Identifier filename, unsigned line,
unsigned column, MLIRContext *context) {
return Base::get(context, filename, line, column);
unsigned column) {
return Base::get(filename.getContext(), filename, line, column);
}
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
MLIRContext *context) {
return get(Identifier::get(filename.empty() ? "-" : filename, context), line,
column, context);
column);
}
StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; }
@@ -112,8 +112,8 @@ Location NameLoc::get(Identifier name, Location child) {
return Base::get(child->getContext(), name, child);
}
Location NameLoc::get(Identifier name, MLIRContext *context) {
return get(name, UnknownLoc::get(context));
Location NameLoc::get(Identifier name) {
return get(name, UnknownLoc::get(name.getContext()));
}
/// Return the name identifier.

View File

@@ -520,9 +520,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
// Refresh all the identifiers dialect field, this catches cases where a
// dialect may be loaded after identifier prefixed with this dialect name
// were already created.
llvm::SmallString<32> dialectPrefix(dialectNamespace);
dialectPrefix.push_back('.');
for (auto &identifierEntry : impl.identifiers)
if (!identifierEntry.second &&
identifierEntry.first().startswith(dialectNamespace))
if (identifierEntry.second.is<MLIRContext *>() &&
identifierEntry.first().startswith(dialectPrefix))
identifierEntry.second = dialect.get();
// Actually register the interfaces with delayed registration.

View File

@@ -35,8 +35,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
}
/// Return the name of the dialect this operation is registered to.
StringRef OperationName::getDialect() const {
return getStringRef().split('.').first;
StringRef OperationName::getDialectNamespace() const {
if (Dialect *dialect = getDialect())
return dialect->getNamespace();
return representation.get<Identifier>().strref().split('.').first;
}
/// Return the operation name with dialect name stripped, if it has one.
@@ -213,14 +215,7 @@ MLIRContext *Operation::getContext() { return location->getContext(); }
/// Return the dialect this operation is associated with, or nullptr if the
/// associated dialect is not registered.
Dialect *Operation::getDialect() {
if (auto *abstractOp = getAbstractOperation())
return &abstractOp->dialect;
// If this operation hasn't been registered or doesn't have abstract
// operation, try looking up the dialect name in the context.
return getContext()->getLoadedDialect(getName().getDialect());
}
Dialect *Operation::getDialect() { return getName().getDialect(); }
Region *Operation::getParentRegion() {
return block ? block->getParent() : nullptr;

View File

@@ -46,13 +46,6 @@ public:
/// Verify the given operation.
LogicalResult verify(Operation &op);
/// Returns the registered dialect for a dialect-specific attribute.
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
assert(attr.first.strref().contains('.') && "expected dialect attribute");
auto dialectNamePair = attr.first.strref().split('.');
return ctx->getLoadedDialect(dialectNamePair.first);
}
private:
/// Verify the given potentially nested region or block.
LogicalResult verifyRegion(Region &region);
@@ -81,10 +74,6 @@ private:
/// Dominance information for this operation, when checking dominance.
DominanceInfo *domInfo = nullptr;
/// Mapping between dialect namespace and if that dialect supports
/// unregistered operations.
llvm::StringMap<bool> dialectAllowsUnknownOps;
};
} // end anonymous namespace
@@ -170,15 +159,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
/// Verify that all of the attributes are okay.
for (auto attr : op.getAttrs()) {
// Check for any optional dialect specific attributes.
if (!attr.first.strref().contains('.'))
continue;
if (auto *dialect = getDialectForAttribute(attr))
if (auto *dialect = attr.first.getDialect())
if (failed(dialect->verifyOperationAttribute(&op, attr)))
return failure();
}
// If we can get operation info for this, check the custom hook.
auto *opInfo = op.getAbstractOperation();
OperationName opName = op.getName();
auto *opInfo = opName.getAbstractOperation();
if (opInfo && failed(opInfo->verifyInvariants(&op)))
return failure();
@@ -213,33 +201,21 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
return success();
// Otherwise, verify that the parent dialect allows un-registered operations.
auto dialectPrefix = op.getName().getDialect();
// Check for an existing answer for the operation dialect.
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
if (it == dialectAllowsUnknownOps.end()) {
// If the operation dialect is registered, query it directly.
if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
it = dialectAllowsUnknownOps
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
.first;
// Otherwise, unregistered dialects (when allowed by the context)
// conservatively allow unknown operations.
else {
if (!op.getContext()->allowsUnregisteredDialects() && !op.getDialect())
return op.emitOpError()
<< "created with unregistered dialect. If this is "
"intended, please call allowUnregisteredDialects() on the "
"MLIRContext, or use -allow-unregistered-dialect with "
"mlir-opt";
it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
Dialect *dialect = opName.getDialect();
if (!dialect) {
if (!ctx->allowsUnregisteredDialects()) {
return op.emitOpError()
<< "created with unregistered dialect. If this is "
"intended, please call allowUnregisteredDialects() on the "
"MLIRContext, or use -allow-unregistered-dialect with "
"mlir-opt";
}
return success();
}
if (!it->second) {
if (!dialect->allowsUnknownOperations()) {
return op.emitError("unregistered operation '")
<< op.getName() << "' found in dialect ('" << dialectPrefix
<< op.getName() << "' found in dialect ('" << dialect->getNamespace()
<< "') that does not allow unknown operations";
}

View File

@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
getEncodedSourceLocation(loc), state.context,
getEncodedSourceLocation(loc),
Identifier::get(dialectName, state.context), symbolData);
});
}

View File

@@ -145,7 +145,7 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
"expected ')' after child location of NameLoc"))
return failure();
} else {
loc = NameLoc::get(Identifier::get(str, ctx), ctx);
loc = NameLoc::get(Identifier::get(str, ctx));
}
return success();

View File

@@ -1944,8 +1944,8 @@ Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
if (fileName.empty())
fileName = "<unknown>";
return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName),
debugLine->line, debugLine->col);
return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
debugLine->col);
}
LogicalResult

View File

@@ -44,8 +44,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
if (it == opToLineCol.end())
return;
const std::pair<unsigned, unsigned> &lineCol = it->second;
auto newLoc =
builder.getFileLineColLoc(file, lineCol.first, lineCol.second);
auto newLoc = FileLineColLoc::get(file, lineCol.first, lineCol.second);
// If we don't have a tag, set the location directly
if (!tagIdentifier) {

View File

@@ -2702,10 +2702,10 @@ auto ConversionTarget::getOpInfo(OperationName op) const
if (it != legalOperations.end())
return it->second;
// Check for info for the parent dialect.
auto dialectIt = legalDialects.find(op.getDialect());
auto dialectIt = legalDialects.find(op.getDialectNamespace());
if (dialectIt != legalDialects.end()) {
Optional<DynamicLegalityCallbackFn> callback;
auto dialectFn = dialectLegalityFns.find(op.getDialect());
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
if (dialectFn != dialectLegalityFns.end())
callback = dialectFn->second;
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,

View File

@@ -862,8 +862,7 @@ int main(int argc, char **argv) {
}
genContext.setLoc(NameLoc::get(
Identifier::get(opConfig.metadata->cppOpName, &mlirContext),
&mlirContext));
Identifier::get(opConfig.metadata->cppOpName, &mlirContext)));
if (failed(generateOp(opConfig, genContext))) {
return 1;
}

View File

@@ -842,8 +842,7 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
if (tree.getNumArgs() == 1) {
DagLeaf leaf = tree.getArgAsLeaf(0);
if (leaf.isStringAttr())
return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
"rewriter.getContext())",
return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
leaf.getStringAttr())
.str();
return lookUpArgLoc(0);

View File

@@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
TEST(DenseSplatTest, StringSplat) {
MLIRContext context;
Type stringType =
OpaqueType::get(&context, Identifier::get("test", &context), "string");
OpaqueType::get(Identifier::get("test", &context), "string");
StringRef value = "test-string";
testSplat(stringType, value);
}
@@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
TEST(DenseSplatTest, StringAttrSplat) {
MLIRContext context;
Type stringType =
OpaqueType::get(&context, Identifier::get("test", &context), "string");
OpaqueType::get(Identifier::get("test", &context), "string");
Attribute stringAttr = StringAttr::get("test-string", stringType);
testSplat(stringType, stringAttr);
}