[ODS] Introduce IntEnumAttr

In ODS, right now we use StringAttrs to emulate enum attributes. It is
suboptimal if the op actually can and wants to store the enum as a
single integer value; we are paying extra cost on storing and comparing
the attribute value.

This CL introduces a new enum attribute subclass that are backed by
IntegerAttr. The downside with IntegerAttr-backed enum attributes is
that the assembly form now uses integer values, which is less obvious
than the StringAttr-backed ones. However, that can be remedied by
defining custom assembly form with the help of the conversion utility
functions generated via EnumsGen.

Choices are given to the dialect writers to decide which one to use for
their enum attributes.

PiperOrigin-RevId: 255935542
This commit is contained in:
Lei Zhang
2019-07-01 05:26:14 -07:00
committed by jpienaar
parent e7f51ad08a
commit 9dd182e0fa
12 changed files with 383 additions and 174 deletions

View File

@@ -483,6 +483,9 @@ def FloatLike : TypeConstraint<Or<[AnyFloat.predicate,
// Attribute definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Base attribute definition
// Base class for all attributes.
class Attr<Pred condition, string descr = ""> :
AttrConstraint<condition, descr> {
@@ -517,6 +520,9 @@ class Attr<Pred condition, string descr = ""> :
bit isOptional = 0;
}
//===----------------------------------------------------------------------===//
// Attribute modifier definition
// Decorates an attribute to have an (unvalidated) default value if not present.
class DefaultValuedAttr<Attr attr, string val> :
Attr<attr.predicate, attr.description> {
@@ -550,6 +556,9 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
string baseAttr = !cast<string>(attr);
}
//===----------------------------------------------------------------------===//
// Primitive attribute kinds
// A generic attribute that must be constructed around a specific type
// `attrValType`. Backed by MLIR attribute kind `attrKind`.
class TypedAttrBase<BuildableType attrValType, string attrKind,
@@ -631,30 +640,46 @@ class TypeAttrBase<string retType, string description> :
def TypeAttr : TypeAttrBase<"Type", "any type attribute">;
// An enum attribute case.
class EnumAttrCase<string sym, int val = -1> : StringBasedAttr<
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym> {
//===----------------------------------------------------------------------===//
// Enum attribute kinds
// Additional information for an enum attribute case.
class EnumAttrCaseInfo<string sym, int val> {
// The C++ enumerant symbol
string symbol = sym;
// The C++ enumerant value
// A non-negative value must be provided if to use EnumsGen backend.
// If less than zero, there will be no explicit discriminator values assigned
// to enumerators in the generated enum class.
int value = val;
}
// An enum attribute. Its value can only be one from the given list of `cases`.
// Enum attributes are emulated via mlir::StringAttr, plus extra verification
// on the string: only the symbols of the allowed cases are permitted as the
// string value.
class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
StringBasedAttr<And<[StrAttr.predicate,
Or<!foreach(case, cases, case.predicate)>]>,
description> {
// An enum attribute case stored with StringAttr.
// TODO(antiagainst): rename this to StrEnumAttrCase to be consistent
class EnumAttrCase<string sym, int val = -1> :
EnumAttrCaseInfo<sym, val>,
StringBasedAttr<
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym>;
// An enum attribute case stored with IntegerAttr.
class IntEnumAttrCaseBase<I intType, string sym, int val> :
EnumAttrCaseInfo<sym, val>,
IntegerAttrBase<intType, "case " # sym> {
let predicate =
CPred<"$_self.cast<IntegerAttr>().getInt() == " # val>;
}
class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
// Additional information for an enum attribute.
class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
// The C++ enum class name
string className = name;
// List of all accepted cases
list<EnumAttrCase> enumerants = cases;
list<EnumAttrCaseInfo> enumerants = cases;
// The following fields are only used by the EnumsGen backend to generate
// an enum class definition and conversion utility functions.
@@ -706,6 +731,51 @@ class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
string maxEnumValFnName = "getMaxEnumValFor" # name;
}
// An enum attribute backed by StringAttr.
//
// Op attributes of this kind are stored as StringAttr. Extra verification will
// be generated on the string though: only the symbols of the allowed cases are
// permitted as the string value.
// TODO(antiagainst): rename this to StrEnumAttr to be consistent
class EnumAttr<string name, string description,
list<EnumAttrCase> cases> :
EnumAttrInfo<name, cases>,
StringBasedAttr<
And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
!if(!empty(description), "allowed string cases: " #
StrJoin<!foreach(case, cases, "'" # case.symbol # "'")>.result,
description)>;
// An enum attribute backed by IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer though: only the values of the allowed cases are
// permitted as the integer value.
class IntEnumAttr<I intType, string name, string description,
list<IntEnumAttrCaseBase> cases> :
EnumAttrInfo<name, cases>,
IntegerAttrBase<intType,
!if(!empty(description), "allowed " # intType.description # " cases: " #
StrJoinInt<!foreach(case, cases, case.value)>.result, description)> {
let predicate = And<[
IntegerAttrBase<intType, "">.predicate,
Or<!foreach(case, cases, case.predicate)>]>;
}
class I32EnumAttr<string name, string description,
list<I32EnumAttrCase> cases> :
IntEnumAttr<I32, name, description, cases> {
let underlyingType = "uint32_t";
}
class I64EnumAttr<string name, string description,
list<I64EnumAttrCase> cases> :
IntEnumAttr<I64, name, description, cases> {
let underlyingType = "uint64_t";
}
//===----------------------------------------------------------------------===//
// Composite attribute kinds
class ElementsAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ElementsAttr }];
@@ -771,6 +841,9 @@ def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
let constBuilderCall = "$_builder.getFunctionAttr($0)";
}
//===----------------------------------------------------------------------===//
// Derive attribute kinds
// DerivedAttr are attributes whose value is computed from properties
// of the operation. They do not require additional storage and are
// materialized as needed.
@@ -782,6 +855,9 @@ class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> {
// Derived attribute that returns a mlir::Type.
class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
//===----------------------------------------------------------------------===//
// Constant attribute kinds
// Represents a constant attribute of specific Attr type. A constant
// attribute can be specified only of attributes that have a constant
// builder call defined. The constant value is specified as a string.

View File

@@ -121,12 +121,15 @@ private:
};
// Wrapper class providing helper methods for accessing enum attribute cases
// defined in TableGen. This class should closely reflect what is defined as
// class `EnumAttrCase` in TableGen.
// defined in TableGen. This is used for enum attribute case backed by both
// StringAttr and IntegerAttr.
class EnumAttrCase : public Attribute {
public:
explicit EnumAttrCase(const llvm::DefInit *init);
// Returns true if this EnumAttrCase is backed by a StringAttr.
bool isStrCase() const;
// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;
@@ -135,14 +138,17 @@ public:
};
// Wrapper class providing helper methods for accessing enum attributes defined
// in TableGen. This class should closely reflect what is defined as class
// `EnumAttr` in TableGen.
// in TableGen.This is used for enum attribute case backed by both StringAttr
// and IntegerAttr.
class EnumAttr : public Attribute {
public:
explicit EnumAttr(const llvm::Record *record);
explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);
// Returns true if this EnumAttr is backed by a StringAttr.
bool isStrEnum() const;
// Returns the enum class name.
StringRef getEnumClassName() const;

View File

@@ -136,8 +136,12 @@ StringRef tblgen::ConstantAttr::getConstantValue() const {
tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
: Attribute(init) {
assert(def->isSubClassOf("EnumAttrCase") &&
"must be subclass of TableGen 'EnumAttrCase' class");
assert(def->isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}
bool tblgen::EnumAttrCase::isStrCase() const {
return def->isSubClassOf("EnumAttrCase");
}
StringRef tblgen::EnumAttrCase::getSymbol() const {
@@ -145,11 +149,12 @@ StringRef tblgen::EnumAttrCase::getSymbol() const {
}
int64_t tblgen::EnumAttrCase::getValue() const {
assert(isStrCase() && "cannot get value for EnumAttrCase");
return def->getValueAsInt("value");
}
tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
assert(def->isSubClassOf("EnumAttr") &&
assert(def->isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}
@@ -158,6 +163,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
: EnumAttr(init->getDef()) {}
bool tblgen::EnumAttr::isStrEnum() const {
return def->isSubClassOf("EnumAttr");
}
StringRef tblgen::EnumAttr::getEnumClassName() const {
return def->getValueAsString("className");
}

View File

@@ -54,7 +54,7 @@ bool tblgen::DagLeaf::isConstantAttr() const {
}
bool tblgen::DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCase");
return isSubClassOf("EnumAttrCaseInfo");
}
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {

View File

@@ -30,3 +30,88 @@ func @string_attr_custom_type() {
test.string_attr_with_type "string_data"
return
}
// -----
//===----------------------------------------------------------------------===//
// Test StrEnumAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @allowed_cases_pass
func @allowed_cases_pass() {
// CHECK: test.str_enum_attr
%0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
// CHECK: test.str_enum_attr
%1 = "test.str_enum_attr"() {attr = "B"} : () -> i32
return
}
// -----
func @disallowed_case_fail() {
// expected-error @+1 {{allowed string cases: 'A', 'B'}}
%0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32
return
}
// -----
//===----------------------------------------------------------------------===//
// Test I32EnumAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @allowed_cases_pass
func @allowed_cases_pass() {
// CHECK: test.i32_enum_attr
%0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
// CHECK: test.i32_enum_attr
%1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32
return
}
// -----
func @disallowed_case7_fail() {
// expected-error @+1 {{allowed 32-bit integer cases: 5, 10}}
%0 = "test.i32_enum_attr"() {attr = 7: i32} : () -> i32
return
}
// -----
func @disallowed_case7_fail() {
// expected-error @+1 {{allowed 32-bit integer cases: 5, 10}}
%0 = "test.i32_enum_attr"() {attr = 5: i64} : () -> i32
return
}
// -----
//===----------------------------------------------------------------------===//
// Test I64EnumAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @allowed_cases_pass
func @allowed_cases_pass() {
// CHECK: test.i64_enum_attr
%0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
// CHECK: test.i64_enum_attr
%1 = "test.i64_enum_attr"() {attr = 10: i64} : () -> i32
return
}
// -----
func @disallowed_case7_fail() {
// expected-error @+1 {{allowed 64-bit integer cases: 5, 10}}
%0 = "test.i64_enum_attr"() {attr = 7: i64} : () -> i32
return
}
// -----
func @disallowed_case7_fail() {
// expected-error @+1 {{allowed 64-bit integer cases: 5, 10}}
%0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32
return
}

View File

@@ -104,6 +104,39 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
}];
}
def StrCaseA: EnumAttrCase<"A">;
def StrCaseB: EnumAttrCase<"B">;
def SomeStrEnum: EnumAttr<
"SomeStrEnum", "", [StrCaseA, StrCaseB]>;
def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
let arguments = (ins SomeStrEnum:$attr);
let results = (outs I32:$val);
}
def I32Case5: I32EnumAttrCase<"case5", 5>;
def I32Case10: I32EnumAttrCase<"case10", 10>;
def SomeI32Enum: I32EnumAttr<
"SomeI32Enum", "", [I32Case5, I32Case10]>;
def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
let arguments = (ins SomeI32Enum:$attr);
let results = (outs I32:$val);
}
def I64Case5: I64EnumAttrCase<"case5", 5>;
def I64Case10: I64EnumAttrCase<"case10", 10>;
def SomeI64Enum: I64EnumAttr<
"SomeI64Enum", "", [I64Case5, I64Case10]>;
def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
let arguments = (ins SomeI64Enum:$attr);
let results = (outs I32:$val);
}
//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//
@@ -192,6 +225,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
// Test string enum attribute in rewrites.
def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
// Test integer enum attribute in rewrites.
def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
//===----------------------------------------------------------------------===//
// Test op regions
//===----------------------------------------------------------------------===//

View File

@@ -1,69 +0,0 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT
include "mlir/IR/OpBase.td"
def NS_SomeEnum_A : EnumAttrCase<"A">;
def NS_SomeEnum_B : EnumAttrCase<"B">;
def NS_SomeEnum : EnumAttr<
"SomeEnum", "some enum", [NS_SomeEnum_A, NS_SomeEnum_B]>;
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "NS";
}
class NS_Op<string mnemonic, list<OpTrait> traits> :
Op<Test_Dialect, mnemonic, traits>;
def NS_OpA : NS_Op<"op_a_with_enum_attr", []> {
let arguments = (ins NS_SomeEnum:$attr);
let results = (outs I32:$result);
}
// Test enum attribute getter method and verification
// ---
// DEF-LABEL: StringRef OpA::attr()
// DEF-NEXT: auto attr = this->getAttr("attr").cast<StringAttr>();
// DEF-NEXT: return attr.getValue();
// DEF-LABEL: OpA::verify()
// DEF: auto tblgen_attr = this->getAttr("attr");
// DEF: if (!(((tblgen_attr.isa<StringAttr>())) && (((tblgen_attr.cast<StringAttr>().getValue() == "A")) || ((tblgen_attr.cast<StringAttr>().getValue() == "B")))))
// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy constraint: some enum");
def NS_OpB : NS_Op<"op_b_with_enum_attr", []> {
let arguments = (ins NS_SomeEnum:$attr);
let results = (outs I32:$result);
}
def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
// Test enum attribute match and rewrite
// ---
// PAT-LABEL: struct GeneratedConvert0
// PAT: PatternMatchResult match
// PAT: auto attr = op0->getAttrOfType<StringAttr>("attr");
// PAT-NEXT: if (!attr) return matchFailure();
// PAT-NEXT: if (!((attr.cast<StringAttr>().getValue() == "A"))) return matchFailure();
// PAT: void rewrite
// PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc,
// PAT-NEXT: rewriter.getStringAttr("B")
// PAT-NEXT: );
def NS_SomeEnum_Array : TypedArrayAttrBase<NS_SomeEnum, "SomeEnum array">;
def NS_OpC : NS_Op<"op_b_with_enum_array_attr", []> {
let arguments = (ins NS_SomeEnum_Array:$attr);
let results = (outs I32:$result);
}
// Test enum array attribute verification
// ---
// DEF-LABEL: OpC::verify()
// DEF: [](Attribute attr) { return ((attr.isa<StringAttr>())) && (((attr.cast<StringAttr>().getValue() == "A")) || ((attr.cast<StringAttr>().getValue() == "B"))); }

View File

@@ -27,4 +27,25 @@ func @verifyBenefit(%arg0 : i32) -> i32 {
// CHECK: "test.op_f"(%arg0)
// CHECK: "test.op_b"(%arg0) {attr = 34 : i32}
return %0 : i32
}
}
// CHECK-LABEL: verifyStrEnumAttr
func @verifyStrEnumAttr() -> i32 {
// CHECK: "test.str_enum_attr"() {attr = "B"}
%0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
return %0 : i32
}
// CHECK-LABEL: verifyI32EnumAttr
func @verifyI32EnumAttr() -> i32 {
// CHECK: "test.i32_enum_attr"() {attr = 10 : i32}
%0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
return %0 : i32
}
// CHECK-LABEL: verifyI64EnumAttr
func @verifyI64EnumAttr() -> i32 {
// CHECK: "test.i64_enum_attr"() {attr = 10 : i64}
%0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
return %0 : i32
}

View File

@@ -60,11 +60,11 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
for (const auto &enumerant : enumerants) {
auto symbol = makeIdentifier(enumerant.getSymbol());
auto value = enumerant.getValue();
if (value < 0) {
llvm::PrintFatalError(enumDef.getLoc(),
"all enumerants must have a non-negative value");
if (value >= 0) {
os << formatv(" {0} = {1},\n", symbol, value);
} else {
os << formatv(" {0},\n", symbol);
}
os << formatv(" {0} = {1},\n", symbol, value);
}
os << "};\n\n";
}
@@ -101,6 +101,88 @@ template<> struct DenseMapInfo<{0}> {{
os << "\n\n";
}
static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
auto enumerants = enumAttr.getAllCases();
unsigned maxEnumVal = 0;
for (const auto &enumerant : enumerants) {
int64_t value = enumerant.getValue();
// Avoid generating the max value function if there is an enumerant without
// explicit value.
if (value < 0)
return;
maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
}
// Emit the function to return the max enum value
os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
os << formatv(" return {0};\n", maxEnumVal);
os << "}\n\n";
}
static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
os << " switch (val) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
makeIdentifier(symbol), symbol);
}
os << " }\n";
os << " return \"\";\n";
os << "}\n\n";
}
static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
enumName);
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
makeIdentifier(symbol));
}
os << " .Default(llvm::None);\n";
os << "}\n";
}
static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName,
underlyingType.empty() ? std::string("unsigned")
: underlyingType)
<< " switch (value) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
auto value = enumerant.getValue();
os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
makeIdentifier(symbol));
}
os << " default: return llvm::None;\n"
<< " }\n"
<< "}\n\n";
}
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
@@ -110,7 +192,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
@@ -130,20 +211,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
strToSymFnName);
emitMaxValueFn(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
// Emit the function to return the max enum value
unsigned maxEnumVal = 0;
for (const auto &enumerant : enumerants) {
auto value = enumerant.getValue();
// Already checked that the value is non-negetive.
maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
}
os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
os << formatv(" return {0};\n", maxEnumVal);
os << "}\n\n";
// Emit DenseMapInfo for this enum class
emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
}
@@ -151,7 +223,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDecl(*def, os);
@@ -160,13 +232,7 @@ static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
@@ -174,43 +240,9 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
os << " switch (val) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
makeIdentifier(symbol), symbol);
}
os << " }\n";
os << " return \"\";\n";
os << "}\n\n";
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName,
underlyingType.empty() ? std::string("unsigned")
: underlyingType)
<< " switch (value) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
auto value = enumerant.getValue();
os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
makeIdentifier(symbol));
}
os << " default: return llvm::None;\n"
<< " }\n"
<< "}\n\n";
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
os << formatv(" return llvm::StringSwitch<llvm::Optional<{0}>>(str)\n",
enumName);
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
makeIdentifier(symbol));
}
os << " .Default(llvm::None);\n";
os << "}\n";
emitSymToStrFn(enumDef, os);
emitStrToSymFn(enumDef, os);
emitUnderlyingToSymFn(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
@@ -220,7 +252,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttr");
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDef(*def, os);

View File

@@ -628,7 +628,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
}
if (leaf.isEnumAttrCase()) {
auto enumCase = leaf.getAsEnumAttrCase();
return handleConstantAttr(enumCase, enumCase.getSymbol());
if (enumCase.isStrCase())
return handleConstantAttr(enumCase, enumCase.getSymbol());
// This is an enum case backed by an IntegerAttr. We need to get its value
// to build the constant.
std::string val = std::to_string(enumCase.getValue());
return handleConstantAttr(enumCase, val);
}
pattern.ensureBoundInSourcePattern(argName);
std::string result = getBoundSymbol(argName).str();

View File

@@ -31,36 +31,40 @@ using ::testing::StrEq;
// Test namespaces and enum class/utility names
using Outer::Inner::ConvertToEnum;
using Outer::Inner::ConvertToString;
using Outer::Inner::MyEnum;
using Outer::Inner::StrEnum;
TEST(EnumsGenTest, GeneratedEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint64_t>(MyEnum::CaseA));
EXPECT_EQ(10u, static_cast<uint64_t>(MyEnum::CaseB));
TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
EXPECT_EQ(10u, static_cast<uint64_t>(StrEnum::CaseB));
}
TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
EXPECT_EQ(5u, static_cast<uint64_t>(I32Enum::Case5));
EXPECT_EQ(10u, static_cast<uint64_t>(I32Enum::Case10));
}
TEST(EnumsGenTest, GeneratedDenseMapInfo) {
llvm::DenseMap<MyEnum, std::string> myMap;
llvm::DenseMap<StrEnum, std::string> myMap;
myMap[MyEnum::CaseA] = "zero";
myMap[MyEnum::CaseB] = "ten";
myMap[StrEnum::CaseA] = "zero";
myMap[StrEnum::CaseB] = "one";
EXPECT_THAT(myMap[MyEnum::CaseA], StrEq("zero"));
EXPECT_THAT(myMap[MyEnum::CaseB], StrEq("ten"));
EXPECT_THAT(myMap[StrEnum::CaseA], StrEq("zero"));
EXPECT_THAT(myMap[StrEnum::CaseB], StrEq("one"));
}
TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
EXPECT_THAT(ConvertToString(MyEnum::CaseA), StrEq("CaseA"));
EXPECT_THAT(ConvertToString(MyEnum::CaseB), StrEq("CaseB"));
EXPECT_THAT(ConvertToString(StrEnum::CaseA), StrEq("CaseA"));
EXPECT_THAT(ConvertToString(StrEnum::CaseB), StrEq("CaseB"));
}
TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseA), ConvertToEnum("CaseA"));
EXPECT_EQ(llvm::Optional<MyEnum>(MyEnum::CaseB), ConvertToEnum("CaseB"));
EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseA), ConvertToEnum("CaseA"));
EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseB), ConvertToEnum("CaseB"));
EXPECT_EQ(llvm::None, ConvertToEnum("X"));
}
TEST(EnumsGenTest, GeneratedUnderlyingType) {
bool v =
std::is_same<uint64_t, std::underlying_type<Uint64Enum>::type>::value;
bool v = std::is_same<uint32_t, std::underlying_type<I32Enum>::type>::value;
EXPECT_TRUE(v);
}

View File

@@ -17,15 +17,16 @@
include "mlir/IR/OpBase.td"
def CaseA: EnumAttrCase<"CaseA", 0>;
def CaseA: EnumAttrCase<"CaseA">;
def CaseB: EnumAttrCase<"CaseB", 10>;
def MyEnum: EnumAttr<"MyEnum", "A test enum", [CaseA, CaseB]> {
def StrEnum: EnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> {
let cppNamespace = "Outer::Inner";
let stringToSymbolFnName = "ConvertToEnum";
let symbolToStringFnName = "ConvertToString";
}
def Uint64Enum : EnumAttr<"Uint64Enum", "A test enum", [CaseA, CaseB]> {
let underlyingType = "uint64_t";
}
def Case5: I32EnumAttrCase<"Case5", 5>;
def Case10: I32EnumAttrCase<"Case10", 10>;
def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;