diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index 55b6f7efd365..067fe53dad38 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -35,9 +35,6 @@ using DialectConstantFoldHook = std::function, SmallVectorImpl &)>; using DialectExtractElementHook = std::function)>; -using DialectTypeParserHook = - std::function; -using DialectTypePrinterHook = std::function; /// Dialects are groups of MLIR operations and behavior associated with the /// entire group. For example, hooks into other systems for constant folding, @@ -80,11 +77,16 @@ public: return Attribute(); }; - /// Registered parsing/printing hooks for types registered to the dialect. - DialectTypeParserHook typeParseHook = nullptr; + /// Parse a type registered to this dialect. + virtual Type parseType(StringRef tyData, Location loc, + MLIRContext *context) const; + + /// Print a type registered to this dialect. /// Note: The data printed for the provided type must not include any '"' /// characters. - DialectTypePrinterHook typePrintHook = nullptr; + virtual void printType(Type, raw_ostream &) const { + assert(0 && "dialect has no registered type printing hook"); + } /// Registered hooks for getting identifier aliases for symbols. The /// identifier is used in place of the symbol when printing textual IR. diff --git a/mlir/include/mlir/LLVMIR/LLVMDialect.h b/mlir/include/mlir/LLVMIR/LLVMDialect.h index cd2b5c9d7085..6c1716597967 100644 --- a/mlir/include/mlir/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/LLVMIR/LLVMDialect.h @@ -76,6 +76,13 @@ public: llvm::LLVMContext &getLLVMContext() { return llvmContext; } llvm::Module &getLLVMModule() { return module; } + /// Parse a type registered to this dialect. + Type parseType(StringRef tyData, Location loc, + MLIRContext *context) const override; + + /// Print a type registered to this dialect. + void printType(Type type, raw_ostream &os) const override; + private: llvm::LLVMContext llvmContext; llvm::Module module; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5348125577d8..9af1794cb052 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -715,8 +715,7 @@ void ModulePrinter::printType(Type type) { default: { auto &dialect = type.getDialect(); os << '!' << dialect.getNamespace() << "<\""; - assert(dialect.typePrintHook && "Expected dialect type printing hook."); - dialect.typePrintHook(type, os); + dialect.printType(type, os); os << "\">"; return; } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 338c918c3396..c24d6b1f388d 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectHooks.h" #include "mlir/IR/MLIRContext.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" using namespace mlir; @@ -65,3 +66,11 @@ Dialect::Dialect(StringRef namePrefix, MLIRContext *context) } Dialect::~Dialect() {} + +/// Parse a type registered to this dialect. +Type Dialect::parseType(StringRef tyData, Location loc, + MLIRContext *context) const { + context->emitError(loc, "dialect '" + getNamespace() + + "' provides no type parsing hook"); + return Type(); +} diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 9c3d31da9ce8..3444b0ee4c7d 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -57,27 +57,6 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) { return Base::get(context, FIRST_LLVM_TYPE, llvmType); } -static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) { - llvm::SMDiagnostic errorMessage; - auto *llvmDialect = - static_cast(ctx->getRegisteredDialect("llvm")); - assert(llvmDialect && "LLVM dialect not registered"); - llvm::Type *type = - llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule()); - if (!type) { - ctx->emitError(loc, errorMessage.getMessage()); - return {}; - } - return LLVMType::get(ctx, type); -} - -static void printLLVMType(Type ty, raw_ostream &os) { - auto type = ty.dyn_cast(); - assert(type && "printing wrong type"); - assert(type.getUnderlyingType() && "no underlying LLVM type"); - type.getUnderlyingType()->print(os); -} - llvm::Type *LLVMType::getUnderlyingType() const { return static_cast(type)->underlyingType; } @@ -91,9 +70,24 @@ LLVMDialect::LLVMDialect(MLIRContext *context) addOperations< #include "mlir/LLVMIR/llvm_ops.inc" >(); +} - typeParseHook = parseLLVMType; - typePrintHook = printLLVMType; +/// Parse a type registered to this dialect. +Type LLVMDialect::parseType(StringRef tyData, Location loc, + MLIRContext *context) const { + llvm::SMDiagnostic errorMessage; + llvm::Type *type = llvm::parseType(tyData, errorMessage, module); + if (!type) + return (context->emitError(loc, errorMessage.getMessage()), nullptr); + return LLVMType::get(context, type); +} + +/// Print a type registered to this dialect. +void LLVMDialect::printType(Type type, raw_ostream &os) const { + auto llvmType = type.dyn_cast(); + assert(llvmType && "printing wrong type"); + assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); + llvmType.getUnderlyingType()->print(os); } static DialectRegistration llvmDialect; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f7ab6f1fe119..ce37f56d9aaf 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -496,15 +496,7 @@ Type Parser::parseExtendedType() { return aliasIt->second; } - // Otherwise, check for a registered dialect with this name. - auto *dialect = state.context->getRegisteredDialect(identifier); - if (dialect) { - // Make sure that the dialect provides a parsing hook. - if (!dialect->typeParseHook) - return (emitError("dialect '" + dialect->getNamespace() + - "' provides no type parsing hook"), - nullptr); - } + // Otherwise, we are parsing a dialect-specific type. // Consume the '<'. if (parseToken(Token::less, "expected '<' in dialect type")) @@ -522,8 +514,8 @@ Type Parser::parseExtendedType() { Type result; // If we found a registered dialect, then ask it to parse the type. - if (dialect) { - result = dialect->typeParseHook(typeData, loc, state.context); + if (auto *dialect = state.context->getRegisteredDialect(identifier)) { + result = dialect->parseType(typeData, loc, state.context); if (!result) return nullptr; } else {