[MLIR] Attribute and type formats in ODS

Add declarative assembly formats for attributes and types. Define an
`assemblyFormat` field in an attribute or type def that has a `mnemonic` to
generate its parser and printer.

```tablegen
def MyAttr : AttrDef<MyDialect, "MyAttr"> {
  let parameters = (ins "int64_t":$count, "AffineMap":$map);
  let mnemonic = "my_attr";
  let assemblyFormat = "`<` $count `,` $map `>`";
}
```

Use `struct` to define a comma-separated list of key-value pairs:

```tablegen
def MyType : TypeDef<MyDialect, "MyType"> {
  let parameters = (ins "int":$one, "int":$two, "int":$three);
  let mnemonic = "my_type";
  let assemblyFormat = "`<` $three `:` struct($one, $two) `>`";
}
```

Use `struct(params)` to capture all parameters.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111594
Jeff Niu
2021-10-15 21:39:07 +00:00
committed by Mogball
parent 56ada0f80d
commit 9a2fdc369d
23 changed files with 2387 additions and 366 deletions


@@ -382,3 +382,172 @@ the things named `*Type` are generally now named `*Attr`.
Aside from that, all of the interfaces for uniquing and storage construction are
all the same.
## Defining Custom Parsers and Printers using Assembly Formats
Attributes and types defined in ODS with a mnemonic can define an
`assemblyFormat` to declaratively describe custom parsers and printers. The
assembly format consists of literals, variables, and directives.
* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
`` `keyword` `` or `` `<` ``.
* A variable is a parameter name preceded by a dollar sign, e.g. `$param0`,
which captures one attribute or type parameter.
* A directive is a keyword followed by an optional argument list that defines
special parser and printer behaviour.
```tablegen
// An example type with an assembly format.
def MyType : TypeDef<My_Dialect, "MyType"> {
// Define a mnemonic to allow the dialect's parser hook to call into the
// generated parser.
let mnemonic = "my_type";
// Define two parameters whose C++ types are indicated in string literals.
let parameters = (ins "int":$count, "AffineMap":$map);
// Define the assembly format. Surround the format with less-than `<` and
// greater-than `>` so that MLIR's printers use the pretty format.
let assemblyFormat = "`<` $count `,` `map` `=` $map `>`";
}
```
The declarative assembly format for `MyType` results in the following format
in the IR:
```mlir
!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>>
```
### Parameter Parsing and Printing
For many basic parameter types, no additional work is needed to define how
these parameters are parsed or printed.
* The default printer for any parameter is `$_printer << $_self`,
where `$_self` is the C++ value of the parameter and `$_printer` is a
`DialectAsmPrinter`.
* The default parser for a parameter is
`FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
of the parameter and `$_parser` is a `DialectAsmParser`.
Printing and parsing behaviour can be added to additional C++ types by
overloading these functions or by defining a `parser` and `printer` in an ODS
parameter class.
Example of overloading:
```c++
using MyParameter = std::pair<int, int>;
DialectAsmPrinter &operator<<(DialectAsmPrinter &printer, MyParameter param) {
return printer << param.first << " * " << param.second;
}
template <> struct FieldParser<MyParameter> {
static FailureOr<MyParameter> parse(DialectAsmParser &parser) {
int a, b;
if (parser.parseInteger(a) || parser.parseStar() ||
parser.parseInteger(b))
return failure();
return MyParameter(a, b);
}
};
```
Example of using ODS parameter classes:
```tablegen
def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
let printer = [{ $_printer << $_self.first << " * " << $_self.second }];
let parser = [{ [&]() -> FailureOr<std::pair<int, int>> {
int a, b;
if ($_parser.parseInteger(a) || $_parser.parseStar() ||
$_parser.parseInteger(b))
return failure();
return std::make_pair(a, b);
}() }];
}
```
A type using this parameter with the assembly format `` `<` $myParam `>` ``
will look as follows in the IR:
```mlir
!my_dialect.my_type<42 * 24>
```
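The dispatch described in this section can be illustrated outside of MLIR. The following is a minimal standalone sketch, not the MLIR API: `ToyParser` and `ToyFieldParser` are hypothetical stand-ins for `DialectAsmParser` and `FieldParser`, and `std::optional` stands in for `FailureOr`. It shows how a partial specialization can handle every integral type while an explicit specialization adds support for a user-defined parameter type.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical parser over a string; stands in for DialectAsmParser.
struct ToyParser {
  std::istringstream in;
  explicit ToyParser(const std::string &text) : in(text) {}
  // Consume a `*` token, skipping leading whitespace.
  bool parseStar() {
    char c;
    return (in >> c) && c == '*';
  }
};

// Primary template: declared only, so unsupported types fail to compile.
template <typename T, typename = T>
struct ToyFieldParser;

// Partial specialization selected for any integral type.
template <typename IntT>
struct ToyFieldParser<IntT,
                      std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
  static std::optional<IntT> parse(ToyParser &parser) {
    IntT value;
    if (!(parser.in >> value))
      return std::nullopt;
    return value;
  }
};

// Explicit specialization for a user-defined parameter type, `a * b`.
using MyParameter = std::pair<int, int>;
template <>
struct ToyFieldParser<MyParameter> {
  static std::optional<MyParameter> parse(ToyParser &parser) {
    auto a = ToyFieldParser<int>::parse(parser);
    if (!a || !parser.parseStar())
      return std::nullopt;
    auto b = ToyFieldParser<int>::parse(parser);
    if (!b)
      return std::nullopt;
    return MyParameter(*a, *b);
  }
};
```

Calling `ToyFieldParser<MyParameter>::parse` on a `ToyParser` built from `42 * 24` yields the pair `(42, 24)`, while input without the `*` yields an empty optional.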
#### Non-POD Parameters
Parameters that aren't plain-old-data (e.g. references) may need to define a
`cppStorageType` to contain the data until it is copied into the allocator.
For example, `StringRefParameter` uses `std::string` as its storage type,
whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
for these parameters are expected to return `FailureOr<$cppStorageType>`.
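To see why an owning storage type is needed, consider the following standalone sketch. It does not use MLIR: `ToyAllocator`, `parseQuoted`, and `parseAndStore` are hypothetical names, and `std::string_view` plays the role of `StringRef`. The parser returns an owning `std::string`, which stays alive until the allocator copies it into long-lived storage and hands back a non-owning view.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <string_view>

// Hypothetical arena standing in for the attribute/type storage allocator.
class ToyAllocator {
public:
  // Copy the data into storage owned by the allocator and return a view of it.
  std::string_view copyInto(std::string_view data) {
    storage.emplace_back(data);
    return storage.back();
  }

private:
  // A deque never relocates existing elements when one is appended, so views
  // returned earlier remain valid.
  std::deque<std::string> storage;
};

// The "parser" returns the owning storage type, not a view. Returning a view
// of a local buffer here would dangle.
std::string parseQuoted(std::string_view input) {
  if (input.size() >= 2 && input.front() == '"' && input.back() == '"')
    input = input.substr(1, input.size() - 2);
  return std::string(input);
}

// Parse into temporary owning storage, then copy into the allocator.
std::string_view parseAndStore(ToyAllocator &allocator,
                               std::string_view input) {
  std::string owned = parseQuoted(input);
  return allocator.copyInto(owned);
}
```

The view returned by `parseAndStore` refers to memory held by the allocator, so it outlives the temporary `owned` string.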
### Assembly Format Directives
Attribute and type assembly formats have the following directives:
* `params`: capture all parameters of an attribute or type.
* `struct`: generate a "struct-like" parser and printer for a list of key-value
pairs.
#### `params` Directive
This directive is used to refer to all parameters of an attribute or type.
When used as a top-level directive, `params` generates a parser and printer for
a comma-separated list of the parameters. For example:
```tablegen
def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
let parameters = (ins "int":$a, "int":$b);
let mnemonic = "pair";
let assemblyFormat = "`<` params `>`";
}
```
In the IR, this type will appear as:
```mlir
!my_dialect.pair<42, 24>
```
The `params` directive can also be passed to other directives, such as `struct`,
as an argument that refers to all parameters in place of explicitly listing all
parameters as variables.
#### `struct` Directive
The `struct` directive accepts a list of variables to capture and will generate
a parser and printer for a comma-separated list of key-value pairs. The
variables are printed in the order they are specified in the argument list **but
can be parsed in any order**. For example:
```tablegen
def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
let parameters = (ins StringRefParameter<>:$sym_name,
"int":$a, "int":$b, "int":$c);
let mnemonic = "struct";
let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`";
}
```
In the IR, this type can appear with any permutation of the order of the
parameters captured in the directive.
```mlir
!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3>
!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1>
```
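The order-independent parsing can be sketched in standalone C++. This is an illustration under simplifying assumptions, not the generated code: `parseStructLike` is a hypothetical function, values are limited to small non-negative integers, and `std::optional` stands in for failure reporting. Each key must appear exactly once; unknown or repeated keys are rejected.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a generated `struct` parser. Parses exactly
// `keys.size()` comma-separated `key = <non-negative integer>` pairs in any
// order. Returns std::nullopt on an unknown key, a duplicate key, or a
// syntax error.
std::optional<std::map<std::string, int>>
parseStructLike(const std::string &text, const std::vector<std::string> &keys) {
  std::map<std::string, int> values;
  std::size_t pos = 0;
  auto at = [&](std::size_t i) { return static_cast<unsigned char>(text[i]); };
  auto skipSpace = [&] {
    while (pos < text.size() && std::isspace(at(pos)))
      ++pos;
  };
  // Consume the punctuation character `c`, skipping leading whitespace.
  auto consume = [&](char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  };

  for (std::size_t index = 0; index < keys.size(); ++index) {
    // Parse the parameter name.
    skipSpace();
    std::size_t start = pos;
    while (pos < text.size() && (std::isalnum(at(pos)) || text[pos] == '_'))
      ++pos;
    std::string key = text.substr(start, pos - start);
    // Reject names that are unknown or were already seen.
    bool known = std::find(keys.begin(), keys.end(), key) != keys.end();
    if (!known || values.count(key) != 0)
      return std::nullopt;
    if (!consume('='))
      return std::nullopt;
    // Parse the value; at most 9 digits so that std::stoi cannot overflow.
    skipSpace();
    start = pos;
    while (pos < text.size() && std::isdigit(at(pos)))
      ++pos;
    if (pos == start || pos - start > 9)
      return std::nullopt;
    values[key] = std::stoi(text.substr(start, pos - start));
    // Expect a comma between pairs, but not after the last one.
    if (index + 1 != keys.size() && !consume(','))
      return std::nullopt;
  }
  skipSpace();
  if (pos != text.size())
    return std::nullopt;
  return values;
}
```

With keys `a`, `b`, and `c`, both `a = 1, b = 2, c = 3` and `b = 2, c = 3, a = 1` produce the same result, whereas `a = 1, a = 2, c = 3` is rejected.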
Passing `params` as the only argument to `struct` makes the directive capture
all the parameters of the attribute or type. For the same type above, an
assembly format of `` `<` struct(params) `>` `` will result in:
```mlir
!my_dialect.struct<b = 2, sym_name = "foo", c = 3, a = 1>
```
The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameters` list.
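The printing side is simpler than parsing because the order is fixed. The following standalone sketch assumes integer-valued parameters only, and `printStructLike` is a hypothetical helper rather than generated code. It emits `key = value` pairs in the order they are given, which corresponds to the declaration order.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a generated `struct` printer: emit the
// parameters as `key = value` pairs, comma-separated, in the given order.
std::string
printStructLike(const std::vector<std::pair<std::string, int>> &params) {
  std::string out;
  for (std::size_t i = 0; i < params.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += params[i].first + " = " + std::to_string(params[i].second);
  }
  return out;
}
```

Whatever order the input was parsed in, printing always follows the fixed order, so a round trip normalizes the textual form.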


@@ -47,6 +47,74 @@ public:
virtual StringRef getFullSymbolSpec() const = 0;
};
//===----------------------------------------------------------------------===//
// Parse Fields
//===----------------------------------------------------------------------===//
/// Provide a template class that can be specialized by users to dispatch to
/// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`,
/// where `T` is the parameter storage type, to parse custom types.
template <typename T, typename = T>
struct FieldParser;
/// Parse an attribute.
template <typename AttributeT>
struct FieldParser<
AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
AttributeT>> {
static FailureOr<AttributeT> parse(DialectAsmParser &parser) {
AttributeT value;
if (parser.parseAttribute(value))
return failure();
return value;
}
};
/// Parse any integer.
template <typename IntT>
struct FieldParser<IntT,
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
static FailureOr<IntT> parse(DialectAsmParser &parser) {
IntT value;
if (parser.parseInteger(value))
return failure();
return value;
}
};
/// Parse a string.
template <>
struct FieldParser<std::string> {
static FailureOr<std::string> parse(DialectAsmParser &parser) {
std::string value;
if (parser.parseString(&value))
return failure();
return value;
}
};
/// Parse any container that supports back insertion as a list.
template <typename ContainerT>
struct FieldParser<
ContainerT, std::enable_if_t<std::is_member_function_pointer<
decltype(&ContainerT::push_back)>::value,
ContainerT>> {
using ElementT = typename ContainerT::value_type;
static FailureOr<ContainerT> parse(DialectAsmParser &parser) {
ContainerT elements;
auto elementParser = [&]() {
auto element = FieldParser<ElementT>::parse(parser);
if (failed(element))
return failure();
elements.push_back(element.getValue());
return success();
};
if (parser.parseCommaSeparatedList(elementParser))
return failure();
return elements;
}
};
} // end namespace mlir
#endif
#endif // MLIR_IR_DIALECTIMPLEMENTATION_H


@@ -2886,6 +2886,11 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
code printer = ?;
code parser = ?;
// Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
// specified at the same time as either 'printer' or 'parser'. The generated
// printer requires 'genAccessors' to be true.
string assemblyFormat = ?;
// If set, generate accessors for each parameter.
bit genAccessors = 1;
@@ -2964,10 +2969,22 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
string cppType = type;
// The C++ type of the accessor for this parameter.
string cppAccessorType = !if(!empty(accessorType), type, accessorType);
// The C++ storage type of this parameter if it is a reference, e.g.
// `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
string cppStorageType = ?;
// One-line human-readable description of the argument.
string summary = desc;
// The format string for the asm syntax (documentation only).
string syntax = ?;
// The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
// which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
// for your type. Or you can provide a custom parser. For attributes, "$_type"
// will be replaced with the required attribute type.
string parser = ?;
// The default parameter printer is `$_printer << $_self`. Overload the stream
// operator of `DialectAsmPrinter` as necessary to print your type. Or you can
// provide a custom printer.
string printer = ?;
}
class AttrParameter<string type, string desc, string accessorType = "">
: AttrOrTypeParameter<type, desc, accessorType>;
@@ -2978,6 +2995,8 @@ class TypeParameter<string type, string desc, string accessorType = "">
class StringRefParameter<string desc = ""> :
AttrOrTypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
let printer = [{$_printer << '"' << $_self << '"';}];
let cppStorageType = "std::string";
}
// For APFloats, which require comparison.
@@ -2990,6 +3009,7 @@ class APFloatParameter<string desc> :
class ArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">";
}
// For classes which require allocation and have their own allocateInto method.


@@ -182,10 +182,10 @@ operator<<(AsmPrinterT &p, const TypeRange &types) {
llvm::interleaveComma(types, p);
return p;
}
template <typename AsmPrinterT>
template <typename AsmPrinterT, typename ElementT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
llvm::interleaveComma(types, p);
return p;
}


@@ -101,6 +101,9 @@ public:
// None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
// Returns the custom assembly format, if one was specified.
Optional<StringRef> getAssemblyFormat() const;
// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
@@ -199,6 +202,15 @@ public:
// Get the C++ accessor type of this parameter.
StringRef getCppAccessorType() const;
// Get the C++ storage type of this parameter.
StringRef getCppStorageType() const;
// Get an optional C++ parameter parser.
Optional<StringRef> getParser() const;
// Get an optional C++ parameter printer.
Optional<StringRef> getPrinter() const;
// Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const;


@@ -1,3 +1,4 @@
//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.


@@ -132,6 +132,10 @@ Optional<StringRef> AttrOrTypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
}
Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
return def->getValueAsOptionalString("assemblyFormat");
}
bool AttrOrTypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
@@ -219,6 +223,32 @@ StringRef AttrOrTypeParameter::getCppAccessorType() const {
return getCppType();
}
StringRef AttrOrTypeParameter::getCppStorageType() const {
if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
return *type;
}
return getCppType();
}
Optional<StringRef> AttrOrTypeParameter::getParser() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
return *parser;
}
return {};
}
Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
return *printer;
}
return {};
}
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {


@@ -116,4 +116,44 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
);
}
def TestParamOne : AttrParameter<"int64_t", ""> {}
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
let printer = "$_printer << '\"' << $_self << '\"'";
}
def TestParamFour : ArrayRefParameter<"int", ""> {
let cppStorageType = "llvm::SmallVector<int>";
let parser = "::parseIntArray($_parser)";
let printer = "::printIntArray($_printer, $_self)";
}
def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
let parameters = (
ins
TestParamOne:$one,
TestParamTwo:$two,
"::mlir::IntegerAttr":$three,
TestParamFour:$four
);
let mnemonic = "attr_with_format";
let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`";
let genVerifyDecl = 1;
}
def TestAttrUgly : Test_Attr<"TestAttrUgly"> {
let parameters = (ins "::mlir::Attribute":$attr);
let mnemonic = "attr_ugly";
let assemblyFormat = "`begin` $attr `end`";
}
def TestAttrParams: Test_Attr<"TestAttrParams"> {
let parameters = (ins "int":$v0, "int":$v1);
let mnemonic = "attr_params";
let assemblyFormat = "`<` params `>`";
}
#endif // TEST_ATTRDEFS


@@ -16,9 +16,11 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
using namespace mlir;
using namespace test;
@@ -127,6 +129,36 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
LogicalResult
TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t one, std::string two, IntegerAttr three,
ArrayRef<int> four) {
if (four.size() != static_cast<unsigned>(one))
return emitError() << "expected 'one' to equal 'four.size()'";
return success();
}
//===----------------------------------------------------------------------===//
// Utility Functions for Generated Attributes
//===----------------------------------------------------------------------===//
static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
SmallVector<int> ints;
if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
ints.push_back(0);
return parser.parseInteger(ints.back());
}) ||
parser.parseRSquare())
return failure();
return ints;
}
static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
printer << '[';
llvm::interleaveComma(ints, printer);
printer << ']';
}
//===----------------------------------------------------------------------===//
// TestSubElementsAccessAttr
//===----------------------------------------------------------------------===//


@@ -15,6 +15,7 @@
// To get the test dialect def.
include "TestOps.td"
include "TestAttrDefs.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -189,4 +190,44 @@ def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
let mnemonic = "test_type_with_trait";
}
// Type with assembly format.
def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> {
let parameters = (
ins
TestParamOne:$one,
TestParamTwo:$two,
"::mlir::Attribute":$three
);
let mnemonic = "type_with_format";
let assemblyFormat = "`<` $one `,` struct($three, $two) `>`";
}
// Test dispatch to `FieldParser`.
def TestTypeNoParser : Test_Type<"TestTypeNoParser"> {
let parameters = (
ins
"uint32_t":$one,
ArrayRefParameter<"int64_t">:$two,
StringRefParameter<>:$three,
"::test::CustomParam":$four
);
let mnemonic = "no_parser";
let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`";
}
def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
let parameters = (
ins
"int":$v0,
"int":$v1,
"int":$v2,
"int":$v3
);
let mnemonic = "struct_capture_all";
let assemblyFormat = "`<` struct(params) `>`";
}
#endif // TEST_TYPEDEFS


@@ -38,8 +38,38 @@ struct FieldInfo {
}
};
/// A custom type for a test type parameter.
struct CustomParam {
int value;
bool operator==(const CustomParam &other) const {
return other.value == value;
}
};
inline llvm::hash_code hash_value(const test::CustomParam &param) {
return llvm::hash_value(param.value);
}
} // namespace test
namespace mlir {
template <>
struct FieldParser<test::CustomParam> {
static FailureOr<test::CustomParam> parse(DialectAsmParser &parser) {
auto value = FieldParser<int>::parse(parser);
if (failed(value))
return failure();
return test::CustomParam{value.getValue()};
}
};
} // end namespace mlir
inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer,
const test::CustomParam &param) {
return printer << param.value;
}
#include "TestTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
@@ -52,17 +82,19 @@ namespace test {
struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
using KeyTy = ::llvm::StringRef;
explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {}
explicit TestRecursiveTypeStorage(::llvm::StringRef key)
: name(key), body(::mlir::Type()) {}
bool operator==(const KeyTy &other) const { return name == other; }
static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
const KeyTy &key) {
static TestRecursiveTypeStorage *
construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
return new (allocator.allocate<TestRecursiveTypeStorage>())
TestRecursiveTypeStorage(allocator.copyInto(key));
}
::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) {
::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
::mlir::Type newBody) {
// Cannot set a different body than before.
if (body && body != newBody)
return ::mlir::failure();
@@ -79,11 +111,13 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// type, potentially itself. This requires the body to be mutated separately
/// from type creation.
class TestRecursiveType
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type, TestRecursiveTypeStorage> {
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
TestRecursiveTypeStorage> {
public:
using Base::Base;
static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) {
static TestRecursiveType get(::mlir::MLIRContext *ctx,
::llvm::StringRef name) {
return Base::get(ctx, name);
}


@@ -0,0 +1,76 @@
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "TestDialect";
let cppNamespace = "::test";
}
class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
let mnemonic = asm;
}
/// Test format is missing a parameter capture.
def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: format is missing reference to parameter: v1
let assemblyFormat = "`<` $v0 `>`";
}
/// Test format has duplicate parameter captures.
def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
}
/// Test format has invalid syntax.
def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: expected literal, directive, or variable
let assemblyFormat = "`<` $v0, $v1 `>`";
}
/// Test struct directive has invalid syntax.
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let parameters = (ins "int":$v0);
// CHECK: literals may only be used in the top-level section of the format
// CHECK: expected a variable in `struct` argument list
let assemblyFormat = "`<` struct($v0, `,`) `>`";
}
/// Test struct directive cannot capture zero parameters.
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
let parameters = (ins "int":$v0);
// CHECK: `struct` argument list expected a variable or directive
let assemblyFormat = "`<` struct() $v0 `>`";
}
/// Test capture parameter that does not exist.
def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
let parameters = (ins "int":$v0);
// CHECK: InvalidTypeF has no parameter named 'v1'
let assemblyFormat = "`<` $v0 $v1 `>`";
}
/// Test duplicate capture of parameter in capture-all struct.
def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` struct(params) $v0 `>`";
}
/// Test capture-all struct duplicate capture.
def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
// CHECK: `params` captures duplicate parameter: v0
let assemblyFormat = "`<` $v0 struct(params) `>`";
}
/// Test capture of parameter after `params` directive.
def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
let parameters = (ins "int":$v0);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` params $v0 `>`";
}


@@ -0,0 +1,21 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK-LABEL: @test_roundtrip_parameter_parsers
// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">
// CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
attributes {
// CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>
attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>,
// CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>
attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>,
// CHECK: #test<"attr_ugly begin 5 : index end">
attr2 = #test<"attr_ugly begin 5 : index end">,
// CHECK: #test.attr_params<42, 24>
attr3 = #test.attr_params<42, 24>
}
// CHECK-LABEL: @test_roundtrip_default_parsers_struct
// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
// CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>


@@ -0,0 +1,127 @@
// RUN: mlir-opt --split-input-file %s --verify-diagnostics
func private @test_ugly_attr_cannot_be_pretty() -> () attributes {
// expected-error@+1 {{expected 'begin'}}
attr = #test.attr_ugly
}
// -----
func private @test_ugly_attr_no_mnemonic() -> () attributes {
// expected-error@+1 {{expected valid keyword}}
attr = #test<"">
}
// -----
func private @test_ugly_attr_parser_dispatch() -> () attributes {
// expected-error@+1 {{expected 'begin'}}
attr = #test<"attr_ugly">
}
// -----
func private @test_ugly_attr_missing_parameter() -> () attributes {
// expected-error@+2 {{failed to parse TestAttrUgly parameter 'attr'}}
// expected-error@+1 {{expected non-function type}}
attr = #test<"attr_ugly begin">
}
// -----
func private @test_ugly_attr_missing_literal() -> () attributes {
// expected-error@+1 {{expected 'end'}}
attr = #test<"attr_ugly begin \"string_attr\"">
}
// -----
func private @test_pretty_attr_expects_less() -> () attributes {
// expected-error@+1 {{expected '<'}}
attr = #test.attr_with_format
}
// -----
func private @test_pretty_attr_missing_param() -> () attributes {
// expected-error@+2 {{expected integer value}}
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
attr = #test.attr_with_format<>
}
// -----
func private @test_parse_invalid_param() -> () attributes {
// Test parameter parser failure is propagated
// expected-error@+2 {{expected integer value}}
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
attr = #test.attr_with_format<"hi">
}
// -----
func private @test_pretty_attr_invalid_syntax() -> () attributes {
// expected-error@+1 {{expected ':'}}
attr = #test.attr_with_format<42>
}
// -----
func private @test_struct_missing_key() -> () attributes {
// expected-error@+2 {{expected valid keyword}}
// expected-error@+1 {{expected a parameter name in struct}}
attr = #test.attr_with_format<42 :>
}
// -----
func private @test_struct_unknown_key() -> () attributes {
// expected-error@+1 {{duplicate or unknown struct parameter}}
attr = #test.attr_with_format<42 : nine = "foo">
}
// -----
func private @test_struct_duplicate_key() -> () attributes {
// expected-error@+1 {{duplicate or unknown struct parameter}}
attr = #test.attr_with_format<42 : two = "foo", two = "bar">
}
// -----
func private @test_struct_not_enough_values() -> () attributes {
// expected-error@+1 {{expected ','}}
attr = #test.attr_with_format<42 : two = "foo">
}
// -----
func private @test_parse_param_after_struct() -> () attributes {
// expected-error@+2 {{expected non-function type}}
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'three'}}
attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : >
}
// -----
// expected-error@+1 {{expected '<'}}
func private @test_invalid_type() -> !test.type_with_format
// -----
// expected-error@+2 {{expected integer value}}
// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'one'}}
func private @test_pretty_type_invalid_param() -> !test.type_with_format<>
// -----
// expected-error@+2 {{expected ':'}}
// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'three'}}
func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>>
// -----
func private @test_verifier_fails() -> () attributes {
// expected-error@+1 {{expected 'one' to equal 'four.size()'}}
attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64>
}


@@ -0,0 +1,394 @@
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE
include "mlir/IR/OpBase.td"
/// Test that attribute and type printers and parsers are correctly generated.
def Test_Dialect : Dialect {
let name = "TestDialect";
let cppNamespace = "::test";
}
class TestAttr<string name> : AttrDef<Test_Dialect, name>;
class TestType<string name> : TypeDef<Test_Dialect, name>;
def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> {
let parser = "::parseAttrParamA($_parser, $_type)";
let printer = "::printAttrParamA($_printer, $_self)";
}
def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> {
let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)";
let printer = "::printAttrB($_printer, $_self)";
}
def TypeParamA : TypeParameter<"TestParamC", "a type param C"> {
let parser = "::parseTypeParamC($_parser)";
let printer = "$_printer << $_self";
}
def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
let parser = "someFcnCall()";
let printer = "myPrinter($_self)";
}
/// Check simple attribute parser and printer are generated correctly.
// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::DialectAsmParser &parser,
// ATTR: ::mlir::Type attrType) {
// ATTR: FailureOr<IntegerAttr> _result_value;
// ATTR: FailureOr<TestParamA> _result_complex;
// ATTR: if (parser.parseKeyword("hello"))
// ATTR: return {};
// ATTR: if (parser.parseEqual())
// ATTR: return {};
// ATTR: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// ATTR: if (failed(_result_value))
// ATTR: return {};
// ATTR: if (parser.parseComma())
// ATTR: return {};
// ATTR: _result_complex = ::parseAttrParamA(parser, attrType);
// ATTR: if (failed(_result_complex))
// ATTR: return {};
// ATTR: if (parser.parseRParen())
// ATTR: return {};
// ATTR: return TestAAttr::get(parser.getContext(),
// ATTR: _result_value.getValue(),
// ATTR: _result_complex.getValue());
// ATTR: }
// ATTR: void TestAAttr::print(::mlir::DialectAsmPrinter &printer) const {
// ATTR: printer << "attr_a";
// ATTR: printer << ' ' << "hello";
// ATTR: printer << ' ' << "=";
// ATTR: printer << ' ';
// ATTR: printer << getValue();
// ATTR: printer << ",";
// ATTR: printer << ' ';
// ATTR: ::printAttrParamA(printer, getComplex());
// ATTR: printer << ")";
// ATTR: }
def AttrA : TestAttr<"TestA"> {
let parameters = (ins
"IntegerAttr":$value,
AttrParamA:$complex
);
let mnemonic = "attr_a";
let assemblyFormat = "`hello` `=` $value `,` $complex `)`";
}
/// Test simple struct parser and printer are generated correctly.
// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::DialectAsmParser &parser,
// ATTR: ::mlir::Type attrType) {
// ATTR: bool _seen_v0 = false;
// ATTR: bool _seen_v1 = false;
// ATTR: for (unsigned _index = 0; _index < 2; ++_index) {
// ATTR: StringRef _paramKey;
// ATTR: if (parser.parseKeyword(&_paramKey))
// ATTR: return {};
// ATTR: if (parser.parseEqual())
// ATTR: return {};
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
// ATTR: _seen_v0 = true;
// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType);
// ATTR: if (failed(_result_v0))
// ATTR: return {};
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
// ATTR: _seen_v1 = true;
// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser);
// ATTR: if (failed(_result_v1))
// ATTR: return {};
// ATTR: } else {
// ATTR: return {};
// ATTR: }
// ATTR: if ((_index != 2 - 1) && parser.parseComma())
// ATTR: return {};
// ATTR: }
// ATTR: return TestBAttr::get(parser.getContext(),
// ATTR: _result_v0.getValue(),
// ATTR: _result_v1.getValue());
// ATTR: }
// ATTR: void TestBAttr::print(::mlir::DialectAsmPrinter &printer) const {
// ATTR: printer << "v0";
// ATTR: printer << ' ' << "=";
// ATTR: printer << ' ';
// ATTR: ::printAttrParamA(printer, getV0());
// ATTR: printer << ",";
// ATTR: printer << ' ' << "v1";
// ATTR: printer << ' ' << "=";
// ATTR: printer << ' ';
// ATTR: ::printAttrB(printer, getV1());
// ATTR: }
def AttrB : TestAttr<"TestB"> {
let parameters = (ins
AttrParamA:$v0,
AttrParamB:$v1
);
let mnemonic = "attr_b";
let assemblyFormat = "`{` struct($v0, $v1) `}`";
}
/// Test attribute with capture-all params has correct parser and printer.
// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::DialectAsmParser &parser,
// ATTR: ::mlir::Type attrType) {
// ATTR: ::mlir::FailureOr<int> _result_v0;
// ATTR: ::mlir::FailureOr<int> _result_v1;
// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(parser);
// ATTR: if (failed(_result_v0))
// ATTR: return {};
// ATTR: if (parser.parseComma())
// ATTR: return {};
// ATTR: _result_v1 = ::mlir::FieldParser<int>::parse(parser);
// ATTR: if (failed(_result_v1))
// ATTR: return {};
// ATTR: return TestFAttr::get(parser.getContext(),
// ATTR: _result_v0.getValue(),
// ATTR: _result_v1.getValue());
// ATTR: }
// ATTR: void TestFAttr::print(::mlir::DialectAsmPrinter &printer) const {
// ATTR: printer << "attr_c";
// ATTR: printer << ' ';
// ATTR: printer << getV0();
// ATTR: printer << ",";
// ATTR: printer << ' ';
// ATTR: printer << getV1();
// ATTR: }
def AttrC : TestAttr<"TestF"> {
let parameters = (ins "int":$v0, "int":$v1);
let mnemonic = "attr_c";
let assemblyFormat = "params";
}
/// Test type parser and printer that mix variables and struct are generated
/// correctly.
// TYPE: ::mlir::Type TestCType::parse(::mlir::DialectAsmParser &parser) {
// TYPE: FailureOr<IntegerAttr> _result_value;
// TYPE: FailureOr<TestParamC> _result_complex;
// TYPE: if (parser.parseKeyword("foo"))
// TYPE: return {};
// TYPE: if (parser.parseComma())
// TYPE: return {};
// TYPE: if (parser.parseColon())
// TYPE: return {};
// TYPE: if (parser.parseKeyword("bob"))
// TYPE: return {};
// TYPE: if (parser.parseKeyword("bar"))
// TYPE: return {};
// TYPE: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// TYPE: if (failed(_result_value))
// TYPE: return {};
// TYPE: bool _seen_complex = false;
// TYPE: for (unsigned _index = 0; _index < 1; ++_index) {
// TYPE: StringRef _paramKey;
// TYPE: if (parser.parseKeyword(&_paramKey))
// TYPE: return {};
// TYPE: if (!_seen_complex && _paramKey == "complex") {
// TYPE: _seen_complex = true;
// TYPE: _result_complex = ::parseTypeParamC(parser);
// TYPE: if (failed(_result_complex))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
// TYPE: if ((_index != 1 - 1) && parser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: if (parser.parseRParen())
// TYPE: return {};
// TYPE: }
// TYPE: void TestCType::print(::mlir::DialectAsmPrinter &printer) const {
// TYPE: printer << "type_c";
// TYPE: printer << ' ' << "foo";
// TYPE: printer << ",";
// TYPE: printer << ' ' << ":";
// TYPE: printer << ' ' << "bob";
// TYPE: printer << ' ' << "bar";
// TYPE: printer << ' ';
// TYPE: printer << getValue();
// TYPE: printer << ' ' << "complex";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getComplex();
// TYPE: printer << ")";
// TYPE: }
def TypeA : TestType<"TestC"> {
let parameters = (ins
"IntegerAttr":$value,
TypeParamA:$complex
);
let mnemonic = "type_c";
let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`";
}
/// Test type parser and printer with mix of variables and struct are generated
/// correctly.
// TYPE: ::mlir::Type TestDType::parse(::mlir::DialectAsmParser &parser) {
// TYPE: _result_v0 = ::parseTypeParamC(parser);
// TYPE: if (failed(_result_v0))
// TYPE: return {};
// TYPE: bool _seen_v1 = false;
// TYPE: bool _seen_v2 = false;
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
// TYPE: StringRef _paramKey;
// TYPE: if (parser.parseKeyword(&_paramKey))
// TYPE: return {};
// TYPE: if (parser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
// TYPE: _seen_v1 = true;
// TYPE: _result_v1 = someFcnCall();
// TYPE: if (failed(_result_v1))
// TYPE: return {};
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
// TYPE: _seen_v2 = true;
// TYPE: _result_v2 = ::parseTypeParamC(parser);
// TYPE: if (failed(_result_v2))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: _result_v3 = someFcnCall();
// TYPE: if (failed(_result_v3))
// TYPE: return {};
// TYPE: return TestDType::get(parser.getContext(),
// TYPE: _result_v0.getValue(),
// TYPE: _result_v1.getValue(),
// TYPE: _result_v2.getValue(),
// TYPE: _result_v3.getValue());
// TYPE: }
// TYPE: void TestDType::print(::mlir::DialectAsmPrinter &printer) const {
// TYPE: printer << getV0();
// TYPE: myPrinter(getV1());
// TYPE: printer << ' ' << "v2";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV2();
// TYPE: myPrinter(getV3());
// TYPE: }
def TypeB : TestType<"TestD"> {
let parameters = (ins
TypeParamA:$v0,
TypeParamB:$v1,
TypeParamA:$v2,
TypeParamB:$v3
);
let mnemonic = "type_d";
let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`";
}
/// Type test with two struct directives has correctly generated parser and
/// printer.
// TYPE: ::mlir::Type TestEType::parse(::mlir::DialectAsmParser &parser) {
// TYPE: FailureOr<IntegerAttr> _result_v0;
// TYPE: FailureOr<IntegerAttr> _result_v1;
// TYPE: FailureOr<IntegerAttr> _result_v2;
// TYPE: FailureOr<IntegerAttr> _result_v3;
// TYPE: bool _seen_v0 = false;
// TYPE: bool _seen_v2 = false;
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
// TYPE: StringRef _paramKey;
// TYPE: if (parser.parseKeyword(&_paramKey))
// TYPE: return {};
// TYPE: if (parser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v0 && _paramKey == "v0") {
// TYPE: _seen_v0 = true;
// TYPE: _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// TYPE: if (failed(_result_v0))
// TYPE: return {};
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
// TYPE: _seen_v2 = true;
// TYPE: _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// TYPE: if (failed(_result_v2))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: bool _seen_v1 = false;
// TYPE: bool _seen_v3 = false;
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
// TYPE: StringRef _paramKey;
// TYPE: if (parser.parseKeyword(&_paramKey))
// TYPE: return {};
// TYPE: if (parser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
// TYPE: _seen_v1 = true;
// TYPE: _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// TYPE: if (failed(_result_v1))
// TYPE: return {};
// TYPE: } else if (!_seen_v3 && _paramKey == "v3") {
// TYPE: _seen_v3 = true;
// TYPE: _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
// TYPE: if (failed(_result_v3))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: return TestEType::get(parser.getContext(),
// TYPE: _result_v0.getValue(),
// TYPE: _result_v1.getValue(),
// TYPE: _result_v2.getValue(),
// TYPE: _result_v3.getValue());
// TYPE: }
// TYPE: void TestEType::print(::mlir::DialectAsmPrinter &printer) const {
// TYPE: printer << "v0";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV0();
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v2";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV2();
// TYPE: printer << "v1";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV1();
// TYPE: printer << ",";
// TYPE: printer << ' ' << "v3";
// TYPE: printer << ' ' << "=";
// TYPE: printer << ' ';
// TYPE: printer << getV3();
// TYPE: }
def TypeC : TestType<"TestE"> {
let parameters = (ins
"IntegerAttr":$v0,
"IntegerAttr":$v1,
"IntegerAttr":$v2,
"IntegerAttr":$v3
);
let mnemonic = "type_e";
let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
}
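
The struct checks above all follow one pattern: loop once per parameter, read a key, and accept it only if it names a parameter that has not been seen yet. Below is a minimal standalone sketch of that pattern in plain C++ with no MLIR dependencies. The function `parseStruct`, its integer-only values, and its hand-rolled tokenization are hypothetical and for illustration only; this is not the code the generator emits.

```cpp
#include <cctype>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a generated struct parser: parses
// "key = int (, key = int)*" where every name in `keys` must appear exactly
// once, in any order. Returns std::nullopt on any error.
std::optional<std::map<std::string, int>>
parseStruct(const std::string &input, const std::vector<std::string> &keys) {
  size_t pos = 0;
  auto skipSpace = [&] {
    while (pos < input.size() &&
           std::isspace(static_cast<unsigned char>(input[pos])))
      ++pos;
  };
  auto parseChar = [&](char c) {
    skipSpace();
    if (pos < input.size() && input[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  };
  std::map<std::string, int> result;
  for (size_t index = 0; index < keys.size(); ++index) {
    // Parse the parameter name, then the `=` that must follow it.
    skipSpace();
    size_t start = pos;
    while (pos < input.size() &&
           (std::isalnum(static_cast<unsigned char>(input[pos])) ||
            input[pos] == '_'))
      ++pos;
    std::string key = input.substr(start, pos - start);
    if (!parseChar('='))
      return std::nullopt;
    // Reject unknown keys and keys that were already seen.
    bool known = false;
    for (const std::string &k : keys)
      known |= (k == key);
    if (!known || result.count(key))
      return std::nullopt;
    // Parse the value (decimal digits only in this sketch).
    skipSpace();
    start = pos;
    while (pos < input.size() &&
           std::isdigit(static_cast<unsigned char>(input[pos])))
      ++pos;
    if (start == pos)
      return std::nullopt;
    result[key] = std::stoi(input.substr(start, pos - start));
    // A comma is required between entries but not after the last one.
    if (index != keys.size() - 1 && !parseChar(','))
      return std::nullopt;
  }
  return result;
}
```

Because the loop runs exactly once per key and each successful iteration records a distinct key, reaching the end of the loop implies that every key was provided; no separate completeness check is needed.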

@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "AttrOrTypeFormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/CodeGenHelpers.h"
@@ -24,6 +25,17 @@
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
assert(!name.empty() && "parameter has empty name");
auto ret = "get" + name.str();
ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
return ret;
}
/// Find all the AttrOrTypeDefs for the specified dialect. If no dialect is
/// specified and defs for only one dialect are found, use that dialect.
static void collectAllDefs(StringRef selectedDialect,
@@ -399,7 +411,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
<< " }\n";
// If mnemonic specified, emit print/parse declarations.
if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
if (def.getParserCode() || def.getPrinterCode() ||
def.getAssemblyFormat() || !params.empty()) {
os << llvm::formatv(defDeclParsePrintStr, valueType,
isAttrGenerator ? ", ::mlir::Type type" : "");
}
@@ -410,10 +423,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
def.getParameters(parameters);
for (AttrOrTypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv(" {0} get{1}() const;\n", parameter.getCppAccessorType(),
name);
os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(),
getParameterAccessorName(parameter.getName()));
}
}
@@ -700,8 +711,32 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
}
void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
auto printerCode = def.getPrinterCode();
auto parserCode = def.getParserCode();
auto assemblyFormat = def.getAssemblyFormat();
if (assemblyFormat && (printerCode || parserCode)) {
// Custom assembly format cannot be specified at the same time as either
// custom printer or parser code.
PrintFatalError(def.getLoc(),
def.getName() + ": assembly format cannot be specified at "
"the same time as printer or parser code");
}
// Generate a parser and printer based on the assembly format, if specified.
if (assemblyFormat) {
// A custom assembly format requires accessors to be generated for the
// generated printer.
if (!def.genAccessors()) {
PrintFatalError(def.getLoc(),
def.getName() +
": the generated printer from 'assemblyFormat' "
"requires 'genAccessors' to be true");
}
return generateAttrOrTypeFormat(def, os);
}
// Emit the printer code, if specified.
if (Optional<StringRef> printerCode = def.getPrinterCode()) {
if (printerCode) {
// Both the mnemonic and printerCode must be defined (for parity with
// parserCode).
os << "void " << def.getCppClassName()
@@ -717,7 +752,7 @@ void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
}
// Emit the parser code, if specified.
if (Optional<StringRef> parserCode = def.getParserCode()) {
if (parserCode) {
FmtContext fmtCtxt;
fmtCtxt.addSubst("_parser", "parser")
.addSubst("_ctxt", "parser.getContext()");
@@ -857,11 +892,10 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
paramStorageName = param.getName();
}
SmallString<16> name = param.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
param.getCppAccessorType(), name, paramStorageName,
def.getCppClassName());
os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
param.getCppAccessorType(),
getParameterAccessorName(param.getName()),
paramStorageName, def.getCppClassName());
}
}
}
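
The accessor-naming rule that this file now routes through `getParameterAccessorName` can be sketched without LLVM types. This is a standalone approximation under stated assumptions: `accessorName` is a hypothetical name, and it uses `std::string` and `std::toupper` where the helper in the diff uses `StringRef` and `llvm::toUpper`.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Standalone sketch of the accessor-naming rule: prefix the parameter name
// with "get" and uppercase the first letter of the name. `accessorName` is a
// hypothetical helper, not part of MLIR.
std::string accessorName(const std::string &name) {
  assert(!name.empty() && "parameter has empty name");
  std::string ret = "get" + name;
  // Index 3 is the first character after the "get" prefix.
  ret[3] = static_cast<char>(std::toupper(static_cast<unsigned char>(ret[3])));
  return ret;
}
```

Sharing one helper keeps the accessor declaration, the accessor definition, and the generated printer in agreement on the spelling of each getter.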

@@ -0,0 +1,781 @@
//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "AttrOrTypeFormatGen.h"
#include "FormatGen.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
//===----------------------------------------------------------------------===//
// Element
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a single format element.
class Element {
public:
/// LLVM-style RTTI.
enum class Kind {
/// This element is a directive.
ParamsDirective,
StructDirective,
/// This element is a literal.
Literal,
/// This element is a variable.
Variable,
};
Element(Kind kind) : kind(kind) {}
virtual ~Element() = default;
/// Return the kind of this element.
Kind getKind() const { return kind; }
private:
/// The kind of this element.
Kind kind;
};
/// This class represents an instance of a literal element.
class LiteralElement : public Element {
public:
LiteralElement(StringRef literal)
: Element(Kind::Literal), literal(literal) {}
static bool classof(const Element *el) {
return el->getKind() == Kind::Literal;
}
/// Get the literal spelling.
StringRef getSpelling() const { return literal; }
private:
/// The spelling of the literal for this element.
StringRef literal;
};
/// This class represents an instance of a variable element. A variable refers
/// to an attribute or type parameter.
class VariableElement : public Element {
public:
VariableElement(AttrOrTypeParameter param)
: Element(Kind::Variable), param(param) {}
static bool classof(const Element *el) {
return el->getKind() == Kind::Variable;
}
/// Get the parameter in the element.
const AttrOrTypeParameter &getParam() const { return param; }
private:
AttrOrTypeParameter param;
};
/// Base class for a directive that contains references to multiple variables.
template <Element::Kind ElementKind>
class ParamsDirectiveBase : public Element {
public:
using Base = ParamsDirectiveBase<ElementKind>;
ParamsDirectiveBase(SmallVector<std::unique_ptr<Element>> &&params)
: Element(ElementKind), params(std::move(params)) {}
static bool classof(const Element *el) {
return el->getKind() == ElementKind;
}
/// Get the parameters contained in this directive.
auto getParams() const {
return llvm::map_range(params, [](auto &el) {
return cast<VariableElement>(el.get())->getParam();
});
}
/// Get the number of parameters.
unsigned getNumParams() const { return params.size(); }
/// Take all of the parameters from this directive.
SmallVector<std::unique_ptr<Element>> takeParams() {
return std::move(params);
}
private:
/// The parameters captured by this directive.
SmallVector<std::unique_ptr<Element>> params;
};
/// This class represents a `params` directive that refers to all parameters
/// of an attribute or type. When used as a top-level directive, it generates
/// a format of the form:
///
/// (param-value (`,` param-value)*)?
///
/// When used as an argument to another directive that accepts variables,
/// `params` can be used in place of manually listing all parameters of an
/// attribute or type.
class ParamsDirective
: public ParamsDirectiveBase<Element::Kind::ParamsDirective> {
public:
using Base::Base;
};
/// This class represents a `struct` directive that generates a struct format
/// of the form:
///
/// param-name `=` param-value (`,` param-name `=` param-value)*
///
class StructDirective
: public ParamsDirectiveBase<Element::Kind::StructDirective> {
public:
using Base::Base;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Format Strings
//===----------------------------------------------------------------------===//
/// Format for defining an attribute parser.
///
/// $0: The attribute C++ class name.
static const char *const attrParserDefn = R"(
::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser,
::mlir::Type $_type) {
)";
/// Format for defining a type parser.
///
/// $0: The type C++ class name.
static const char *const typeParserDefn = R"(
::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) {
)";
/// Default parser for attribute or type parameters.
static const char *const defaultParameterParser =
"::mlir::FieldParser<$0>::parse($_parser)";
/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter = "$_printer << $_self";
/// Code prefix for emitting an error when failing to parse an element. The
/// caller appends the error message and the closing parenthesis.
static const char *const parseErrorStr =
"$_parser.emitError($_parser.getCurrentLocation(), ";
/// Format for defining an attribute or type printer.
///
/// $0: The attribute or type C++ class name.
/// $1: The attribute or type mnemonic.
static const char *const attrOrTypePrinterDefn = R"(
void $0::print(::mlir::DialectAsmPrinter &$_printer) const {
$_printer << "$1";
)";
/// Loop declaration for struct parser.
///
/// $0: Number of expected parameters.
static const char *const structParseLoopStart = R"(
for (unsigned _index = 0; _index < $0; ++_index) {
StringRef _paramKey;
if ($_parser.parseKeyword(&_paramKey)) {
$_parser.emitError($_parser.getCurrentLocation(),
"expected a parameter name in struct");
return {};
}
)";
/// Terminator code segment for the struct parser loop. Check for duplicate or
/// unknown parameters. Parse a comma except on the last element.
///
/// {0}: Code template for printing an error.
/// {1}: Number of elements in the struct.
static const char *const structParseLoopEnd = R"({{
{0}"duplicate or unknown struct parameter name: ") << _paramKey;
return {{};
}
if ((_index != {1} - 1) && parser.parseComma())
return {{};
}
)";
/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser[] = {
" // Parse variable '{0}'",
" _result_{0} = {1};",
" if (failed(_result_{0})) {{",
" {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");",
" return {{};",
" }",
};
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
/// Get a list of an attribute's or type's parameters. These can be wrapper
/// objects around `AttrOrTypeParameter` or string inits.
static auto getParameters(const AttrOrTypeDef &def) {
SmallVector<AttrOrTypeParameter> params;
def.getParameters(params);
return params;
}
//===----------------------------------------------------------------------===//
// AttrOrTypeFormat
//===----------------------------------------------------------------------===//
namespace {
class AttrOrTypeFormat {
public:
AttrOrTypeFormat(const AttrOrTypeDef &def,
std::vector<std::unique_ptr<Element>> &&elements)
: def(def), elements(std::move(elements)) {}
/// Generate the attribute or type parser.
void genParser(raw_ostream &os);
/// Generate the attribute or type printer.
void genPrinter(raw_ostream &os);
private:
/// Generate the parser code for a specific format element.
void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os);
/// Generate the parser code for a literal.
void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os,
unsigned indent = 0);
/// Generate the parser code for a variable.
void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
raw_ostream &os, unsigned indent = 0);
/// Generate the parser code for a `params` directive.
void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
/// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os);
/// Generate the printer code for a specific format element.
void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os);
/// Generate the printer code for a literal.
void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os);
/// Generate the printer code for a variable.
void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
raw_ostream &os);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
/// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os);
/// The ODS definition of the attribute or type whose format is being used to
/// generate a parser and printer.
const AttrOrTypeDef &def;
/// The list of top-level format elements returned by the assembly format
/// parser.
std::vector<std::unique_ptr<Element>> elements;
/// Flags for printing spaces.
bool shouldEmitSpace;
bool lastWasPunctuation;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// ParserGen
//===----------------------------------------------------------------------===//
void AttrOrTypeFormat::genParser(raw_ostream &os) {
FmtContext ctx;
ctx.addSubst("_parser", "parser");
/// Generate the definition.
if (isa<AttrDef>(def)) {
ctx.addSubst("_type", "attrType");
os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
} else {
os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
}
/// Declare variables to store all of the parameters. Allocated parameters
/// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
/// FailureOr<T> to defer type construction for parameters that are parsed in
/// a loop (parsers return FailureOr anyway).
SmallVector<AttrOrTypeParameter> params = getParameters(def);
for (const AttrOrTypeParameter &param : params) {
os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n",
param.getCppStorageType(), param.getName());
}
/// Store the initial location of the parser.
ctx.addSubst("_loc", "loc");
os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
" (void) $_loc;\n",
&ctx);
/// Generate call to each parameter parser.
for (auto &el : elements)
genElementParser(el.get(), ctx, os);
/// Generate call to the attribute or type builder. Use the checked getter
/// if one was generated.
if (def.genVerifyDecl()) {
os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
} else {
os << tgfmt(" return $0::get($_parser.getContext()", &ctx,
def.getCppClassName());
}
for (const AttrOrTypeParameter &param : params)
os << formatv(",\n _result_{0}.getValue()", param.getName());
os << ");\n}\n\n";
}
void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
raw_ostream &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralParser(literal->getSpelling(), ctx, os);
if (auto *var = dyn_cast<VariableElement>(el))
return genVariableParser(var->getParam(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
return genParamsParser(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructParser(strct, ctx, os);
llvm_unreachable("unknown format element");
}
void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
raw_ostream &os, unsigned indent) {
os.indent(indent) << " // Parse literal '" << value << "'\n";
os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx);
if (value.front() == '_' || isalpha(value.front())) {
os << "Keyword(\"" << value << "\")";
} else {
os << StringSwitch<StringRef>(value)
.Case("->", "Arrow")
.Case(":", "Colon")
.Case(",", "Comma")
.Case("=", "Equal")
.Case("<", "Less")
.Case(">", "Greater")
.Case("{", "LBrace")
.Case("}", "RBrace")
.Case("(", "LParen")
.Case(")", "RParen")
.Case("[", "LSquare")
.Case("]", "RSquare")
.Case("?", "Question")
.Case("+", "Plus")
.Case("*", "Star")
<< "()";
}
os << ")\n";
// Parser will emit an error
os.indent(indent) << " return {};\n";
}
void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
FmtContext &ctx, raw_ostream &os,
unsigned indent) {
/// Check for a custom parser. Use the default attribute parser otherwise.
auto customParser = param.getParser();
auto parser =
customParser ? *customParser : StringRef(defaultParameterParser);
for (const char *line : variableParser) {
os.indent(indent) << formatv(line, param.getName(),
tgfmt(parser, &ctx, param.getCppStorageType()),
tgfmt(parseErrorStr, &ctx), def.getName(),
param.getCppType())
<< "\n";
}
}
void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
raw_ostream &os) {
os << " // Parse parameter list\n";
llvm::interleave(
el->getParams(), [&](auto param) { genVariableParser(param, ctx, os); },
[&]() { genLiteralParser(",", ctx, os); });
}
void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
raw_ostream &os) {
os << " // Parse parameter struct\n";
/// Declare a "seen" variable for each key.
for (const AttrOrTypeParameter &param : el->getParams())
os << formatv(" bool _seen_{0} = false;\n", param.getName());
/// Generate the parsing loop.
os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
genLiteralParser("=", ctx, os, 2);
os << " ";
for (const AttrOrTypeParameter &param : el->getParams()) {
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
" _seen_{0} = true;\n",
param.getName());
genVariableParser(param, ctx, os, 4);
os << " } else ";
}
/// Duplicate or unknown parameter.
os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx),
el->getNumParams());
/// Because the loop loops N times and each non-failing iteration sets 1 of
/// N flags, successfully exiting the loop means that all parameters have been
/// seen. `parseOptionalComma` would cause issues with any formats that use
/// "struct(...) `,`" because structs aren't surrounded by braces.
}
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
void AttrOrTypeFormat::genPrinter(raw_ostream &os) {
FmtContext ctx;
ctx.addSubst("_printer", "printer");
/// Generate the definition.
os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(),
*def.getMnemonic());
/// Generate printers.
shouldEmitSpace = true;
lastWasPunctuation = false;
for (auto &el : elements)
genElementPrinter(el.get(), ctx, os);
os << "}\n\n";
}
void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
raw_ostream &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralPrinter(literal->getSpelling(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
return genParamsPrinter(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
if (auto *var = dyn_cast<VariableElement>(el))
return genVariablePrinter(var->getParam(), ctx, os);
llvm_unreachable("unknown format element");
}
void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
raw_ostream &os) {
/// Don't insert a space before certain punctuation.
bool needSpace =
shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
value);
/// Update the flags.
shouldEmitSpace =
value.size() != 1 || !StringRef("<({[").contains(value.front());
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
}
void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
FmtContext &ctx, raw_ostream &os) {
/// Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation)
os << tgfmt(" $_printer << ' ';\n", &ctx);
shouldEmitSpace = true;
lastWasPunctuation = false;
ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
os << " ";
if (auto printer = param.getPrinter())
os << tgfmt(*printer, &ctx) << ";\n";
else
os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
}
void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
raw_ostream &os) {
llvm::interleave(
el->getParams(), [&](auto param) { genVariablePrinter(param, ctx, os); },
[&]() { genLiteralPrinter(",", ctx, os); });
}
void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
raw_ostream &os) {
llvm::interleave(
el->getParams(),
[&](auto param) {
genLiteralPrinter(param.getName(), ctx, os);
genLiteralPrinter("=", ctx, os);
os << tgfmt(" $_printer << ' ';\n", &ctx);
genVariablePrinter(param, ctx, os);
},
[&]() { genLiteralPrinter(",", ctx, os); });
}
//===----------------------------------------------------------------------===//
// FormatParser
//===----------------------------------------------------------------------===//
namespace {
class FormatParser {
public:
FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
: lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def),
seenParams(def.getNumParameters()) {}
/// Parse the attribute or type format and create the format elements.
FailureOr<AttrOrTypeFormat> parse();
private:
/// The current context of the parser when parsing an element.
enum ParserContext {
/// The element is being parsed in the default context, at the top level of
/// the format.
TopLevelContext,
/// The element is being parsed as a child to a `struct` directive.
StructDirective,
};
/// Emit an error.
LogicalResult emitError(const Twine &msg) {
lexer.emitError(curToken.getLoc(), msg);
return failure();
}
/// Parse an expected token.
LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
if (curToken.getKind() != kind)
return emitError(msg);
consumeToken();
return success();
}
/// Advance the lexer to the next token.
void consumeToken() {
assert(curToken.getKind() != FormatToken::eof &&
curToken.getKind() != FormatToken::error &&
"shouldn't advance past EOF or errors");
curToken = lexer.lexToken();
}
/// Parse any element.
FailureOr<std::unique_ptr<Element>> parseElement(ParserContext ctx);
/// Parse a literal element.
FailureOr<std::unique_ptr<Element>> parseLiteral(ParserContext ctx);
/// Parse a variable element.
FailureOr<std::unique_ptr<Element>> parseVariable(ParserContext ctx);
/// Parse a directive.
FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
/// Parse a `params` directive.
FailureOr<std::unique_ptr<Element>> parseParamsDirective();
/// Parse a `struct` directive.
FailureOr<std::unique_ptr<Element>> parseStructDirective();
/// The current format lexer.
FormatLexer lexer;
/// The current token in the stream.
FormatToken curToken;
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
/// Seen attribute or type parameters.
llvm::BitVector seenParams;
};
} // end anonymous namespace
FailureOr<AttrOrTypeFormat> FormatParser::parse() {
std::vector<std::unique_ptr<Element>> elements;
elements.reserve(16);
/// Parse the format elements.
while (curToken.getKind() != FormatToken::eof) {
auto element = parseElement(TopLevelContext);
if (failed(element))
return failure();
/// Add the format element and continue.
elements.push_back(std::move(*element));
}
/// Check that all parameters have been seen.
SmallVector<AttrOrTypeParameter> params = getParameters(def);
for (auto it : llvm::enumerate(params)) {
if (!seenParams.test(it.index())) {
return emitError("format is missing reference to parameter: " +
it.value().getName());
}
}
return AttrOrTypeFormat(def, std::move(elements));
}
FailureOr<std::unique_ptr<Element>>
FormatParser::parseElement(ParserContext ctx) {
if (curToken.getKind() == FormatToken::literal)
return parseLiteral(ctx);
if (curToken.getKind() == FormatToken::variable)
return parseVariable(ctx);
if (curToken.isKeyword())
return parseDirective(ctx);
return emitError("expected literal, directive, or variable");
}
FailureOr<std::unique_ptr<Element>>
FormatParser::parseLiteral(ParserContext ctx) {
if (ctx != TopLevelContext) {
return emitError(
"literals may only be used in the top-level section of the format");
}
/// Get the literal spelling without the surrounding "`".
auto value = curToken.getSpelling().drop_front().drop_back();
if (!isValidLiteral(value))
return emitError("literal '" + value + "' is not valid");
consumeToken();
return {std::make_unique<LiteralElement>(value)};
}
FailureOr<std::unique_ptr<Element>>
FormatParser::parseVariable(ParserContext ctx) {
/// Get the parameter name without the preceding "$".
auto name = curToken.getSpelling().drop_front();
/// Look up the parameter.
SmallVector<AttrOrTypeParameter> params = getParameters(def);
auto *it = llvm::find_if(
params, [&](auto &param) { return param.getName() == name; });
/// Check that the parameter reference is valid.
if (it == params.end())
return emitError(def.getName() + " has no parameter named '" + name + "'");
auto idx = std::distance(params.begin(), it);
if (seenParams.test(idx))
return emitError("duplicate parameter '" + name + "'");
seenParams.set(idx);
consumeToken();
return {std::make_unique<VariableElement>(*it)};
}
FailureOr<std::unique_ptr<Element>>
FormatParser::parseDirective(ParserContext ctx) {
switch (curToken.getKind()) {
case FormatToken::kw_params:
return parseParamsDirective();
case FormatToken::kw_struct:
if (ctx != TopLevelContext) {
return emitError(
"`struct` may only be used in the top-level section of the format");
}
return parseStructDirective();
default:
return emitError("unknown directive in format: " + curToken.getSpelling());
}
}
FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
consumeToken();
/// Collect all of the attribute's or type's parameters.
SmallVector<AttrOrTypeParameter> params = getParameters(def);
SmallVector<std::unique_ptr<Element>> vars;
/// Ensure that none of the parameters have already been captured.
for (auto it : llvm::enumerate(params)) {
if (seenParams.test(it.index())) {
return emitError("`params` captures duplicate parameter: " +
it.value().getName());
}
seenParams.set(it.index());
vars.push_back(std::make_unique<VariableElement>(it.value()));
}
return {std::make_unique<ParamsDirective>(std::move(vars))};
}
FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
consumeToken();
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list")))
return failure();
/// Parse variables captured by `struct`.
SmallVector<std::unique_ptr<Element>> vars;
/// Parse first captured parameter or a `params` directive.
FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
return emitError("`struct` argument list expected a variable or directive");
if (isa<VariableElement>(*var)) {
/// Parse any other parameters.
vars.push_back(std::move(*var));
while (curToken.getKind() == FormatToken::comma) {
consumeToken();
var = parseElement(StructDirective);
if (failed(var) || !isa<VariableElement>(*var))
return emitError("expected a variable in `struct` argument list");
vars.push_back(std::move(*var));
}
} else {
/// `struct(params)` captures all parameters in the attribute or type.
vars = cast<ParamsDirective>(var->get())->takeParams();
}
if (curToken.getKind() != FormatToken::r_paren)
return emitError("expected ')' at the end of an argument list");
consumeToken();
return {std::make_unique<::StructDirective>(std::move(vars))};
}
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
raw_ostream &os) {
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
llvm::SMLoc());
/// Parse the custom assembly format.
FormatParser parser(mgr, def);
FailureOr<AttrOrTypeFormat> format = parser.parse();
if (failed(format)) {
if (formatErrorIsFatal)
PrintFatalError(def.getLoc(), "failed to parse assembly format");
return;
}
/// Generate the parser and printer.
format->genParser(os);
format->genPrinter(os);
}
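
The bookkeeping in `FormatParser::parse` and `FormatParser::parseVariable` above amounts to one rule: every parameter must be referenced exactly once. A minimal standalone sketch of that rule follows. It does not use the mlir-tblgen classes; `checkParameterCoverage` and its string-based inputs are hypothetical stand-ins for `seenParams` and the parsed elements.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper, not part of mlir-tblgen. `params` lists the declared
// parameter names and `referenced` lists the names used in the format, in
// order. Returns an error message, or std::nullopt when every parameter is
// referenced exactly once.
std::optional<std::string>
checkParameterCoverage(const std::vector<std::string> &params,
                       const std::vector<std::string> &referenced) {
  std::vector<bool> seen(params.size(), false);
  for (const std::string &name : referenced) {
    // Look up the referenced name among the declared parameters.
    std::size_t idx = 0;
    while (idx < params.size() && params[idx] != name)
      ++idx;
    if (idx == params.size())
      return "no parameter named '" + name + "'";
    // A parameter may only be captured once.
    if (seen[idx])
      return "duplicate parameter '" + name + "'";
    seen[idx] = true;
  }
  // Once the whole format is consumed, every parameter must have been seen.
  for (std::size_t idx = 0; idx < params.size(); ++idx)
    if (!seen[idx])
      return "format is missing reference to parameter: " + params[idx];
  return std::nullopt;
}
```

Under these assumptions, a format that references `$three`, `$one`, and `$two` for a definition with parameters `one`, `two`, `three` is accepted, while omitting or repeating any of them is rejected.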


@@ -0,0 +1,32 @@
//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
#include "llvm/Support/raw_ostream.h"
#include <string>
namespace mlir {
namespace tblgen {
class AttrOrTypeDef;
/// Generate a parser and printer based on a custom assembly format for an
/// attribute or type.
void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
/// Get the camelCase name of the accessor function for a parameter. The first
/// letter of the parameter name is upper-cased and the result is prefixed with
/// "get", e.g. 'value' -> 'getValue'.
std::string getParameterAccessorName(llvm::StringRef name);
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
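
The naming rule documented for `getParameterAccessorName` can be sketched in isolation. This is only an illustration of the rule as the comment states it; `accessorNameFor` is a hypothetical name, it takes a `std::string` rather than `llvm::StringRef`, and the real implementation may handle more cases.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical sketch of the documented rule: upper-case the first letter of
// the parameter name and prefix the result with "get".
std::string accessorNameFor(const std::string &name) {
  assert(!name.empty() && "parameter name must not be empty");
  std::string result = "get";
  result += static_cast<char>(
      std::toupper(static_cast<unsigned char>(name.front())));
  // Append the rest of the name unchanged.
  result.append(name, 1, std::string::npos);
  return result;
}
```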


@@ -6,10 +6,12 @@ set(LLVM_LINK_COMPONENTS
add_tablegen(mlir-tblgen MLIR
AttrOrTypeDefGen.cpp
AttrOrTypeFormatGen.cpp
CodeGenHelpers.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
FormatGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
mlir-tblgen.cpp


@@ -0,0 +1,225 @@
//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "FormatGen.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/TableGen/Error.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// FormatToken
//===----------------------------------------------------------------------===//
llvm::SMLoc FormatToken::getLoc() const {
return llvm::SMLoc::getFromPointer(spelling.data());
}
//===----------------------------------------------------------------------===//
// FormatLexer
//===----------------------------------------------------------------------===//
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
: mgr(mgr), loc(loc),
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
curPtr(curBuffer.begin()) {}
FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
"in custom assembly format for this operation");
return formToken(FormatToken::error, loc.getPointer());
}
FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
return emitError(llvm::SMLoc::getFromPointer(loc), msg);
}
FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
const Twine &note) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
"in custom assembly format for this operation");
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
return formToken(FormatToken::error, loc.getPointer());
}
int FormatLexer::getNextChar() {
char curChar = *curPtr++;
switch (curChar) {
default:
return (unsigned char)curChar;
case 0: {
// A nul character in the stream is either the end of the current buffer or
// a random nul in the file. Disambiguate that here.
if (curPtr - 1 != curBuffer.end())
return 0;
// Otherwise, return end of file.
--curPtr;
return EOF;
}
case '\n':
case '\r':
// Normalize newlines to a single '\n'. Be careful about DOS-style files:
// a \r\n (or \n\r) pair is consumed together and treated as one newline,
// while two identical characters in a row remain two separate newlines.
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
++curPtr;
return '\n';
}
}
FormatToken FormatLexer::lexToken() {
const char *tokStart = curPtr;
// This always consumes at least one character.
int curChar = getNextChar();
switch (curChar) {
default:
// Handle identifiers: [a-zA-Z_]
if (isalpha(curChar) || curChar == '_')
return lexIdentifier(tokStart);
// Unknown character, emit an error.
return emitError(tokStart, "unexpected character");
case EOF:
// Return EOF denoting the end of lexing.
return formToken(FormatToken::eof, tokStart);
// Lex punctuation.
case '^':
return formToken(FormatToken::caret, tokStart);
case ':':
return formToken(FormatToken::colon, tokStart);
case ',':
return formToken(FormatToken::comma, tokStart);
case '=':
return formToken(FormatToken::equal, tokStart);
case '<':
return formToken(FormatToken::less, tokStart);
case '>':
return formToken(FormatToken::greater, tokStart);
case '?':
return formToken(FormatToken::question, tokStart);
case '(':
return formToken(FormatToken::l_paren, tokStart);
case ')':
return formToken(FormatToken::r_paren, tokStart);
case '*':
return formToken(FormatToken::star, tokStart);
// Ignore whitespace characters.
case 0:
case ' ':
case '\t':
case '\n':
return lexToken();
case '`':
return lexLiteral(tokStart);
case '$':
return lexVariable(tokStart);
}
}
FormatToken FormatLexer::lexLiteral(const char *tokStart) {
assert(curPtr[-1] == '`');
// Lex a literal surrounded by ``.
while (const char curChar = *curPtr++) {
if (curChar == '`')
return formToken(FormatToken::literal, tokStart);
}
return emitError(curPtr - 1, "unexpected end of file in literal");
}
FormatToken FormatLexer::lexVariable(const char *tokStart) {
if (!isalpha(curPtr[0]) && curPtr[0] != '_')
return emitError(curPtr - 1, "expected variable name");
// Otherwise, consume the rest of the characters.
while (isalnum(*curPtr) || *curPtr == '_')
++curPtr;
return formToken(FormatToken::variable, tokStart);
}
FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
++curPtr;
// Check to see if this identifier is a keyword.
StringRef str(tokStart, curPtr - tokStart);
auto kind =
StringSwitch<FormatToken::Kind>(str)
.Case("attr-dict", FormatToken::kw_attr_dict)
.Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
.Case("custom", FormatToken::kw_custom)
.Case("functional-type", FormatToken::kw_functional_type)
.Case("operands", FormatToken::kw_operands)
.Case("params", FormatToken::kw_params)
.Case("ref", FormatToken::kw_ref)
.Case("regions", FormatToken::kw_regions)
.Case("results", FormatToken::kw_results)
.Case("struct", FormatToken::kw_struct)
.Case("successors", FormatToken::kw_successors)
.Case("type", FormatToken::kw_type)
.Default(FormatToken::identifier);
return FormatToken(kind, str);
}
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
bool lastWasPunctuation) {
if (value.size() != 1 && value != "->")
return true;
if (lastWasPunctuation)
return !StringRef(">)}],").contains(value.front());
return !StringRef("<>(){}[],").contains(value.front());
}
bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
if (!isalpha(value.front()) && value.front() != '_')
return false;
return llvm::all_of(value.drop_front(), [](char c) {
return isalnum(c) || c == '_' || c == '$' || c == '.';
});
}
bool mlir::tblgen::isValidLiteral(StringRef value) {
if (value.empty())
return false;
char front = value.front();
// If there is only one character, this must either be punctuation or a
// single character bare identifier.
if (value.size() == 1)
return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
// Check for punctuation that is longer than a single character.
if (value == "->")
return true;
// Otherwise, this must be an identifier.
return canFormatStringAsKeyword(value);
}
//===----------------------------------------------------------------------===//
// Commandline Options
//===----------------------------------------------------------------------===//
llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
"asmformat-error-is-fatal",
llvm::cl::desc("Emit a fatal error if format parsing fails"),
llvm::cl::init(true));
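
To make the literal rules in the utility functions above easier to try out, here is a standalone sketch that mirrors `canFormatStringAsKeyword`, `isValidLiteral`, and `shouldEmitSpaceBefore` using `std::string_view` instead of `llvm::StringRef`. The `*Sketch` names are hypothetical and the code is a re-statement of the logic shown in this diff, not the mlir-tblgen API.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string_view>

bool isAlphaChar(char c) {
  return std::isalpha(static_cast<unsigned char>(c)) != 0;
}
bool isAlnumChar(char c) {
  return std::isalnum(static_cast<unsigned char>(c)) != 0;
}

// A keyword starts with a letter or '_' and continues with letters, digits,
// '_', '$', or '.'. Unlike the original, an empty string is rejected here.
bool canFormatAsKeywordSketch(std::string_view value) {
  if (value.empty() || (!isAlphaChar(value.front()) && value.front() != '_'))
    return false;
  return std::all_of(value.begin() + 1, value.end(), [](char c) {
    return isAlnumChar(c) || c == '_' || c == '$' || c == '.';
  });
}

// A literal is a single letter, one of the listed punctuation characters,
// the arrow "->", or a keyword.
bool isValidLiteralSketch(std::string_view value) {
  if (value.empty())
    return false;
  if (value.size() == 1)
    return isAlphaChar(value.front()) ||
           std::string_view("_:,=<>()[]{}?+*").find(value.front()) !=
               std::string_view::npos;
  if (value == "->")
    return true;
  return canFormatAsKeywordSketch(value);
}

// Keywords always get a leading space; single punctuation characters are
// checked against a set that depends on whether punctuation came before.
bool shouldEmitSpaceBeforeSketch(std::string_view value,
                                 bool lastWasPunctuation) {
  assert(!value.empty() && "literal must not be empty");
  if (value.size() != 1 && value != "->")
    return true;
  std::string_view noSpace = lastWasPunctuation ? ">)}]," : "<>(){}[],";
  return noSpace.find(value.front()) == std::string_view::npos;
}
```

With these definitions, `` `<` `` following a keyword is printed without a leading space, while two keywords in a row are separated by one.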


@@ -0,0 +1,161 @@
//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains common classes for building custom assembly format parsers
// and generators.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SMLoc.h"
namespace llvm {
class SourceMgr;
} // end namespace llvm
namespace mlir {
namespace tblgen {
//===----------------------------------------------------------------------===//
// FormatToken
//===----------------------------------------------------------------------===//
/// This class represents a specific token in the input format.
class FormatToken {
public:
/// Basic token kinds.
enum Kind {
// Markers.
eof,
error,
// Tokens with no info.
l_paren,
r_paren,
caret,
colon,
comma,
equal,
less,
greater,
question,
star,
// Keywords.
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
kw_custom,
kw_functional_type,
kw_operands,
kw_params,
kw_ref,
kw_regions,
kw_results,
kw_struct,
kw_successors,
kw_type,
keyword_end,
// String valued tokens.
identifier,
literal,
variable,
};
FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
/// Return the bytes that make up this token.
StringRef getSpelling() const { return spelling; }
/// Return the kind of this token.
Kind getKind() const { return kind; }
/// Return a location for this token.
llvm::SMLoc getLoc() const;
/// Return if this token is a keyword.
bool isKeyword() const {
return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
}
private:
/// Discriminator that indicates the kind of token this is.
Kind kind;
/// A reference to the entire token contents; this is always a pointer into
/// a memory buffer owned by the source manager.
StringRef spelling;
};
//===----------------------------------------------------------------------===//
// FormatLexer
//===----------------------------------------------------------------------===//
/// This class implements a simple lexer for custom assembly format strings.
class FormatLexer {
public:
FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
/// Lex the next token and return it.
FormatToken lexToken();
/// Emit an error to the lexer with the given location and message.
FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
FormatToken emitError(const char *loc, const Twine &msg);
FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
const Twine &note);
private:
/// Return the next character in the stream.
int getNextChar();
/// Lex an identifier, literal, or variable.
FormatToken lexIdentifier(const char *tokStart);
FormatToken lexLiteral(const char *tokStart);
FormatToken lexVariable(const char *tokStart);
/// Create a token with the current pointer and a start pointer.
FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
}
/// The source manager containing the format string.
llvm::SourceMgr &mgr;
/// Location of the format string.
llvm::SMLoc loc;
/// Buffer containing the format string.
StringRef curBuffer;
/// Current pointer in the buffer.
const char *curPtr;
};
/// Whether a space needs to be emitted before a literal. E.g., two keywords
/// back-to-back require a space separator, but a keyword followed by '<' does
/// not require a space.
bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
/// Returns true if the given string can be formatted as a keyword.
bool canFormatStringAsKeyword(StringRef value);
/// Returns true if the given string is a valid format literal element.
bool isValidLiteral(StringRef value);
/// Whether a failure in parsing the assembly format should be a fatal error.
extern llvm::cl::opt<bool> formatErrorIsFatal;
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
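
`FormatToken::isKeyword` relies on the `keyword_start` and `keyword_end` sentinels bracketing every keyword enumerator, so adding a keyword such as `kw_params` or `kw_struct` only requires placing it between them. A reduced sketch of that pattern, using a hypothetical `TokenKind` enum with a handful of the kinds from the header:

```cpp
#include <cassert>

// Hypothetical, reduced version of FormatToken::Kind. The two sentinels must
// enclose all keyword kinds and nothing else.
enum class TokenKind {
  eof,
  l_paren,
  r_paren,
  keyword_start,
  kw_params,
  kw_struct,
  keyword_end,
  identifier,
  literal,
  variable,
};

// A kind is a keyword exactly when it lies strictly between the sentinels.
constexpr bool isKeyword(TokenKind kind) {
  return kind > TokenKind::keyword_start && kind < TokenKind::keyword_end;
}

// The sentinels themselves are not keywords.
static_assert(!isKeyword(TokenKind::keyword_start), "sentinel is excluded");
static_assert(!isKeyword(TokenKind::keyword_end), "sentinel is excluded");
```

The trade-off of this design is that the check stays a two-comparison range test, at the cost of requiring enumerator order to be maintained by hand.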


@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "OpFormatGen.h"
#include "FormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
@@ -20,7 +21,6 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -30,20 +30,6 @@
using namespace mlir;
using namespace mlir::tblgen;
static llvm::cl::opt<bool> formatErrorIsFatal(
"asmformat-error-is-fatal",
llvm::cl::desc("Emit a fatal error if format parsing fails"),
llvm::cl::init(true));
/// Returns true if the given string can be formatted as a keyword.
static bool canFormatStringAsKeyword(StringRef value) {
if (!isalpha(value.front()) && value.front() != '_')
return false;
return llvm::all_of(value.drop_front(), [](char c) {
return isalnum(c) || c == '_' || c == '$' || c == '.';
});
}
//===----------------------------------------------------------------------===//
// Element
//===----------------------------------------------------------------------===//
@@ -273,33 +259,12 @@ public:
/// Return the literal for this element.
StringRef getLiteral() const { return literal; }
/// Returns true if the given string is a valid literal.
static bool isValidLiteral(StringRef value);
private:
/// The spelling of the literal for this element.
StringRef literal;
};
} // end anonymous namespace
bool LiteralElement::isValidLiteral(StringRef value) {
if (value.empty())
return false;
char front = value.front();
// If there is only one character, this must either be punctuation or a
// single character bare identifier.
if (value.size() == 1)
return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
// Check the punctuation that are larger than a single character.
if (value == "->")
return true;
// Otherwise, this must be an identifier.
return canFormatStringAsKeyword(value);
}
//===----------------------------------------------------------------------===//
// WhitespaceElement
@@ -1705,14 +1670,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
body << " _odsPrinter";
// Don't insert a space for certain punctuation.
auto shouldPrintSpaceBeforeLiteral = [&] {
if (value.size() != 1 && value != "->")
return true;
if (lastWasPunctuation)
return !StringRef(">)}],").contains(value.front());
return !StringRef("<>(){}[],").contains(value.front());
};
if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
body << " << ' '";
body << " << \"" << value << "\";\n";
@@ -2101,253 +2059,6 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
lastWasPunctuation);
}
//===----------------------------------------------------------------------===//
// FormatLexer
//===----------------------------------------------------------------------===//
namespace {
/// This class represents a specific token in the input format.
class Token {
public:
enum Kind {
// Markers.
eof,
error,
// Tokens with no info.
l_paren,
r_paren,
caret,
colon,
comma,
equal,
less,
greater,
question,
// Keywords.
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
kw_custom,
kw_functional_type,
kw_operands,
kw_ref,
kw_regions,
kw_results,
kw_successors,
kw_type,
keyword_end,
// String valued tokens.
identifier,
literal,
variable,
};
Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
/// Return the bytes that make up this token.
StringRef getSpelling() const { return spelling; }
/// Return the kind of this token.
Kind getKind() const { return kind; }
/// Return a location for this token.
llvm::SMLoc getLoc() const {
return llvm::SMLoc::getFromPointer(spelling.data());
}
/// Return if this token is a keyword.
bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
private:
/// Discriminator that indicates the kind of token this is.
Kind kind;
/// A reference to the entire token contents; this is always a pointer into
/// a memory buffer owned by the source manager.
StringRef spelling;
};
/// This class implements a simple lexer for operation assembly format strings.
class FormatLexer {
public:
FormatLexer(llvm::SourceMgr &mgr, Operator &op);
/// Lex the next token and return it.
Token lexToken();
/// Emit an error to the lexer with the given location and message.
Token emitError(llvm::SMLoc loc, const Twine &msg);
Token emitError(const char *loc, const Twine &msg);
Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);
private:
Token formToken(Token::Kind kind, const char *tokStart) {
return Token(kind, StringRef(tokStart, curPtr - tokStart));
}
/// Return the next character in the stream.
int getNextChar();
/// Lex an identifier, literal, or variable.
Token lexIdentifier(const char *tokStart);
Token lexLiteral(const char *tokStart);
Token lexVariable(const char *tokStart);
llvm::SourceMgr &srcMgr;
Operator &op;
StringRef curBuffer;
const char *curPtr;
};
} // end anonymous namespace
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
: srcMgr(mgr), op(op) {
curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
curPtr = curBuffer.begin();
}
Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
"in custom assembly format for this operation");
return formToken(Token::error, loc.getPointer());
}
Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
const Twine &note) {
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
"in custom assembly format for this operation");
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
return formToken(Token::error, loc.getPointer());
}
Token FormatLexer::emitError(const char *loc, const Twine &msg) {
return emitError(llvm::SMLoc::getFromPointer(loc), msg);
}
int FormatLexer::getNextChar() {
char curChar = *curPtr++;
switch (curChar) {
default:
return (unsigned char)curChar;
case 0: {
// A nul character in the stream is either the end of the current buffer or
// a random nul in the file. Disambiguate that here.
if (curPtr - 1 != curBuffer.end())
return 0;
// Otherwise, return end of file.
--curPtr;
return EOF;
}
case '\n':
case '\r':
// Handle the newline character by ignoring it and incrementing the line
// count. However, be careful about 'dos style' files with \n\r in them.
// Only treat a \n\r or \r\n as a single line.
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
++curPtr;
return '\n';
}
}
Token FormatLexer::lexToken() {
const char *tokStart = curPtr;
// This always consumes at least one character.
int curChar = getNextChar();
switch (curChar) {
default:
// Handle identifiers: [a-zA-Z_]
if (isalpha(curChar) || curChar == '_')
return lexIdentifier(tokStart);
// Unknown character, emit an error.
return emitError(tokStart, "unexpected character");
case EOF:
// Return EOF denoting the end of lexing.
return formToken(Token::eof, tokStart);
// Lex punctuation.
case '^':
return formToken(Token::caret, tokStart);
case ':':
return formToken(Token::colon, tokStart);
case ',':
return formToken(Token::comma, tokStart);
case '=':
return formToken(Token::equal, tokStart);
case '<':
return formToken(Token::less, tokStart);
case '>':
return formToken(Token::greater, tokStart);
case '?':
return formToken(Token::question, tokStart);
case '(':
return formToken(Token::l_paren, tokStart);
case ')':
return formToken(Token::r_paren, tokStart);
// Ignore whitespace characters.
case 0:
case ' ':
case '\t':
case '\n':
return lexToken();
case '`':
return lexLiteral(tokStart);
case '$':
return lexVariable(tokStart);
}
}
Token FormatLexer::lexLiteral(const char *tokStart) {
assert(curPtr[-1] == '`');
// Lex a literal surrounded by ``.
while (const char curChar = *curPtr++) {
if (curChar == '`')
return formToken(Token::literal, tokStart);
}
return emitError(curPtr - 1, "unexpected end of file in literal");
}
Token FormatLexer::lexVariable(const char *tokStart) {
if (!isalpha(curPtr[0]) && curPtr[0] != '_')
return emitError(curPtr - 1, "expected variable name");
// Otherwise, consume the rest of the characters.
while (isalnum(*curPtr) || *curPtr == '_')
++curPtr;
return formToken(Token::variable, tokStart);
}
Token FormatLexer::lexIdentifier(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
++curPtr;
// Check to see if this identifier is a keyword.
StringRef str(tokStart, curPtr - tokStart);
Token::Kind kind =
StringSwitch<Token::Kind>(str)
.Case("attr-dict", Token::kw_attr_dict)
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
.Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("ref", Token::kw_ref)
.Case("regions", Token::kw_regions)
.Case("results", Token::kw_results)
.Case("successors", Token::kw_successors)
.Case("type", Token::kw_type)
.Default(Token::identifier);
return Token(kind, str);
}
//===----------------------------------------------------------------------===//
// FormatParser
//===----------------------------------------------------------------------===//
@@ -2366,8 +2077,8 @@ namespace {
class FormatParser {
public:
FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
: lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
seenOperandTypes(op.getNumOperands()),
: lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
op(op), seenOperandTypes(op.getNumOperands()),
seenResultTypes(op.getNumResults()) {}
/// Parse the operation assembly format.
@@ -2469,7 +2180,8 @@ private:
LogicalResult parseCustomDirectiveParameter(
std::vector<std::unique_ptr<Element>> &parameters);
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, ParserContext context);
FormatToken tok,
ParserContext context);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, ParserContext context);
LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
@@ -2481,8 +2193,8 @@ private:
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc,
ParserContext context);
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
ParserContext context);
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
FormatToken tok, ParserContext context);
LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
bool isRefChild = false);
@@ -2492,12 +2204,12 @@ private:
/// Advance the current lexer onto the next token.
void consumeToken() {
assert(curToken.getKind() != Token::eof &&
curToken.getKind() != Token::error &&
assert(curToken.getKind() != FormatToken::eof &&
curToken.getKind() != FormatToken::error &&
"shouldn't advance past EOF or errors");
curToken = lexer.lexToken();
}
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
if (curToken.getKind() != kind)
return emitError(curToken.getLoc(), msg);
consumeToken();
@@ -2518,7 +2230,7 @@ private:
//===--------------------------------------------------------------------===//
FormatLexer lexer;
Token curToken;
FormatToken curToken;
OperationFormat &fmt;
Operator &op;
@@ -2539,7 +2251,7 @@ LogicalResult FormatParser::parse() {
llvm::SMLoc loc = curToken.getLoc();
// Parse each of the format elements into the main format.
while (curToken.getKind() != Token::eof) {
while (curToken.getKind() != FormatToken::eof) {
std::unique_ptr<Element> element;
if (failed(parseElement(element, TopLevelContext)))
return ::mlir::failure();
@@ -2864,13 +2576,13 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
if (curToken.isKeyword())
return parseDirective(element, context);
// Literals.
if (curToken.getKind() == Token::literal)
if (curToken.getKind() == FormatToken::literal)
return parseLiteral(element, context);
// Optionals.
if (curToken.getKind() == Token::l_paren)
if (curToken.getKind() == FormatToken::l_paren)
return parseOptional(element, context);
// Variables.
if (curToken.getKind() == Token::variable)
if (curToken.getKind() == FormatToken::variable)
return parseVariable(element, context);
return emitError(curToken.getLoc(),
"expected directive, literal, variable, or optional group");
@@ -2878,7 +2590,7 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
ParserContext context) {
Token varTok = curToken;
FormatToken varTok = curToken;
consumeToken();
StringRef name = varTok.getSpelling().drop_front();
@@ -2958,31 +2670,31 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
ParserContext context) {
Token dirTok = curToken;
FormatToken dirTok = curToken;
consumeToken();
switch (dirTok.getKind()) {
case Token::kw_attr_dict:
case FormatToken::kw_attr_dict:
return parseAttrDictDirective(element, dirTok.getLoc(), context,
/*withKeyword=*/false);
case Token::kw_attr_dict_w_keyword:
case FormatToken::kw_attr_dict_w_keyword:
return parseAttrDictDirective(element, dirTok.getLoc(), context,
/*withKeyword=*/true);
case Token::kw_custom:
case FormatToken::kw_custom:
return parseCustomDirective(element, dirTok.getLoc(), context);
case Token::kw_functional_type:
case FormatToken::kw_functional_type:
return parseFunctionalTypeDirective(element, dirTok, context);
case Token::kw_operands:
case FormatToken::kw_operands:
return parseOperandsDirective(element, dirTok.getLoc(), context);
case Token::kw_regions:
case FormatToken::kw_regions:
return parseRegionsDirective(element, dirTok.getLoc(), context);
case Token::kw_results:
case FormatToken::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), context);
case Token::kw_successors:
case FormatToken::kw_successors:
return parseSuccessorsDirective(element, dirTok.getLoc(), context);
case Token::kw_ref:
case FormatToken::kw_ref:
return parseReferenceDirective(element, dirTok.getLoc(), context);
case Token::kw_type:
case FormatToken::kw_type:
return parseTypeDirective(element, dirTok, context);
default:
@@ -2992,7 +2704,7 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
ParserContext context) {
Token literalTok = curToken;
FormatToken literalTok = curToken;
if (context != TopLevelContext) {
return emitError(
literalTok.getLoc(),
@@ -3014,7 +2726,7 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
}
// Check that the parsed literal is valid.
if (!LiteralElement::isValidLiteral(value))
if (!isValidLiteral(value))
return emitError(literalTok.getLoc(), "expected valid literal");
element = std::make_unique<LiteralElement>(value);
@@ -3035,14 +2747,15 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
do {
if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
return ::mlir::failure();
} while (curToken.getKind() != Token::r_paren);
} while (curToken.getKind() != FormatToken::r_paren);
consumeToken();
// Parse the `else` elements of this optional group.
if (curToken.getKind() == Token::colon) {
if (curToken.getKind() == FormatToken::colon) {
consumeToken();
if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
"of optional group")))
if (failed(parseToken(FormatToken::l_paren,
"expected '(' to start else branch "
"of optional group")))
return failure();
do {
llvm::SMLoc childLoc = curToken.getLoc();
@@ -3051,11 +2764,12 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
/*isAnchor=*/false)))
return failure();
} while (curToken.getKind() != Token::r_paren);
} while (curToken.getKind() != FormatToken::r_paren);
consumeToken();
}
if (failed(parseToken(Token::question, "expected '?' after optional group")))
if (failed(parseToken(FormatToken::question,
"expected '?' after optional group")))
return ::mlir::failure();
// The optional group is required to have an anchor.
@@ -3090,7 +2804,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
return ::mlir::failure();
// Check to see if this element is the anchor of the optional group.
bool isAnchor = curToken.getKind() == Token::caret;
bool isAnchor = curToken.getKind() == FormatToken::caret;
if (isAnchor) {
if (anchorIdx)
return emitError(childLoc, "only one element can be marked as the anchor "
@@ -3194,16 +2908,16 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'custom' is only valid as a top-level directive");
// Parse the custom directive name.
if (failed(
parseToken(Token::less, "expected '<' before custom directive name")))
if (failed(parseToken(FormatToken::less,
"expected '<' before custom directive name")))
return ::mlir::failure();
Token nameTok = curToken;
if (failed(parseToken(Token::identifier,
FormatToken nameTok = curToken;
if (failed(parseToken(FormatToken::identifier,
"expected custom directive name identifier")) ||
failed(parseToken(Token::greater,
failed(parseToken(FormatToken::greater,
"expected '>' after custom directive name")) ||
failed(parseToken(Token::l_paren,
failed(parseToken(FormatToken::l_paren,
"expected '(' before custom directive parameters")))
return ::mlir::failure();
@@ -3212,12 +2926,12 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
do {
if (failed(parseCustomDirectiveParameter(elements)))
return ::mlir::failure();
if (curToken.getKind() != Token::comma)
if (curToken.getKind() != FormatToken::comma)
break;
consumeToken();
} while (true);
if (failed(parseToken(Token::r_paren,
if (failed(parseToken(FormatToken::r_paren,
"expected ')' after custom directive parameters")))
return ::mlir::failure();
@@ -3254,9 +2968,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return ::mlir::success();
}
LogicalResult
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, ParserContext context) {
LogicalResult FormatParser::parseFunctionalTypeDirective(
std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
llvm::SMLoc loc = tok.getLoc();
if (context != TopLevelContext)
return emitError(
@@ -3264,11 +2977,14 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
// Parse the main operand.
std::unique_ptr<Element> inputs, results;
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(parseTypeDirectiveOperand(inputs)) ||
failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
failed(parseToken(FormatToken::comma,
"expected ',' after inputs argument")) ||
failed(parseTypeDirectiveOperand(results)) ||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
std::move(results));
@@ -3299,9 +3015,11 @@ FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'ref' is only valid within a `custom` directive");
std::unique_ptr<Element> operand;
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(parseElement(operand, RefDirectiveContext)) ||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<RefDirective>(std::move(operand));
@@ -3360,17 +3078,19 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
}
LogicalResult
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
ParserContext context) {
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
FormatToken tok, ParserContext context) {
llvm::SMLoc loc = tok.getLoc();
if (context == TypeDirectiveContext)
return emitError(loc, "'type' cannot be used as a child of another `type`");
bool isRefChild = context == RefDirectiveContext;
std::unique_ptr<Element> operand;
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<TypeDirective>(std::move(operand));
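
Note on the hunks above: they swap `Token` for `FormatToken` but keep the same expect-and-consume idiom, in which each `parseToken(kind, message)` call either consumes a token of the expected kind or reports the message and fails. The following is a minimal standalone sketch of that idiom, not the MLIR implementation. `SketchToken`, `SketchParser`, and `parseParenthesizedOperand` are hypothetical names introduced only for illustration, and the sketch returns `bool` where the patch returns `LogicalResult`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the FormatToken class in the diff; only the
// token kinds needed by this sketch are modeled.
struct SketchToken {
  enum Kind { l_paren, r_paren, comma, identifier, eof };
  Kind kind;
  std::string spelling;
};

// Hypothetical miniature parser. This is NOT the MLIR FormatParser; it only
// illustrates the expect-and-consume pattern used throughout the hunks.
class SketchParser {
public:
  explicit SketchParser(std::vector<SketchToken> toks)
      : tokens(std::move(toks)) {
    // Always terminate the stream with an eof sentinel so that reading the
    // current token is safe at any position.
    tokens.push_back({SketchToken::eof, ""});
  }

  // Consume the current token if it has the expected kind. Otherwise record
  // the diagnostic message and report failure. Returns true on success.
  bool parseToken(SketchToken::Kind kind, const std::string &msg) {
    if (tokens[pos].kind != kind) {
      lastError = msg;
      return false;
    }
    // Never advance past the eof sentinel.
    if (tokens[pos].kind != SketchToken::eof)
      ++pos;
    return true;
  }

  // Parse `( identifier )`, mirroring the shape of the single-operand
  // directives: open paren, operand, close paren, with early exit on the
  // first failure.
  bool parseParenthesizedOperand(std::string &operand) {
    if (!parseToken(SketchToken::l_paren, "expected '(' before argument list"))
      return false;
    // Remember the spelling before the token is consumed.
    std::string name = tokens[pos].spelling;
    if (!parseToken(SketchToken::identifier, "expected operand") ||
        !parseToken(SketchToken::r_paren, "expected ')' after argument list"))
      return false;
    operand = name;
    return true;
  }

  // Message of the most recent failed parseToken call, if any.
  std::string lastError;

private:
  std::vector<SketchToken> tokens;
  std::size_t pos = 0;
};
```

Chaining the calls with `||` means the first failing expectation determines the diagnostic, which is the behaviour the `failed(...) || failed(...)` chains in the patch rely on.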

View File

@@ -220,6 +220,7 @@ cc_library(
"//mlir:SideEffects",
"//mlir:StandardOps",
"//mlir:StandardOpsTransforms",
"//mlir:Support",
"//mlir:TensorDialect",
"//mlir:TransformUtils",
"//mlir:Transforms",