mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir] Simplify various pieces of code now that Identifier has access to the Context/Dialect
This also exposed a bug in Dialect loading where it was not correctly identifying identifiers that had the dialect namespace as a prefix. Differential Revision: https://reviews.llvm.org/D97431
This commit is contained in:
@@ -93,7 +93,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
|
||||
@@ -56,8 +56,6 @@ public:
|
||||
|
||||
// Locations.
|
||||
Location getUnknownLoc();
|
||||
Location getFileLineColLoc(Identifier filename, unsigned line,
|
||||
unsigned column);
|
||||
Location getFusedLoc(ArrayRef<Location> locs,
|
||||
Attribute metadata = Attribute());
|
||||
|
||||
|
||||
@@ -296,8 +296,7 @@ public:
|
||||
using Base::getChecked;
|
||||
|
||||
/// Get or create a new OpaqueAttr with the provided dialect and string data.
|
||||
static OpaqueAttr get(MLIRContext *context, Identifier dialect,
|
||||
StringRef attrData, Type type);
|
||||
static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type);
|
||||
|
||||
/// Get or create a new OpaqueAttr with the provided dialect and string data.
|
||||
/// If the given identifier is not a valid namespace for a dialect, then a
|
||||
|
||||
@@ -293,6 +293,15 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
|
||||
"Identifier":$dialectNamespace,
|
||||
StringRefParameter<"">:$typeData
|
||||
);
|
||||
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins
|
||||
"Identifier":$dialectNamespace, CArg<"StringRef", "{}">:$typeData
|
||||
), [{
|
||||
return $_get(dialectNamespace.getContext(), dialectNamespace, typeData);
|
||||
}]>
|
||||
];
|
||||
let skipDefaultBuilders = 1;
|
||||
let genVerifyDecl = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -129,8 +129,7 @@ public:
|
||||
using Base::Base;
|
||||
|
||||
/// Return a uniqued FileLineCol location object.
|
||||
static Location get(Identifier filename, unsigned line, unsigned column,
|
||||
MLIRContext *context);
|
||||
static Location get(Identifier filename, unsigned line, unsigned column);
|
||||
static Location get(StringRef filename, unsigned line, unsigned column,
|
||||
MLIRContext *context);
|
||||
|
||||
@@ -174,7 +173,7 @@ public:
|
||||
static Location get(Identifier name, Location child);
|
||||
|
||||
/// Return a uniqued name location object with an unknown child.
|
||||
static Location get(Identifier name, MLIRContext *context);
|
||||
static Location get(Identifier name);
|
||||
|
||||
/// Return the name identifier.
|
||||
Identifier getName() const;
|
||||
|
||||
@@ -491,7 +491,7 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
|
||||
class OpaqueType<string dialect, string name, string summary>
|
||||
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
|
||||
summary, "::mlir::OpaqueType">,
|
||||
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
|
||||
BuildableType<"::mlir::OpaqueType::get("
|
||||
"$_builder.getIdentifier(\"" # dialect # "\"), \""
|
||||
# name # "\")">;
|
||||
|
||||
|
||||
@@ -314,7 +314,15 @@ public:
|
||||
OperationName(StringRef name, MLIRContext *context);
|
||||
|
||||
/// Return the name of the dialect this operation is registered to.
|
||||
StringRef getDialect() const;
|
||||
StringRef getDialectNamespace() const;
|
||||
|
||||
/// Return the Dialect this operation is registered to if it is loaded in the
|
||||
/// context, or nullptr if the dialect isn't loaded.
|
||||
Dialect *getDialect() const {
|
||||
if (const auto *abstractOp = getAbstractOperation())
|
||||
return &abstractOp->dialect;
|
||||
return representation.get<Identifier>().getDialect();
|
||||
}
|
||||
|
||||
/// Return the operation name with dialect name stripped, if it has one.
|
||||
StringRef stripDialect() const;
|
||||
|
||||
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
|
||||
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
|
||||
intptr_t dataLength, const char *data,
|
||||
MlirType type) {
|
||||
return wrap(OpaqueAttr::get(
|
||||
unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
|
||||
StringRef(data, dataLength), unwrap(type)));
|
||||
return wrap(
|
||||
OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
|
||||
StringRef(data, dataLength), unwrap(type)));
|
||||
}
|
||||
|
||||
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
|
||||
|
||||
@@ -29,11 +29,6 @@ Identifier Builder::getIdentifier(StringRef str) {
|
||||
|
||||
Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
|
||||
|
||||
Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
|
||||
unsigned column) {
|
||||
return FileLineColLoc::get(filename, line, column, context);
|
||||
}
|
||||
|
||||
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
|
||||
return FusedLoc::get(locs, metadata, context);
|
||||
}
|
||||
|
||||
@@ -382,9 +382,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
|
||||
// OpaqueAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
|
||||
StringRef attrData, Type type) {
|
||||
return Base::get(context, dialect, attrData, type);
|
||||
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) {
|
||||
return Base::get(dialect.getContext(), dialect, attrData, type);
|
||||
}
|
||||
|
||||
OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
|
||||
@@ -127,8 +127,8 @@ Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
|
||||
Type Dialect::parseType(DialectAsmParser &parser) const {
|
||||
// If this dialect allows unknown types, then represent this with OpaqueType.
|
||||
if (allowsUnknownTypes()) {
|
||||
auto ns = Identifier::get(getNamespace(), getContext());
|
||||
return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
|
||||
Identifier ns = Identifier::get(getNamespace(), getContext());
|
||||
return OpaqueType::get(ns, parser.getFullSymbolSpec());
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc())
|
||||
|
||||
@@ -48,14 +48,14 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Location FileLineColLoc::get(Identifier filename, unsigned line,
|
||||
unsigned column, MLIRContext *context) {
|
||||
return Base::get(context, filename, line, column);
|
||||
unsigned column) {
|
||||
return Base::get(filename.getContext(), filename, line, column);
|
||||
}
|
||||
|
||||
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
|
||||
MLIRContext *context) {
|
||||
return get(Identifier::get(filename.empty() ? "-" : filename, context), line,
|
||||
column, context);
|
||||
column);
|
||||
}
|
||||
|
||||
StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; }
|
||||
@@ -112,8 +112,8 @@ Location NameLoc::get(Identifier name, Location child) {
|
||||
return Base::get(child->getContext(), name, child);
|
||||
}
|
||||
|
||||
Location NameLoc::get(Identifier name, MLIRContext *context) {
|
||||
return get(name, UnknownLoc::get(context));
|
||||
Location NameLoc::get(Identifier name) {
|
||||
return get(name, UnknownLoc::get(name.getContext()));
|
||||
}
|
||||
|
||||
/// Return the name identifier.
|
||||
|
||||
@@ -520,9 +520,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
// Refresh all the identifiers dialect field, this catches cases where a
|
||||
// dialect may be loaded after identifier prefixed with this dialect name
|
||||
// were already created.
|
||||
llvm::SmallString<32> dialectPrefix(dialectNamespace);
|
||||
dialectPrefix.push_back('.');
|
||||
for (auto &identifierEntry : impl.identifiers)
|
||||
if (!identifierEntry.second &&
|
||||
identifierEntry.first().startswith(dialectNamespace))
|
||||
if (identifierEntry.second.is<MLIRContext *>() &&
|
||||
identifierEntry.first().startswith(dialectPrefix))
|
||||
identifierEntry.second = dialect.get();
|
||||
|
||||
// Actually register the interfaces with delayed registration.
|
||||
|
||||
@@ -35,8 +35,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
|
||||
}
|
||||
|
||||
/// Return the name of the dialect this operation is registered to.
|
||||
StringRef OperationName::getDialect() const {
|
||||
return getStringRef().split('.').first;
|
||||
StringRef OperationName::getDialectNamespace() const {
|
||||
if (Dialect *dialect = getDialect())
|
||||
return dialect->getNamespace();
|
||||
return representation.get<Identifier>().strref().split('.').first;
|
||||
}
|
||||
|
||||
/// Return the operation name with dialect name stripped, if it has one.
|
||||
@@ -213,14 +215,7 @@ MLIRContext *Operation::getContext() { return location->getContext(); }
|
||||
|
||||
/// Return the dialect this operation is associated with, or nullptr if the
|
||||
/// associated dialect is not registered.
|
||||
Dialect *Operation::getDialect() {
|
||||
if (auto *abstractOp = getAbstractOperation())
|
||||
return &abstractOp->dialect;
|
||||
|
||||
// If this operation hasn't been registered or doesn't have abstract
|
||||
// operation, try looking up the dialect name in the context.
|
||||
return getContext()->getLoadedDialect(getName().getDialect());
|
||||
}
|
||||
Dialect *Operation::getDialect() { return getName().getDialect(); }
|
||||
|
||||
Region *Operation::getParentRegion() {
|
||||
return block ? block->getParent() : nullptr;
|
||||
|
||||
@@ -46,13 +46,6 @@ public:
|
||||
/// Verify the given operation.
|
||||
LogicalResult verify(Operation &op);
|
||||
|
||||
/// Returns the registered dialect for a dialect-specific attribute.
|
||||
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
|
||||
assert(attr.first.strref().contains('.') && "expected dialect attribute");
|
||||
auto dialectNamePair = attr.first.strref().split('.');
|
||||
return ctx->getLoadedDialect(dialectNamePair.first);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Verify the given potentially nested region or block.
|
||||
LogicalResult verifyRegion(Region ®ion);
|
||||
@@ -81,10 +74,6 @@ private:
|
||||
|
||||
/// Dominance information for this operation, when checking dominance.
|
||||
DominanceInfo *domInfo = nullptr;
|
||||
|
||||
/// Mapping between dialect namespace and if that dialect supports
|
||||
/// unregistered operations.
|
||||
llvm::StringMap<bool> dialectAllowsUnknownOps;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
@@ -170,15 +159,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
|
||||
/// Verify that all of the attributes are okay.
|
||||
for (auto attr : op.getAttrs()) {
|
||||
// Check for any optional dialect specific attributes.
|
||||
if (!attr.first.strref().contains('.'))
|
||||
continue;
|
||||
if (auto *dialect = getDialectForAttribute(attr))
|
||||
if (auto *dialect = attr.first.getDialect())
|
||||
if (failed(dialect->verifyOperationAttribute(&op, attr)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// If we can get operation info for this, check the custom hook.
|
||||
auto *opInfo = op.getAbstractOperation();
|
||||
OperationName opName = op.getName();
|
||||
auto *opInfo = opName.getAbstractOperation();
|
||||
if (opInfo && failed(opInfo->verifyInvariants(&op)))
|
||||
return failure();
|
||||
|
||||
@@ -213,33 +201,21 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
|
||||
return success();
|
||||
|
||||
// Otherwise, verify that the parent dialect allows un-registered operations.
|
||||
auto dialectPrefix = op.getName().getDialect();
|
||||
|
||||
// Check for an existing answer for the operation dialect.
|
||||
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
|
||||
if (it == dialectAllowsUnknownOps.end()) {
|
||||
// If the operation dialect is registered, query it directly.
|
||||
if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
|
||||
it = dialectAllowsUnknownOps
|
||||
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
|
||||
.first;
|
||||
// Otherwise, unregistered dialects (when allowed by the context)
|
||||
// conservatively allow unknown operations.
|
||||
else {
|
||||
if (!op.getContext()->allowsUnregisteredDialects() && !op.getDialect())
|
||||
return op.emitOpError()
|
||||
<< "created with unregistered dialect. If this is "
|
||||
"intended, please call allowUnregisteredDialects() on the "
|
||||
"MLIRContext, or use -allow-unregistered-dialect with "
|
||||
"mlir-opt";
|
||||
|
||||
it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
|
||||
Dialect *dialect = opName.getDialect();
|
||||
if (!dialect) {
|
||||
if (!ctx->allowsUnregisteredDialects()) {
|
||||
return op.emitOpError()
|
||||
<< "created with unregistered dialect. If this is "
|
||||
"intended, please call allowUnregisteredDialects() on the "
|
||||
"MLIRContext, or use -allow-unregistered-dialect with "
|
||||
"mlir-opt";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (!it->second) {
|
||||
if (!dialect->allowsUnknownOperations()) {
|
||||
return op.emitError("unregistered operation '")
|
||||
<< op.getName() << "' found in dialect ('" << dialectPrefix
|
||||
<< op.getName() << "' found in dialect ('" << dialect->getNamespace()
|
||||
<< "') that does not allow unknown operations";
|
||||
}
|
||||
|
||||
|
||||
@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
|
||||
|
||||
// Otherwise, form a new opaque type.
|
||||
return OpaqueType::getChecked(
|
||||
getEncodedSourceLocation(loc), state.context,
|
||||
getEncodedSourceLocation(loc),
|
||||
Identifier::get(dialectName, state.context), symbolData);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
|
||||
"expected ')' after child location of NameLoc"))
|
||||
return failure();
|
||||
} else {
|
||||
loc = NameLoc::get(Identifier::get(str, ctx), ctx);
|
||||
loc = NameLoc::get(Identifier::get(str, ctx));
|
||||
}
|
||||
|
||||
return success();
|
||||
|
||||
@@ -1944,8 +1944,8 @@ Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
|
||||
auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
|
||||
if (fileName.empty())
|
||||
fileName = "<unknown>";
|
||||
return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName),
|
||||
debugLine->line, debugLine->col);
|
||||
return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
|
||||
debugLine->col);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
||||
@@ -44,8 +44,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
|
||||
if (it == opToLineCol.end())
|
||||
return;
|
||||
const std::pair<unsigned, unsigned> &lineCol = it->second;
|
||||
auto newLoc =
|
||||
builder.getFileLineColLoc(file, lineCol.first, lineCol.second);
|
||||
auto newLoc = FileLineColLoc::get(file, lineCol.first, lineCol.second);
|
||||
|
||||
// If we don't have a tag, set the location directly
|
||||
if (!tagIdentifier) {
|
||||
|
||||
@@ -2702,10 +2702,10 @@ auto ConversionTarget::getOpInfo(OperationName op) const
|
||||
if (it != legalOperations.end())
|
||||
return it->second;
|
||||
// Check for info for the parent dialect.
|
||||
auto dialectIt = legalDialects.find(op.getDialect());
|
||||
auto dialectIt = legalDialects.find(op.getDialectNamespace());
|
||||
if (dialectIt != legalDialects.end()) {
|
||||
Optional<DynamicLegalityCallbackFn> callback;
|
||||
auto dialectFn = dialectLegalityFns.find(op.getDialect());
|
||||
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
|
||||
if (dialectFn != dialectLegalityFns.end())
|
||||
callback = dialectFn->second;
|
||||
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
|
||||
|
||||
@@ -862,8 +862,7 @@ int main(int argc, char **argv) {
|
||||
}
|
||||
|
||||
genContext.setLoc(NameLoc::get(
|
||||
Identifier::get(opConfig.metadata->cppOpName, &mlirContext),
|
||||
&mlirContext));
|
||||
Identifier::get(opConfig.metadata->cppOpName, &mlirContext)));
|
||||
if (failed(generateOp(opConfig, genContext))) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -842,8 +842,7 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) {
|
||||
if (tree.getNumArgs() == 1) {
|
||||
DagLeaf leaf = tree.getArgAsLeaf(0);
|
||||
if (leaf.isStringAttr())
|
||||
return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
|
||||
"rewriter.getContext())",
|
||||
return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
|
||||
leaf.getStringAttr())
|
||||
.str();
|
||||
return lookUpArgLoc(0);
|
||||
|
||||
@@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
|
||||
TEST(DenseSplatTest, StringSplat) {
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||
OpaqueType::get(Identifier::get("test", &context), "string");
|
||||
StringRef value = "test-string";
|
||||
testSplat(stringType, value);
|
||||
}
|
||||
@@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
|
||||
TEST(DenseSplatTest, StringAttrSplat) {
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||
OpaqueType::get(Identifier::get("test", &context), "string");
|
||||
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
||||
testSplat(stringType, stringAttr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user