From 9dd182e0fa3aeb178d274dc0d64b5891436fba47 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 1 Jul 2019 05:26:14 -0700 Subject: [PATCH] [ODS] Introduce IntEnumAttr In ODS, right now we use StringAttrs to emulate enum attributes. It is suboptimal if the op actually can and wants to store the enum as a single integer value; we are paying extra cost on storing and comparing the attribute value. This CL introduces a new enum attribute subclass that are backed by IntegerAttr. The downside with IntegerAttr-backed enum attributes is that the assembly form now uses integer values, which is less obvious than the StringAttr-backed ones. However, that can be remedied by defining custom assembly form with the help of the conversion utility functions generated via EnumsGen. Choices are given to the dialect writers to decide which one to use for their enum attributes. PiperOrigin-RevId: 255935542 --- mlir/include/mlir/IR/OpBase.td | 104 ++++++++++++--- mlir/include/mlir/TableGen/Attribute.h | 14 ++- mlir/lib/TableGen/Attribute.cpp | 15 ++- mlir/lib/TableGen/Pattern.cpp | 2 +- mlir/test/IR/attribute.mlir | 85 +++++++++++++ mlir/test/lib/TestDialect/TestOps.td | 39 ++++++ mlir/test/mlir-tblgen/attr-enum.td | 69 ---------- mlir/test/mlir-tblgen/pattern.mlir | 23 +++- mlir/tools/mlir-tblgen/EnumsGen.cpp | 154 ++++++++++++++--------- mlir/tools/mlir-tblgen/RewriterGen.cpp | 7 +- mlir/unittests/TableGen/EnumsGenTest.cpp | 34 ++--- mlir/unittests/TableGen/enums.td | 11 +- 12 files changed, 383 insertions(+), 174 deletions(-) delete mode 100644 mlir/test/mlir-tblgen/attr-enum.td diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 269c09a030ef..802ae2c6335f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -483,6 +483,9 @@ def FloatLike : TypeConstraint : AttrConstraint { @@ -517,6 +520,9 @@ class Attr : bit isOptional = 0; } +//===----------------------------------------------------------------------===// +// Attribute modifier definition + // Decorates an attribute to have an (unvalidated) default value if not present. class DefaultValuedAttr : Attr { @@ -550,6 +556,9 @@ class OptionalAttr : Attr { string baseAttr = !cast(attr); } +//===----------------------------------------------------------------------===// +// Primitive attribute kinds + // A generic attribute that must be constructed around a specific type // `attrValType`. Backed by MLIR attribute kind `attrKind`. class TypedAttrBase : def TypeAttr : TypeAttrBase<"Type", "any type attribute">; -// An enum attribute case. -class EnumAttrCase : StringBasedAttr< - CPred<"$_self.cast().getValue() == \"" # sym # "\"">, - "case " # sym> { +//===----------------------------------------------------------------------===// +// Enum attribute kinds + +// Additional information for an enum attribute case. +class EnumAttrCaseInfo { // The C++ enumerant symbol string symbol = sym; + // The C++ enumerant value - // A non-negative value must be provided if to use EnumsGen backend. + // If less than zero, there will be no explicit discriminator values assigned + // to enumerators in the generated enum class. int value = val; } -// An enum attribute. Its value can only be one from the given list of `cases`. -// Enum attributes are emulated via mlir::StringAttr, plus extra verification -// on the string: only the symbols of the allowed cases are permitted as the -// string value. -class EnumAttr cases> : - StringBasedAttr]>, - description> { +// An enum attribute case stored with StringAttr. +// TODO(antiagainst): rename this to StrEnumAttrCase to be consistent +class EnumAttrCase : + EnumAttrCaseInfo, + StringBasedAttr< + CPred<"$_self.cast().getValue() == \"" # sym # "\"">, + "case " # sym>; + +// An enum attribute case stored with IntegerAttr. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = + CPred<"$_self.cast().getInt() == " # val>; +} + +class I32EnumAttrCase : IntEnumAttrCaseBase; +class I64EnumAttrCase : IntEnumAttrCaseBase; + +// Additional information for an enum attribute. +class EnumAttrInfo cases> { // The C++ enum class name string className = name; // List of all accepted cases - list enumerants = cases; + list enumerants = cases; // The following fields are only used by the EnumsGen backend to generate // an enum class definition and conversion utility functions. @@ -706,6 +731,51 @@ class EnumAttr cases> : string maxEnumValFnName = "getMaxEnumValFor" # name; } +// An enum attribute backed by StringAttr. +// +// Op attributes of this kind are stored as StringAttr. Extra verification will +// be generated on the string though: only the symbols of the allowed cases are +// permitted as the string value. +// TODO(antiagainst): rename this to StrEnumAttr to be consistent +class EnumAttr cases> : + EnumAttrInfo, + StringBasedAttr< + And<[StrAttr.predicate, Or]>, + !if(!empty(description), "allowed string cases: " # + StrJoin.result, + description)>; + +// An enum attribute backed by IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer though: only the values of the allowed cases are +// permitted as the integer value. +class IntEnumAttr cases> : + EnumAttrInfo, + IntegerAttrBase.result, description)> { + let predicate = And<[ + IntegerAttrBase.predicate, + Or]>; +} + +class I32EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint32_t"; +} +class I64EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint64_t"; +} + +//===----------------------------------------------------------------------===// +// Composite attribute kinds + class ElementsAttrBase : Attr { let storageType = [{ ElementsAttr }]; @@ -771,6 +841,9 @@ def FunctionAttr : Attr()">, let constBuilderCall = "$_builder.getFunctionAttr($0)"; } +//===----------------------------------------------------------------------===// +// Derive attribute kinds + // DerivedAttr are attributes whose value is computed from properties // of the operation. They do not require additional storage and are // materialized as needed. @@ -782,6 +855,9 @@ class DerivedAttr : Attr, "derived attribute"> { // Derived attribute that returns a mlir::Type. class DerivedTypeAttr : DerivedAttr<"Type", body>; +//===----------------------------------------------------------------------===// +// Constant attribute kinds + // Represents a constant attribute of specific Attr type. A constant // attribute can be specified only of attributes that have a constant // builder call defined. The constant value is specified as a string. diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index d4914703c5bd..a3fbf6dcf4b9 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -121,12 +121,15 @@ private: }; // Wrapper class providing helper methods for accessing enum attribute cases -// defined in TableGen. This class should closely reflect what is defined as -// class `EnumAttrCase` in TableGen. +// defined in TableGen. This is used for enum attribute case backed by both +// StringAttr and IntegerAttr. class EnumAttrCase : public Attribute { public: explicit EnumAttrCase(const llvm::DefInit *init); + // Returns true if this EnumAttrCase is backed by a StringAttr. + bool isStrCase() const; + // Returns the symbol of this enum attribute case. StringRef getSymbol() const; @@ -135,14 +138,17 @@ public: }; // Wrapper class providing helper methods for accessing enum attributes defined -// in TableGen. This class should closely reflect what is defined as class -// `EnumAttr` in TableGen. +// in TableGen.This is used for enum attribute case backed by both StringAttr +// and IntegerAttr. class EnumAttr : public Attribute { public: explicit EnumAttr(const llvm::Record *record); explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); + // Returns true if this EnumAttr is backed by a StringAttr. + bool isStrEnum() const; + // Returns the enum class name. StringRef getEnumClassName() const; diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 1107048bb882..d292f2bdac62 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -136,8 +136,12 @@ StringRef tblgen::ConstantAttr::getConstantValue() const { tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) : Attribute(init) { - assert(def->isSubClassOf("EnumAttrCase") && - "must be subclass of TableGen 'EnumAttrCase' class"); + assert(def->isSubClassOf("EnumAttrCaseInfo") && + "must be subclass of TableGen 'EnumAttrInfo' class"); +} + +bool tblgen::EnumAttrCase::isStrCase() const { + return def->isSubClassOf("EnumAttrCase"); } StringRef tblgen::EnumAttrCase::getSymbol() const { @@ -145,11 +149,12 @@ StringRef tblgen::EnumAttrCase::getSymbol() const { } int64_t tblgen::EnumAttrCase::getValue() const { + assert(isStrCase() && "cannot get value for EnumAttrCase"); return def->getValueAsInt("value"); } tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { - assert(def->isSubClassOf("EnumAttr") && + assert(def->isSubClassOf("EnumAttrInfo") && "must be subclass of TableGen 'EnumAttr' class"); } @@ -158,6 +163,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +bool tblgen::EnumAttr::isStrEnum() const { + return def->isSubClassOf("EnumAttr"); +} + StringRef tblgen::EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index e2ddcbae076e..467c3d5c177d 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -54,7 +54,7 @@ bool tblgen::DagLeaf::isConstantAttr() const { } bool tblgen::DagLeaf::isEnumAttrCase() const { - return isSubClassOf("EnumAttrCase"); + return isSubClassOf("EnumAttrCaseInfo"); } tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const { diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir index f7149145eac9..9c47bcec80b7 100644 --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -30,3 +30,88 @@ func @string_attr_custom_type() { test.string_attr_with_type "string_data" return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test StrEnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.str_enum_attr + %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 + // CHECK: test.str_enum_attr + %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32 + return +} + +// ----- + +func @disallowed_case_fail() { + // expected-error @+1 {{allowed string cases: 'A', 'B'}} + %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test I32EnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.i32_enum_attr + %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 + // CHECK: test.i32_enum_attr + %1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 32-bit integer cases: 5, 10}} + %0 = "test.i32_enum_attr"() {attr = 7: i32} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 32-bit integer cases: 5, 10}} + %0 = "test.i32_enum_attr"() {attr = 5: i64} : () -> i32 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// Test I64EnumAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @allowed_cases_pass +func @allowed_cases_pass() { + // CHECK: test.i64_enum_attr + %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32 + // CHECK: test.i64_enum_attr + %1 = "test.i64_enum_attr"() {attr = 10: i64} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 64-bit integer cases: 5, 10}} + %0 = "test.i64_enum_attr"() {attr = 7: i64} : () -> i32 + return +} + +// ----- + +func @disallowed_case7_fail() { + // expected-error @+1 {{allowed 64-bit integer cases: 5, 10}} + %0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32 + return +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index ba5362b9985c..d64af5b0a972 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -104,6 +104,39 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> { }]; } +def StrCaseA: EnumAttrCase<"A">; +def StrCaseB: EnumAttrCase<"B">; + +def SomeStrEnum: EnumAttr< + "SomeStrEnum", "", [StrCaseA, StrCaseB]>; + +def StrEnumAttrOp : TEST_Op<"str_enum_attr"> { + let arguments = (ins SomeStrEnum:$attr); + let results = (outs I32:$val); +} + +def I32Case5: I32EnumAttrCase<"case5", 5>; +def I32Case10: I32EnumAttrCase<"case10", 10>; + +def SomeI32Enum: I32EnumAttr< + "SomeI32Enum", "", [I32Case5, I32Case10]>; + +def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> { + let arguments = (ins SomeI32Enum:$attr); + let results = (outs I32:$val); +} + +def I64Case5: I64EnumAttrCase<"case5", 5>; +def I64Case10: I64EnumAttrCase<"case10", 10>; + +def SomeI64Enum: I64EnumAttr< + "SomeI64Enum", "", [I64Case5, I64Case10]>; + +def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> { + let arguments = (ins SomeI64Enum:$attr); + let results = (outs I32:$val); +} + //===----------------------------------------------------------------------===// // Test Traits //===----------------------------------------------------------------------===// @@ -192,6 +225,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>; def : Pat<(OpG $input), (OpB $input, ConstantAttr:$attr)>; def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr:$attr)>; +// Test string enum attribute in rewrites. +def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>; +// Test integer enum attribute in rewrites. +def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>; +def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>; + //===----------------------------------------------------------------------===// // Test op regions //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attr-enum.td b/mlir/test/mlir-tblgen/attr-enum.td deleted file mode 100644 index a86829c5dd15..000000000000 --- a/mlir/test/mlir-tblgen/attr-enum.td +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF -// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT - -include "mlir/IR/OpBase.td" - -def NS_SomeEnum_A : EnumAttrCase<"A">; -def NS_SomeEnum_B : EnumAttrCase<"B">; - -def NS_SomeEnum : EnumAttr< - "SomeEnum", "some enum", [NS_SomeEnum_A, NS_SomeEnum_B]>; - -def Test_Dialect : Dialect { - let name = "test"; - let cppNamespace = "NS"; -} -class NS_Op traits> : - Op; - -def NS_OpA : NS_Op<"op_a_with_enum_attr", []> { - let arguments = (ins NS_SomeEnum:$attr); - let results = (outs I32:$result); -} - -// Test enum attribute getter method and verification -// --- - -// DEF-LABEL: StringRef OpA::attr() -// DEF-NEXT: auto attr = this->getAttr("attr").cast(); -// DEF-NEXT: return attr.getValue(); - -// DEF-LABEL: OpA::verify() -// DEF: auto tblgen_attr = this->getAttr("attr"); -// DEF: if (!(((tblgen_attr.isa())) && (((tblgen_attr.cast().getValue() == "A")) || ((tblgen_attr.cast().getValue() == "B"))))) -// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: some enum"); - -def NS_OpB : NS_Op<"op_b_with_enum_attr", []> { - let arguments = (ins NS_SomeEnum:$attr); - let results = (outs I32:$result); -} - -def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>; - -// Test enum attribute match and rewrite -// --- - -// PAT-LABEL: struct GeneratedConvert0 - -// PAT: PatternMatchResult match -// PAT: auto attr = op0->getAttrOfType("attr"); -// PAT-NEXT: if (!attr) return matchFailure(); -// PAT-NEXT: if (!((attr.cast().getValue() == "A"))) return matchFailure(); - -// PAT: void rewrite -// PAT: auto vOpB0 = rewriter.create(loc, -// PAT-NEXT: rewriter.getStringAttr("B") -// PAT-NEXT: ); - -def NS_SomeEnum_Array : TypedArrayAttrBase; - -def NS_OpC : NS_Op<"op_b_with_enum_array_attr", []> { - let arguments = (ins NS_SomeEnum_Array:$attr); - let results = (outs I32:$result); -} - -// Test enum array attribute verification -// --- - -// DEF-LABEL: OpC::verify() -// DEF: [](Attribute attr) { return ((attr.isa())) && (((attr.cast().getValue() == "A")) || ((attr.cast().getValue() == "B"))); } diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index cf7026b43ed1..557c340fd63c 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -27,4 +27,25 @@ func @verifyBenefit(%arg0 : i32) -> i32 { // CHECK: "test.op_f"(%arg0) // CHECK: "test.op_b"(%arg0) {attr = 34 : i32} return %0 : i32 -} \ No newline at end of file +} + +// CHECK-LABEL: verifyStrEnumAttr +func @verifyStrEnumAttr() -> i32 { + // CHECK: "test.str_enum_attr"() {attr = "B"} + %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32 + return %0 : i32 +} + +// CHECK-LABEL: verifyI32EnumAttr +func @verifyI32EnumAttr() -> i32 { + // CHECK: "test.i32_enum_attr"() {attr = 10 : i32} + %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 + return %0 : i32 +} + +// CHECK-LABEL: verifyI64EnumAttr +func @verifyI64EnumAttr() -> i32 { + // CHECK: "test.i64_enum_attr"() {attr = 10 : i64} + %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32 + return %0 : i32 +} diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index f85fa138be1b..962a7eafaf43 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -60,11 +60,11 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName, for (const auto &enumerant : enumerants) { auto symbol = makeIdentifier(enumerant.getSymbol()); auto value = enumerant.getValue(); - if (value < 0) { - llvm::PrintFatalError(enumDef.getLoc(), - "all enumerants must have a non-negative value"); + if (value >= 0) { + os << formatv(" {0} = {1},\n", symbol, value); + } else { + os << formatv(" {0},\n", symbol); } - os << formatv(" {0} = {1},\n", symbol, value); } os << "};\n\n"; } @@ -101,6 +101,88 @@ template<> struct DenseMapInfo<{0}> {{ os << "\n\n"; } +static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); + auto enumerants = enumAttr.getAllCases(); + + unsigned maxEnumVal = 0; + for (const auto &enumerant : enumerants) { + int64_t value = enumerant.getValue(); + // Avoid generating the max value function if there is an enumerant without + // explicit value. + if (value < 0) + return; + + maxEnumVal = std::max(maxEnumVal, static_cast(value)); + } + + // Emit the function to return the max enum value + os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); + os << formatv(" return {0};\n", maxEnumVal); + os << "}\n\n"; +} + +static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); + os << " switch (val) {\n"; + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, + makeIdentifier(symbol), symbol); + } + os << " }\n"; + os << " return \"\";\n"; + os << "}\n\n"; +} + +static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, + strToSymFnName); + os << formatv(" return llvm::StringSwitch>(str)\n", + enumName); + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, + makeIdentifier(symbol)); + } + os << " .Default(llvm::None);\n"; + os << "}\n"; +} + +static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, + underlyingToSymFnName, + underlyingType.empty() ? std::string("unsigned") + : underlyingType) + << " switch (value) {\n"; + for (const auto &enumerant : enumerants) { + auto symbol = enumerant.getSymbol(); + auto value = enumerant.getValue(); + os << formatv(" case {0}: return {1}::{2};\n", value, enumName, + makeIdentifier(symbol)); + } + os << " default: return llvm::None;\n" + << " }\n" + << "}\n\n"; +} + static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); @@ -110,7 +192,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); auto enumerants = enumAttr.getAllCases(); llvm::SmallVector namespaces; @@ -130,20 +211,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, strToSymFnName); + emitMaxValueFn(enumDef, os); + for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; - // Emit the function to return the max enum value - unsigned maxEnumVal = 0; - for (const auto &enumerant : enumerants) { - auto value = enumerant.getValue(); - // Already checked that the value is non-negetive. - maxEnumVal = std::max(maxEnumVal, static_cast(value)); - } - os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); - os << formatv(" return {0};\n", maxEnumVal); - os << "}\n\n"; - // Emit DenseMapInfo for this enum class emitDenseMapInfo(enumName, underlyingType, cppNamespace, os); } @@ -151,7 +223,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Declarations", os); - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) emitEnumDecl(*def, os); @@ -160,13 +232,7 @@ static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); StringRef cppNamespace = enumAttr.getCppNamespace(); - std::string underlyingType = enumAttr.getUnderlyingType(); - StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); - StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); - StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); - auto enumerants = enumAttr.getAllCases(); llvm::SmallVector namespaces; llvm::SplitString(cppNamespace, namespaces, "::"); @@ -174,43 +240,9 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { for (auto ns : namespaces) os << "namespace " << ns << " {\n"; - os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); - os << " switch (val) {\n"; - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, - makeIdentifier(symbol), symbol); - } - os << " }\n"; - os << " return \"\";\n"; - os << "}\n\n"; - - os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, - underlyingToSymFnName, - underlyingType.empty() ? std::string("unsigned") - : underlyingType) - << " switch (value) {\n"; - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - auto value = enumerant.getValue(); - os << formatv(" case {0}: return {1}::{2};\n", value, enumName, - makeIdentifier(symbol)); - } - os << " default: return llvm::None;\n" - << " }\n" - << "}\n\n"; - - os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, - strToSymFnName); - os << formatv(" return llvm::StringSwitch>(str)\n", - enumName); - for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); - os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, - makeIdentifier(symbol)); - } - os << " .Default(llvm::None);\n"; - os << "}\n"; + emitSymToStrFn(enumDef, os); + emitStrToSymFn(enumDef, os); + emitUnderlyingToSymFn(enumDef, os); for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -220,7 +252,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Definitions", os); - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) emitEnumDef(*def, os); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 203393b5caa9..2419c31b1701 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -628,7 +628,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, } if (leaf.isEnumAttrCase()) { auto enumCase = leaf.getAsEnumAttrCase(); - return handleConstantAttr(enumCase, enumCase.getSymbol()); + if (enumCase.isStrCase()) + return handleConstantAttr(enumCase, enumCase.getSymbol()); + // This is an enum case backed by an IntegerAttr. We need to get its value + // to build the constant. + std::string val = std::to_string(enumCase.getValue()); + return handleConstantAttr(enumCase, val); } pattern.ensureBoundInSourcePattern(argName); std::string result = getBoundSymbol(argName).str(); diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index b9a98a4504ca..f20ec0ca3819 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -31,36 +31,40 @@ using ::testing::StrEq; // Test namespaces and enum class/utility names using Outer::Inner::ConvertToEnum; using Outer::Inner::ConvertToString; -using Outer::Inner::MyEnum; +using Outer::Inner::StrEnum; -TEST(EnumsGenTest, GeneratedEnumDefinition) { - EXPECT_EQ(0u, static_cast(MyEnum::CaseA)); - EXPECT_EQ(10u, static_cast(MyEnum::CaseB)); +TEST(EnumsGenTest, GeneratedStrEnumDefinition) { + EXPECT_EQ(0u, static_cast(StrEnum::CaseA)); + EXPECT_EQ(10u, static_cast(StrEnum::CaseB)); +} + +TEST(EnumsGenTest, GeneratedI32EnumDefinition) { + EXPECT_EQ(5u, static_cast(I32Enum::Case5)); + EXPECT_EQ(10u, static_cast(I32Enum::Case10)); } TEST(EnumsGenTest, GeneratedDenseMapInfo) { - llvm::DenseMap myMap; + llvm::DenseMap myMap; - myMap[MyEnum::CaseA] = "zero"; - myMap[MyEnum::CaseB] = "ten"; + myMap[StrEnum::CaseA] = "zero"; + myMap[StrEnum::CaseB] = "one"; - EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero")); - EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten")); + EXPECT_THAT(myMap[StrEnum::CaseA], StrEq("zero")); + EXPECT_THAT(myMap[StrEnum::CaseB], StrEq("one")); } TEST(EnumsGenTest, GeneratedSymbolToStringFn) { - EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA")); - EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB")); + EXPECT_THAT(ConvertToString(StrEnum::CaseA), StrEq("CaseA")); + EXPECT_THAT(ConvertToString(StrEnum::CaseB), StrEq("CaseB")); } TEST(EnumsGenTest, GeneratedStringToSymbolFn) { - EXPECT_EQ(llvm::Optional(MyEnum::CaseA), ConvertToEnum("CaseA")); - EXPECT_EQ(llvm::Optional(MyEnum::CaseB), ConvertToEnum("CaseB")); + EXPECT_EQ(llvm::Optional(StrEnum::CaseA), ConvertToEnum("CaseA")); + EXPECT_EQ(llvm::Optional(StrEnum::CaseB), ConvertToEnum("CaseB")); EXPECT_EQ(llvm::None, ConvertToEnum("X")); } TEST(EnumsGenTest, GeneratedUnderlyingType) { - bool v = - std::is_same::type>::value; + bool v = std::is_same::type>::value; EXPECT_TRUE(v); } diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index 289829552c37..bd8306718419 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -17,15 +17,16 @@ include "mlir/IR/OpBase.td" -def CaseA: EnumAttrCase<"CaseA", 0>; +def CaseA: EnumAttrCase<"CaseA">; def CaseB: EnumAttrCase<"CaseB", 10>; -def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> { +def StrEnum: EnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> { let cppNamespace = "Outer::Inner"; let stringToSymbolFnName = "ConvertToEnum"; let symbolToStringFnName = "ConvertToString"; } -def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> { - let underlyingType = "uint64_t"; -} +def Case5: I32EnumAttrCase<"Case5", 5>; +def Case10: I32EnumAttrCase<"Case10", 10>; + +def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;