From fbba6395171d8644e859db8e4ca1ed662a0962bc Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 3 Feb 2020 21:52:43 -0800 Subject: [PATCH] [mlir][ODS] Refactor BuildableType to use $_builder as part of the format Summary: Currently BuildableType is assumed to be preceded by a builder. This prevents constructing types that don't have a callable 'get' method with the builder. This revision reworks the format to be like attribute builders, i.e. by accepting $_builder within the format itself. Differential Revision: https://reviews.llvm.org/D73736 --- mlir/include/mlir/IR/OpBase.td | 17 +++++++++-------- mlir/lib/TableGen/Type.cpp | 14 +++++++++++--- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 12 +++++++++--- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 499d73285cf3..399bd478322c 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -259,12 +259,14 @@ class Dialect { class Type : TypeConstraint { string typeDescription = ""; + string builderCall = ""; } // Allows providing an alternative name and description to an existing type def. class TypeAlias : Type { let typeDescription = t.typeDescription; + let builderCall = t.builderCall; } // A type of a specific dialect. @@ -289,7 +291,6 @@ class Variadic : TypeConstraint { // making some Types and some Attrs buildable. class BuildableType { // The builder call to invoke (if specified) to construct the BuildableType. - // Format: this will be affixed to the builder. code builderCall = builder; } @@ -313,13 +314,13 @@ def AnyInteger : Type()">, "integer">; // Index type. def Index : Type()">, "index">, - BuildableType<"getIndexType()">; + BuildableType<"$_builder.getIndexType()">; // Integer type of a specific width. class I : Type, width # "-bit integer">, - BuildableType<"getIntegerType(" # width # ")"> { + BuildableType<"$_builder.getIntegerType(" # width # ")"> { int bitwidth = width; } @@ -342,7 +343,7 @@ def AnyFloat : Type()">, "floating-point">; class F : Type, width # "-bit float">, - BuildableType<"getF" # width # "Type()"> { + BuildableType<"$_builder.getF" # width # "Type()"> { int bitwidth = width; } @@ -355,7 +356,7 @@ def F32 : F<32>; def F64 : F<64>; def BF16 : Type, "bfloat16 type">, - BuildableType<"getBF16Type()">; + BuildableType<"$_builder.getBF16Type()">; class Complex : Type allowedTypes> : TensorRankOf; class 4DTensorOf allowedTypes> : TensorRankOf; // Unranked Memref type -def AnyUnrankedMemRef : - ShapedContainerType<[AnyType], +def AnyUnrankedMemRef : + ShapedContainerType<[AnyType], IsUnrankedMemRefTypePred, "unranked.memref">; // Memref type. @@ -685,7 +686,7 @@ class OptionalAttr : Attr { class TypedAttrBase : Attr { - let constBuilderCall = "$_builder.get" # attrKind # "($_builder." # + let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = attrKind; } diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index 8e84360d2c36..4105c48292fd 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Type.h" +#include "mlir/ADT/TypeSwitch.h" #include "llvm/TableGen/Record.h" using namespace mlir; @@ -36,9 +37,16 @@ Optional TypeConstraint::getBuilderCall() const { if (isVariadic()) baseType = baseType->getValueAsDef("baseType"); - if (!baseType->isSubClassOf("BuildableType")) - return None; - return baseType->getValueAsString("builderCall"); + // Check to see if this type constraint has a builder call. + const llvm::RecordVal *builderCall = baseType->getValue("builderCall"); + if (!builderCall || !builderCall->getValue()) + return llvm::None; + return TypeSwitch>(builderCall->getValue()) + .Case([&](auto *init) { + StringRef value = init->getValue(); + return value.empty() ? Optional() : value; + }) + .Default([](auto *) { return llvm::None; }); } Type::Type(const llvm::Record *record) : TypeConstraint(record) {} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 9e97f79c5378..5ccc780fb8cb 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -410,9 +410,15 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) { void OperationFormat::genParserTypeResolution(Operator &op, OpMethodBody &body) { // Initialize the set of buildable types. - for (auto &it : buildableTypes) - body << " Type odsBuildableType" << it.second << " = parser.getBuilder()." - << it.first << ";\n"; + if (!buildableTypes.empty()) { + body << " Builder &builder = parser.getBuilder();\n"; + + FmtContext typeBuilderCtx; + typeBuilderCtx.withBuilder("builder"); + for (auto &it : buildableTypes) + body << " Type odsBuildableType" << it.second << " = " + << tgfmt(it.first, &typeBuilderCtx) << ";\n"; + } // Resolve each of the result types. if (allResultTypes) {