mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 09:31:59 +08:00
[mlir][OpFormatGen] Add initial support for regions in the custom op assembly format
This adds some initial support for regions and does not support formatting the specific arguments of a region. For now this can be achieved by using a custom directive that formats the arguments and then parses the region. Differential Revision: https://reviews.llvm.org/D86760
This commit is contained in:
@@ -681,6 +681,10 @@ The available directives are as follows:
|
||||
|
||||
- Represents all of the operands of an operation.
|
||||
|
||||
* `regions`
|
||||
|
||||
- Represents all of the regions of an operation.
|
||||
|
||||
* `results`
|
||||
|
||||
- Represents all of the results of an operation.
|
||||
@@ -700,13 +704,14 @@ The available directives are as follows:
|
||||
A literal is either a keyword or punctuation surrounded by \`\`.
|
||||
|
||||
The following are the set of valid punctuation:
|
||||
`:`, `,`, `=`, `<`, `>`, `(`, `)`, `[`, `]`, `->`
|
||||
|
||||
`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`
|
||||
|
||||
#### Variables
|
||||
|
||||
A variable is an entity that has been registered on the operation itself, i.e.
|
||||
an argument(attribute or operand), result, successor, etc. In the `CallOp`
|
||||
example above, the variables would be `$callee` and `$args`.
|
||||
an argument(attribute or operand), region, result, successor, etc. In the
|
||||
`CallOp` example above, the variables would be `$callee` and `$args`.
|
||||
|
||||
Attribute variables are printed with their respective value type, unless that
|
||||
value type is buildable. In those cases, the type of the attribute is elided.
|
||||
@@ -747,6 +752,9 @@ declarative parameter to `parse` method argument is detailed below:
|
||||
- Single: `OpAsmParser::OperandType &`
|
||||
- Optional: `Optional<OpAsmParser::OperandType> &`
|
||||
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
|
||||
* Region Variables
|
||||
- Single: `Region &`
|
||||
- Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
|
||||
* Successor Variables
|
||||
- Single: `Block *&`
|
||||
- Variadic: `SmallVectorImpl<Block *> &`
|
||||
@@ -770,6 +778,9 @@ declarative parameter to `print` method argument is detailed below:
|
||||
- Single: `Value`
|
||||
- Optional: `Value`
|
||||
- Variadic: `OperandRange`
|
||||
* Region Variables
|
||||
- Single: `Region &`
|
||||
- Variadic: `MutableArrayRef<Region>`
|
||||
* Successor Variables
|
||||
- Single: `Block *`
|
||||
- Variadic: `SuccessorRange`
|
||||
@@ -788,8 +799,8 @@ of the assembly format can be marked as `optional` based on the presence of this
|
||||
information. An optional group is defined by wrapping a set of elements within
|
||||
`()` followed by a `?` and has the following requirements:
|
||||
|
||||
* The first element of the group must either be a literal, attribute, or an
|
||||
operand.
|
||||
* The first element of the group must either be a attribute, literal, operand,
|
||||
or region.
|
||||
- This is because the first element must be optionally parsable.
|
||||
* Exactly one argument variable within the group must be marked as the anchor
|
||||
of the group.
|
||||
@@ -797,11 +808,15 @@ information. An optional group is defined by wrapping a set of elements within
|
||||
should be printed/parsed.
|
||||
- An element is marked as the anchor by adding a trailing `^`.
|
||||
- The first element is *not* required to be the anchor of the group.
|
||||
- When a non-variadic region anchors a group, the detector for printing
|
||||
the group is if the region is empty.
|
||||
* Literals, variables, custom directives, and type directives are the only
|
||||
valid elements within the group.
|
||||
- Any attribute variable may be used, but only optional attributes can be
|
||||
marked as the anchor.
|
||||
- Only variadic or optional operand arguments can be used.
|
||||
- All region variables can be used. When a non-variable length region is
|
||||
used, if the group is not present the region is empty.
|
||||
- The operands to a type directive must be defined within the optional
|
||||
group.
|
||||
|
||||
@@ -853,18 +868,22 @@ foo.op
|
||||
The format specification has a certain set of requirements that must be adhered
|
||||
to:
|
||||
|
||||
1. The output and operation name are never shown as they are fixed and cannot be
|
||||
altered.
|
||||
1. All operands within the operation must appear within the format, either
|
||||
individually or with the `operands` directive.
|
||||
1. All operand and result types must appear within the format using the various
|
||||
`type` directives, either individually or with the `operands` or `results`
|
||||
directives.
|
||||
1. The `attr-dict` directive must always be present.
|
||||
1. Must not contain overlapping information; e.g. multiple instances of
|
||||
'attr-dict', types, operands, etc.
|
||||
- Note that `attr-dict` does not overlap with individual attributes. These
|
||||
attributes will simply be elided when printing the attribute dictionary.
|
||||
1. The output and operation name are never shown as they are fixed and cannot
|
||||
be altered.
|
||||
1. All operands within the operation must appear within the format, either
|
||||
individually or with the `operands` directive.
|
||||
1. All regions within the operation must appear within the format, either
|
||||
individually or with the `regions` directive.
|
||||
1. All successors within the operation must appear within the format, either
|
||||
individually or with the `successors` directive.
|
||||
1. All operand and result types must appear within the format using the various
|
||||
`type` directives, either individually or with the `operands` or `results`
|
||||
directives.
|
||||
1. The `attr-dict` directive must always be present.
|
||||
1. Must not contain overlapping information; e.g. multiple instances of
|
||||
'attr-dict', types, operands, etc.
|
||||
- Note that `attr-dict` does not overlap with individual attributes. These
|
||||
attributes will simply be elided when printing the attribute dictionary.
|
||||
|
||||
##### Type Inference
|
||||
|
||||
|
||||
@@ -424,8 +424,8 @@ public:
|
||||
Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) = 0;
|
||||
OptionalParseResult parseOptionalAttribute(Attribute &result,
|
||||
StringRef attrName,
|
||||
template <typename AttrT>
|
||||
OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
|
||||
NamedAttrList &attrs) {
|
||||
return parseOptionalAttribute(result, Type(), attrName, attrs);
|
||||
}
|
||||
@@ -433,6 +433,7 @@ public:
|
||||
/// Specialized variants of `parseOptionalAttribute` that remove potential
|
||||
/// ambiguities in syntax.
|
||||
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
|
||||
Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) = 0;
|
||||
|
||||
@@ -621,16 +622,23 @@ public:
|
||||
/// can only be set to true for regions attached to operations that are
|
||||
/// "IsolatedFromAbove".
|
||||
virtual ParseResult parseRegion(Region ®ion,
|
||||
ArrayRef<OperandType> arguments,
|
||||
ArrayRef<Type> argTypes,
|
||||
ArrayRef<OperandType> arguments = {},
|
||||
ArrayRef<Type> argTypes = {},
|
||||
bool enableNameShadowing = false) = 0;
|
||||
|
||||
/// Parses a region if present.
|
||||
virtual ParseResult parseOptionalRegion(Region ®ion,
|
||||
ArrayRef<OperandType> arguments,
|
||||
ArrayRef<Type> argTypes,
|
||||
ArrayRef<OperandType> arguments = {},
|
||||
ArrayRef<Type> argTypes = {},
|
||||
bool enableNameShadowing = false) = 0;
|
||||
|
||||
/// Parses a region if present. If the region is present, a new region is
|
||||
/// allocated and placed in `region`. If no region is present or on failure,
|
||||
/// `region` remains untouched.
|
||||
virtual OptionalParseResult parseOptionalRegion(
|
||||
std::unique_ptr<Region> ®ion, ArrayRef<OperandType> arguments = {},
|
||||
ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
|
||||
|
||||
/// Parse a region argument, this argument is resolved when calling
|
||||
/// 'parseRegion'.
|
||||
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
|
||||
|
||||
@@ -414,6 +414,10 @@ public:
|
||||
/// region is null, a new empty region will be attached to the Operation.
|
||||
void addRegion(std::unique_ptr<Region> &®ion);
|
||||
|
||||
/// Take ownership of a set of regions that should be attached to the
|
||||
/// Operation.
|
||||
void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
|
||||
|
||||
/// Get the context held by this operation state.
|
||||
MLIRContext *getContext() const { return location->getContext(); }
|
||||
};
|
||||
|
||||
@@ -199,6 +199,12 @@ void OperationState::addRegion(std::unique_ptr<Region> &®ion) {
|
||||
regions.push_back(std::move(region));
|
||||
}
|
||||
|
||||
void OperationState::addRegions(
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions) {
|
||||
for (std::unique_ptr<Region> ®ion : regions)
|
||||
addRegion(std::move(region));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperandStorage
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -221,8 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
|
||||
return result;
|
||||
}
|
||||
}
|
||||
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
|
||||
return parseOptionalAttributeWithToken(Token::l_square, attribute);
|
||||
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
|
||||
Type type) {
|
||||
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
|
||||
}
|
||||
|
||||
/// Attribute dictionary.
|
||||
|
||||
@@ -1045,7 +1045,6 @@ public:
|
||||
}
|
||||
|
||||
/// Parse an optional attribute.
|
||||
/// Template utilities to simplify specifying multiple derived overloads.
|
||||
template <typename AttrT>
|
||||
OptionalParseResult
|
||||
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
|
||||
@@ -1056,25 +1055,15 @@ public:
|
||||
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
|
||||
return parseResult;
|
||||
}
|
||||
template <typename AttrT>
|
||||
OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) {
|
||||
OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
|
||||
if (parseResult.hasValue() && succeeded(*parseResult))
|
||||
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
|
||||
return parseResult;
|
||||
}
|
||||
|
||||
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) override {
|
||||
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
|
||||
}
|
||||
OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
|
||||
OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
|
||||
StringRef attrName,
|
||||
NamedAttrList &attrs) override {
|
||||
return parseOptionalAttributeAndAddToList(result, attrName, attrs);
|
||||
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
|
||||
}
|
||||
|
||||
/// Parse a named dictionary into 'result' if it is present.
|
||||
@@ -1355,6 +1344,23 @@ public:
|
||||
return parseRegion(region, arguments, argTypes, enableNameShadowing);
|
||||
}
|
||||
|
||||
/// Parses a region if present. If the region is present, a new region is
|
||||
/// allocated and placed in `region`. If no region is present, `region`
|
||||
/// remains untouched.
|
||||
OptionalParseResult
|
||||
parseOptionalRegion(std::unique_ptr<Region> ®ion,
|
||||
ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
|
||||
bool enableNameShadowing = false) override {
|
||||
if (parser.getToken().isNot(Token::l_brace))
|
||||
return llvm::None;
|
||||
std::unique_ptr<Region> newRegion = std::make_unique<Region>();
|
||||
if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
|
||||
return failure();
|
||||
|
||||
region = std::move(newRegion);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a region argument. The type of the argument will be resolved later
|
||||
/// by a call to `parseRegion`.
|
||||
ParseResult parseRegionArgument(OperandType &argument) override {
|
||||
|
||||
@@ -187,7 +187,7 @@ public:
|
||||
/// Parse an optional attribute with the provided type.
|
||||
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
|
||||
Type type = {});
|
||||
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
|
||||
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
|
||||
|
||||
/// Parse an optional attribute that is demarcated by a specific token.
|
||||
template <typename AttributeT>
|
||||
@@ -197,8 +197,8 @@ public:
|
||||
if (getToken().isNot(kind))
|
||||
return llvm::None;
|
||||
|
||||
if (Attribute parsedAttr = parseAttribute()) {
|
||||
attr = parsedAttr.cast<ArrayAttr>();
|
||||
if (Attribute parsedAttr = parseAttribute(type)) {
|
||||
attr = parsedAttr.cast<AttributeT>();
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
||||
@@ -319,6 +319,19 @@ static ParseResult parseCustomDirectiveOperandsAndTypes(
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
static ParseResult parseCustomDirectiveRegions(
|
||||
OpAsmParser &parser, Region ®ion,
|
||||
SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
|
||||
if (parser.parseRegion(region))
|
||||
return failure();
|
||||
if (failed(parser.parseOptionalComma()))
|
||||
return success();
|
||||
std::unique_ptr<Region> varRegion = std::make_unique<Region>();
|
||||
if (parser.parseRegion(*varRegion))
|
||||
return failure();
|
||||
varRegions.emplace_back(std::move(varRegion));
|
||||
return success();
|
||||
}
|
||||
static ParseResult
|
||||
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
|
||||
SmallVectorImpl<Block *> &varSuccessors) {
|
||||
@@ -361,6 +374,15 @@ printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
|
||||
printCustomDirectiveResults(printer, operandType, optOperandType,
|
||||
varOperandTypes);
|
||||
}
|
||||
static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region ®ion,
|
||||
MutableArrayRef<Region> varRegions) {
|
||||
printer.printRegion(region);
|
||||
if (!varRegions.empty()) {
|
||||
printer << ", ";
|
||||
for (Region ®ion : varRegions)
|
||||
printer.printRegion(region);
|
||||
}
|
||||
}
|
||||
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
|
||||
Block *successor,
|
||||
SuccessorRange varSuccessors) {
|
||||
|
||||
@@ -1161,8 +1161,13 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TestRegionBuilderOp : TEST_Op<"region_builder">;
|
||||
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
|
||||
Arguments<(ins Variadic<AnyType>)>;
|
||||
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]> {
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &state",
|
||||
[{ build(builder, state, {}); }]>
|
||||
];
|
||||
}
|
||||
def TestCastOp : TEST_Op<"cast">,
|
||||
Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
|
||||
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
|
||||
@@ -1333,6 +1338,43 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
|
||||
let assemblyFormat = "$buildable attr-dict";
|
||||
}
|
||||
|
||||
// Test various mixings of region formatting.
|
||||
class FormatRegionBase<string suffix, string fmt>
|
||||
: TEST_Op<"format_region_" # suffix # "_op"> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
let assemblyFormat = fmt;
|
||||
}
|
||||
def FormatRegionAOp : FormatRegionBase<"a", [{
|
||||
regions attr-dict
|
||||
}]>;
|
||||
def FormatRegionBOp : FormatRegionBase<"b", [{
|
||||
$region attr-dict
|
||||
}]>;
|
||||
def FormatRegionCOp : FormatRegionBase<"c", [{
|
||||
(`region` $region^)? attr-dict
|
||||
}]>;
|
||||
class FormatVariadicRegionBase<string suffix, string fmt>
|
||||
: TEST_Op<"format_variadic_region_" # suffix # "_op"> {
|
||||
let regions = (region VariadicRegion<AnyRegion>:$regions);
|
||||
let assemblyFormat = fmt;
|
||||
}
|
||||
def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
|
||||
$regions attr-dict
|
||||
}]>;
|
||||
def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
|
||||
($regions^ `found_regions`)? attr-dict
|
||||
}]>;
|
||||
class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
|
||||
: TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
|
||||
[SingleBlockImplicitTerminator<"TestReturnOp">]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
let assemblyFormat = fmt;
|
||||
}
|
||||
def FormatFormatRegionImplicitTerminatorAOp
|
||||
: FormatRegionImplicitTerminatorBase<"a", [{
|
||||
$region attr-dict
|
||||
}]>;
|
||||
|
||||
// Test various mixings of result type formatting.
|
||||
class FormatResultBase<string suffix, string fmt>
|
||||
: TEST_Op<"format_result_" # suffix # "_op"> {
|
||||
@@ -1454,6 +1496,16 @@ def FormatCustomDirectiveOperandsAndTypes
|
||||
}];
|
||||
}
|
||||
|
||||
def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
|
||||
let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$regions);
|
||||
let assemblyFormat = [{
|
||||
custom<CustomDirectiveRegions>(
|
||||
$region, $regions
|
||||
)
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
def FormatCustomDirectiveResults
|
||||
: TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
|
||||
let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
|
||||
|
||||
@@ -133,6 +133,28 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
|
||||
operands attr-dict
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// regions
|
||||
|
||||
// CHECK: error: 'regions' directive creates overlap in format
|
||||
def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{
|
||||
regions regions attr-dict
|
||||
}]>;
|
||||
// CHECK: error: 'regions' directive creates overlap in format
|
||||
def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{
|
||||
$region regions attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
}
|
||||
// CHECK: error: 'regions' is only valid as a top-level directive
|
||||
def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{
|
||||
type(regions)
|
||||
}]>;
|
||||
// CHECK-NOT: error:
|
||||
def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{
|
||||
regions attr-dict
|
||||
}]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// results
|
||||
|
||||
@@ -249,7 +271,7 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
|
||||
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
|
||||
($attr)? attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
|
||||
// CHECK: error: first element of an operand group must be an attribute, literal, or operand
|
||||
// CHECK: error: first element of an operand group must be an attribute, literal, operand, or region
|
||||
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
|
||||
(type($operand) $operand^)? attr-dict
|
||||
}]>, Arguments<(ins Optional<I64>:$operand)>;
|
||||
@@ -290,7 +312,7 @@ def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
|
||||
// Variables
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: error: expected variable to refer to an argument, result, or successor
|
||||
// CHECK: error: expected variable to refer to an argument, region, result, or successor
|
||||
def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
|
||||
$unknown_arg attr-dict
|
||||
}]>;
|
||||
@@ -330,11 +352,35 @@ def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
|
||||
def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
|
||||
(`foo` $attr^)? `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
|
||||
// CHECK-NOT: error:
|
||||
// CHECK: error: region 'region' is already bound
|
||||
def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
|
||||
$region $region attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
}
|
||||
// CHECK: error: region 'region' is already bound
|
||||
def VariableInvalidK : TestFormat_Op<"variable_invalid_K", [{
|
||||
regions $region attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
}
|
||||
// CHECK: error: regions can only be used at the top level
|
||||
def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{
|
||||
type($region)
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
}
|
||||
// CHECK: error: region #0, named 'region', not found
|
||||
def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{
|
||||
attr-dict
|
||||
}]> {
|
||||
let regions = (region AnyRegion:$region);
|
||||
}
|
||||
// CHECK-NOT: error:
|
||||
def VariableValidA : TestFormat_Op<"variable_valid_a", [{
|
||||
$attr `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
|
||||
def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
|
||||
def VariableValidB : TestFormat_Op<"variable_valid_b", [{
|
||||
(`foo` $attr^)? `:` attr-dict
|
||||
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
|
||||
|
||||
|
||||
@@ -40,6 +40,72 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64, opt_attr = 10 : i64
|
||||
// CHECK: test.format_buildable_type_op %[[I64]]
|
||||
%ignored = test.format_buildable_type_op %i64
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format regions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: test.format_region_a_op {
|
||||
// CHECK-NEXT: test.return
|
||||
test.format_region_a_op {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK: test.format_region_b_op {
|
||||
// CHECK-NEXT: test.return
|
||||
test.format_region_b_op {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK: test.format_region_c_op region {
|
||||
// CHECK-NEXT: test.return
|
||||
test.format_region_c_op region {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
// CHECK: test.format_region_c_op
|
||||
// CHECK-NOT: region {
|
||||
test.format_region_c_op
|
||||
|
||||
// CHECK: test.format_variadic_region_a_op {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }
|
||||
test.format_variadic_region_a_op {
|
||||
"test.return"() : () -> ()
|
||||
}, {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
// CHECK: test.format_variadic_region_b_op {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: } found_regions
|
||||
test.format_variadic_region_b_op {
|
||||
"test.return"() : () -> ()
|
||||
}, {
|
||||
"test.return"() : () -> ()
|
||||
} found_regions
|
||||
// CHECK: test.format_variadic_region_b_op
|
||||
// CHECK-NOT: {
|
||||
// CHECK-NOT: found_regions
|
||||
test.format_variadic_region_b_op
|
||||
|
||||
// CHECK: test.format_implicit_terminator_region_a_op {
|
||||
// CHECK-NEXT: }
|
||||
test.format_implicit_terminator_region_a_op {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
// CHECK: test.format_implicit_terminator_region_a_op {
|
||||
// CHECK-NEXT: test.return"() {foo.attr
|
||||
test.format_implicit_terminator_region_a_op {
|
||||
"test.return"() {foo.attr} : () -> ()
|
||||
}
|
||||
// CHECK: test.format_implicit_terminator_region_a_op {
|
||||
// CHECK-NEXT: test.return"(%[[I64]]) : (i64)
|
||||
test.format_implicit_terminator_region_a_op {
|
||||
"test.return"(%i64) : (i64) -> ()
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Format results
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -147,6 +213,24 @@ test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64
|
||||
// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
|
||||
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
|
||||
|
||||
// CHECK: test.format_custom_directive_regions {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }
|
||||
test.format_custom_directive_regions {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK: test.format_custom_directive_regions {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: test.return
|
||||
// CHECK-NEXT: }
|
||||
test.format_custom_directive_regions {
|
||||
"test.return"() : () -> ()
|
||||
}, {
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
|
||||
test.format_custom_directive_results : i64, i64 -> (i64)
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
@@ -48,6 +49,7 @@ public:
|
||||
CustomDirective,
|
||||
FunctionalTypeDirective,
|
||||
OperandsDirective,
|
||||
RegionsDirective,
|
||||
ResultsDirective,
|
||||
SuccessorsDirective,
|
||||
TypeDirective,
|
||||
@@ -58,6 +60,7 @@ public:
|
||||
/// This element is an variable value.
|
||||
AttributeVariable,
|
||||
OperandVariable,
|
||||
RegionVariable,
|
||||
ResultVariable,
|
||||
SuccessorVariable,
|
||||
|
||||
@@ -119,6 +122,10 @@ struct AttributeVariable
|
||||
using OperandVariable =
|
||||
VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
|
||||
|
||||
/// This class represents a variable that refers to a region.
|
||||
using RegionVariable =
|
||||
VariableElement<NamedRegion, Element::Kind::RegionVariable>;
|
||||
|
||||
/// This class represents a variable that refers to a result.
|
||||
using ResultVariable =
|
||||
VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
|
||||
@@ -133,7 +140,8 @@ using SuccessorVariable =
|
||||
|
||||
namespace {
|
||||
/// This class implements single kind directives.
|
||||
template <Element::Kind type> class DirectiveElement : public Element {
|
||||
template <Element::Kind type>
|
||||
class DirectiveElement : public Element {
|
||||
public:
|
||||
DirectiveElement() : Element(type){};
|
||||
static bool classof(const Element *ele) { return ele->getKind() == type; }
|
||||
@@ -142,6 +150,10 @@ public:
|
||||
/// all of the operands of an operation.
|
||||
using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
|
||||
|
||||
/// This class represents the `regions` directive. This directive represents
|
||||
/// all of the regions of an operation.
|
||||
using RegionsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
|
||||
|
||||
/// This class represents the `results` directive. This directive represents
|
||||
/// all of the results of an operation.
|
||||
using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
|
||||
@@ -350,13 +362,23 @@ struct OperationFormat {
|
||||
: allOperands(false), allOperandTypes(false), allResultTypes(false) {
|
||||
operandTypes.resize(op.getNumOperands(), TypeResolution());
|
||||
resultTypes.resize(op.getNumResults(), TypeResolution());
|
||||
|
||||
hasImplicitTermTrait =
|
||||
llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
|
||||
return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
|
||||
});
|
||||
}
|
||||
|
||||
/// Generate the operation parser from this format.
|
||||
void genParser(Operator &op, OpClass &opClass);
|
||||
/// Generate the parser code for a specific format element.
|
||||
void genElementParser(Element *element, OpMethodBody &body,
|
||||
FmtContext &attrTypeCtx);
|
||||
/// Generate the c++ to resolve the types of operands and results during
|
||||
/// parsing.
|
||||
void genParserTypeResolution(Operator &op, OpMethodBody &body);
|
||||
/// Generate the c++ to resolve regions during parsing.
|
||||
void genParserRegionResolution(Operator &op, OpMethodBody &body);
|
||||
/// Generate the c++ to resolve successors during parsing.
|
||||
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
|
||||
/// Generate the c++ to handling variadic segment size traits.
|
||||
@@ -365,6 +387,10 @@ struct OperationFormat {
|
||||
/// Generate the operation printer from this format.
|
||||
void genPrinter(Operator &op, OpClass &opClass);
|
||||
|
||||
/// Generate the printer code for a specific format element.
|
||||
void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
|
||||
bool &shouldEmitSpace, bool &lastWasPunctuation);
|
||||
|
||||
/// The various elements in this format.
|
||||
std::vector<std::unique_ptr<Element>> elements;
|
||||
|
||||
@@ -372,11 +398,18 @@ struct OperationFormat {
|
||||
/// contains these, it can not contain individual type resolvers.
|
||||
bool allOperands, allOperandTypes, allResultTypes;
|
||||
|
||||
/// A flag indicating if this operation has the SingleBlockImplicitTerminator
|
||||
/// trait.
|
||||
bool hasImplicitTermTrait;
|
||||
|
||||
/// A map of buildable types to indices.
|
||||
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
|
||||
|
||||
/// The index of the buildable type, if valid, for every operand and result.
|
||||
std::vector<TypeResolution> operandTypes, resultTypes;
|
||||
|
||||
/// The set of attributes explicitly used within the format.
|
||||
SmallVector<const NamedAttribute *, 8> usedAttributes;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
@@ -541,6 +574,60 @@ const char *const functionalTypeParserCode = R"(
|
||||
{1}Types = {0}__{1}_functionType.getResults();
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a region list.
|
||||
///
|
||||
/// {0}: The name for the region list.
|
||||
const char *regionListParserCode = R"(
|
||||
{
|
||||
std::unique_ptr<::mlir::Region> region;
|
||||
auto firstRegionResult = parser.parseOptionalRegion(region);
|
||||
if (firstRegionResult.hasValue()) {
|
||||
if (failed(*firstRegionResult))
|
||||
return failure();
|
||||
{0}Regions.emplace_back(std::move(region));
|
||||
|
||||
// Parse any trailing regions.
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
region = std::make_unique<::mlir::Region>();
|
||||
if (parser.parseRegion(*region))
|
||||
return failure();
|
||||
{0}Regions.emplace_back(std::move(region));
|
||||
}
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
/// The code snippet used to ensure a list of regions have terminators.
|
||||
///
|
||||
/// {0}: The name of the region list.
|
||||
const char *regionListEnsureTerminatorParserCode = R"(
|
||||
for (auto ®ion : {0}Regions)
|
||||
ensureTerminator(*region, parser.getBuilder(), result.location);
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for an optional region.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
const char *optionalRegionParserCode = R"(
|
||||
if (parser.parseOptionalRegion(*{0}Region))
|
||||
return failure();
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a region.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
const char *regionParserCode = R"(
|
||||
if (parser.parseRegion(*{0}Region))
|
||||
return failure();
|
||||
)";
|
||||
|
||||
/// The code snippet used to ensure a region has a terminator.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
const char *regionEnsureTerminatorParserCode = R"(
|
||||
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a successor list.
|
||||
///
|
||||
/// {0}: The name for the successor list.
|
||||
@@ -658,6 +745,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
|
||||
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
|
||||
"allOperands;\n";
|
||||
|
||||
} else if (isa<RegionsDirective>(element)) {
|
||||
body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
|
||||
"fullRegions;\n";
|
||||
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
|
||||
|
||||
@@ -680,6 +771,20 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
|
||||
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
||||
" (void){0}OperandsLoc;\n",
|
||||
name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||
StringRef name = region->getVar()->name;
|
||||
if (region->getVar()->isVariadic()) {
|
||||
body << llvm::formatv(
|
||||
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
|
||||
"{0}Regions;\n",
|
||||
name);
|
||||
} else {
|
||||
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
|
||||
"std::make_unique<::mlir::Region>();\n",
|
||||
name);
|
||||
}
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic()) {
|
||||
@@ -725,6 +830,13 @@ static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
|
||||
else
|
||||
body << formatv("{0}RawOperands[0]", name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
|
||||
StringRef name = region->getVar()->name;
|
||||
if (region->getVar()->isVariadic())
|
||||
body << llvm::formatv("{0}Regions", name);
|
||||
else
|
||||
body << llvm::formatv("*{0}Region", name);
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic())
|
||||
@@ -809,9 +921,39 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
|
||||
body << " }\n";
|
||||
}
|
||||
|
||||
/// Generate the parser for a single format element.
|
||||
static void genElementParser(Element *element, OpMethodBody &body,
|
||||
FmtContext &attrTypeCtx) {
|
||||
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
|
||||
auto &method = opClass.newMethod(
|
||||
"::mlir::ParseResult", "parse",
|
||||
"::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
|
||||
OpMethod::MP_Static);
|
||||
auto &body = method.body();
|
||||
|
||||
// Generate variables to store the operands and type within the format. This
|
||||
// allows for referencing these variables in the presence of optional
|
||||
// groupings.
|
||||
for (auto &element : elements)
|
||||
genElementParserStorage(&*element, body);
|
||||
|
||||
// A format context used when parsing attributes with buildable types.
|
||||
FmtContext attrTypeCtx;
|
||||
attrTypeCtx.withBuilder("parser.getBuilder()");
|
||||
|
||||
// Generate parsers for each of the elements.
|
||||
for (auto &element : elements)
|
||||
genElementParser(element.get(), body, attrTypeCtx);
|
||||
|
||||
// Generate the code to resolve the operand/result types and successors now
|
||||
// that they have been parsed.
|
||||
genParserTypeResolution(op, body);
|
||||
genParserRegionResolution(op, body);
|
||||
genParserSuccessorResolution(op, body);
|
||||
genParserVariadicSegmentResolution(op, body);
|
||||
|
||||
body << " return success();\n";
|
||||
}
|
||||
|
||||
void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
||||
FmtContext &attrTypeCtx) {
|
||||
/// Optional Group.
|
||||
if (auto *optional = dyn_cast<OptionalElement>(element)) {
|
||||
auto elements = optional->getElements();
|
||||
@@ -829,6 +971,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
||||
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
|
||||
genElementParser(opVar, body, attrTypeCtx);
|
||||
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
|
||||
} else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
|
||||
const NamedRegion *region = regionVar->getVar();
|
||||
if (region->isVariadic()) {
|
||||
genElementParser(regionVar, body, attrTypeCtx);
|
||||
body << " if (!" << region->name << "Regions.empty()) {\n";
|
||||
} else {
|
||||
body << llvm::formatv(optionalRegionParserCode, region->name);
|
||||
body << " if (!" << region->name << "Region->empty()) {\n ";
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
|
||||
}
|
||||
}
|
||||
|
||||
// If the anchor is a unit attribute, we don't need to print it. When
|
||||
@@ -907,6 +1060,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
||||
body << llvm::formatv(optionalOperandParserCode, name);
|
||||
else
|
||||
body << formatv(operandParserCode, name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||
bool isVariadic = region->getVar()->isVariadic();
|
||||
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
|
||||
region->getVar()->name);
|
||||
if (hasImplicitTermTrait) {
|
||||
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
||||
: regionEnsureTerminatorParserCode,
|
||||
region->getVar()->name);
|
||||
}
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
bool isVariadic = successor->getVar()->isVariadic();
|
||||
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
|
||||
@@ -925,8 +1089,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
||||
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
|
||||
<< " if (parser.parseOperandList(allOperands))\n"
|
||||
<< " return failure();\n";
|
||||
|
||||
} else if (isa<RegionsDirective>(element)) {
|
||||
body << llvm::formatv(regionListParserCode, "full");
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
|
||||
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << llvm::formatv(successorListParserCode, "full");
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
|
||||
@@ -946,36 +1117,6 @@ static void genElementParser(Element *element, OpMethodBody &body,
|
||||
}
|
||||
}
|
||||
|
||||
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
|
||||
auto &method = opClass.newMethod(
|
||||
"::mlir::ParseResult", "parse",
|
||||
"::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
|
||||
OpMethod::MP_Static);
|
||||
auto &body = method.body();
|
||||
|
||||
// Generate variables to store the operands and type within the format. This
|
||||
// allows for referencing these variables in the presence of optional
|
||||
// groupings.
|
||||
for (auto &element : elements)
|
||||
genElementParserStorage(&*element, body);
|
||||
|
||||
// A format context used when parsing attributes with buildable types.
|
||||
FmtContext attrTypeCtx;
|
||||
attrTypeCtx.withBuilder("parser.getBuilder()");
|
||||
|
||||
// Generate parsers for each of the elements.
|
||||
for (auto &element : elements)
|
||||
genElementParser(element.get(), body, attrTypeCtx);
|
||||
|
||||
// Generate the code to resolve the operand/result types and successors now
|
||||
// that they have been parsed.
|
||||
genParserTypeResolution(op, body);
|
||||
genParserSuccessorResolution(op, body);
|
||||
genParserVariadicSegmentResolution(op, body);
|
||||
|
||||
body << " return success();\n";
|
||||
}
|
||||
|
||||
void OperationFormat::genParserTypeResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
// If any of type resolutions use transformed variables, make sure that the
|
||||
@@ -1133,6 +1274,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
|
||||
}
|
||||
}
|
||||
|
||||
void OperationFormat::genParserRegionResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
// Check for the case where all regions were parsed.
|
||||
bool hasAllRegions = llvm::any_of(
|
||||
elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
|
||||
if (hasAllRegions) {
|
||||
body << " result.addRegions(fullRegions);\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, handle each region individually.
|
||||
for (const NamedRegion ®ion : op.getRegions()) {
|
||||
if (region.isVariadic())
|
||||
body << " result.addRegions(" << region.name << "Regions);\n";
|
||||
else
|
||||
body << " result.addRegion(std::move(" << region.name << "Region));\n";
|
||||
}
|
||||
}
|
||||
|
||||
void OperationFormat::genParserSuccessorResolution(Operator &op,
|
||||
OpMethodBody &body) {
|
||||
// Check for the case where all successors were parsed.
|
||||
@@ -1186,23 +1346,26 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrinterGen
|
||||
|
||||
/// The code snippet used to generate a printer call for a region of an
|
||||
// operation that has the SingleBlockImplicitTerminator trait.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
|
||||
{
|
||||
bool printTerminator = true;
|
||||
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
|
||||
printTerminator = !term->getMutableAttrDict().empty() ||
|
||||
term->getNumOperands() != 0 ||
|
||||
term->getNumResults() != 0;
|
||||
}
|
||||
p.printRegion({0}, /*printEntryBlockArgs=*/true,
|
||||
/*printBlockTerminators=*/printTerminator);
|
||||
}
|
||||
)";
|
||||
|
||||
/// Generate the printer for the 'attr-dict' directive.
|
||||
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
|
||||
OpMethodBody &body, bool withKeyword) {
|
||||
// Collect all of the attributes used in the format, these will be elided.
|
||||
SmallVector<const NamedAttribute *, 1> usedAttributes;
|
||||
for (auto &it : fmt.elements) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(it.get()))
|
||||
usedAttributes.push_back(attr->getVar());
|
||||
// Collect the optional attributes.
|
||||
if (auto *opt = dyn_cast<OptionalElement>(it.get())) {
|
||||
for (auto &elem : opt->getElements()) {
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(&elem))
|
||||
usedAttributes.push_back(attr->getVar());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
|
||||
<< "(getAttrs(), /*elidedAttrs=*/{";
|
||||
// Elide the variadic segment size attributes if necessary.
|
||||
@@ -1210,9 +1373,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
|
||||
body << "\"operand_segment_sizes\", ";
|
||||
if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
|
||||
body << "\"result_segment_sizes\", ";
|
||||
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
|
||||
body << "\"" << attr->name << "\"";
|
||||
});
|
||||
llvm::interleaveComma(
|
||||
fmt.usedAttributes, body,
|
||||
[&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
|
||||
body << "});\n";
|
||||
}
|
||||
|
||||
@@ -1255,6 +1418,9 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
|
||||
body << operand->getVar()->name << "()";
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
|
||||
body << region->getVar()->name << "()";
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
|
||||
body << successor->getVar()->name << "()";
|
||||
|
||||
@@ -1277,6 +1443,24 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
body << ");\n";
|
||||
}
|
||||
|
||||
/// Generate the printer for a region with the given variable name.
|
||||
static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body,
|
||||
bool hasImplicitTermTrait) {
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
|
||||
regionName);
|
||||
else
|
||||
body << " p.printRegion(" << regionName << ");\n";
|
||||
}
|
||||
static void genVariadicRegionPrinter(const Twine ®ionListName,
|
||||
OpMethodBody &body,
|
||||
bool hasImplicitTermTrait) {
|
||||
body << " llvm::interleaveComma(" << regionListName
|
||||
<< ", p, [&](::mlir::Region ®ion) {\n ";
|
||||
genRegionPrinter("region", body, hasImplicitTermTrait);
|
||||
body << " });\n";
|
||||
}
|
||||
|
||||
/// Generate the C++ for an operand to a (*-)type directive.
|
||||
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
|
||||
if (isa<OperandsDirective>(arg))
|
||||
@@ -1296,10 +1480,9 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
|
||||
<< "().getType())";
|
||||
}
|
||||
|
||||
/// Generate the code for printing the given element.
|
||||
static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
OperationFormat &fmt, Operator &op,
|
||||
bool &shouldEmitSpace, bool &lastWasPunctuation) {
|
||||
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
|
||||
Operator &op, bool &shouldEmitSpace,
|
||||
bool &lastWasPunctuation) {
|
||||
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
|
||||
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
@@ -1314,6 +1497,11 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
body << " if (" << var->name << "()) {\n";
|
||||
else if (var->isVariadic())
|
||||
body << " if (!" << var->name << "().empty()) {\n";
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
|
||||
const NamedRegion *var = region->getVar();
|
||||
// TODO: Add a check for optional here when ODS supports it.
|
||||
body << " if (!" << var->name << "().empty()) {\n";
|
||||
|
||||
} else {
|
||||
body << " if (getAttr(\""
|
||||
<< cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
|
||||
@@ -1332,7 +1520,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
// Emit each of the elements.
|
||||
for (Element &childElement : elements) {
|
||||
if (&childElement != elidedAnchorElement) {
|
||||
genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
|
||||
genElementPrinter(&childElement, body, op, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
}
|
||||
}
|
||||
@@ -1342,7 +1530,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
|
||||
// Emit the attribute dictionary.
|
||||
if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
|
||||
genAttrDictPrinter(fmt, op, body, attrDict->isWithKeyword());
|
||||
genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
|
||||
lastWasPunctuation = false;
|
||||
return;
|
||||
}
|
||||
@@ -1384,6 +1572,13 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
} else {
|
||||
body << " p << " << operand->getVar()->name << "();\n";
|
||||
}
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||
const NamedRegion *var = region->getVar();
|
||||
if (var->isVariadic()) {
|
||||
genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
|
||||
} else {
|
||||
genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
|
||||
}
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
const NamedSuccessor *var = successor->getVar();
|
||||
if (var->isVariadic())
|
||||
@@ -1394,6 +1589,9 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
|
||||
genCustomDirectivePrinter(dir, body);
|
||||
} else if (isa<OperandsDirective>(element)) {
|
||||
body << " p << getOperation()->getOperands();\n";
|
||||
} else if (isa<RegionsDirective>(element)) {
|
||||
genVariadicRegionPrinter("getOperation()->getRegions()", body,
|
||||
hasImplicitTermTrait);
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
@@ -1426,7 +1624,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
|
||||
// punctuation.
|
||||
bool shouldEmitSpace = true, lastWasPunctuation = false;
|
||||
for (auto &element : elements)
|
||||
genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
|
||||
genElementPrinter(element.get(), body, op, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
}
|
||||
|
||||
@@ -1460,6 +1658,7 @@ public:
|
||||
kw_custom,
|
||||
kw_functional_type,
|
||||
kw_operands,
|
||||
kw_regions,
|
||||
kw_results,
|
||||
kw_successors,
|
||||
kw_type,
|
||||
@@ -1663,6 +1862,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
|
||||
.Case("custom", Token::kw_custom)
|
||||
.Case("functional-type", Token::kw_functional_type)
|
||||
.Case("operands", Token::kw_operands)
|
||||
.Case("regions", Token::kw_regions)
|
||||
.Case("results", Token::kw_results)
|
||||
.Case("successors", Token::kw_successors)
|
||||
.Case("type", Token::kw_type)
|
||||
@@ -1676,7 +1876,8 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
|
||||
|
||||
/// Function to find an element within the given range that has the same name as
|
||||
/// 'name'.
|
||||
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
|
||||
template <typename RangeT>
|
||||
static auto findArg(RangeT &&range, StringRef name) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
}
|
||||
@@ -1719,6 +1920,9 @@ private:
|
||||
verifyOperands(llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
|
||||
/// Verify the state of operation regions within the format.
|
||||
LogicalResult verifyRegions(llvm::SMLoc loc);
|
||||
|
||||
/// Verify the state of operation results within the format.
|
||||
LogicalResult
|
||||
verifyResults(llvm::SMLoc loc,
|
||||
@@ -1775,6 +1979,8 @@ private:
|
||||
Token tok, bool isTopLevel);
|
||||
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel);
|
||||
LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel);
|
||||
LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel);
|
||||
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
|
||||
@@ -1821,11 +2027,12 @@ private:
|
||||
|
||||
// The following are various bits of format state used for verification
|
||||
// during parsing.
|
||||
bool hasAllOperands = false, hasAttrDict = false;
|
||||
bool hasAllSuccessors = false;
|
||||
bool hasAttrDict = false;
|
||||
bool hasAllRegions = false, hasAllSuccessors = false;
|
||||
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
|
||||
llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
|
||||
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
|
||||
llvm::DenseSet<const NamedAttribute *> seenAttrs;
|
||||
llvm::DenseSet<const NamedRegion *> seenRegions;
|
||||
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
|
||||
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
|
||||
};
|
||||
@@ -1867,13 +2074,11 @@ LogicalResult FormatParser::parse() {
|
||||
if (failed(verifyAttributes(loc)) ||
|
||||
failed(verifyResults(loc, variableTyResolver)) ||
|
||||
failed(verifyOperands(loc, variableTyResolver)) ||
|
||||
failed(verifySuccessors(loc)))
|
||||
failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
|
||||
return failure();
|
||||
|
||||
// Check to see if we are formatting all of the operands.
|
||||
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
|
||||
return isa<OperandsDirective>(elt.get());
|
||||
});
|
||||
// Collect the set of used attributes in the format.
|
||||
fmt.usedAttributes = seenAttrs.takeVector();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1953,7 +2158,7 @@ LogicalResult FormatParser::verifyOperands(
|
||||
NamedTypeConstraint &operand = op.getOperand(i);
|
||||
|
||||
// Check that the operand itself is in the format.
|
||||
if (!hasAllOperands && !seenOperands.count(&operand)) {
|
||||
if (!fmt.allOperands && !seenOperands.count(&operand)) {
|
||||
return emitErrorAndNote(loc,
|
||||
"operand #" + Twine(i) + ", named '" +
|
||||
operand.name + "', not found",
|
||||
@@ -1976,7 +2181,7 @@ LogicalResult FormatParser::verifyOperands(
|
||||
// Similarly to results, allow a custom builder for resolving the type if
|
||||
// we aren't using the 'operands' directive.
|
||||
Optional<StringRef> builder = operand.constraint.getBuilderCall();
|
||||
if (!builder || (hasAllOperands && operand.isVariableLength())) {
|
||||
if (!builder || (fmt.allOperands && operand.isVariableLength())) {
|
||||
return emitErrorAndNote(
|
||||
loc,
|
||||
"type of operand #" + Twine(i) + ", named '" + operand.name +
|
||||
@@ -1991,6 +2196,24 @@ LogicalResult FormatParser::verifyOperands(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
|
||||
// Check that all of the regions are within the format.
|
||||
if (hasAllRegions)
|
||||
return success();
|
||||
|
||||
for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
|
||||
const NamedRegion ®ion = op.getRegion(i);
|
||||
if (!seenRegions.count(®ion)) {
|
||||
return emitErrorAndNote(loc,
|
||||
"region #" + Twine(i) + ", named '" +
|
||||
region.name + "', not found",
|
||||
"suggest adding a '$" + region.name +
|
||||
"' directive to the custom assembly format");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::verifyResults(
|
||||
llvm::SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
@@ -2108,7 +2331,7 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
|
||||
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
|
||||
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
|
||||
return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
|
||||
return seenAttrs.count(attr) ? attr : nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -2142,7 +2365,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
||||
// op.
|
||||
/// Attributes
|
||||
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
|
||||
if (isTopLevel && !seenAttrs.insert(attr).second)
|
||||
if (isTopLevel && !seenAttrs.insert(attr))
|
||||
return emitError(loc, "attribute '" + name + "' is already bound");
|
||||
element = std::make_unique<AttributeVariable>(attr);
|
||||
return success();
|
||||
@@ -2150,12 +2373,21 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
||||
/// Operands
|
||||
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
|
||||
if (isTopLevel) {
|
||||
if (hasAllOperands || !seenOperands.insert(operand).second)
|
||||
if (fmt.allOperands || !seenOperands.insert(operand).second)
|
||||
return emitError(loc, "operand '" + name + "' is already bound");
|
||||
}
|
||||
element = std::make_unique<OperandVariable>(operand);
|
||||
return success();
|
||||
}
|
||||
/// Regions
|
||||
if (const NamedRegion *region = findArg(op.getRegions(), name)) {
|
||||
if (!isTopLevel)
|
||||
return emitError(loc, "regions can only be used at the top level");
|
||||
if (hasAllRegions || !seenRegions.insert(region).second)
|
||||
return emitError(loc, "region '" + name + "' is already bound");
|
||||
element = std::make_unique<RegionVariable>(region);
|
||||
return success();
|
||||
}
|
||||
/// Results.
|
||||
if (const auto *result = findArg(op.getResults(), name)) {
|
||||
if (isTopLevel)
|
||||
@@ -2172,8 +2404,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
|
||||
element = std::make_unique<SuccessorVariable>(successor);
|
||||
return success();
|
||||
}
|
||||
return emitError(
|
||||
loc, "expected variable to refer to an argument, result, or successor");
|
||||
return emitError(loc, "expected variable to refer to an argument, region, "
|
||||
"result, or successor");
|
||||
}
|
||||
|
||||
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
|
||||
@@ -2194,6 +2426,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
|
||||
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
|
||||
case Token::kw_operands:
|
||||
return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
|
||||
case Token::kw_regions:
|
||||
return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
|
||||
case Token::kw_results:
|
||||
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
|
||||
case Token::kw_successors:
|
||||
@@ -2247,9 +2481,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
|
||||
// optional fashion.
|
||||
Element *firstElement = &*elements.front();
|
||||
if (!isa<AttributeVariable>(firstElement) &&
|
||||
!isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
|
||||
!isa<LiteralElement>(firstElement) &&
|
||||
!isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
|
||||
return emitError(curLoc, "first element of an operand group must be an "
|
||||
"attribute, literal, or operand");
|
||||
"attribute, literal, operand, or region");
|
||||
|
||||
// After parsing all of the elements, ensure that all type directives refer
|
||||
// only to elements within the group.
|
||||
@@ -2314,10 +2549,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
|
||||
seenVariables.insert(ele->getVar());
|
||||
return success();
|
||||
})
|
||||
.Case<RegionVariable>([&](RegionVariable *) {
|
||||
// TODO: When ODS has proper support for marking "optional" regions, add
|
||||
// a check here.
|
||||
return success();
|
||||
})
|
||||
// Literals, custom directives, and type directives may be used,
|
||||
// but they can't anchor the group.
|
||||
.Case<LiteralElement, CustomDirective, TypeDirective,
|
||||
FunctionalTypeDirective>([&](Element *) {
|
||||
.Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
|
||||
OptionalElement, TypeDirective>([&](Element *) {
|
||||
if (isAnchor)
|
||||
return emitError(childLoc, "only variables can be used to anchor "
|
||||
"an optional group");
|
||||
@@ -2401,7 +2641,7 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
|
||||
return failure();
|
||||
|
||||
// Verify that the element can be placed within a custom directive.
|
||||
if (!isa<TypeDirective, AttributeVariable, OperandVariable,
|
||||
if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
|
||||
SuccessorVariable>(parameters.back().get())) {
|
||||
return emitError(childLoc, "only variables and types may be used as "
|
||||
"parameters to a custom directive");
|
||||
@@ -2433,13 +2673,27 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
|
||||
LogicalResult
|
||||
FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel) {
|
||||
if (isTopLevel && (hasAllOperands || !seenOperands.empty()))
|
||||
return emitError(loc, "'operands' directive creates overlap in format");
|
||||
hasAllOperands = true;
|
||||
if (isTopLevel) {
|
||||
if (fmt.allOperands || !seenOperands.empty())
|
||||
return emitError(loc, "'operands' directive creates overlap in format");
|
||||
fmt.allOperands = true;
|
||||
}
|
||||
element = std::make_unique<OperandsDirective>();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel) {
|
||||
if (!isTopLevel)
|
||||
return emitError(loc, "'regions' is only valid as a top-level directive");
|
||||
if (hasAllRegions || !seenRegions.empty())
|
||||
return emitError(loc, "'regions' directive creates overlap in format");
|
||||
hasAllRegions = true;
|
||||
element = std::make_unique<RegionsDirective>();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
|
||||
llvm::SMLoc loc, bool isTopLevel) {
|
||||
|
||||
Reference in New Issue
Block a user