mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 16:06:39 +08:00
[MLIR] Attribute and type formats in ODS
Declarative attribute and type formats with assembly formats. Define an
`assemblyFormat` field in attribute and type defs with a `mnemonic` to
generate a parser and printer.
```tablegen
def MyAttr : AttrDef<MyDialect, "MyAttr"> {
let parameters = (ins "int64_t":$count, "AffineMap":$map);
let mnemonic = "my_attr";
let assemblyFormat = "`<` $count `,` $map `>`";
}
```
Use `struct` to define a comma-separated list of key-value pairs:
```tablegen
def MyType : TypeDef<MyDialect, "MyType"> {
let parameters = (ins "int":$one, "int":$two, "int":$three);
let mnemonic = "my_attr";
let assemblyFormat = "`<` $three `:` struct($one, $two) `>`";
}
```
Use `struct(*)` to capture all parameters.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D111594
This commit is contained in:
@@ -382,3 +382,172 @@ the things named `*Type` are generally now named `*Attr`.
|
||||
|
||||
Aside from that, all of the interfaces for uniquing and storage construction are
|
||||
all the same.
|
||||
|
||||
## Defining Custom Parsers and Printers using Assembly Formats
|
||||
|
||||
Attributes and types defined in ODS with a mnemonic can define an
|
||||
`assemblyFormat` to declaratively describe custom parsers and printers. The
|
||||
assembly format consists of literals, variables, and directives.
|
||||
|
||||
* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
|
||||
`` `keyword` `` or `` `<` ``.
|
||||
* A variable is a parameter name preceeded by a dollar sign, e.g. `$param0`,
|
||||
which captures one attribute or type parameter.
|
||||
* A directive is a keyword followed by an optional argument list that defines
|
||||
special parser and printer behaviour.
|
||||
|
||||
```tablegen
|
||||
// An example type with an assembly format.
|
||||
def MyType : TypeDef<My_Dialect, "MyType"> {
|
||||
// Define a mnemonic to allow the dialect's parser hook to call into the
|
||||
// generated parser.
|
||||
let mnemonic = "my_type";
|
||||
|
||||
// Define two parameters whose C++ types are indicated in string literals.
|
||||
let parameters = (ins "int":$count, "AffineMap":$map);
|
||||
|
||||
// Define the assembly format. Surround the format with less `<` and greater
|
||||
// `>` so that MLIR's printers use the pretty format.
|
||||
let assemblyFormat = "`<` $count `,` `map` `=` $map `>`";
|
||||
}
|
||||
```
|
||||
|
||||
The declarative assembly format for `MyType` results in the following format
|
||||
in the IR:
|
||||
|
||||
```mlir
|
||||
!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>
|
||||
```
|
||||
|
||||
### Parameter Parsing and Printing
|
||||
|
||||
For many basic parameter types, no additional work is needed to define how
|
||||
these parameters are parsed or printerd.
|
||||
|
||||
* The default printer for any parameter is `$_printer << $_self`,
|
||||
where `$_self` is the C++ value of the parameter and `$_printer` is a
|
||||
`DialectAsmPrinter`.
|
||||
* The default parser for a parameter is
|
||||
`FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
|
||||
of the parameter and `$_parser` is a `DialectAsmParser`.
|
||||
|
||||
Printing and parsing behaviour can be added to additional C++ types by
|
||||
overloading these functions or by defining a `parser` and `printer` in an ODS
|
||||
parameter class.
|
||||
|
||||
Example of overloading:
|
||||
|
||||
```c++
|
||||
using MyParameter = std::pair<int, int>;
|
||||
|
||||
DialectAsmPrinter &operator<<(DialectAsmPrinter &printer, MyParameter param) {
|
||||
printer << param.first << " * " << param.second;
|
||||
}
|
||||
|
||||
template <> struct FieldParser<MyParameter> {
|
||||
static FailureOr<MyParameter> parse(DialectAsmParser &parser) {
|
||||
int a, b;
|
||||
if (parser.parseInteger(a) || parser.parseStar() ||
|
||||
parser.parseInteger(b))
|
||||
return failure();
|
||||
return MyParameter(a, b);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
Example of using ODS parameter classes:
|
||||
|
||||
```
|
||||
def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
|
||||
let printer = [{ $_printer << $_self.first << " * " << $_self.second }];
|
||||
let parser = [{ [&] -> FailureOr<std::pair<int, int>> {
|
||||
int a, b;
|
||||
if ($_parser.parseInteger(a) || $_parser.parseStar() ||
|
||||
$_parser.parseInteger(b))
|
||||
return failure();
|
||||
return std::make_pair(a, b);
|
||||
}() }];
|
||||
}
|
||||
```
|
||||
|
||||
A type using this parameter with the assembly format `` `<` $myParam `>` ``
|
||||
will look as follows in the IR:
|
||||
|
||||
```mlir
|
||||
!my_dialect.my_type<42 * 24>
|
||||
```
|
||||
|
||||
#### Non-POD Parameters
|
||||
|
||||
Parameters that aren't plain-old-data (e.g. references) may need to define a
|
||||
`cppStorageType` to contain the data until it is copied into the allocator.
|
||||
For example, `StringRefParameter` uses `std::string` as its storage type,
|
||||
whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
|
||||
for these parameters are expected to return `FailureOr<$cppStorageType>`.
|
||||
|
||||
### Assembly Format Directives
|
||||
|
||||
Attribute and type assembly formats have the following directives:
|
||||
|
||||
* `params`: capture all parameters of an attribute or type.
|
||||
* `struct`: generate a "struct-like" parser and printer for a list of key-value
|
||||
pairs.
|
||||
|
||||
#### `params` Directive
|
||||
|
||||
This directive is used to refer to all parameters of an attribute or type.
|
||||
When used as a top-level directive, `params` generates a parser and printer for
|
||||
a comma-separated list of the parameters. For example:
|
||||
|
||||
```tablegen
|
||||
def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
|
||||
let parameters = (ins "int":$a, "int":$b);
|
||||
let mnemonic = "pair";
|
||||
let assemblyFormat = "`<` params `>`";
|
||||
}
|
||||
```
|
||||
|
||||
In the IR, this type will appear as:
|
||||
|
||||
```mlir
|
||||
!my_dialect.pair<42, 24>
|
||||
```
|
||||
|
||||
The `params` directive can also be passed to other directives, such as `struct`,
|
||||
as an argument that refers to all parameters in place of explicitly listing all
|
||||
parameters as variables.
|
||||
|
||||
#### `struct` Directive
|
||||
|
||||
The `struct` directive accepts a list of variables to capture and will generate
|
||||
a parser and printer for a comma-separated list of key-value pairs. The
|
||||
variables are printed in the order they are specified in the argument list **but
|
||||
can be parsed in any order**. For example:
|
||||
|
||||
```tablegen
|
||||
def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
|
||||
let parameters = (ins StringRefParameter<>:$sym_name,
|
||||
"int":$a, "int":$b, "int":$c);
|
||||
let mnemonic = "struct";
|
||||
let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`";
|
||||
}
|
||||
```
|
||||
|
||||
In the IR, this type can appear with any permutation of the order of the
|
||||
parameters captured in the directive.
|
||||
|
||||
```mlir
|
||||
!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3>
|
||||
!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1>
|
||||
```
|
||||
|
||||
Passing `params` as the only argument to `struct` makes the directive capture
|
||||
all the parameters of the attribute or type. For the same type above, an
|
||||
assembly format of `` `<` struct(params) `>` `` will result in:
|
||||
|
||||
```mlir
|
||||
!my_dialect.struct<b = 2, sym_name = "foo", c = 3, a = 1>
|
||||
```
|
||||
|
||||
The order in which the parameters are printed is the order in which they are
|
||||
declared in the attribute's or type's `parameter` list.
|
||||
|
||||
@@ -47,6 +47,74 @@ public:
|
||||
virtual StringRef getFullSymbolSpec() const = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Parse Fields
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Provide a template class that can be specialized by users to dispatch to
|
||||
/// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`,
|
||||
/// where `T` is the parameter storage type, to parse custom types.
|
||||
template <typename T, typename = T>
|
||||
struct FieldParser;
|
||||
|
||||
/// Parse an attribute.
|
||||
template <typename AttributeT>
|
||||
struct FieldParser<
|
||||
AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
|
||||
AttributeT>> {
|
||||
static FailureOr<AttributeT> parse(DialectAsmParser &parser) {
|
||||
AttributeT value;
|
||||
if (parser.parseAttribute(value))
|
||||
return failure();
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parse any integer.
|
||||
template <typename IntT>
|
||||
struct FieldParser<IntT,
|
||||
std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
|
||||
static FailureOr<IntT> parse(DialectAsmParser &parser) {
|
||||
IntT value;
|
||||
if (parser.parseInteger(value))
|
||||
return failure();
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parse a string.
|
||||
template <>
|
||||
struct FieldParser<std::string> {
|
||||
static FailureOr<std::string> parse(DialectAsmParser &parser) {
|
||||
std::string value;
|
||||
if (parser.parseString(&value))
|
||||
return failure();
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
/// Parse any container that supports back insertion as a list.
|
||||
template <typename ContainerT>
|
||||
struct FieldParser<
|
||||
ContainerT, std::enable_if_t<std::is_member_function_pointer<
|
||||
decltype(&ContainerT::push_back)>::value,
|
||||
ContainerT>> {
|
||||
using ElementT = typename ContainerT::value_type;
|
||||
static FailureOr<ContainerT> parse(DialectAsmParser &parser) {
|
||||
ContainerT elements;
|
||||
auto elementParser = [&]() {
|
||||
auto element = FieldParser<ElementT>::parse(parser);
|
||||
if (failed(element))
|
||||
return failure();
|
||||
elements.push_back(element.getValue());
|
||||
return success();
|
||||
};
|
||||
if (parser.parseCommaSeparatedList(elementParser))
|
||||
return failure();
|
||||
return elements;
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif
|
||||
#endif // MLIR_IR_DIALECTIMPLEMENTATION_H
|
||||
|
||||
@@ -2886,6 +2886,11 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
|
||||
code printer = ?;
|
||||
code parser = ?;
|
||||
|
||||
// Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
|
||||
// specified at the same time as either 'printer' or 'parser'. The generated
|
||||
// printer requires 'genAccessors' to be true.
|
||||
string assemblyFormat = ?;
|
||||
|
||||
// If set, generate accessors for each parameter.
|
||||
bit genAccessors = 1;
|
||||
|
||||
@@ -2964,10 +2969,22 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
|
||||
string cppType = type;
|
||||
// The C++ type of the accessor for this parameter.
|
||||
string cppAccessorType = !if(!empty(accessorType), type, accessorType);
|
||||
// The C++ storage type of of this parameter if it is a reference, e.g.
|
||||
// `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
|
||||
string cppStorageType = ?;
|
||||
// One-line human-readable description of the argument.
|
||||
string summary = desc;
|
||||
// The format string for the asm syntax (documentation only).
|
||||
string syntax = ?;
|
||||
// The default parameter parser is `::mlir::parseField<T>($_parser)`, which
|
||||
// returns `FailureOr<T>`. Overload `parseField` to support parsing for your
|
||||
// type. Or you can provide a customer printer. For attributes, "$_type" will
|
||||
// be replaced with the required attribute type.
|
||||
string parser = ?;
|
||||
// The default parameter printer is `$_printer << $_self`. Overload the stream
|
||||
// operator of `DialectAsmPrinter` as necessary to print your type. Or you can
|
||||
// provide a custom printer.
|
||||
string printer = ?;
|
||||
}
|
||||
class AttrParameter<string type, string desc, string accessorType = "">
|
||||
: AttrOrTypeParameter<type, desc, accessorType>;
|
||||
@@ -2978,6 +2995,8 @@ class TypeParameter<string type, string desc, string accessorType = "">
|
||||
class StringRefParameter<string desc = ""> :
|
||||
AttrOrTypeParameter<"::llvm::StringRef", desc> {
|
||||
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
|
||||
let printer = [{$_printer << '"' << $_self << '"';}];
|
||||
let cppStorageType = "std::string";
|
||||
}
|
||||
|
||||
// For APFloats, which require comparison.
|
||||
@@ -2990,6 +3009,7 @@ class APFloatParameter<string desc> :
|
||||
class ArrayRefParameter<string arrayOf, string desc = ""> :
|
||||
AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
|
||||
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
|
||||
let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">";
|
||||
}
|
||||
|
||||
// For classes which require allocation and have their own allocateInto method.
|
||||
|
||||
@@ -182,10 +182,10 @@ operator<<(AsmPrinterT &p, const TypeRange &types) {
|
||||
llvm::interleaveComma(types, p);
|
||||
return p;
|
||||
}
|
||||
template <typename AsmPrinterT>
|
||||
template <typename AsmPrinterT, typename ElementT>
|
||||
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
|
||||
AsmPrinterT &>
|
||||
operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
|
||||
operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
|
||||
llvm::interleaveComma(types, p);
|
||||
return p;
|
||||
}
|
||||
|
||||
@@ -101,6 +101,9 @@ public:
|
||||
// None. Otherwise, returns the contents of that code block.
|
||||
Optional<StringRef> getParserCode() const;
|
||||
|
||||
// Returns the custom assembly format, if one was specified.
|
||||
Optional<StringRef> getAssemblyFormat() const;
|
||||
|
||||
// Returns true if the accessors based on the parameters should be generated.
|
||||
bool genAccessors() const;
|
||||
|
||||
@@ -199,6 +202,15 @@ public:
|
||||
// Get the C++ accessor type of this parameter.
|
||||
StringRef getCppAccessorType() const;
|
||||
|
||||
// Get the C++ storage type of this parameter.
|
||||
StringRef getCppStorageType() const;
|
||||
|
||||
// Get an optional C++ parameter parser.
|
||||
Optional<StringRef> getParser() const;
|
||||
|
||||
// Get an optional C++ parameter printer.
|
||||
Optional<StringRef> getPrinter() const;
|
||||
|
||||
// Get a description of this parameter for documentation purposes.
|
||||
Optional<StringRef> getSummary() const;
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
||||
@@ -132,6 +132,10 @@ Optional<StringRef> AttrOrTypeDef::getParserCode() const {
|
||||
return def->getValueAsOptionalString("parser");
|
||||
}
|
||||
|
||||
Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
|
||||
return def->getValueAsOptionalString("assemblyFormat");
|
||||
}
|
||||
|
||||
bool AttrOrTypeDef::genAccessors() const {
|
||||
return def->getValueAsBit("genAccessors");
|
||||
}
|
||||
@@ -219,6 +223,32 @@ StringRef AttrOrTypeParameter::getCppAccessorType() const {
|
||||
return getCppType();
|
||||
}
|
||||
|
||||
StringRef AttrOrTypeParameter::getCppStorageType() const {
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
|
||||
if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
|
||||
return *type;
|
||||
}
|
||||
return getCppType();
|
||||
}
|
||||
|
||||
Optional<StringRef> AttrOrTypeParameter::getParser() const {
|
||||
auto *parameterType = def->getArg(index);
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
|
||||
return *parser;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
|
||||
auto *parameterType = def->getArg(index);
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
|
||||
return *printer;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
|
||||
auto *parameterType = def->getArg(index);
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
|
||||
@@ -116,4 +116,44 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
|
||||
);
|
||||
}
|
||||
|
||||
def TestParamOne : AttrParameter<"int64_t", ""> {}
|
||||
|
||||
def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
|
||||
let printer = "$_printer << '\"' << $_self << '\"'";
|
||||
}
|
||||
|
||||
def TestParamFour : ArrayRefParameter<"int", ""> {
|
||||
let cppStorageType = "llvm::SmallVector<int>";
|
||||
let parser = "::parseIntArray($_parser)";
|
||||
let printer = "::printIntArray($_printer, $_self)";
|
||||
}
|
||||
|
||||
def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
|
||||
let parameters = (
|
||||
ins
|
||||
TestParamOne:$one,
|
||||
TestParamTwo:$two,
|
||||
"::mlir::IntegerAttr":$three,
|
||||
TestParamFour:$four
|
||||
);
|
||||
|
||||
let mnemonic = "attr_with_format";
|
||||
let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`";
|
||||
let genVerifyDecl = 1;
|
||||
}
|
||||
|
||||
def TestAttrUgly : Test_Attr<"TestAttrUgly"> {
|
||||
let parameters = (ins "::mlir::Attribute":$attr);
|
||||
|
||||
let mnemonic = "attr_ugly";
|
||||
let assemblyFormat = "`begin` $attr `end`";
|
||||
}
|
||||
|
||||
def TestAttrParams: Test_Attr<"TestAttrParams"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1);
|
||||
|
||||
let mnemonic = "attr_params";
|
||||
let assemblyFormat = "`<` params `>`";
|
||||
}
|
||||
|
||||
#endif // TEST_ATTRDEFS
|
||||
|
||||
@@ -16,9 +16,11 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/ADT/bit.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace test;
|
||||
@@ -127,6 +129,36 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
int64_t one, std::string two, IntegerAttr three,
|
||||
ArrayRef<int> four) {
|
||||
if (four.size() != static_cast<unsigned>(one))
|
||||
return emitError() << "expected 'one' to equal 'four.size()'";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions for Generated Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
|
||||
SmallVector<int> ints;
|
||||
if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
|
||||
ints.push_back(0);
|
||||
return parser.parseInteger(ints.back());
|
||||
}) ||
|
||||
parser.parseRSquare())
|
||||
return failure();
|
||||
return ints;
|
||||
}
|
||||
|
||||
static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
|
||||
printer << '[';
|
||||
llvm::interleaveComma(ints, printer);
|
||||
printer << ']';
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestSubElementsAccessAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
// To get the test dialect def.
|
||||
include "TestOps.td"
|
||||
include "TestAttrDefs.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
@@ -189,4 +190,44 @@ def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
|
||||
let mnemonic = "test_type_with_trait";
|
||||
}
|
||||
|
||||
// Type with assembly format.
|
||||
def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> {
|
||||
let parameters = (
|
||||
ins
|
||||
TestParamOne:$one,
|
||||
TestParamTwo:$two,
|
||||
"::mlir::Attribute":$three
|
||||
);
|
||||
|
||||
let mnemonic = "type_with_format";
|
||||
let assemblyFormat = "`<` $one `,` struct($three, $two) `>`";
|
||||
}
|
||||
|
||||
// Test dispatch to parseField
|
||||
def TestTypeNoParser : Test_Type<"TestTypeNoParser"> {
|
||||
let parameters = (
|
||||
ins
|
||||
"uint32_t":$one,
|
||||
ArrayRefParameter<"int64_t">:$two,
|
||||
StringRefParameter<>:$three,
|
||||
"::test::CustomParam":$four
|
||||
);
|
||||
|
||||
let mnemonic = "no_parser";
|
||||
let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`";
|
||||
}
|
||||
|
||||
def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
|
||||
let parameters = (
|
||||
ins
|
||||
"int":$v0,
|
||||
"int":$v1,
|
||||
"int":$v2,
|
||||
"int":$v3
|
||||
);
|
||||
|
||||
let mnemonic = "struct_capture_all";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
||||
@@ -38,8 +38,38 @@ struct FieldInfo {
|
||||
}
|
||||
};
|
||||
|
||||
/// A custom type for a test type parameter.
|
||||
struct CustomParam {
|
||||
int value;
|
||||
|
||||
bool operator==(const CustomParam &other) const {
|
||||
return other.value == value;
|
||||
}
|
||||
};
|
||||
|
||||
inline llvm::hash_code hash_value(const test::CustomParam ¶m) {
|
||||
return llvm::hash_value(param.value);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
||||
namespace mlir {
|
||||
template <>
|
||||
struct FieldParser<test::CustomParam> {
|
||||
static FailureOr<test::CustomParam> parse(DialectAsmParser &parser) {
|
||||
auto value = FieldParser<int>::parse(parser);
|
||||
if (failed(value))
|
||||
return failure();
|
||||
return test::CustomParam{value.getValue()};
|
||||
}
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer,
|
||||
const test::CustomParam ¶m) {
|
||||
return printer << param.value;
|
||||
}
|
||||
|
||||
#include "TestTypeInterfaces.h.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
@@ -52,17 +82,19 @@ namespace test {
|
||||
struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
|
||||
using KeyTy = ::llvm::StringRef;
|
||||
|
||||
explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {}
|
||||
explicit TestRecursiveTypeStorage(::llvm::StringRef key)
|
||||
: name(key), body(::mlir::Type()) {}
|
||||
|
||||
bool operator==(const KeyTy &other) const { return name == other; }
|
||||
|
||||
static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
static TestRecursiveTypeStorage *
|
||||
construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
|
||||
return new (allocator.allocate<TestRecursiveTypeStorage>())
|
||||
TestRecursiveTypeStorage(allocator.copyInto(key));
|
||||
}
|
||||
|
||||
::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) {
|
||||
::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
|
||||
::mlir::Type newBody) {
|
||||
// Cannot set a different body than before.
|
||||
if (body && body != newBody)
|
||||
return ::mlir::failure();
|
||||
@@ -79,11 +111,13 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
|
||||
/// type, potentially itself. This requires the body to be mutated separately
|
||||
/// from type creation.
|
||||
class TestRecursiveType
|
||||
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type, TestRecursiveTypeStorage> {
|
||||
: public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
|
||||
TestRecursiveTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) {
|
||||
static TestRecursiveType get(::mlir::MLIRContext *ctx,
|
||||
::llvm::StringRef name) {
|
||||
return Base::get(ctx, name);
|
||||
}
|
||||
|
||||
|
||||
76
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
Normal file
76
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
Normal file
@@ -0,0 +1,76 @@
|
||||
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "TestDialect";
|
||||
let cppNamespace = "::test";
|
||||
}
|
||||
|
||||
class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
|
||||
let mnemonic = asm;
|
||||
}
|
||||
|
||||
/// Test format is missing a parameter capture.
|
||||
def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1);
|
||||
// CHECK: format is missing reference to parameter: v1
|
||||
let assemblyFormat = "`<` $v0 `>`";
|
||||
}
|
||||
|
||||
/// Test format has duplicate parameter captures.
|
||||
def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1);
|
||||
// CHECK: duplicate parameter 'v0'
|
||||
let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
|
||||
}
|
||||
|
||||
/// Test format has invalid syntax.
|
||||
def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1);
|
||||
// CHECK: expected literal, directive, or variable
|
||||
let assemblyFormat = "`<` $v0, $v1 `>`";
|
||||
}
|
||||
|
||||
/// Test struct directive has invalid syntax.
|
||||
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: literals may only be used in the top-level section of the format
|
||||
// CHECK: expected a variable in `struct` argument list
|
||||
let assemblyFormat = "`<` struct($v0, `,`) `>`";
|
||||
}
|
||||
|
||||
/// Test struct directive cannot capture zero parameters.
|
||||
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: `struct` argument list expected a variable or directive
|
||||
let assemblyFormat = "`<` struct() $v0 `>`";
|
||||
}
|
||||
|
||||
/// Test capture parameter that does not exist.
|
||||
def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: InvalidTypeF has no parameter named 'v1'
|
||||
let assemblyFormat = "`<` $v0 $v1 `>`";
|
||||
}
|
||||
|
||||
/// Test duplicate capture of parameter in capture-all struct.
|
||||
def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
|
||||
// CHECK: duplicate parameter 'v0'
|
||||
let assemblyFormat = "`<` struct(params) $v0 `>`";
|
||||
}
|
||||
|
||||
/// Test capture-all struct duplicate capture.
|
||||
def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
|
||||
// CHECK: `params` captures duplicate parameter: v0
|
||||
let assemblyFormat = "`<` $v0 struct(params) `>`";
|
||||
}
|
||||
|
||||
/// Test capture of parameter after `params` directive.
|
||||
def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
|
||||
let parameters = (ins "int":$v0);
|
||||
// CHECK: duplicate parameter 'v0'
|
||||
let assemblyFormat = "`<` params $v0 `>`";
|
||||
}
|
||||
21
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
Normal file
21
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
Normal file
@@ -0,0 +1,21 @@
|
||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_roundtrip_parameter_parsers
|
||||
// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">
|
||||
// CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
|
||||
func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
|
||||
attributes {
|
||||
// CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>
|
||||
attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>,
|
||||
// CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>
|
||||
attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>,
|
||||
// CHECK: #test<"attr_ugly begin 5 : index end">
|
||||
attr2 = #test<"attr_ugly begin 5 : index end">,
|
||||
// CHECK: #test.attr_params<42, 24>
|
||||
attr3 = #test.attr_params<42, 24>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_roundtrip_default_parsers_struct
|
||||
// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
|
||||
// CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
|
||||
func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>
|
||||
127
mlir/test/mlir-tblgen/attr-or-type-format.mlir
Normal file
127
mlir/test/mlir-tblgen/attr-or-type-format.mlir
Normal file
@@ -0,0 +1,127 @@
|
||||
// RUN: mlir-opt --split-input-file %s --verify-diagnostics
|
||||
|
||||
func private @test_ugly_attr_cannot_be_pretty() -> () attributes {
|
||||
// expected-error@+1 {{expected 'begin'}}
|
||||
attr = #test.attr_ugly
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_ugly_attr_no_mnemonic() -> () attributes {
|
||||
// expected-error@+1 {{expected valid keyword}}
|
||||
attr = #test<"">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_ugly_attr_parser_dispatch() -> () attributes {
|
||||
// expected-error@+1 {{expected 'begin'}}
|
||||
attr = #test<"attr_ugly">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_ugly_attr_missing_parameter() -> () attributes {
|
||||
// expected-error@+2 {{failed to parse TestAttrUgly parameter 'attr'}}
|
||||
// expected-error@+1 {{expected non-function type}}
|
||||
attr = #test<"attr_ugly begin">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_ugly_attr_missing_literal() -> () attributes {
|
||||
// expected-error@+1 {{expected 'end'}}
|
||||
attr = #test<"attr_ugly begin \"string_attr\"">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_pretty_attr_expects_less() -> () attributes {
|
||||
// expected-error@+1 {{expected '<'}}
|
||||
attr = #test.attr_with_format
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_pretty_attr_missing_param() -> () attributes {
|
||||
// expected-error@+2 {{expected integer value}}
|
||||
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
|
||||
attr = #test.attr_with_format<>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_parse_invalid_param() -> () attributes {
|
||||
// Test parameter parser failure is propagated
|
||||
// expected-error@+2 {{expected integer value}}
|
||||
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
|
||||
attr = #test.attr_with_format<"hi">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_pretty_attr_invalid_syntax() -> () attributes {
|
||||
// expected-error@+1 {{expected ':'}}
|
||||
attr = #test.attr_with_format<42>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_struct_missing_key() -> () attributes {
|
||||
// expected-error@+2 {{expected valid keyword}}
|
||||
// expected-error@+1 {{expected a parameter name in struct}}
|
||||
attr = #test.attr_with_format<42 :>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_struct_unknown_key() -> () attributes {
|
||||
// expected-error@+1 {{duplicate or unknown struct parameter}}
|
||||
attr = #test.attr_with_format<42 : nine = "foo">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_struct_duplicate_key() -> () attributes {
|
||||
// expected-error@+1 {{duplicate or unknown struct parameter}}
|
||||
attr = #test.attr_with_format<42 : two = "foo", two = "bar">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_struct_not_enough_values() -> () attributes {
|
||||
// expected-error@+1 {{expected ','}}
|
||||
attr = #test.attr_with_format<42 : two = "foo">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_parse_param_after_struct() -> () attributes {
|
||||
// expected-error@+2 {{expected non-function type}}
|
||||
// expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'three'}}
|
||||
attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : >
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{expected '<'}}
|
||||
func private @test_invalid_type() -> !test.type_with_format
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+2 {{expected integer value}}
|
||||
// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'one'}}
|
||||
func private @test_pretty_type_invalid_param() -> !test.type_with_format<>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+2 {{expected ':'}}
|
||||
// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'three'}}
|
||||
func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>>
|
||||
|
||||
// -----
|
||||
|
||||
func private @test_verifier_fails() -> () attributes {
|
||||
// expected-error@+1 {{expected 'one' to equal 'four.size()'}}
|
||||
attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64>
|
||||
}
|
||||
394
mlir/test/mlir-tblgen/attr-or-type-format.td
Normal file
394
mlir/test/mlir-tblgen/attr-or-type-format.td
Normal file
@@ -0,0 +1,394 @@
|
||||
// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
|
||||
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
/// Test that attribute and type printers and parsers are correctly generated.
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "TestDialect";
|
||||
let cppNamespace = "::test";
|
||||
}
|
||||
|
||||
class TestAttr<string name> : AttrDef<Test_Dialect, name>;
|
||||
class TestType<string name> : TypeDef<Test_Dialect, name>;
|
||||
|
||||
def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> {
|
||||
let parser = "::parseAttrParamA($_parser, $_type)";
|
||||
let printer = "::printAttrParamA($_printer, $_self)";
|
||||
}
|
||||
|
||||
def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> {
|
||||
let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)";
|
||||
let printer = "::printAttrB($_printer, $_self)";
|
||||
}
|
||||
|
||||
def TypeParamA : TypeParameter<"TestParamC", "a type param C"> {
|
||||
let parser = "::parseTypeParamC($_parser)";
|
||||
let printer = "$_printer << $_self";
|
||||
}
|
||||
|
||||
def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
|
||||
let parser = "someFcnCall()";
|
||||
let printer = "myPrinter($_self)";
|
||||
}
|
||||
|
||||
/// Check simple attribute parser and printer are generated correctly.
|
||||
|
||||
// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::DialectAsmParser &parser,
|
||||
// ATTR: ::mlir::Type attrType) {
|
||||
// ATTR: FailureOr<IntegerAttr> _result_value;
|
||||
// ATTR: FailureOr<TestParamA> _result_complex;
|
||||
// ATTR: if (parser.parseKeyword("hello"))
|
||||
// ATTR: return {};
|
||||
// ATTR: if (parser.parseEqual())
|
||||
// ATTR: return {};
|
||||
// ATTR: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// ATTR: if (failed(_result_value))
|
||||
// ATTR: return {};
|
||||
// ATTR: if (parser.parseComma())
|
||||
// ATTR: return {};
|
||||
// ATTR: _result_complex = ::parseAttrParamA(parser, attrType);
|
||||
// ATTR: if (failed(_result_complex))
|
||||
// ATTR: return {};
|
||||
// ATTR: if (parser.parseRParen())
|
||||
// ATTR: return {};
|
||||
// ATTR: return TestAAttr::get(parser.getContext(),
|
||||
// ATTR: _result_value.getValue(),
|
||||
// ATTR: _result_complex.getValue());
|
||||
// ATTR: }
|
||||
|
||||
// ATTR: void TestAAttr::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// ATTR: printer << "attr_a";
|
||||
// ATTR: printer << ' ' << "hello";
|
||||
// ATTR: printer << ' ' << "=";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getValue();
|
||||
// ATTR: printer << ",";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: ::printAttrParamA(printer, getComplex());
|
||||
// ATTR: printer << ")";
|
||||
// ATTR: }
|
||||
|
||||
def AttrA : TestAttr<"TestA"> {
|
||||
let parameters = (ins
|
||||
"IntegerAttr":$value,
|
||||
AttrParamA:$complex
|
||||
);
|
||||
|
||||
let mnemonic = "attr_a";
|
||||
let assemblyFormat = "`hello` `=` $value `,` $complex `)`";
|
||||
}
|
||||
|
||||
/// Test simple struct parser and printer are generated correctly.
|
||||
|
||||
// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::DialectAsmParser &parser,
|
||||
// ATTR: ::mlir::Type attrType) {
|
||||
// ATTR: bool _seen_v0 = false;
|
||||
// ATTR: bool _seen_v1 = false;
|
||||
// ATTR: for (unsigned _index = 0; _index < 2; ++_index) {
|
||||
// ATTR: StringRef _paramKey;
|
||||
// ATTR: if (parser.parseKeyword(&_paramKey))
|
||||
// ATTR: return {};
|
||||
// ATTR: if (parser.parseEqual())
|
||||
// ATTR: return {};
|
||||
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
|
||||
// ATTR: _seen_v0 = true;
|
||||
// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType);
|
||||
// ATTR: if (failed(_result_v0))
|
||||
// ATTR: return {};
|
||||
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
|
||||
// ATTR: _seen_v1 = true;
|
||||
// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser);
|
||||
// ATTR: if (failed(_result_v1))
|
||||
// ATTR: return {};
|
||||
// ATTR: } else {
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: if ((_index != 2 - 1) && parser.parseComma())
|
||||
// ATTR: return {};
|
||||
// ATTR: }
|
||||
// ATTR: return TestBAttr::get(parser.getContext(),
|
||||
// ATTR: _result_v0.getValue(),
|
||||
// ATTR: _result_v1.getValue());
|
||||
// ATTR: }
|
||||
|
||||
// ATTR: void TestBAttr::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// ATTR: printer << "v0";
|
||||
// ATTR: printer << ' ' << "=";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: ::printAttrParamA(printer, getV0());
|
||||
// ATTR: printer << ",";
|
||||
// ATTR: printer << ' ' << "v1";
|
||||
// ATTR: printer << ' ' << "=";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: ::printAttrB(printer, getV1());
|
||||
// ATTR: }
|
||||
|
||||
def AttrB : TestAttr<"TestB"> {
|
||||
let parameters = (ins
|
||||
AttrParamA:$v0,
|
||||
AttrParamB:$v1
|
||||
);
|
||||
|
||||
let mnemonic = "attr_b";
|
||||
let assemblyFormat = "`{` struct($v0, $v1) `}`";
|
||||
}
|
||||
|
||||
/// Test attribute with capture-all params has correct parser and printer.
|
||||
|
||||
// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::DialectAsmParser &parser,
|
||||
// ATTR: ::mlir::Type attrType) {
|
||||
// ATTR: ::mlir::FailureOr<int> _result_v0;
|
||||
// ATTR: ::mlir::FailureOr<int> _result_v1;
|
||||
// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(parser);
|
||||
// ATTR: if (failed(_result_v0))
|
||||
// ATTR: return {};
|
||||
// ATTR: if (parser.parseComma())
|
||||
// ATTR: return {};
|
||||
// ATTR: _result_v1 = ::mlir::FieldParser<int>::parse(parser);
|
||||
// ATTR: if (failed(_result_v1))
|
||||
// ATTR: return {};
|
||||
// ATTR: return TestFAttr::get(parser.getContext(),
|
||||
// ATTR: _result_v0.getValue(),
|
||||
// ATTR: _result_v1.getValue());
|
||||
// ATTR: }
|
||||
|
||||
// ATTR: void TestFAttr::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// ATTR: printer << "attr_c";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getV0();
|
||||
// ATTR: printer << ",";
|
||||
// ATTR: printer << ' ';
|
||||
// ATTR: printer << getV1();
|
||||
// ATTR: }
|
||||
|
||||
def AttrC : TestAttr<"TestF"> {
|
||||
let parameters = (ins "int":$v0, "int":$v1);
|
||||
|
||||
let mnemonic = "attr_c";
|
||||
let assemblyFormat = "params";
|
||||
}
|
||||
|
||||
/// Test type parser and printer that mix variables and struct are generated
|
||||
/// correctly.
|
||||
|
||||
// TYPE: ::mlir::Type TestCType::parse(::mlir::DialectAsmParser &parser) {
|
||||
// TYPE: FailureOr<IntegerAttr> _result_value;
|
||||
// TYPE: FailureOr<TestParamC> _result_complex;
|
||||
// TYPE: if (parser.parseKeyword("foo"))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseComma())
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseColon())
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseKeyword("bob"))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseKeyword("bar"))
|
||||
// TYPE: return {};
|
||||
// TYPE: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// TYPE: if (failed(_result_value))
|
||||
// TYPE: return {};
|
||||
// TYPE: bool _seen_complex = false;
|
||||
// TYPE: for (unsigned _index = 0; _index < 1; ++_index) {
|
||||
// TYPE: StringRef _paramKey;
|
||||
// TYPE: if (parser.parseKeyword(&_paramKey))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (!_seen_complex && _paramKey == "complex") {
|
||||
// TYPE: _seen_complex = true;
|
||||
// TYPE: _result_complex = ::parseTypeParamC(parser);
|
||||
// TYPE: if (failed(_result_complex))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else {
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: if ((_index != 1 - 1) && parser.parseComma())
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: if (parser.parseRParen())
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
|
||||
// TYPE: void TestCType::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// TYPE: printer << "type_c";
|
||||
// TYPE: printer << ' ' << "foo";
|
||||
// TYPE: printer << ",";
|
||||
// TYPE: printer << ' ' << ":";
|
||||
// TYPE: printer << ' ' << "bob";
|
||||
// TYPE: printer << ' ' << "bar";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getValue();
|
||||
// TYPE: printer << ' ' << "complex";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getComplex();
|
||||
// TYPE: printer << ")";
|
||||
// TYPE: }
|
||||
|
||||
def TypeA : TestType<"TestC"> {
|
||||
let parameters = (ins
|
||||
"IntegerAttr":$value,
|
||||
TypeParamA:$complex
|
||||
);
|
||||
|
||||
let mnemonic = "type_c";
|
||||
let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`";
|
||||
}
|
||||
|
||||
/// Test type parser and printer with mix of variables and struct are generated
|
||||
/// correctly.
|
||||
|
||||
// TYPE: ::mlir::Type TestDType::parse(::mlir::DialectAsmParser &parser) {
|
||||
// TYPE: _result_v0 = ::parseTypeParamC(parser);
|
||||
// TYPE: if (failed(_result_v0))
|
||||
// TYPE: return {};
|
||||
// TYPE: bool _seen_v1 = false;
|
||||
// TYPE: bool _seen_v2 = false;
|
||||
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
|
||||
// TYPE: StringRef _paramKey;
|
||||
// TYPE: if (parser.parseKeyword(&_paramKey))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseEqual())
|
||||
// TYPE: return {};
|
||||
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
|
||||
// TYPE: _seen_v1 = true;
|
||||
// TYPE: _result_v1 = someFcnCall();
|
||||
// TYPE: if (failed(_result_v1))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
|
||||
// TYPE: _seen_v2 = true;
|
||||
// TYPE: _result_v2 = ::parseTypeParamC(parser);
|
||||
// TYPE: if (failed(_result_v2))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else {
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: _result_v3 = someFcnCall();
|
||||
// TYPE: if (failed(_result_v3))
|
||||
// TYPE: return {};
|
||||
// TYPE: return TestDType::get(parser.getContext(),
|
||||
// TYPE: _result_v0.getValue(),
|
||||
// TYPE: _result_v1.getValue(),
|
||||
// TYPE: _result_v2.getValue(),
|
||||
// TYPE: _result_v3.getValue());
|
||||
// TYPE: }
|
||||
|
||||
// TYPE: void TestDType::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// TYPE: printer << getV0();
|
||||
// TYPE: myPrinter(getV1());
|
||||
// TYPE: printer << ' ' << "v2";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV2();
|
||||
// TYPE: myPrinter(getV3());
|
||||
// TYPE: }
|
||||
|
||||
def TypeB : TestType<"TestD"> {
|
||||
let parameters = (ins
|
||||
TypeParamA:$v0,
|
||||
TypeParamB:$v1,
|
||||
TypeParamA:$v2,
|
||||
TypeParamB:$v3
|
||||
);
|
||||
|
||||
let mnemonic = "type_d";
|
||||
let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`";
|
||||
}
|
||||
|
||||
/// Type test with two struct directives has correctly generated parser and
|
||||
/// printer.
|
||||
|
||||
// TYPE: ::mlir::Type TestEType::parse(::mlir::DialectAsmParser &parser) {
|
||||
// TYPE: FailureOr<IntegerAttr> _result_v0;
|
||||
// TYPE: FailureOr<IntegerAttr> _result_v1;
|
||||
// TYPE: FailureOr<IntegerAttr> _result_v2;
|
||||
// TYPE: FailureOr<IntegerAttr> _result_v3;
|
||||
// TYPE: bool _seen_v0 = false;
|
||||
// TYPE: bool _seen_v2 = false;
|
||||
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
|
||||
// TYPE: StringRef _paramKey;
|
||||
// TYPE: if (parser.parseKeyword(&_paramKey))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseEqual())
|
||||
// TYPE: return {};
|
||||
// TYPE: if (!_seen_v0 && _paramKey == "v0") {
|
||||
// TYPE: _seen_v0 = true;
|
||||
// TYPE: _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// TYPE: if (failed(_result_v0))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
|
||||
// TYPE: _seen_v2 = true;
|
||||
// TYPE: _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// TYPE: if (failed(_result_v2))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else {
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: bool _seen_v1 = false;
|
||||
// TYPE: bool _seen_v3 = false;
|
||||
// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
|
||||
// TYPE: StringRef _paramKey;
|
||||
// TYPE: if (parser.parseKeyword(&_paramKey))
|
||||
// TYPE: return {};
|
||||
// TYPE: if (parser.parseEqual())
|
||||
// TYPE: return {};
|
||||
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
|
||||
// TYPE: _seen_v1 = true;
|
||||
// TYPE: _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// TYPE: if (failed(_result_v1))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else if (!_seen_v3 && _paramKey == "v3") {
|
||||
// TYPE: _seen_v3 = true;
|
||||
// TYPE: _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
|
||||
// TYPE: if (failed(_result_v3))
|
||||
// TYPE: return {};
|
||||
// TYPE: } else {
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: if ((_index != 2 - 1) && parser.parseComma())
|
||||
// TYPE: return {};
|
||||
// TYPE: }
|
||||
// TYPE: return TestEType::get(parser.getContext(),
|
||||
// TYPE: _result_v0.getValue(),
|
||||
// TYPE: _result_v1.getValue(),
|
||||
// TYPE: _result_v2.getValue(),
|
||||
// TYPE: _result_v3.getValue());
|
||||
// TYPE: }
|
||||
|
||||
// TYPE: void TestEType::print(::mlir::DialectAsmPrinter &printer) const {
|
||||
// TYPE: printer << "v0";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV0();
|
||||
// TYPE: printer << ",";
|
||||
// TYPE: printer << ' ' << "v2";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV2();
|
||||
// TYPE: printer << "v1";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV1();
|
||||
// TYPE: printer << ",";
|
||||
// TYPE: printer << ' ' << "v3";
|
||||
// TYPE: printer << ' ' << "=";
|
||||
// TYPE: printer << ' ';
|
||||
// TYPE: printer << getV3();
|
||||
// TYPE: }
|
||||
|
||||
def TypeC : TestType<"TestE"> {
|
||||
let parameters = (ins
|
||||
"IntegerAttr":$v0,
|
||||
"IntegerAttr":$v1,
|
||||
"IntegerAttr":$v2,
|
||||
"IntegerAttr":$v3
|
||||
);
|
||||
|
||||
let mnemonic = "type_e";
|
||||
let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
|
||||
}
|
||||
@@ -6,6 +6,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "AttrOrTypeFormatGen.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/TableGen/AttrOrTypeDef.h"
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
@@ -24,6 +25,17 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
|
||||
assert(!name.empty() && "parameter has empty name");
|
||||
auto ret = "get" + name.str();
|
||||
ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
|
||||
/// specified and can only find one dialect's defs, use that.
|
||||
static void collectAllDefs(StringRef selectedDialect,
|
||||
@@ -399,7 +411,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
|
||||
<< " }\n";
|
||||
|
||||
// If mnemonic specified, emit print/parse declarations.
|
||||
if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
|
||||
if (def.getParserCode() || def.getPrinterCode() ||
|
||||
def.getAssemblyFormat() || !params.empty()) {
|
||||
os << llvm::formatv(defDeclParsePrintStr, valueType,
|
||||
isAttrGenerator ? ", ::mlir::Type type" : "");
|
||||
}
|
||||
@@ -410,10 +423,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
|
||||
def.getParameters(parameters);
|
||||
|
||||
for (AttrOrTypeParameter ¶meter : parameters) {
|
||||
SmallString<16> name = parameter.getName();
|
||||
name[0] = llvm::toUpper(name[0]);
|
||||
os << formatv(" {0} get{1}() const;\n", parameter.getCppAccessorType(),
|
||||
name);
|
||||
os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(),
|
||||
getParameterAccessorName(parameter.getName()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -700,8 +711,32 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
|
||||
}
|
||||
|
||||
void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
|
||||
auto printerCode = def.getPrinterCode();
|
||||
auto parserCode = def.getParserCode();
|
||||
auto assemblyFormat = def.getAssemblyFormat();
|
||||
if (assemblyFormat && (printerCode || parserCode)) {
|
||||
// Custom assembly format cannot be specified at the same time as either
|
||||
// custom printer or parser code.
|
||||
PrintFatalError(def.getLoc(),
|
||||
def.getName() + ": assembly format cannot be specified at "
|
||||
"the same time as printer or parser code");
|
||||
}
|
||||
|
||||
// Generate a parser and printer based on the assembly format, if specified.
|
||||
if (assemblyFormat) {
|
||||
// A custom assembly format requires accessors to be generated for the
|
||||
// generated printer.
|
||||
if (!def.genAccessors()) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
def.getName() +
|
||||
": the generated printer from 'assemblyFormat' "
|
||||
"requires 'genAccessors' to be true");
|
||||
}
|
||||
return generateAttrOrTypeFormat(def, os);
|
||||
}
|
||||
|
||||
// Emit the printer code, if specified.
|
||||
if (Optional<StringRef> printerCode = def.getPrinterCode()) {
|
||||
if (printerCode) {
|
||||
// Both the mnenomic and printerCode must be defined (for parity with
|
||||
// parserCode).
|
||||
os << "void " << def.getCppClassName()
|
||||
@@ -717,7 +752,7 @@ void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
|
||||
}
|
||||
|
||||
// Emit the parser code, if specified.
|
||||
if (Optional<StringRef> parserCode = def.getParserCode()) {
|
||||
if (parserCode) {
|
||||
FmtContext fmtCtxt;
|
||||
fmtCtxt.addSubst("_parser", "parser")
|
||||
.addSubst("_ctxt", "parser.getContext()");
|
||||
@@ -857,11 +892,10 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
|
||||
paramStorageName = param.getName();
|
||||
}
|
||||
|
||||
SmallString<16> name = param.getName();
|
||||
name[0] = llvm::toUpper(name[0]);
|
||||
os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
|
||||
param.getCppAccessorType(), name, paramStorageName,
|
||||
def.getCppClassName());
|
||||
os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
|
||||
param.getCppAccessorType(),
|
||||
getParameterAccessorName(param.getName()),
|
||||
paramStorageName, def.getCppClassName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
781
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Normal file
781
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Normal file
@@ -0,0 +1,781 @@
|
||||
//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "AttrOrTypeFormatGen.h"
|
||||
#include "FormatGen.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/TableGen/AttrOrTypeDef.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
using llvm::formatv;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
/// This class represents a single format element.
|
||||
class Element {
|
||||
public:
|
||||
/// LLVM-style RTTI.
|
||||
enum class Kind {
|
||||
/// This element is a directive.
|
||||
ParamsDirective,
|
||||
StructDirective,
|
||||
|
||||
/// This element is a literal.
|
||||
Literal,
|
||||
|
||||
/// This element is a variable.
|
||||
Variable,
|
||||
};
|
||||
Element(Kind kind) : kind(kind) {}
|
||||
virtual ~Element() = default;
|
||||
|
||||
/// Return the kind of this element.
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
private:
|
||||
/// The kind of this element.
|
||||
Kind kind;
|
||||
};
|
||||
|
||||
/// This class represents an instance of a literal element.
|
||||
class LiteralElement : public Element {
|
||||
public:
|
||||
LiteralElement(StringRef literal)
|
||||
: Element(Kind::Literal), literal(literal) {}
|
||||
|
||||
static bool classof(const Element *el) {
|
||||
return el->getKind() == Kind::Literal;
|
||||
}
|
||||
|
||||
/// Get the literal spelling.
|
||||
StringRef getSpelling() const { return literal; }
|
||||
|
||||
private:
|
||||
/// The spelling of the literal for this element.
|
||||
StringRef literal;
|
||||
};
|
||||
|
||||
/// This class represents an instance of a variable element. A variable refers
|
||||
/// to an attribute or type parameter.
|
||||
class VariableElement : public Element {
|
||||
public:
|
||||
VariableElement(AttrOrTypeParameter param)
|
||||
: Element(Kind::Variable), param(param) {}
|
||||
|
||||
static bool classof(const Element *el) {
|
||||
return el->getKind() == Kind::Variable;
|
||||
}
|
||||
|
||||
/// Get the parameter in the element.
|
||||
const AttrOrTypeParameter &getParam() const { return param; }
|
||||
|
||||
private:
|
||||
AttrOrTypeParameter param;
|
||||
};
|
||||
|
||||
/// Base class for a directive that contains references to multiple variables.
|
||||
template <Element::Kind ElementKind>
|
||||
class ParamsDirectiveBase : public Element {
|
||||
public:
|
||||
using Base = ParamsDirectiveBase<ElementKind>;
|
||||
|
||||
ParamsDirectiveBase(SmallVector<std::unique_ptr<Element>> &¶ms)
|
||||
: Element(ElementKind), params(std::move(params)) {}
|
||||
|
||||
static bool classof(const Element *el) {
|
||||
return el->getKind() == ElementKind;
|
||||
}
|
||||
|
||||
/// Get the parameters contained in this directive.
|
||||
auto getParams() const {
|
||||
return llvm::map_range(params, [](auto &el) {
|
||||
return cast<VariableElement>(el.get())->getParam();
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the number of parameters.
|
||||
unsigned getNumParams() const { return params.size(); }
|
||||
|
||||
/// Take all of the parameters from this directive.
|
||||
SmallVector<std::unique_ptr<Element>> takeParams() {
|
||||
return std::move(params);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The parameters captured by this directive.
|
||||
SmallVector<std::unique_ptr<Element>> params;
|
||||
};
|
||||
|
||||
/// This class represents a `params` directive that refers to all parameters
|
||||
/// of an attribute or type. When used as a top-level directive, it generates
|
||||
/// a format of the form:
|
||||
///
|
||||
/// (param-value (`,` param-value)*)?
|
||||
///
|
||||
/// When used as an argument to another directive that accepts variables,
|
||||
/// `params` can be used in place of manually listing all parameters of an
|
||||
/// attribute or type.
|
||||
class ParamsDirective
|
||||
: public ParamsDirectiveBase<Element::Kind::ParamsDirective> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
/// This class represents a `struct` directive that generates a struct format
|
||||
/// of the form:
|
||||
///
|
||||
/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
|
||||
///
|
||||
class StructDirective
|
||||
: public ParamsDirectiveBase<Element::Kind::StructDirective> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format Strings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Format for defining an attribute parser.
|
||||
///
|
||||
/// $0: The attribute C++ class name.
|
||||
static const char *const attrParserDefn = R"(
|
||||
::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser,
|
||||
::mlir::Type $_type) {
|
||||
)";
|
||||
|
||||
/// Format for defining a type parser.
|
||||
///
|
||||
/// $0: The type C++ class name.
|
||||
static const char *const typeParserDefn = R"(
|
||||
::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) {
|
||||
)";
|
||||
|
||||
/// Default parser for attribute or type parameters.
|
||||
static const char *const defaultParameterParser =
|
||||
"::mlir::FieldParser<$0>::parse($_parser)";
|
||||
|
||||
/// Default printer for attribute or type parameters.
|
||||
static const char *const defaultParameterPrinter = "$_printer << $_self";
|
||||
|
||||
/// Print an error when failing to parse an element.
|
||||
///
|
||||
/// $0: The parameter C++ class name.
|
||||
static const char *const parseErrorStr =
|
||||
"$_parser.emitError($_parser.getCurrentLocation(), ";
|
||||
|
||||
/// Format for defining an attribute or type printer.
|
||||
///
|
||||
/// $0: The attribute or type C++ class name.
|
||||
/// $1: The attribute or type mnemonic.
|
||||
static const char *const attrOrTypePrinterDefn = R"(
|
||||
void $0::print(::mlir::DialectAsmPrinter &$_printer) const {
|
||||
$_printer << "$1";
|
||||
)";
|
||||
|
||||
/// Loop declaration for struct parser.
|
||||
///
|
||||
/// $0: Number of expected parameters.
|
||||
static const char *const structParseLoopStart = R"(
|
||||
for (unsigned _index = 0; _index < $0; ++_index) {
|
||||
StringRef _paramKey;
|
||||
if ($_parser.parseKeyword(&_paramKey)) {
|
||||
$_parser.emitError($_parser.getCurrentLocation(),
|
||||
"expected a parameter name in struct");
|
||||
return {};
|
||||
}
|
||||
)";
|
||||
|
||||
/// Terminator code segment for the struct parser loop. Check for duplicate or
|
||||
/// unknown parameters. Parse a comma except on the last element.
|
||||
///
|
||||
/// {0}: Code template for printing an error.
|
||||
/// {1}: Number of elements in the struct.
|
||||
static const char *const structParseLoopEnd = R"({{
|
||||
{0}"duplicate or unknown struct parameter name: ") << _paramKey;
|
||||
return {{};
|
||||
}
|
||||
if ((_index != {1} - 1) && parser.parseComma())
|
||||
return {{};
|
||||
}
|
||||
)";
|
||||
|
||||
/// Code format to parse a variable. Separate by lines because variable parsers
|
||||
/// may be generated inside other directives, which requires indentation.
|
||||
///
|
||||
/// {0}: The parameter name.
|
||||
/// {1}: The parse code for the parameter.
|
||||
/// {2}: Code template for printing an error.
|
||||
/// {3}: Name of the attribute or type.
|
||||
/// {4}: C++ class of the parameter.
|
||||
static const char *const variableParser[] = {
|
||||
" // Parse variable '{0}'",
|
||||
" _result_{0} = {1};",
|
||||
" if (failed(_result_{0})) {{",
|
||||
" {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");",
|
||||
" return {{};",
|
||||
" }",
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Get a list of an attribute's or type's parameters. These can be wrapper
|
||||
/// objects around `AttrOrTypeParameter` or string inits.
|
||||
static auto getParameters(const AttrOrTypeDef &def) {
|
||||
SmallVector<AttrOrTypeParameter> params;
|
||||
def.getParameters(params);
|
||||
return params;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttrOrTypeFormat
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class AttrOrTypeFormat {
|
||||
public:
|
||||
AttrOrTypeFormat(const AttrOrTypeDef &def,
|
||||
std::vector<std::unique_ptr<Element>> &&elements)
|
||||
: def(def), elements(std::move(elements)) {}
|
||||
|
||||
/// Generate the attribute or type parser.
|
||||
void genParser(raw_ostream &os);
|
||||
/// Generate the attribute or type printer.
|
||||
void genPrinter(raw_ostream &os);
|
||||
|
||||
private:
|
||||
/// Generate the parser code for a specific format element.
|
||||
void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os);
|
||||
/// Generate the parser code for a literal.
|
||||
void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os,
|
||||
unsigned indent = 0);
|
||||
/// Generate the parser code for a variable.
|
||||
void genVariableParser(const AttrOrTypeParameter ¶m, FmtContext &ctx,
|
||||
raw_ostream &os, unsigned indent = 0);
|
||||
/// Generate the parser code for a `params` directive.
|
||||
void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
|
||||
/// Generate the parser code for a `struct` directive.
|
||||
void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os);
|
||||
|
||||
/// Generate the printer code for a specific format element.
|
||||
void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os);
|
||||
/// Generate the printer code for a literal.
|
||||
void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os);
|
||||
/// Generate the printer code for a variable.
|
||||
void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx,
|
||||
raw_ostream &os);
|
||||
/// Generate the printer code for a `params` directive.
|
||||
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
|
||||
/// Generate the printer code for a `struct` directive.
|
||||
void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os);
|
||||
|
||||
/// The ODS definition of the attribute or type whose format is being used to
|
||||
/// generate a parser and printer.
|
||||
const AttrOrTypeDef &def;
|
||||
/// The list of top-level format elements returned by the assembly format
|
||||
/// parser.
|
||||
std::vector<std::unique_ptr<Element>> elements;
|
||||
|
||||
/// Flags for printing spaces.
|
||||
bool shouldEmitSpace;
|
||||
bool lastWasPunctuation;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ParserGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AttrOrTypeFormat::genParser(raw_ostream &os) {
|
||||
FmtContext ctx;
|
||||
ctx.addSubst("_parser", "parser");
|
||||
|
||||
/// Generate the definition.
|
||||
if (isa<AttrDef>(def)) {
|
||||
ctx.addSubst("_type", "attrType");
|
||||
os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
|
||||
} else {
|
||||
os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
|
||||
}
|
||||
|
||||
/// Declare variables to store all of the parameters. Allocated parameters
|
||||
/// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
|
||||
/// FailureOr<T> to defer type construction for parameters that are parsed in
|
||||
/// a loop (parsers return FailureOr anyways).
|
||||
SmallVector<AttrOrTypeParameter> params = getParameters(def);
|
||||
for (const AttrOrTypeParameter ¶m : params) {
|
||||
os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n",
|
||||
param.getCppStorageType(), param.getName());
|
||||
}
|
||||
|
||||
/// Store the initial location of the parser.
|
||||
ctx.addSubst("_loc", "loc");
|
||||
os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
|
||||
" (void) $_loc;\n",
|
||||
&ctx);
|
||||
|
||||
/// Generate call to each parameter parser.
|
||||
for (auto &el : elements)
|
||||
genElementParser(el.get(), ctx, os);
|
||||
|
||||
/// Generate call to the attribute or type builder. Use the checked getter
|
||||
/// if one was generated.
|
||||
if (def.genVerifyDecl()) {
|
||||
os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
|
||||
&ctx, def.getCppClassName());
|
||||
} else {
|
||||
os << tgfmt(" return $0::get($_parser.getContext()", &ctx,
|
||||
def.getCppClassName());
|
||||
}
|
||||
for (const AttrOrTypeParameter ¶m : params)
|
||||
os << formatv(",\n _result_{0}.getValue()", param.getName());
|
||||
os << ");\n}\n\n";
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
if (auto *literal = dyn_cast<LiteralElement>(el))
|
||||
return genLiteralParser(literal->getSpelling(), ctx, os);
|
||||
if (auto *var = dyn_cast<VariableElement>(el))
|
||||
return genVariableParser(var->getParam(), ctx, os);
|
||||
if (auto *params = dyn_cast<ParamsDirective>(el))
|
||||
return genParamsParser(params, ctx, os);
|
||||
if (auto *strct = dyn_cast<StructDirective>(el))
|
||||
return genStructParser(strct, ctx, os);
|
||||
|
||||
llvm_unreachable("unknown format element");
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
|
||||
raw_ostream &os, unsigned indent) {
|
||||
os.indent(indent) << " // Parse literal '" << value << "'\n";
|
||||
os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx);
|
||||
if (value.front() == '_' || isalpha(value.front())) {
|
||||
os << "Keyword(\"" << value << "\")";
|
||||
} else {
|
||||
os << StringSwitch<StringRef>(value)
|
||||
.Case("->", "Arrow")
|
||||
.Case(":", "Colon")
|
||||
.Case(",", "Comma")
|
||||
.Case("=", "Equal")
|
||||
.Case("<", "Less")
|
||||
.Case(">", "Greater")
|
||||
.Case("{", "LBrace")
|
||||
.Case("}", "RBrace")
|
||||
.Case("(", "LParen")
|
||||
.Case(")", "RParen")
|
||||
.Case("[", "LSquare")
|
||||
.Case("]", "RSquare")
|
||||
.Case("?", "Question")
|
||||
.Case("+", "Plus")
|
||||
.Case("*", "Star")
|
||||
<< "()";
|
||||
}
|
||||
os << ")\n";
|
||||
// Parser will emit an error
|
||||
os.indent(indent) << " return {};\n";
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter ¶m,
|
||||
FmtContext &ctx, raw_ostream &os,
|
||||
unsigned indent) {
|
||||
/// Check for a custom parser. Use the default attribute parser otherwise.
|
||||
auto customParser = param.getParser();
|
||||
auto parser =
|
||||
customParser ? *customParser : StringRef(defaultParameterParser);
|
||||
for (const char *line : variableParser) {
|
||||
os.indent(indent) << formatv(line, param.getName(),
|
||||
tgfmt(parser, &ctx, param.getCppStorageType()),
|
||||
tgfmt(parseErrorStr, &ctx), def.getName(),
|
||||
param.getCppType())
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
os << " // Parse parameter list\n";
|
||||
llvm::interleave(
|
||||
el->getParams(), [&](auto param) { genVariableParser(param, ctx, os); },
|
||||
[&]() { genLiteralParser(",", ctx, os); });
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
os << " // Parse parameter struct\n";
|
||||
|
||||
/// Declare a "seen" variable for each key.
|
||||
for (const AttrOrTypeParameter ¶m : el->getParams())
|
||||
os << formatv(" bool _seen_{0} = false;\n", param.getName());
|
||||
|
||||
/// Generate the parsing loop.
|
||||
os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
|
||||
genLiteralParser("=", ctx, os, 2);
|
||||
os << " ";
|
||||
for (const AttrOrTypeParameter ¶m : el->getParams()) {
|
||||
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
|
||||
" _seen_{0} = true;\n",
|
||||
param.getName());
|
||||
genVariableParser(param, ctx, os, 4);
|
||||
os << " } else ";
|
||||
}
|
||||
|
||||
/// Duplicate or unknown parameter.
|
||||
os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx),
|
||||
el->getNumParams());
|
||||
|
||||
/// Because the loop loops N times and each non-failing iteration sets 1 of
|
||||
/// N flags, successfully exiting the loop means that all parameters have been
|
||||
/// seen. `parseOptionalComma` would cause issues with any formats that use
|
||||
/// "struct(...) `,`" beacuse structs aren't sounded by braces.
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrinterGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AttrOrTypeFormat::genPrinter(raw_ostream &os) {
|
||||
FmtContext ctx;
|
||||
ctx.addSubst("_printer", "printer");
|
||||
|
||||
/// Generate the definition.
|
||||
os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(),
|
||||
*def.getMnemonic());
|
||||
|
||||
/// Generate printers.
|
||||
shouldEmitSpace = true;
|
||||
lastWasPunctuation = false;
|
||||
for (auto &el : elements)
|
||||
genElementPrinter(el.get(), ctx, os);
|
||||
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
if (auto *literal = dyn_cast<LiteralElement>(el))
|
||||
return genLiteralPrinter(literal->getSpelling(), ctx, os);
|
||||
if (auto *params = dyn_cast<ParamsDirective>(el))
|
||||
return genParamsPrinter(params, ctx, os);
|
||||
if (auto *strct = dyn_cast<StructDirective>(el))
|
||||
return genStructPrinter(strct, ctx, os);
|
||||
if (auto *var = dyn_cast<VariableElement>(el))
|
||||
return genVariablePrinter(var->getParam(), ctx, os);
|
||||
|
||||
llvm_unreachable("unknown format element");
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
/// Don't insert a space before certain punctuation.
|
||||
bool needSpace =
|
||||
shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
|
||||
os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
|
||||
value);
|
||||
|
||||
/// Update the flags.
|
||||
shouldEmitSpace =
|
||||
value.size() != 1 || !StringRef("<({[").contains(value.front());
|
||||
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m,
|
||||
FmtContext &ctx, raw_ostream &os) {
|
||||
/// Insert a space before the next parameter, if necessary.
|
||||
if (shouldEmitSpace || !lastWasPunctuation)
|
||||
os << tgfmt(" $_printer << ' ';\n", &ctx);
|
||||
shouldEmitSpace = true;
|
||||
lastWasPunctuation = false;
|
||||
|
||||
ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
|
||||
os << " ";
|
||||
if (auto printer = param.getPrinter())
|
||||
os << tgfmt(*printer, &ctx) << ";\n";
|
||||
else
|
||||
os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
llvm::interleave(
|
||||
el->getParams(), [&](auto param) { genVariablePrinter(param, ctx, os); },
|
||||
[&]() { genLiteralPrinter(",", ctx, os); });
|
||||
}
|
||||
|
||||
void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
|
||||
raw_ostream &os) {
|
||||
llvm::interleave(
|
||||
el->getParams(),
|
||||
[&](auto param) {
|
||||
genLiteralPrinter(param.getName(), ctx, os);
|
||||
genLiteralPrinter("=", ctx, os);
|
||||
os << tgfmt(" $_printer << ' ';\n", &ctx);
|
||||
genVariablePrinter(param, ctx, os);
|
||||
},
|
||||
[&]() { genLiteralPrinter(",", ctx, os); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatParser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class FormatParser {
|
||||
public:
|
||||
FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
|
||||
: lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def),
|
||||
seenParams(def.getNumParameters()) {}
|
||||
|
||||
/// Parse the attribute or type format and create the format elements.
|
||||
FailureOr<AttrOrTypeFormat> parse();
|
||||
|
||||
private:
|
||||
/// The current context of the parser when parsing an element.
|
||||
enum ParserContext {
|
||||
/// The element is being parsed in the default context - at the top of the
|
||||
/// format
|
||||
TopLevelContext,
|
||||
/// The element is being parsed as a child to a `struct` directive.
|
||||
StructDirective,
|
||||
};
|
||||
|
||||
/// Emit an error.
|
||||
LogicalResult emitError(const Twine &msg) {
|
||||
lexer.emitError(curToken.getLoc(), msg);
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Parse an expected token.
|
||||
LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
|
||||
if (curToken.getKind() != kind)
|
||||
return emitError(msg);
|
||||
consumeToken();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Advance the lexer to the next token.
|
||||
void consumeToken() {
|
||||
assert(curToken.getKind() != FormatToken::eof &&
|
||||
curToken.getKind() != FormatToken::error &&
|
||||
"shouldn't advance past EOF or errors");
|
||||
curToken = lexer.lexToken();
|
||||
}
|
||||
|
||||
/// Parse any element.
|
||||
FailureOr<std::unique_ptr<Element>> parseElement(ParserContext ctx);
|
||||
/// Parse a literal element.
|
||||
FailureOr<std::unique_ptr<Element>> parseLiteral(ParserContext ctx);
|
||||
/// Parse a variable element.
|
||||
FailureOr<std::unique_ptr<Element>> parseVariable(ParserContext ctx);
|
||||
/// Parse a directive.
|
||||
FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
|
||||
/// Parse a `params` directive.
|
||||
FailureOr<std::unique_ptr<Element>> parseParamsDirective();
|
||||
/// Parse a `struct` directive.
|
||||
FailureOr<std::unique_ptr<Element>> parseStructDirective();
|
||||
|
||||
/// The current format lexer.
|
||||
FormatLexer lexer;
|
||||
/// The current token in the stream.
|
||||
FormatToken curToken;
|
||||
/// Attribute or type tablegen def.
|
||||
const AttrOrTypeDef &def;
|
||||
|
||||
/// Seen attribute or type parameters.
|
||||
llvm::BitVector seenParams;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FailureOr<AttrOrTypeFormat> FormatParser::parse() {
|
||||
std::vector<std::unique_ptr<Element>> elements;
|
||||
elements.reserve(16);
|
||||
|
||||
/// Parse the format elements.
|
||||
while (curToken.getKind() != FormatToken::eof) {
|
||||
auto element = parseElement(TopLevelContext);
|
||||
if (failed(element))
|
||||
return failure();
|
||||
|
||||
/// Add the format element and continue.
|
||||
elements.push_back(std::move(*element));
|
||||
}
|
||||
|
||||
/// Check that all parameters have been seen.
|
||||
SmallVector<AttrOrTypeParameter> params = getParameters(def);
|
||||
for (auto it : llvm::enumerate(params)) {
|
||||
if (!seenParams.test(it.index())) {
|
||||
return emitError("format is missing reference to parameter: " +
|
||||
it.value().getName());
|
||||
}
|
||||
}
|
||||
|
||||
return AttrOrTypeFormat(def, std::move(elements));
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>>
|
||||
FormatParser::parseElement(ParserContext ctx) {
|
||||
if (curToken.getKind() == FormatToken::literal)
|
||||
return parseLiteral(ctx);
|
||||
if (curToken.getKind() == FormatToken::variable)
|
||||
return parseVariable(ctx);
|
||||
if (curToken.isKeyword())
|
||||
return parseDirective(ctx);
|
||||
|
||||
return emitError("expected literal, directive, or variable");
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>>
|
||||
FormatParser::parseLiteral(ParserContext ctx) {
|
||||
if (ctx != TopLevelContext) {
|
||||
return emitError(
|
||||
"literals may only be used in the top-level section of the format");
|
||||
}
|
||||
|
||||
/// Get the literal spelling without the surrounding "`".
|
||||
auto value = curToken.getSpelling().drop_front().drop_back();
|
||||
if (!isValidLiteral(value))
|
||||
return emitError("literal '" + value + "' is not valid");
|
||||
|
||||
consumeToken();
|
||||
return {std::make_unique<LiteralElement>(value)};
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>>
|
||||
FormatParser::parseVariable(ParserContext ctx) {
|
||||
/// Get the parameter name without the preceding "$".
|
||||
auto name = curToken.getSpelling().drop_front();
|
||||
|
||||
/// Lookup the parameter.
|
||||
SmallVector<AttrOrTypeParameter> params = getParameters(def);
|
||||
auto *it = llvm::find_if(
|
||||
params, [&](auto ¶m) { return param.getName() == name; });
|
||||
|
||||
/// Check that the parameter reference is valid.
|
||||
if (it == params.end())
|
||||
return emitError(def.getName() + " has no parameter named '" + name + "'");
|
||||
auto idx = std::distance(params.begin(), it);
|
||||
if (seenParams.test(idx))
|
||||
return emitError("duplicate parameter '" + name + "'");
|
||||
seenParams.set(idx);
|
||||
|
||||
consumeToken();
|
||||
return {std::make_unique<VariableElement>(*it)};
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>>
|
||||
FormatParser::parseDirective(ParserContext ctx) {
|
||||
|
||||
switch (curToken.getKind()) {
|
||||
case FormatToken::kw_params:
|
||||
return parseParamsDirective();
|
||||
case FormatToken::kw_struct:
|
||||
if (ctx != TopLevelContext) {
|
||||
return emitError(
|
||||
"`struct` may only be used in the top-level section of the format");
|
||||
}
|
||||
return parseStructDirective();
|
||||
default:
|
||||
return emitError("unknown directive in format: " + curToken.getSpelling());
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
|
||||
consumeToken();
|
||||
/// Collect all of the attribute's or type's parameters.
|
||||
SmallVector<AttrOrTypeParameter> params = getParameters(def);
|
||||
SmallVector<std::unique_ptr<Element>> vars;
|
||||
/// Ensure that none of the parameters have already been captured.
|
||||
for (auto it : llvm::enumerate(params)) {
|
||||
if (seenParams.test(it.index())) {
|
||||
return emitError("`params` captures duplicate parameter: " +
|
||||
it.value().getName());
|
||||
}
|
||||
seenParams.set(it.index());
|
||||
vars.push_back(std::make_unique<VariableElement>(it.value()));
|
||||
}
|
||||
return {std::make_unique<ParamsDirective>(std::move(vars))};
|
||||
}
|
||||
|
||||
FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
|
||||
consumeToken();
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before `struct` argument list")))
|
||||
return failure();
|
||||
|
||||
/// Parse variables captured by `struct`.
|
||||
SmallVector<std::unique_ptr<Element>> vars;
|
||||
|
||||
/// Parse first captured parameter or a `params` directive.
|
||||
FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
|
||||
if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
|
||||
return emitError("`struct` argument list expected a variable or directive");
|
||||
if (isa<VariableElement>(*var)) {
|
||||
/// Parse any other parameters.
|
||||
vars.push_back(std::move(*var));
|
||||
while (curToken.getKind() == FormatToken::comma) {
|
||||
consumeToken();
|
||||
var = parseElement(StructDirective);
|
||||
if (failed(var) || !isa<VariableElement>(*var))
|
||||
return emitError("expected a variable in `struct` argument list");
|
||||
vars.push_back(std::move(*var));
|
||||
}
|
||||
} else {
|
||||
/// `struct(params)` captures all parameters in the attribute or type.
|
||||
vars = cast<ParamsDirective>(var->get())->takeParams();
|
||||
}
|
||||
|
||||
if (curToken.getKind() != FormatToken::r_paren)
|
||||
return emitError("expected ')' at the end of an argument list");
|
||||
|
||||
consumeToken();
|
||||
return {std::make_unique<::StructDirective>(std::move(vars))};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
|
||||
raw_ostream &os) {
|
||||
llvm::SourceMgr mgr;
|
||||
mgr.AddNewSourceBuffer(
|
||||
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
|
||||
llvm::SMLoc());
|
||||
|
||||
/// Parse the custom assembly format>
|
||||
FormatParser parser(mgr, def);
|
||||
FailureOr<AttrOrTypeFormat> format = parser.parse();
|
||||
if (failed(format)) {
|
||||
if (formatErrorIsFatal)
|
||||
PrintFatalError(def.getLoc(), "failed to parse assembly format");
|
||||
return;
|
||||
}
|
||||
|
||||
/// Generate the parser and printer.
|
||||
format->genParser(os);
|
||||
format->genPrinter(os);
|
||||
}
|
||||
32
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
Normal file
32
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
Normal file
@@ -0,0 +1,32 @@
|
||||
//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
|
||||
#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
class AttrOrTypeDef;
|
||||
|
||||
/// Generate a parser and printer based on a custom assembly format for an
|
||||
/// attribute or type.
|
||||
void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
|
||||
|
||||
/// From the parameter name, get the name of the accessor function in camelcase.
|
||||
/// The first letter of the parameter is upper-cased and prefixed with "get".
|
||||
/// E.g. 'value' -> 'getValue'.
|
||||
std::string getParameterAccessorName(llvm::StringRef name);
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
|
||||
@@ -6,10 +6,12 @@ set(LLVM_LINK_COMPONENTS
|
||||
|
||||
add_tablegen(mlir-tblgen MLIR
|
||||
AttrOrTypeDefGen.cpp
|
||||
AttrOrTypeFormatGen.cpp
|
||||
CodeGenHelpers.cpp
|
||||
DialectGen.cpp
|
||||
DirectiveCommonGen.cpp
|
||||
EnumsGen.cpp
|
||||
FormatGen.cpp
|
||||
LLVMIRConversionGen.cpp
|
||||
LLVMIRIntrinsicGen.cpp
|
||||
mlir-tblgen.cpp
|
||||
|
||||
225
mlir/tools/mlir-tblgen/FormatGen.cpp
Normal file
225
mlir/tools/mlir-tblgen/FormatGen.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "FormatGen.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatToken
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
llvm::SMLoc FormatToken::getLoc() const {
|
||||
return llvm::SMLoc::getFromPointer(spelling.data());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatLexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
|
||||
: mgr(mgr), loc(loc),
|
||||
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
|
||||
curPtr(curBuffer.begin()) {}
|
||||
|
||||
FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
return formToken(FormatToken::error, loc.getPointer());
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
|
||||
return emitError(llvm::SMLoc::getFromPointer(loc), msg);
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
|
||||
const Twine ¬e) {
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
|
||||
return formToken(FormatToken::error, loc.getPointer());
|
||||
}
|
||||
|
||||
int FormatLexer::getNextChar() {
|
||||
char curChar = *curPtr++;
|
||||
switch (curChar) {
|
||||
default:
|
||||
return (unsigned char)curChar;
|
||||
case 0: {
|
||||
// A nul character in the stream is either the end of the current buffer or
|
||||
// a random nul in the file. Disambiguate that here.
|
||||
if (curPtr - 1 != curBuffer.end())
|
||||
return 0;
|
||||
|
||||
// Otherwise, return end of file.
|
||||
--curPtr;
|
||||
return EOF;
|
||||
}
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Handle the newline character by ignoring it and incrementing the line
|
||||
// count. However, be careful about 'dos style' files with \n\r in them.
|
||||
// Only treat a \n\r or \r\n as a single line.
|
||||
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
||||
++curPtr;
|
||||
return '\n';
|
||||
}
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexToken() {
|
||||
const char *tokStart = curPtr;
|
||||
|
||||
// This always consumes at least one character.
|
||||
int curChar = getNextChar();
|
||||
switch (curChar) {
|
||||
default:
|
||||
// Handle identifiers: [a-zA-Z_]
|
||||
if (isalpha(curChar) || curChar == '_')
|
||||
return lexIdentifier(tokStart);
|
||||
|
||||
// Unknown character, emit an error.
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case EOF:
|
||||
// Return EOF denoting the end of lexing.
|
||||
return formToken(FormatToken::eof, tokStart);
|
||||
|
||||
// Lex punctuation.
|
||||
case '^':
|
||||
return formToken(FormatToken::caret, tokStart);
|
||||
case ':':
|
||||
return formToken(FormatToken::colon, tokStart);
|
||||
case ',':
|
||||
return formToken(FormatToken::comma, tokStart);
|
||||
case '=':
|
||||
return formToken(FormatToken::equal, tokStart);
|
||||
case '<':
|
||||
return formToken(FormatToken::less, tokStart);
|
||||
case '>':
|
||||
return formToken(FormatToken::greater, tokStart);
|
||||
case '?':
|
||||
return formToken(FormatToken::question, tokStart);
|
||||
case '(':
|
||||
return formToken(FormatToken::l_paren, tokStart);
|
||||
case ')':
|
||||
return formToken(FormatToken::r_paren, tokStart);
|
||||
case '*':
|
||||
return formToken(FormatToken::star, tokStart);
|
||||
|
||||
// Ignore whitespace characters.
|
||||
case 0:
|
||||
case ' ':
|
||||
case '\t':
|
||||
case '\n':
|
||||
return lexToken();
|
||||
|
||||
case '`':
|
||||
return lexLiteral(tokStart);
|
||||
case '$':
|
||||
return lexVariable(tokStart);
|
||||
}
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexLiteral(const char *tokStart) {
|
||||
assert(curPtr[-1] == '`');
|
||||
|
||||
// Lex a literal surrounded by ``.
|
||||
while (const char curChar = *curPtr++) {
|
||||
if (curChar == '`')
|
||||
return formToken(FormatToken::literal, tokStart);
|
||||
}
|
||||
return emitError(curPtr - 1, "unexpected end of file in literal");
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexVariable(const char *tokStart) {
|
||||
if (!isalpha(curPtr[0]) && curPtr[0] != '_')
|
||||
return emitError(curPtr - 1, "expected variable name");
|
||||
|
||||
// Otherwise, consume the rest of the characters.
|
||||
while (isalnum(*curPtr) || *curPtr == '_')
|
||||
++curPtr;
|
||||
return formToken(FormatToken::variable, tokStart);
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
|
||||
++curPtr;
|
||||
|
||||
// Check to see if this identifier is a keyword.
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
auto kind =
|
||||
StringSwitch<FormatToken::Kind>(str)
|
||||
.Case("attr-dict", FormatToken::kw_attr_dict)
|
||||
.Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
|
||||
.Case("custom", FormatToken::kw_custom)
|
||||
.Case("functional-type", FormatToken::kw_functional_type)
|
||||
.Case("operands", FormatToken::kw_operands)
|
||||
.Case("params", FormatToken::kw_params)
|
||||
.Case("ref", FormatToken::kw_ref)
|
||||
.Case("regions", FormatToken::kw_regions)
|
||||
.Case("results", FormatToken::kw_results)
|
||||
.Case("struct", FormatToken::kw_struct)
|
||||
.Case("successors", FormatToken::kw_successors)
|
||||
.Case("type", FormatToken::kw_type)
|
||||
.Default(FormatToken::identifier);
|
||||
return FormatToken(kind, str);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
|
||||
bool lastWasPunctuation) {
|
||||
if (value.size() != 1 && value != "->")
|
||||
return true;
|
||||
if (lastWasPunctuation)
|
||||
return !StringRef(">)}],").contains(value.front());
|
||||
return !StringRef("<>(){}[],").contains(value.front());
|
||||
}
|
||||
|
||||
bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
|
||||
if (!isalpha(value.front()) && value.front() != '_')
|
||||
return false;
|
||||
return llvm::all_of(value.drop_front(), [](char c) {
|
||||
return isalnum(c) || c == '_' || c == '$' || c == '.';
|
||||
});
|
||||
}
|
||||
|
||||
bool mlir::tblgen::isValidLiteral(StringRef value) {
|
||||
if (value.empty())
|
||||
return false;
|
||||
char front = value.front();
|
||||
|
||||
// If there is only one character, this must either be punctuation or a
|
||||
// single character bare identifier.
|
||||
if (value.size() == 1)
|
||||
return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
|
||||
|
||||
// Check the punctuation that are larger than a single character.
|
||||
if (value == "->")
|
||||
return true;
|
||||
|
||||
// Otherwise, this must be an identifier.
|
||||
return canFormatStringAsKeyword(value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Commandline Options
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
|
||||
"asmformat-error-is-fatal",
|
||||
llvm::cl::desc("Emit a fatal error if format parsing fails"),
|
||||
llvm::cl::init(true));
|
||||
161
mlir/tools/mlir-tblgen/FormatGen.h
Normal file
161
mlir/tools/mlir-tblgen/FormatGen.h
Normal file
@@ -0,0 +1,161 @@
|
||||
//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains common classes for building custom assembly format parsers
|
||||
// and generators.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
|
||||
#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
|
||||
namespace llvm {
|
||||
class SourceMgr;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatToken
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a specific token in the input format.
|
||||
class FormatToken {
|
||||
public:
|
||||
/// Basic token kinds.
|
||||
enum Kind {
|
||||
// Markers.
|
||||
eof,
|
||||
error,
|
||||
|
||||
// Tokens with no info.
|
||||
l_paren,
|
||||
r_paren,
|
||||
caret,
|
||||
colon,
|
||||
comma,
|
||||
equal,
|
||||
less,
|
||||
greater,
|
||||
question,
|
||||
star,
|
||||
|
||||
// Keywords.
|
||||
keyword_start,
|
||||
kw_attr_dict,
|
||||
kw_attr_dict_w_keyword,
|
||||
kw_custom,
|
||||
kw_functional_type,
|
||||
kw_operands,
|
||||
kw_params,
|
||||
kw_ref,
|
||||
kw_regions,
|
||||
kw_results,
|
||||
kw_struct,
|
||||
kw_successors,
|
||||
kw_type,
|
||||
keyword_end,
|
||||
|
||||
// String valued tokens.
|
||||
identifier,
|
||||
literal,
|
||||
variable,
|
||||
};
|
||||
|
||||
FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
|
||||
|
||||
/// Return the bytes that make up this token.
|
||||
StringRef getSpelling() const { return spelling; }
|
||||
|
||||
/// Return the kind of this token.
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
/// Return a location for this token.
|
||||
llvm::SMLoc getLoc() const;
|
||||
|
||||
/// Return if this token is a keyword.
|
||||
bool isKeyword() const {
|
||||
return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Discriminator that indicates the kind of token this is.
|
||||
Kind kind;
|
||||
|
||||
/// A reference to the entire token contents; this is always a pointer into
|
||||
/// a memory buffer owned by the source manager.
|
||||
StringRef spelling;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatLexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class implements a simple lexer for operation assembly format strings.
|
||||
class FormatLexer {
|
||||
public:
|
||||
FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
|
||||
|
||||
/// Lex the next token and return it.
|
||||
FormatToken lexToken();
|
||||
|
||||
/// Emit an error to the lexer with the given location and message.
|
||||
FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
|
||||
FormatToken emitError(const char *loc, const Twine &msg);
|
||||
|
||||
FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
|
||||
const Twine ¬e);
|
||||
|
||||
private:
|
||||
/// Return the next character in the stream.
|
||||
int getNextChar();
|
||||
|
||||
/// Lex an identifier, literal, or variable.
|
||||
FormatToken lexIdentifier(const char *tokStart);
|
||||
FormatToken lexLiteral(const char *tokStart);
|
||||
FormatToken lexVariable(const char *tokStart);
|
||||
|
||||
/// Create a token with the current pointer and a start pointer.
|
||||
FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
|
||||
return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
|
||||
}
|
||||
|
||||
/// The source manager containing the format string.
|
||||
llvm::SourceMgr &mgr;
|
||||
/// Location of the format string.
|
||||
llvm::SMLoc loc;
|
||||
/// Buffer containing the format string.
|
||||
StringRef curBuffer;
|
||||
/// Current pointer in the buffer.
|
||||
const char *curPtr;
|
||||
};
|
||||
|
||||
/// Whether a space needs to be emitted before a literal. E.g., two keywords
|
||||
/// back-to-back require a space separator, but a keyword followed by '<' does
|
||||
/// not require a space.
|
||||
bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
|
||||
|
||||
/// Returns true if the given string can be formatted as a keyword.
|
||||
bool canFormatStringAsKeyword(StringRef value);
|
||||
|
||||
/// Returns true if the given string is valid format literal element.
|
||||
bool isValidLiteral(StringRef value);
|
||||
|
||||
/// Whether a failure in parsing the assembly format should be a fatal error.
|
||||
extern llvm::cl::opt<bool> formatErrorIsFatal;
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
|
||||
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OpFormatGen.h"
|
||||
#include "FormatGen.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
@@ -20,7 +21,6 @@
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
@@ -30,20 +30,6 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
static llvm::cl::opt<bool> formatErrorIsFatal(
|
||||
"asmformat-error-is-fatal",
|
||||
llvm::cl::desc("Emit a fatal error if format parsing fails"),
|
||||
llvm::cl::init(true));
|
||||
|
||||
/// Returns true if the given string can be formatted as a keyword.
|
||||
static bool canFormatStringAsKeyword(StringRef value) {
|
||||
if (!isalpha(value.front()) && value.front() != '_')
|
||||
return false;
|
||||
return llvm::all_of(value.drop_front(), [](char c) {
|
||||
return isalnum(c) || c == '_' || c == '$' || c == '.';
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Element
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -273,33 +259,12 @@ public:
|
||||
/// Return the literal for this element.
|
||||
StringRef getLiteral() const { return literal; }
|
||||
|
||||
/// Returns true if the given string is a valid literal.
|
||||
static bool isValidLiteral(StringRef value);
|
||||
|
||||
private:
|
||||
/// The spelling of the literal for this element.
|
||||
StringRef literal;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
bool LiteralElement::isValidLiteral(StringRef value) {
|
||||
if (value.empty())
|
||||
return false;
|
||||
char front = value.front();
|
||||
|
||||
// If there is only one character, this must either be punctuation or a
|
||||
// single character bare identifier.
|
||||
if (value.size() == 1)
|
||||
return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
|
||||
|
||||
// Check the punctuation that are larger than a single character.
|
||||
if (value == "->")
|
||||
return true;
|
||||
|
||||
// Otherwise, this must be an identifier.
|
||||
return canFormatStringAsKeyword(value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WhitespaceElement
|
||||
|
||||
@@ -1705,14 +1670,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
|
||||
body << " _odsPrinter";
|
||||
|
||||
// Don't insert a space for certain punctuation.
|
||||
auto shouldPrintSpaceBeforeLiteral = [&] {
|
||||
if (value.size() != 1 && value != "->")
|
||||
return true;
|
||||
if (lastWasPunctuation)
|
||||
return !StringRef(">)}],").contains(value.front());
|
||||
return !StringRef("<>(){}[],").contains(value.front());
|
||||
};
|
||||
if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
|
||||
if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
|
||||
body << " << ' '";
|
||||
body << " << \"" << value << "\";\n";
|
||||
|
||||
@@ -2101,253 +2059,6 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
|
||||
lastWasPunctuation);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatLexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class represents a specific token in the input format.
|
||||
class Token {
|
||||
public:
|
||||
enum Kind {
|
||||
// Markers.
|
||||
eof,
|
||||
error,
|
||||
|
||||
// Tokens with no info.
|
||||
l_paren,
|
||||
r_paren,
|
||||
caret,
|
||||
colon,
|
||||
comma,
|
||||
equal,
|
||||
less,
|
||||
greater,
|
||||
question,
|
||||
|
||||
// Keywords.
|
||||
keyword_start,
|
||||
kw_attr_dict,
|
||||
kw_attr_dict_w_keyword,
|
||||
kw_custom,
|
||||
kw_functional_type,
|
||||
kw_operands,
|
||||
kw_ref,
|
||||
kw_regions,
|
||||
kw_results,
|
||||
kw_successors,
|
||||
kw_type,
|
||||
keyword_end,
|
||||
|
||||
// String valued tokens.
|
||||
identifier,
|
||||
literal,
|
||||
variable,
|
||||
};
|
||||
Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
|
||||
|
||||
/// Return the bytes that make up this token.
|
||||
StringRef getSpelling() const { return spelling; }
|
||||
|
||||
/// Return the kind of this token.
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
/// Return a location for this token.
|
||||
llvm::SMLoc getLoc() const {
|
||||
return llvm::SMLoc::getFromPointer(spelling.data());
|
||||
}
|
||||
|
||||
/// Return if this token is a keyword.
|
||||
bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
|
||||
|
||||
private:
|
||||
/// Discriminator that indicates the kind of token this is.
|
||||
Kind kind;
|
||||
|
||||
/// A reference to the entire token contents; this is always a pointer into
|
||||
/// a memory buffer owned by the source manager.
|
||||
StringRef spelling;
|
||||
};
|
||||
|
||||
/// This class implements a simple lexer for operation assembly format strings.
|
||||
class FormatLexer {
|
||||
public:
|
||||
FormatLexer(llvm::SourceMgr &mgr, Operator &op);
|
||||
|
||||
/// Lex the next token and return it.
|
||||
Token lexToken();
|
||||
|
||||
/// Emit an error to the lexer with the given location and message.
|
||||
Token emitError(llvm::SMLoc loc, const Twine &msg);
|
||||
Token emitError(const char *loc, const Twine &msg);
|
||||
|
||||
Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine ¬e);
|
||||
|
||||
private:
|
||||
Token formToken(Token::Kind kind, const char *tokStart) {
|
||||
return Token(kind, StringRef(tokStart, curPtr - tokStart));
|
||||
}
|
||||
|
||||
/// Return the next character in the stream.
|
||||
int getNextChar();
|
||||
|
||||
/// Lex an identifier, literal, or variable.
|
||||
Token lexIdentifier(const char *tokStart);
|
||||
Token lexLiteral(const char *tokStart);
|
||||
Token lexVariable(const char *tokStart);
|
||||
|
||||
llvm::SourceMgr &srcMgr;
|
||||
Operator &op;
|
||||
StringRef curBuffer;
|
||||
const char *curPtr;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
|
||||
: srcMgr(mgr), op(op) {
|
||||
curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
|
||||
curPtr = curBuffer.begin();
|
||||
}
|
||||
|
||||
Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
|
||||
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
return formToken(Token::error, loc.getPointer());
|
||||
}
|
||||
Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
|
||||
const Twine ¬e) {
|
||||
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
|
||||
return formToken(Token::error, loc.getPointer());
|
||||
}
|
||||
Token FormatLexer::emitError(const char *loc, const Twine &msg) {
|
||||
return emitError(llvm::SMLoc::getFromPointer(loc), msg);
|
||||
}
|
||||
|
||||
int FormatLexer::getNextChar() {
|
||||
char curChar = *curPtr++;
|
||||
switch (curChar) {
|
||||
default:
|
||||
return (unsigned char)curChar;
|
||||
case 0: {
|
||||
// A nul character in the stream is either the end of the current buffer or
|
||||
// a random nul in the file. Disambiguate that here.
|
||||
if (curPtr - 1 != curBuffer.end())
|
||||
return 0;
|
||||
|
||||
// Otherwise, return end of file.
|
||||
--curPtr;
|
||||
return EOF;
|
||||
}
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Handle the newline character by ignoring it and incrementing the line
|
||||
// count. However, be careful about 'dos style' files with \n\r in them.
|
||||
// Only treat a \n\r or \r\n as a single line.
|
||||
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
||||
++curPtr;
|
||||
return '\n';
|
||||
}
|
||||
}
|
||||
|
||||
Token FormatLexer::lexToken() {
|
||||
const char *tokStart = curPtr;
|
||||
|
||||
// This always consumes at least one character.
|
||||
int curChar = getNextChar();
|
||||
switch (curChar) {
|
||||
default:
|
||||
// Handle identifiers: [a-zA-Z_]
|
||||
if (isalpha(curChar) || curChar == '_')
|
||||
return lexIdentifier(tokStart);
|
||||
|
||||
// Unknown character, emit an error.
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case EOF:
|
||||
// Return EOF denoting the end of lexing.
|
||||
return formToken(Token::eof, tokStart);
|
||||
|
||||
// Lex punctuation.
|
||||
case '^':
|
||||
return formToken(Token::caret, tokStart);
|
||||
case ':':
|
||||
return formToken(Token::colon, tokStart);
|
||||
case ',':
|
||||
return formToken(Token::comma, tokStart);
|
||||
case '=':
|
||||
return formToken(Token::equal, tokStart);
|
||||
case '<':
|
||||
return formToken(Token::less, tokStart);
|
||||
case '>':
|
||||
return formToken(Token::greater, tokStart);
|
||||
case '?':
|
||||
return formToken(Token::question, tokStart);
|
||||
case '(':
|
||||
return formToken(Token::l_paren, tokStart);
|
||||
case ')':
|
||||
return formToken(Token::r_paren, tokStart);
|
||||
|
||||
// Ignore whitespace characters.
|
||||
case 0:
|
||||
case ' ':
|
||||
case '\t':
|
||||
case '\n':
|
||||
return lexToken();
|
||||
|
||||
case '`':
|
||||
return lexLiteral(tokStart);
|
||||
case '$':
|
||||
return lexVariable(tokStart);
|
||||
}
|
||||
}
|
||||
|
||||
Token FormatLexer::lexLiteral(const char *tokStart) {
|
||||
assert(curPtr[-1] == '`');
|
||||
|
||||
// Lex a literal surrounded by ``.
|
||||
while (const char curChar = *curPtr++) {
|
||||
if (curChar == '`')
|
||||
return formToken(Token::literal, tokStart);
|
||||
}
|
||||
return emitError(curPtr - 1, "unexpected end of file in literal");
|
||||
}
|
||||
|
||||
Token FormatLexer::lexVariable(const char *tokStart) {
|
||||
if (!isalpha(curPtr[0]) && curPtr[0] != '_')
|
||||
return emitError(curPtr - 1, "expected variable name");
|
||||
|
||||
// Otherwise, consume the rest of the characters.
|
||||
while (isalnum(*curPtr) || *curPtr == '_')
|
||||
++curPtr;
|
||||
return formToken(Token::variable, tokStart);
|
||||
}
|
||||
|
||||
Token FormatLexer::lexIdentifier(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
|
||||
++curPtr;
|
||||
|
||||
// Check to see if this identifier is a keyword.
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
Token::Kind kind =
|
||||
StringSwitch<Token::Kind>(str)
|
||||
.Case("attr-dict", Token::kw_attr_dict)
|
||||
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
|
||||
.Case("custom", Token::kw_custom)
|
||||
.Case("functional-type", Token::kw_functional_type)
|
||||
.Case("operands", Token::kw_operands)
|
||||
.Case("ref", Token::kw_ref)
|
||||
.Case("regions", Token::kw_regions)
|
||||
.Case("results", Token::kw_results)
|
||||
.Case("successors", Token::kw_successors)
|
||||
.Case("type", Token::kw_type)
|
||||
.Default(Token::identifier);
|
||||
return Token(kind, str);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatParser
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -2366,8 +2077,8 @@ namespace {
|
||||
class FormatParser {
|
||||
public:
|
||||
FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
|
||||
: lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
|
||||
seenOperandTypes(op.getNumOperands()),
|
||||
: lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
|
||||
op(op), seenOperandTypes(op.getNumOperands()),
|
||||
seenResultTypes(op.getNumResults()) {}
|
||||
|
||||
/// Parse the operation assembly format.
|
||||
@@ -2469,7 +2180,8 @@ private:
|
||||
LogicalResult parseCustomDirectiveParameter(
|
||||
std::vector<std::unique_ptr<Element>> ¶meters);
|
||||
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
Token tok, ParserContext context);
|
||||
FormatToken tok,
|
||||
ParserContext context);
|
||||
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, ParserContext context);
|
||||
LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
|
||||
@@ -2481,8 +2193,8 @@ private:
|
||||
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc,
|
||||
ParserContext context);
|
||||
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
|
||||
ParserContext context);
|
||||
LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
|
||||
FormatToken tok, ParserContext context);
|
||||
LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
|
||||
bool isRefChild = false);
|
||||
|
||||
@@ -2492,12 +2204,12 @@ private:
|
||||
|
||||
/// Advance the current lexer onto the next token.
|
||||
void consumeToken() {
|
||||
assert(curToken.getKind() != Token::eof &&
|
||||
curToken.getKind() != Token::error &&
|
||||
assert(curToken.getKind() != FormatToken::eof &&
|
||||
curToken.getKind() != FormatToken::error &&
|
||||
"shouldn't advance past EOF or errors");
|
||||
curToken = lexer.lexToken();
|
||||
}
|
||||
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
|
||||
LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
|
||||
if (curToken.getKind() != kind)
|
||||
return emitError(curToken.getLoc(), msg);
|
||||
consumeToken();
|
||||
@@ -2518,7 +2230,7 @@ private:
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
FormatLexer lexer;
|
||||
Token curToken;
|
||||
FormatToken curToken;
|
||||
OperationFormat &fmt;
|
||||
Operator &op;
|
||||
|
||||
@@ -2539,7 +2251,7 @@ LogicalResult FormatParser::parse() {
|
||||
llvm::SMLoc loc = curToken.getLoc();
|
||||
|
||||
// Parse each of the format elements into the main format.
|
||||
while (curToken.getKind() != Token::eof) {
|
||||
while (curToken.getKind() != FormatToken::eof) {
|
||||
std::unique_ptr<Element> element;
|
||||
if (failed(parseElement(element, TopLevelContext)))
|
||||
return ::mlir::failure();
|
||||
@@ -2864,13 +2576,13 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
|
||||
if (curToken.isKeyword())
|
||||
return parseDirective(element, context);
|
||||
// Literals.
|
||||
if (curToken.getKind() == Token::literal)
|
||||
if (curToken.getKind() == FormatToken::literal)
|
||||
return parseLiteral(element, context);
|
||||
// Optionals.
|
||||
if (curToken.getKind() == Token::l_paren)
|
||||
if (curToken.getKind() == FormatToken::l_paren)
|
||||
return parseOptional(element, context);
|
||||
// Variables.
|
||||
if (curToken.getKind() == Token::variable)
|
||||
if (curToken.getKind() == FormatToken::variable)
|
||||
return parseVariable(element, context);
|
||||
return emitError(curToken.getLoc(),
|
||||
"expected directive, literal, variable, or optional group");
|
||||
@@ -2878,7 +2590,7 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
|
||||
|
||||
LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
||||
ParserContext context) {
|
||||
Token varTok = curToken;
|
||||
FormatToken varTok = curToken;
|
||||
consumeToken();
|
||||
|
||||
StringRef name = varTok.getSpelling().drop_front();
|
||||
@@ -2958,31 +2670,31 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
||||
|
||||
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
|
||||
ParserContext context) {
|
||||
Token dirTok = curToken;
|
||||
FormatToken dirTok = curToken;
|
||||
consumeToken();
|
||||
|
||||
switch (dirTok.getKind()) {
|
||||
case Token::kw_attr_dict:
|
||||
case FormatToken::kw_attr_dict:
|
||||
return parseAttrDictDirective(element, dirTok.getLoc(), context,
|
||||
/*withKeyword=*/false);
|
||||
case Token::kw_attr_dict_w_keyword:
|
||||
case FormatToken::kw_attr_dict_w_keyword:
|
||||
return parseAttrDictDirective(element, dirTok.getLoc(), context,
|
||||
/*withKeyword=*/true);
|
||||
case Token::kw_custom:
|
||||
case FormatToken::kw_custom:
|
||||
return parseCustomDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_functional_type:
|
||||
case FormatToken::kw_functional_type:
|
||||
return parseFunctionalTypeDirective(element, dirTok, context);
|
||||
case Token::kw_operands:
|
||||
case FormatToken::kw_operands:
|
||||
return parseOperandsDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_regions:
|
||||
case FormatToken::kw_regions:
|
||||
return parseRegionsDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_results:
|
||||
case FormatToken::kw_results:
|
||||
return parseResultsDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_successors:
|
||||
case FormatToken::kw_successors:
|
||||
return parseSuccessorsDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_ref:
|
||||
case FormatToken::kw_ref:
|
||||
return parseReferenceDirective(element, dirTok.getLoc(), context);
|
||||
case Token::kw_type:
|
||||
case FormatToken::kw_type:
|
||||
return parseTypeDirective(element, dirTok, context);
|
||||
|
||||
default:
|
||||
@@ -2992,7 +2704,7 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
|
||||
|
||||
LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
|
||||
ParserContext context) {
|
||||
Token literalTok = curToken;
|
||||
FormatToken literalTok = curToken;
|
||||
if (context != TopLevelContext) {
|
||||
return emitError(
|
||||
literalTok.getLoc(),
|
||||
@@ -3014,7 +2726,7 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
|
||||
}
|
||||
|
||||
// Check that the parsed literal is valid.
|
||||
if (!LiteralElement::isValidLiteral(value))
|
||||
if (!isValidLiteral(value))
|
||||
return emitError(literalTok.getLoc(), "expected valid literal");
|
||||
|
||||
element = std::make_unique<LiteralElement>(value);
|
||||
@@ -3035,14 +2747,15 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
|
||||
do {
|
||||
if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
|
||||
return ::mlir::failure();
|
||||
} while (curToken.getKind() != Token::r_paren);
|
||||
} while (curToken.getKind() != FormatToken::r_paren);
|
||||
consumeToken();
|
||||
|
||||
// Parse the `else` elements of this optional group.
|
||||
if (curToken.getKind() == Token::colon) {
|
||||
if (curToken.getKind() == FormatToken::colon) {
|
||||
consumeToken();
|
||||
if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
|
||||
"of optional group")))
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' to start else branch "
|
||||
"of optional group")))
|
||||
return failure();
|
||||
do {
|
||||
llvm::SMLoc childLoc = curToken.getLoc();
|
||||
@@ -3051,11 +2764,12 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
|
||||
failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
|
||||
/*isAnchor=*/false)))
|
||||
return failure();
|
||||
} while (curToken.getKind() != Token::r_paren);
|
||||
} while (curToken.getKind() != FormatToken::r_paren);
|
||||
consumeToken();
|
||||
}
|
||||
|
||||
if (failed(parseToken(Token::question, "expected '?' after optional group")))
|
||||
if (failed(parseToken(FormatToken::question,
|
||||
"expected '?' after optional group")))
|
||||
return ::mlir::failure();
|
||||
|
||||
// The optional group is required to have an anchor.
|
||||
@@ -3090,7 +2804,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
|
||||
return ::mlir::failure();
|
||||
|
||||
// Check to see if this element is the anchor of the optional group.
|
||||
bool isAnchor = curToken.getKind() == Token::caret;
|
||||
bool isAnchor = curToken.getKind() == FormatToken::caret;
|
||||
if (isAnchor) {
|
||||
if (anchorIdx)
|
||||
return emitError(childLoc, "only one element can be marked as the anchor "
|
||||
@@ -3194,16 +2908,16 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
|
||||
return emitError(loc, "'custom' is only valid as a top-level directive");
|
||||
|
||||
// Parse the custom directive name.
|
||||
if (failed(
|
||||
parseToken(Token::less, "expected '<' before custom directive name")))
|
||||
if (failed(parseToken(FormatToken::less,
|
||||
"expected '<' before custom directive name")))
|
||||
return ::mlir::failure();
|
||||
|
||||
Token nameTok = curToken;
|
||||
if (failed(parseToken(Token::identifier,
|
||||
FormatToken nameTok = curToken;
|
||||
if (failed(parseToken(FormatToken::identifier,
|
||||
"expected custom directive name identifier")) ||
|
||||
failed(parseToken(Token::greater,
|
||||
failed(parseToken(FormatToken::greater,
|
||||
"expected '>' after custom directive name")) ||
|
||||
failed(parseToken(Token::l_paren,
|
||||
failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before custom directive parameters")))
|
||||
return ::mlir::failure();
|
||||
|
||||
@@ -3212,12 +2926,12 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
|
||||
do {
|
||||
if (failed(parseCustomDirectiveParameter(elements)))
|
||||
return ::mlir::failure();
|
||||
if (curToken.getKind() != Token::comma)
|
||||
if (curToken.getKind() != FormatToken::comma)
|
||||
break;
|
||||
consumeToken();
|
||||
} while (true);
|
||||
|
||||
if (failed(parseToken(Token::r_paren,
|
||||
if (failed(parseToken(FormatToken::r_paren,
|
||||
"expected ')' after custom directive parameters")))
|
||||
return ::mlir::failure();
|
||||
|
||||
@@ -3254,9 +2968,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
Token tok, ParserContext context) {
|
||||
LogicalResult FormatParser::parseFunctionalTypeDirective(
|
||||
std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
|
||||
llvm::SMLoc loc = tok.getLoc();
|
||||
if (context != TopLevelContext)
|
||||
return emitError(
|
||||
@@ -3264,11 +2977,14 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
|
||||
// Parse the main operand.
|
||||
std::unique_ptr<Element> inputs, results;
|
||||
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before argument list")) ||
|
||||
failed(parseTypeDirectiveOperand(inputs)) ||
|
||||
failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
|
||||
failed(parseToken(FormatToken::comma,
|
||||
"expected ',' after inputs argument")) ||
|
||||
failed(parseTypeDirectiveOperand(results)) ||
|
||||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
|
||||
failed(
|
||||
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
|
||||
return ::mlir::failure();
|
||||
element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
|
||||
std::move(results));
|
||||
@@ -3299,9 +3015,11 @@ FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
|
||||
return emitError(loc, "'ref' is only valid within a `custom` directive");
|
||||
|
||||
std::unique_ptr<Element> operand;
|
||||
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before argument list")) ||
|
||||
failed(parseElement(operand, RefDirectiveContext)) ||
|
||||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
|
||||
failed(
|
||||
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
|
||||
return ::mlir::failure();
|
||||
|
||||
element = std::make_unique<RefDirective>(std::move(operand));
|
||||
@@ -3360,17 +3078,19 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
|
||||
ParserContext context) {
|
||||
FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
|
||||
FormatToken tok, ParserContext context) {
|
||||
llvm::SMLoc loc = tok.getLoc();
|
||||
if (context == TypeDirectiveContext)
|
||||
return emitError(loc, "'type' cannot be used as a child of another `type`");
|
||||
|
||||
bool isRefChild = context == RefDirectiveContext;
|
||||
std::unique_ptr<Element> operand;
|
||||
if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
|
||||
if (failed(parseToken(FormatToken::l_paren,
|
||||
"expected '(' before argument list")) ||
|
||||
failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
|
||||
failed(parseToken(Token::r_paren, "expected ')' after argument list")))
|
||||
failed(
|
||||
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
|
||||
return ::mlir::failure();
|
||||
|
||||
element = std::make_unique<TypeDirective>(std::move(operand));
|
||||
|
||||
@@ -220,6 +220,7 @@ cc_library(
|
||||
"//mlir:SideEffects",
|
||||
"//mlir:StandardOps",
|
||||
"//mlir:StandardOpsTransforms",
|
||||
"//mlir:Support",
|
||||
"//mlir:TensorDialect",
|
||||
"//mlir:TransformUtils",
|
||||
"//mlir:Transforms",
|
||||
|
||||
Reference in New Issue
Block a user