diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 269c09a030ef..802ae2c6335f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -483,6 +483,9 @@ def FloatLike : TypeConstraint : AttrConstraint { @@ -517,6 +520,9 @@ class Attr : bit isOptional = 0; } +//===----------------------------------------------------------------------===// +// Attribute modifier definition + // Decorates an attribute to have an (unvalidated) default value if not present. class DefaultValuedAttr : Attr { @@ -550,6 +556,9 @@ class OptionalAttr : Attr { string baseAttr = !cast(attr); } +//===----------------------------------------------------------------------===// +// Primitive attribute kinds + // A generic attribute that must be constructed around a specific type // `attrValType`. Backed by MLIR attribute kind `attrKind`. class TypedAttrBase : def TypeAttr : TypeAttrBase<"Type", "any type attribute">; -// An enum attribute case. -class EnumAttrCase : StringBasedAttr< - CPred<"$_self.cast().getValue() == \"" # sym # "\"">, - "case " # sym> { +//===----------------------------------------------------------------------===// +// Enum attribute kinds + +// Additional information for an enum attribute case. +class EnumAttrCaseInfo { // The C++ enumerant symbol string symbol = sym; + // The C++ enumerant value - // A non-negative value must be provided if to use EnumsGen backend. + // If less than zero, there will be no explicit discriminator values assigned + // to enumerators in the generated enum class. int value = val; } -// An enum attribute. Its value can only be one from the given list of `cases`. -// Enum attributes are emulated via mlir::StringAttr, plus extra verification -// on the string: only the symbols of the allowed cases are permitted as the -// string value. -class EnumAttr cases> : - StringBasedAttr]>, - description> { +// An enum attribute case stored with StringAttr. +// TODO(antiagainst): rename this to StrEnumAttrCase to be consistent +class EnumAttrCase : + EnumAttrCaseInfo, + StringBasedAttr< + CPred<"$_self.cast().getValue() == \"" # sym # "\"">, + "case " # sym>; + +// An enum attribute case stored with IntegerAttr. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = + CPred<"$_self.cast().getInt() == " # val>; +} + +class I32EnumAttrCase : IntEnumAttrCaseBase; +class I64EnumAttrCase : IntEnumAttrCaseBase; + +// Additional information for an enum attribute. +class EnumAttrInfo cases> { // The C++ enum class name string className = name; // List of all accepted cases - list enumerants = cases; + list enumerants = cases; // The following fields are only used by the EnumsGen backend to generate // an enum class definition and conversion utility functions. @@ -706,6 +731,51 @@ class EnumAttr cases> : string maxEnumValFnName = "getMaxEnumValFor" # name; } +// An enum attribute backed by StringAttr. +// +// Op attributes of this kind are stored as StringAttr. Extra verification will +// be generated on the string though: only the symbols of the allowed cases are +// permitted as the string value. +// TODO(antiagainst): rename this to StrEnumAttr to be consistent +class EnumAttr cases> : + EnumAttrInfo, + StringBasedAttr< + And<[StrAttr.predicate, Or]>, + !if(!empty(description), "allowed string cases: " # + StrJoin.result, + description)>; + +// An enum attribute backed by IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer though: only the values of the allowed cases are +// permitted as the integer value. +class IntEnumAttr cases> : + EnumAttrInfo, + IntegerAttrBase.result, description)> { + let predicate = And<[ + IntegerAttrBase.predicate, + Or]>; +} + +class I32EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint32_t"; +} +class I64EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint64_t"; +} + +//===----------------------------------------------------------------------===// +// Composite attribute kinds + class ElementsAttrBase : Attr { let storageType = [{ ElementsAttr }]; @@ -771,6 +841,9 @@ def FunctionAttr : Attr()">, let constBuilderCall = "$_builder.getFunctionAttr($0)"; } +//===----------------------------------------------------------------------===// +// Derive attribute kinds + // DerivedAttr are attributes whose value is computed from properties // of the operation. They do not require additional storage and are // materialized as needed. @@ -782,6 +855,9 @@ class DerivedAttr : Attr, "derived attribute"> { // Derived attribute that returns a mlir::Type. class DerivedTypeAttr : DerivedAttr<"Type", body>; +//===----------------------------------------------------------------------===// +// Constant attribute kinds + // Represents a constant attribute of specific Attr type. A constant // attribute can be specified only of attributes that have a constant // builder call defined. The constant value is specified as a string. diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index d4914703c5bd..a3fbf6dcf4b9 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -121,12 +121,15 @@ private: }; // Wrapper class providing helper methods for accessing enum attribute cases -// defined in TableGen. This class should closely reflect what is defined as -// class `EnumAttrCase` in TableGen. +// defined in TableGen. This is used for enum attribute case backed by both +// StringAttr and IntegerAttr. class EnumAttrCase : public Attribute { public: explicit EnumAttrCase(const llvm::DefInit *init); + // Returns true if this EnumAttrCase is backed by a StringAttr. + bool isStrCase() const; + // Returns the symbol of this enum attribute case. StringRef getSymbol() const; @@ -135,14 +138,17 @@ public: }; // Wrapper class providing helper methods for accessing enum attributes defined -// in TableGen. This class should closely reflect what is defined as class -// `EnumAttr` in TableGen. +// in TableGen.This is used for enum attribute case backed by both StringAttr +// and IntegerAttr. class EnumAttr : public Attribute { public: explicit EnumAttr(const llvm::Record *record); explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); + // Returns true if this EnumAttr is backed by a StringAttr. + bool isStrEnum() const; + // Returns the enum class name. StringRef getEnumClassName() const; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 1107048bb882..d292f2bdac62 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -136,8 +136,12 @@ StringRef tblgen::ConstantAttr::getConstantValue() const { tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) : Attribute(init) { - assert(def->isSubClassOf("EnumAttrCase") && - "must be subclass of TableGen 'EnumAttrCase' class"); + assert(def->isSubClassOf("EnumAttrCaseInfo") && + "must be subclass of TableGen 'EnumAttrInfo' class"); +} + +bool tblgen::EnumAttrCase::isStrCase() const { + return def->isSubClassOf("EnumAttrCase"); } StringRef tblgen::EnumAttrCase::getSymbol() const { @@ -145,11 +149,12 @@ StringRef tblgen::EnumAttrCase::getSymbol() const { } int64_t tblgen::EnumAttrCase::getValue() const { + assert(isStrCase() && "cannot get value for EnumAttrCase"); return def->getValueAsInt("value"); } tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { - assert(def->isSubClassOf("EnumAttr") && + assert(def->isSubClassOf("EnumAttrInfo") && "must be subclass of TableGen 'EnumAttr' class"); } @@ -158,6 +163,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +bool tblgen::EnumAttr::isStrEnum() const { + return def->isSubClassOf("EnumAttr"); +} + StringRef tblgen::EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index e2ddcbae076e..467c3d5c177d 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -54,7 +54,7 @@ bool tblgen::DagLeaf::isConstantAttr() const { } bool tblgen::DagLeaf::isEnumAttrCase() const { - return isSubClassOf("EnumAttrCase"); + return isSubClassOf("EnumAttrCaseInfo"); } tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index f7149145eac9..9c47bcec80b7 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -30,3 +30,88 @@ func @string_attr_custom_type() { test.string_attr_with_type "string_data" return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test StrEnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.str_enum_attr + %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 + // CHECK: test.str_enum_attr + %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32 + return +} + +// ----- + +func @disallowed_case_fail() { + // expected-error @+1 {{allowed string cases: 'A', 'B'}} + %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test I32EnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.i32_enum_attr + %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 + // CHECK: test.i32_enum_attr + %1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 32-bit integer cases: 5, 10}} + %0 = "test.i32_enum_attr"() {attr = 7: i32} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 32-bit integer cases: 5, 10}} + %0 = "test.i32_enum_attr"() {attr = 5: i64} : () -> i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test I64EnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.i64_enum_attr + %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32 + // CHECK: test.i64_enum_attr + %1 = "test.i64_enum_attr"() {attr = 10: i64} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 64-bit integer cases: 5, 10}} + %0 = "test.i64_enum_attr"() {attr = 7: i64} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 64-bit integer cases: 5, 10}} + %0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32 + return +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index ba5362b9985c..d64af5b0a972 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -104,6 +104,39 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> { }]; } +def StrCaseA: EnumAttrCase<"A">; +def StrCaseB: EnumAttrCase<"B">; + +def SomeStrEnum: EnumAttr< + "SomeStrEnum", "", [StrCaseA, StrCaseB]>; + +def StrEnumAttrOp : TEST_Op<"str_enum_attr"> { + let arguments = (ins SomeStrEnum:$attr); + let results = (outs I32:$val); +} + +def I32Case5: I32EnumAttrCase<"case5", 5>; +def I32Case10: I32EnumAttrCase<"case10", 10>; + +def SomeI32Enum: I32EnumAttr< + "SomeI32Enum", "", [I32Case5, I32Case10]>; + +def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> { + let arguments = (ins SomeI32Enum:$attr); + let results = (outs I32:$val); +} + +def I64Case5: I64EnumAttrCase<"case5", 5>; +def I64Case10: I64EnumAttrCase<"case10", 10>; + +def SomeI64Enum: I64EnumAttr< + "SomeI64Enum", "", [I64Case5, I64Case10]>; + +def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { + let arguments = (ins SomeI64Enum:$attr); + let results = (outs I32:$val); +} + //===----------------------------------------------------------------------===// // Test Traits //===----------------------------------------------------------------------===// @@ -192,6 +225,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>; def : Pat<(OpG $input), (OpB $input, ConstantAttr:$attr)>; def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr:$attr)>; +// Test string enum attribute in rewrites. +def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>; +// Test integer enum attribute in rewrites. +def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; +def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; + //===----------------------------------------------------------------------===// // Test op regions //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attr-enum.td b/mlir/test/mlir-tblgen/attr-enum.td deleted file mode 100644 index a86829c5dd15..000000000000 --- a/mlir/test/mlir-tblgen/attr-enum.td +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF -// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT - -include "mlir/IR/OpBase.td" - -def NS_SomeEnum_A : EnumAttrCase<"A">; -def NS_SomeEnum_B : EnumAttrCase<"B">; - -def NS_SomeEnum : EnumAttr< - "SomeEnum", "some enum", [NS_SomeEnum_A, NS_SomeEnum_B]>; - -def Test_Dialect : Dialect { - let name = "test"; - let cppNamespace = "NS"; -} -class NS_Op traits> : - Op; - -def NS_OpA : NS_Op<"op_a_with_enum_attr", []> { - let arguments = (ins NS_SomeEnum:$attr); - let results = (outs I32:$result); -} - -// Test enum attribute getter method and verification -// --- - -// DEF-LABEL: StringRef OpA::attr() -// DEF-NEXT: auto attr = this->getAttr("attr").cast(); -// DEF-NEXT: return attr.getValue(); - -// DEF-LABEL: OpA::verify() -// DEF: auto tblgen_attr = this->getAttr("attr"); -// DEF: if (!(((tblgen_attr.isa())) && (((tblgen_attr.cast().getValue() == "A")) || ((tblgen_attr.cast().getValue() == "B"))))) -// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: some enum"); - -def NS_OpB : NS_Op<"op_b_with_enum_attr", []> { - let arguments = (ins NS_SomeEnum:$attr); - let results = (outs I32:$result); -} - -def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>; - -// Test enum attribute match and rewrite -// --- - -// PAT-LABEL: struct GeneratedConvert0 - -// PAT: PatternMatchResult match -// PAT: auto attr = op0->getAttrOfType("attr"); -// PAT-NEXT: if (!attr) return matchFailure(); -// PAT-NEXT: if (!((attr.cast().getValue() == "A"))) return matchFailure(); - -// PAT: void rewrite -// PAT: auto vOpB0 = rewriter.create(loc, -// PAT-NEXT: rewriter.getStringAttr("B") -// PAT-NEXT: ); - -def NS_SomeEnum_Array : TypedArrayAttrBase; - -def NS_OpC : NS_Op<"op_b_with_enum_array_attr", []> { - let arguments = (ins NS_SomeEnum_Array:$attr); - let results = (outs I32:$result); -} - -// Test enum array attribute verification -// --- - -// DEF-LABEL: OpC::verify() -// DEF: [](Attribute attr) { return ((attr.isa())) && (((attr.cast().getValue() == "A")) || ((attr.cast().getValue() == "B"))); } diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index cf7026b43ed1..557c340fd63c 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -27,4 +27,25 @@ func @verifyBenefit(%arg0 : i32) -> i32 { // CHECK: "test.op_f"(%arg0) // CHECK: "test.op_b"(%arg0) {attr = 34 : i32} return %0 : i32 -} \ No newline at end of file +} + +// CHECK-LABEL: verifyStrEnumAttr +func @verifyStrEnumAttr() -> i32 { + // CHECK: "test.str_enum_attr"() {attr = "B"} + %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 + return %0 : i32 +} + +// CHECK-LABEL: verifyI32EnumAttr +func @verifyI32EnumAttr() -> i32 { + // CHECK: "test.i32_enum_attr"() {attr = 10 : i32} + %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 + return %0 : i32 +} + +// CHECK-LABEL: verifyI64EnumAttr +func @verifyI64EnumAttr() -> i32 { + // CHECK: "test.i64_enum_attr"() {attr = 10 : i64} + %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32 + return %0 : i32 +} diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index f85fa138be1b..962a7eafaf43 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -60,11 +60,11 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName, for (const auto &enumerant : enumerants) { auto symbol = makeIdentifier(enumerant.getSymbol()); auto value = enumerant.getValue(); - if (value < 0) { - llvm::PrintFatalError(enumDef.getLoc(), - "all enumerants must have a non-negative value"); + if (value >= 0) { + os << formatv(" {0} = {1},\n", symbol, value); + } else { + os << formatv(" {0},\n", symbol); } - os << formatv(" {0} = {1},\n", symbol, value); } os << "};\n\n"; } @@ -101,6 +101,88 @@ template<> struct DenseMapInfo<{0}> {{ os << "\n\n"; } +static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); + auto enumerants = enumAttr.getAllCases(); + + unsigned maxEnumVal = 0; + for (const auto &enumerant : enumerants) { + int64_t value = enumerant.getValue(); + // Avoid generating the max value function if there is an enumerant without + // explicit value. + if (value < 0) + return; + + maxEnumVal = std::max(maxEnumVal, static_cast(value)); + } + + // Emit the function to return the max enum value + os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); + os << formatv(" return {0};\n", maxEnumVal); + os << "}\n\n"; +} + +static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); + os << " switch (val) {\n"; + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, + makeIdentifier(symbol), symbol); + } + os << " }\n"; + os << " return \"\";\n"; + os << "}\n\n"; +} + +static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, + strToSymFnName); + os << formatv(" return llvm::StringSwitch>(str)\n", + enumName); + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, + makeIdentifier(symbol)); + } + os << " .Default(llvm::None);\n"; + os << "}\n"; +} + +static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, + underlyingToSymFnName, + underlyingType.empty() ? std::string("unsigned") + : underlyingType) + << " switch (value) {\n"; + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + auto value = enumerant.getValue(); + os << formatv(" case {0}: return {1}::{2};\n", value, enumName, + makeIdentifier(symbol)); + } + os << " default: return llvm::None;\n" + << " }\n" + << "}\n\n"; +} + static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); @@ -110,7 +192,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); auto enumerants = enumAttr.getAllCases(); llvm::SmallVector namespaces; @@ -130,20 +211,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, strToSymFnName); + emitMaxValueFn(enumDef, os); + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; - // Emit the function to return the max enum value - unsigned maxEnumVal = 0; - for (const auto &enumerant : enumerants) { - auto value = enumerant.getValue(); - // Already checked that the value is non-negetive. - maxEnumVal = std::max(maxEnumVal, static_cast(value)); - } - os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); - os << formatv(" return {0};\n", maxEnumVal); - os << "}\n\n"; - // Emit DenseMapInfo for this enum class emitDenseMapInfo(enumName, underlyingType, cppNamespace, os); } @@ -151,7 +223,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Declarations", os); - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) emitEnumDecl(*def, os); @@ -160,13 +232,7 @@ static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); StringRef cppNamespace = enumAttr.getCppNamespace(); - std::string underlyingType = enumAttr.getUnderlyingType(); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -174,43 +240,9 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { for (auto ns : namespaces) os << "namespace " << ns << " {\n"; - os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); - os << " switch (val) {\n"; - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, - makeIdentifier(symbol), symbol); - } - os << " }\n"; - os << " return \"\";\n"; - os << "}\n\n"; - - os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, - underlyingToSymFnName, - underlyingType.empty() ? std::string("unsigned") - : underlyingType) - << " switch (value) {\n"; - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - auto value = enumerant.getValue(); - os << formatv(" case {0}: return {1}::{2};\n", value, enumName, - makeIdentifier(symbol)); - } - os << " default: return llvm::None;\n" - << " }\n" - << "}\n\n"; - - os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, - strToSymFnName); - os << formatv(" return llvm::StringSwitch>(str)\n", - enumName); - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, - makeIdentifier(symbol)); - } - os << " .Default(llvm::None);\n"; - os << "}\n"; + emitSymToStrFn(enumDef, os); + emitStrToSymFn(enumDef, os); + emitUnderlyingToSymFn(enumDef, os); for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -220,7 +252,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Definitions", os); - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) emitEnumDef(*def, os); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 203393b5caa9..2419c31b1701 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -628,7 +628,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, } if (leaf.isEnumAttrCase()) { auto enumCase = leaf.getAsEnumAttrCase(); - return handleConstantAttr(enumCase, enumCase.getSymbol()); + if (enumCase.isStrCase()) + return handleConstantAttr(enumCase, enumCase.getSymbol()); + // This is an enum case backed by an IntegerAttr. We need to get its value + // to build the constant. + std::string val = std::to_string(enumCase.getValue()); + return handleConstantAttr(enumCase, val); } pattern.ensureBoundInSourcePattern(argName); std::string result = getBoundSymbol(argName).str(); diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index b9a98a4504ca..f20ec0ca3819 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -31,36 +31,40 @@ using ::testing::StrEq; // Test namespaces and enum class/utility names using Outer::Inner::ConvertToEnum; using Outer::Inner::ConvertToString; -using Outer::Inner::MyEnum; +using Outer::Inner::StrEnum; -TEST(EnumsGenTest, GeneratedEnumDefinition) { - EXPECT_EQ(0u, static_cast(MyEnum::CaseA)); - EXPECT_EQ(10u, static_cast(MyEnum::CaseB)); +TEST(EnumsGenTest, GeneratedStrEnumDefinition) { + EXPECT_EQ(0u, static_cast(StrEnum::CaseA)); + EXPECT_EQ(10u, static_cast(StrEnum::CaseB)); +} + +TEST(EnumsGenTest, GeneratedI32EnumDefinition) { + EXPECT_EQ(5u, static_cast(I32Enum::Case5)); + EXPECT_EQ(10u, static_cast(I32Enum::Case10)); } TEST(EnumsGenTest, GeneratedDenseMapInfo) { - llvm::DenseMap myMap; + llvm::DenseMap myMap; - myMap[MyEnum::CaseA] = "zero"; - myMap[MyEnum::CaseB] = "ten"; + myMap[StrEnum::CaseA] = "zero"; + myMap[StrEnum::CaseB] = "one"; - EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero")); - EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten")); + EXPECT_THAT(myMap[StrEnum::CaseA], StrEq("zero")); + EXPECT_THAT(myMap[StrEnum::CaseB], StrEq("one")); } TEST(EnumsGenTest, GeneratedSymbolToStringFn) { - EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA")); - EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB")); + EXPECT_THAT(ConvertToString(StrEnum::CaseA), StrEq("CaseA")); + EXPECT_THAT(ConvertToString(StrEnum::CaseB), StrEq("CaseB")); } TEST(EnumsGenTest, GeneratedStringToSymbolFn) { - EXPECT_EQ(llvm::Optional(MyEnum::CaseA), ConvertToEnum("CaseA")); - EXPECT_EQ(llvm::Optional(MyEnum::CaseB), ConvertToEnum("CaseB")); + EXPECT_EQ(llvm::Optional(StrEnum::CaseA), ConvertToEnum("CaseA")); + EXPECT_EQ(llvm::Optional(StrEnum::CaseB), ConvertToEnum("CaseB")); EXPECT_EQ(llvm::None, ConvertToEnum("X")); } TEST(EnumsGenTest, GeneratedUnderlyingType) { - bool v = - std::is_same::type>::value; + bool v = std::is_same::type>::value; EXPECT_TRUE(v); } diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index 289829552c37..bd8306718419 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -17,15 +17,16 @@ include "mlir/IR/OpBase.td" -def CaseA: EnumAttrCase<"CaseA", 0>; +def CaseA: EnumAttrCase<"CaseA">; def CaseB: EnumAttrCase<"CaseB", 10>; -def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> { +def StrEnum: EnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> { let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString"; } -def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> { - let underlyingType = "uint64_t"; -} +def Case5: I32EnumAttrCase<"Case5", 5>; +def Case10: I32EnumAttrCase<"Case10", 10>; + +def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;