[mlir][OpFormatGen] Add initial support for regions in the custom op assembly format

This adds some initial support for regions and does not support formatting the specific arguments of a region. For now this can be achieved by using a custom directive that formats the arguments and then parses the region.

Differential Revision: https://reviews.llvm.org/D86760
This commit is contained in:
River Riddle
2020-08-31 12:33:55 -07:00
parent 24b88920fe
commit eaeadce9bd
12 changed files with 631 additions and 129 deletions

View File

@@ -681,6 +681,10 @@ The available directives are as follows:
- Represents all of the operands of an operation.
* `regions`
- Represents all of the regions of an operation.
* `results`
- Represents all of the results of an operation.
@@ -700,13 +704,14 @@ The available directives are as follows:
A literal is either a keyword or punctuation surrounded by \`\`.
The following are the set of valid punctuation:
`:`, `,`, `=`, `<`, `>`, `(`, `)`, `[`, `]`, `->`
`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`
#### Variables
A variable is an entity that has been registered on the operation itself, i.e.
an argument(attribute or operand), result, successor, etc. In the `CallOp`
example above, the variables would be `$callee` and `$args`.
an argument(attribute or operand), region, result, successor, etc. In the
`CallOp` example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided.
@@ -747,6 +752,9 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `OpAsmParser::OperandType &`
- Optional: `Optional<OpAsmParser::OperandType> &`
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
* Region Variables
- Single: `Region &`
- Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
* Successor Variables
- Single: `Block *&`
- Variadic: `SmallVectorImpl<Block *> &`
@@ -770,6 +778,9 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Value`
- Optional: `Value`
- Variadic: `OperandRange`
* Region Variables
- Single: `Region &`
- Variadic: `MutableArrayRef<Region>`
* Successor Variables
- Single: `Block *`
- Variadic: `SuccessorRange`
@@ -788,8 +799,8 @@ of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
* The first element of the group must either be a literal, attribute, or an
operand.
* The first element of the group must either be a attribute, literal, operand,
or region.
- This is because the first element must be optionally parsable.
* Exactly one argument variable within the group must be marked as the anchor
of the group.
@@ -797,11 +808,15 @@ information. An optional group is defined by wrapping a set of elements within
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.
- When a non-variadic region anchors a group, the detector for printing
the group is if the region is empty.
* Literals, variables, custom directives, and type directives are the only
valid elements within the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic or optional operand arguments can be used.
- All region variables can be used. When a non-variable length region is
used, if the group is not present the region is empty.
- The operands to a type directive must be defined within the optional
group.
@@ -853,18 +868,22 @@ foo.op
The format specification has a certain set of requirements that must be adhered
to:
1. The output and operation name are never shown as they are fixed and cannot be
altered.
1. All operands within the operation must appear within the format, either
individually or with the `operands` directive.
1. All operand and result types must appear within the format using the various
`type` directives, either individually or with the `operands` or `results`
directives.
1. The `attr-dict` directive must always be present.
1. Must not contain overlapping information; e.g. multiple instances of
'attr-dict', types, operands, etc.
- Note that `attr-dict` does not overlap with individual attributes. These
attributes will simply be elided when printing the attribute dictionary.
1. The output and operation name are never shown as they are fixed and cannot
be altered.
1. All operands within the operation must appear within the format, either
individually or with the `operands` directive.
1. All regions within the operation must appear within the format, either
individually or with the `regions` directive.
1. All successors within the operation must appear within the format, either
individually or with the `successors` directive.
1. All operand and result types must appear within the format using the various
`type` directives, either individually or with the `operands` or `results`
directives.
1. The `attr-dict` directive must always be present.
1. Must not contain overlapping information; e.g. multiple instances of
'attr-dict', types, operands, etc.
- Note that `attr-dict` does not overlap with individual attributes. These
attributes will simply be elided when printing the attribute dictionary.
##### Type Inference

View File

@@ -424,8 +424,8 @@ public:
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
OptionalParseResult parseOptionalAttribute(Attribute &result,
StringRef attrName,
template <typename AttrT>
OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
NamedAttrList &attrs) {
return parseOptionalAttribute(result, Type(), attrName, attrs);
}
@@ -433,6 +433,7 @@ public:
/// Specialized variants of `parseOptionalAttribute` that remove potential
/// ambiguities in syntax.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
@@ -621,16 +622,23 @@ public:
/// can only be set to true for regions attached to operations that are
/// "IsolatedFromAbove".
virtual ParseResult parseRegion(Region &region,
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes,
ArrayRef<OperandType> arguments = {},
ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;
/// Parses a region if present.
virtual ParseResult parseOptionalRegion(Region &region,
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes,
ArrayRef<OperandType> arguments = {},
ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;
/// Parses a region if present. If the region is present, a new region is
/// allocated and placed in `region`. If no region is present or on failure,
/// `region` remains untouched.
virtual OptionalParseResult parseOptionalRegion(
std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
/// Parse a region argument, this argument is resolved when calling
/// 'parseRegion'.
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;

View File

@@ -414,6 +414,10 @@ public:
/// region is null, a new empty region will be attached to the Operation.
void addRegion(std::unique_ptr<Region> &&region);
/// Take ownership of a set of regions that should be attached to the
/// Operation.
void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
/// Get the context held by this operation state.
MLIRContext *getContext() const { return location->getContext(); }
};

View File

@@ -199,6 +199,12 @@ void OperationState::addRegion(std::unique_ptr<Region> &&region) {
regions.push_back(std::move(region));
}
void OperationState::addRegions(
MutableArrayRef<std::unique_ptr<Region>> regions) {
for (std::unique_ptr<Region> &region : regions)
addRegion(std::move(region));
}
//===----------------------------------------------------------------------===//
// OperandStorage
//===----------------------------------------------------------------------===//

View File

@@ -221,8 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
return result;
}
}
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
return parseOptionalAttributeWithToken(Token::l_square, attribute);
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
Type type) {
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
/// Attribute dictionary.

View File

@@ -1045,7 +1045,6 @@ public:
}
/// Parse an optional attribute.
/// Template utilities to simplify specifying multiple derived overloads.
template <typename AttrT>
OptionalParseResult
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
@@ -1056,25 +1055,15 @@ public:
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
template <typename AttrT>
OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
StringRef attrName,
NamedAttrList &attrs) {
OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, attrName, attrs);
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present.
@@ -1355,6 +1344,23 @@ public:
return parseRegion(region, arguments, argTypes, enableNameShadowing);
}
/// Parses a region if present. If the region is present, a new region is
/// allocated and placed in `region`. If no region is present, `region`
/// remains untouched.
OptionalParseResult
parseOptionalRegion(std::unique_ptr<Region> &region,
ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
bool enableNameShadowing = false) override {
if (parser.getToken().isNot(Token::l_brace))
return llvm::None;
std::unique_ptr<Region> newRegion = std::make_unique<Region>();
if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
return failure();
region = std::move(newRegion);
return success();
}
/// Parse a region argument. The type of the argument will be resolved later
/// by a call to `parseRegion`.
ParseResult parseRegionArgument(OperandType &argument) override {

View File

@@ -187,7 +187,7 @@ public:
/// Parse an optional attribute with the provided type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
/// Parse an optional attribute that is demarcated by a specific token.
template <typename AttributeT>
@@ -197,8 +197,8 @@ public:
if (getToken().isNot(kind))
return llvm::None;
if (Attribute parsedAttr = parseAttribute()) {
attr = parsedAttr.cast<ArrayAttr>();
if (Attribute parsedAttr = parseAttribute(type)) {
attr = parsedAttr.cast<AttributeT>();
return success();
}
return failure();

View File

@@ -319,6 +319,19 @@ static ParseResult parseCustomDirectiveOperandsAndTypes(
return failure();
return success();
}
static ParseResult parseCustomDirectiveRegions(
OpAsmParser &parser, Region &region,
SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
if (parser.parseRegion(region))
return failure();
if (failed(parser.parseOptionalComma()))
return success();
std::unique_ptr<Region> varRegion = std::make_unique<Region>();
if (parser.parseRegion(*varRegion))
return failure();
varRegions.emplace_back(std::move(varRegion));
return success();
}
static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
SmallVectorImpl<Block *> &varSuccessors) {
@@ -361,6 +374,15 @@ printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
printCustomDirectiveResults(printer, operandType, optOperandType,
varOperandTypes);
}
static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
MutableArrayRef<Region> varRegions) {
printer.printRegion(region);
if (!varRegions.empty()) {
printer << ", ";
for (Region &region : varRegions)
printer.printRegion(region);
}
}
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
Block *successor,
SuccessorRange varSuccessors) {

View File

@@ -1161,8 +1161,13 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
//===----------------------------------------------------------------------===//
def TestRegionBuilderOp : TEST_Op<"region_builder">;
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]> {
let arguments = (ins Variadic<AnyType>);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state",
[{ build(builder, state, {}); }]>
];
}
def TestCastOp : TEST_Op<"cast">,
Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
@@ -1333,6 +1338,43 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
let assemblyFormat = "$buildable attr-dict";
}
// Test various mixings of region formatting.
class FormatRegionBase<string suffix, string fmt>
: TEST_Op<"format_region_" # suffix # "_op"> {
let regions = (region AnyRegion:$region);
let assemblyFormat = fmt;
}
def FormatRegionAOp : FormatRegionBase<"a", [{
regions attr-dict
}]>;
def FormatRegionBOp : FormatRegionBase<"b", [{
$region attr-dict
}]>;
def FormatRegionCOp : FormatRegionBase<"c", [{
(`region` $region^)? attr-dict
}]>;
class FormatVariadicRegionBase<string suffix, string fmt>
: TEST_Op<"format_variadic_region_" # suffix # "_op"> {
let regions = (region VariadicRegion<AnyRegion>:$regions);
let assemblyFormat = fmt;
}
def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
$regions attr-dict
}]>;
def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
($regions^ `found_regions`)? attr-dict
}]>;
class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
: TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
[SingleBlockImplicitTerminator<"TestReturnOp">]> {
let regions = (region AnyRegion:$region);
let assemblyFormat = fmt;
}
def FormatFormatRegionImplicitTerminatorAOp
: FormatRegionImplicitTerminatorBase<"a", [{
$region attr-dict
}]>;
// Test various mixings of result type formatting.
class FormatResultBase<string suffix, string fmt>
: TEST_Op<"format_result_" # suffix # "_op"> {
@@ -1454,6 +1496,16 @@ def FormatCustomDirectiveOperandsAndTypes
}];
}
def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$regions);
let assemblyFormat = [{
custom<CustomDirectiveRegions>(
$region, $regions
)
attr-dict
}];
}
def FormatCustomDirectiveResults
: TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
let results = (outs AnyType:$result, Optional<AnyType>:$optResult,

View File

@@ -133,6 +133,28 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
operands attr-dict
}]>;
//===----------------------------------------------------------------------===//
// regions
// CHECK: error: 'regions' directive creates overlap in format
def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{
regions regions attr-dict
}]>;
// CHECK: error: 'regions' directive creates overlap in format
def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{
$region regions attr-dict
}]> {
let regions = (region AnyRegion:$region);
}
// CHECK: error: 'regions' is only valid as a top-level directive
def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{
type(regions)
}]>;
// CHECK-NOT: error:
def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{
regions attr-dict
}]>;
//===----------------------------------------------------------------------===//
// results
@@ -249,7 +271,7 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: first element of an operand group must be an attribute, literal, or operand
// CHECK: error: first element of an operand group must be an attribute, literal, operand, or region
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
(type($operand) $operand^)? attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
@@ -290,7 +312,7 @@ def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
// Variables
//===----------------------------------------------------------------------===//
// CHECK: error: expected variable to refer to an argument, result, or successor
// CHECK: error: expected variable to refer to an argument, region, result, or successor
def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
$unknown_arg attr-dict
}]>;
@@ -330,11 +352,35 @@ def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
// CHECK-NOT: error:
// CHECK: error: region 'region' is already bound
def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
$region $region attr-dict
}]> {
let regions = (region AnyRegion:$region);
}
// CHECK: error: region 'region' is already bound
def VariableInvalidK : TestFormat_Op<"variable_invalid_K", [{
regions $region attr-dict
}]> {
let regions = (region AnyRegion:$region);
}
// CHECK: error: regions can only be used at the top level
def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{
type($region)
}]> {
let regions = (region AnyRegion:$region);
}
// CHECK: error: region #0, named 'region', not found
def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{
attr-dict
}]> {
let regions = (region AnyRegion:$region);
}
// CHECK-NOT: error:
def VariableValidA : TestFormat_Op<"variable_valid_a", [{
$attr `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
def VariableValidB : TestFormat_Op<"variable_valid_b", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;

View File

@@ -40,6 +40,72 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64, opt_attr = 10 : i64
// CHECK: test.format_buildable_type_op %[[I64]]
%ignored = test.format_buildable_type_op %i64
//===----------------------------------------------------------------------===//
// Format regions
//===----------------------------------------------------------------------===//
// CHECK: test.format_region_a_op {
// CHECK-NEXT: test.return
test.format_region_a_op {
"test.return"() : () -> ()
}
// CHECK: test.format_region_b_op {
// CHECK-NEXT: test.return
test.format_region_b_op {
"test.return"() : () -> ()
}
// CHECK: test.format_region_c_op region {
// CHECK-NEXT: test.return
test.format_region_c_op region {
"test.return"() : () -> ()
}
// CHECK: test.format_region_c_op
// CHECK-NOT: region {
test.format_region_c_op
// CHECK: test.format_variadic_region_a_op {
// CHECK-NEXT: test.return
// CHECK-NEXT: }, {
// CHECK-NEXT: test.return
// CHECK-NEXT: }
test.format_variadic_region_a_op {
"test.return"() : () -> ()
}, {
"test.return"() : () -> ()
}
// CHECK: test.format_variadic_region_b_op {
// CHECK-NEXT: test.return
// CHECK-NEXT: }, {
// CHECK-NEXT: test.return
// CHECK-NEXT: } found_regions
test.format_variadic_region_b_op {
"test.return"() : () -> ()
}, {
"test.return"() : () -> ()
} found_regions
// CHECK: test.format_variadic_region_b_op
// CHECK-NOT: {
// CHECK-NOT: found_regions
test.format_variadic_region_b_op
// CHECK: test.format_implicit_terminator_region_a_op {
// CHECK-NEXT: }
test.format_implicit_terminator_region_a_op {
"test.return"() : () -> ()
}
// CHECK: test.format_implicit_terminator_region_a_op {
// CHECK-NEXT: test.return"() {foo.attr
test.format_implicit_terminator_region_a_op {
"test.return"() {foo.attr} : () -> ()
}
// CHECK: test.format_implicit_terminator_region_a_op {
// CHECK-NEXT: test.return"(%[[I64]]) : (i64)
test.format_implicit_terminator_region_a_op {
"test.return"(%i64) : (i64) -> ()
}
//===----------------------------------------------------------------------===//
// Format results
//===----------------------------------------------------------------------===//
@@ -147,6 +213,24 @@ test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64
// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
// CHECK: test.format_custom_directive_regions {
// CHECK-NEXT: test.return
// CHECK-NEXT: }
test.format_custom_directive_regions {
"test.return"() : () -> ()
}
// CHECK: test.format_custom_directive_regions {
// CHECK-NEXT: test.return
// CHECK-NEXT: }, {
// CHECK-NEXT: test.return
// CHECK-NEXT: }
test.format_custom_directive_regions {
"test.return"() : () -> ()
}, {
"test.return"() : () -> ()
}
// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
test.format_custom_directive_results : i64, i64 -> (i64)

View File

@@ -16,6 +16,7 @@
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -48,6 +49,7 @@ public:
CustomDirective,
FunctionalTypeDirective,
OperandsDirective,
RegionsDirective,
ResultsDirective,
SuccessorsDirective,
TypeDirective,
@@ -58,6 +60,7 @@ public:
/// This element is an variable value.
AttributeVariable,
OperandVariable,
RegionVariable,
ResultVariable,
SuccessorVariable,
@@ -119,6 +122,10 @@ struct AttributeVariable
using OperandVariable =
VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
/// This class represents a variable that refers to a region.
using RegionVariable =
VariableElement<NamedRegion, Element::Kind::RegionVariable>;
/// This class represents a variable that refers to a result.
using ResultVariable =
VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
@@ -133,7 +140,8 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
template <Element::Kind type> class DirectiveElement : public Element {
template <Element::Kind type>
class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -142,6 +150,10 @@ public:
/// all of the operands of an operation.
using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
/// This class represents the `regions` directive. This directive represents
/// all of the regions of an operation.
using RegionsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
/// This class represents the `results` directive. This directive represents
/// all of the results of an operation.
using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
@@ -350,13 +362,23 @@ struct OperationFormat {
: allOperands(false), allOperandTypes(false), allResultTypes(false) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
hasImplicitTermTrait =
llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
});
}
/// Generate the operation parser from this format.
void genParser(Operator &op, OpClass &opClass);
/// Generate the parser code for a specific format element.
void genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx);
/// Generate the c++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to resolve regions during parsing.
void genParserRegionResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to handling variadic segment size traits.
@@ -365,6 +387,10 @@ struct OperationFormat {
/// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass);
/// Generate the printer code for a specific format element.
void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
bool &shouldEmitSpace, bool &lastWasPunctuation);
/// The various elements in this format.
std::vector<std::unique_ptr<Element>> elements;
@@ -372,11 +398,18 @@ struct OperationFormat {
/// contains these, it can not contain individual type resolvers.
bool allOperands, allOperandTypes, allResultTypes;
/// A flag indicating if this operation has the SingleBlockImplicitTerminator
/// trait.
bool hasImplicitTermTrait;
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
/// The index of the buildable type, if valid, for every operand and result.
std::vector<TypeResolution> operandTypes, resultTypes;
/// The set of attributes explicitly used within the format.
SmallVector<const NamedAttribute *, 8> usedAttributes;
};
} // end anonymous namespace
@@ -541,6 +574,60 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults();
)";
/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
const char *regionListParserCode = R"(
{
std::unique_ptr<::mlir::Region> region;
auto firstRegionResult = parser.parseOptionalRegion(region);
if (firstRegionResult.hasValue()) {
if (failed(*firstRegionResult))
return failure();
{0}Regions.emplace_back(std::move(region));
// Parse any trailing regions.
while (succeeded(parser.parseOptionalComma())) {
region = std::make_unique<::mlir::Region>();
if (parser.parseRegion(*region))
return failure();
{0}Regions.emplace_back(std::move(region));
}
}
}
)";
/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *regionListEnsureTerminatorParserCode = R"(
for (auto &region : {0}Regions)
ensureTerminator(*region, parser.getBuilder(), result.location);
)";
/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *optionalRegionParserCode = R"(
if (parser.parseOptionalRegion(*{0}Region))
return failure();
)";
/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *regionParserCode = R"(
if (parser.parseRegion(*{0}Region))
return failure();
)";
/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *regionEnsureTerminatorParserCode = R"(
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";
/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
@@ -658,6 +745,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
"allOperands;\n";
} else if (isa<RegionsDirective>(element)) {
body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
"fullRegions;\n";
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
@@ -680,6 +771,20 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic()) {
body << llvm::formatv(
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
"{0}Regions;\n",
name);
} else {
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
"std::make_unique<::mlir::Region>();\n",
name);
}
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) {
@@ -725,6 +830,13 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
else
body << formatv("{0}RawOperands[0]", name);
} else if (auto *region = dyn_cast<RegionVariable>(&param)) {
StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic())
body << llvm::formatv("{0}Regions", name);
else
body << llvm::formatv("*{0}Region", name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic())
@@ -809,9 +921,39 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << " }\n";
}
/// Generate the parser for a single format element.
static void genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
auto &method = opClass.newMethod(
"::mlir::ParseResult", "parse",
"::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
OpMethod::MP_Static);
auto &body = method.body();
// Generate variables to store the operands and type within the format. This
// allows for referencing these variables in the presence of optional
// groupings.
for (auto &element : elements)
genElementParserStorage(&*element, body);
// A format context used when parsing attributes with buildable types.
FmtContext attrTypeCtx;
attrTypeCtx.withBuilder("parser.getBuilder()");
// Generate parsers for each of the elements.
for (auto &element : elements)
genElementParser(element.get(), body, attrTypeCtx);
// Generate the code to resolve the operand/result types and successors now
// that they have been parsed.
genParserTypeResolution(op, body);
genParserRegionResolution(op, body);
genParserSuccessorResolution(op, body);
genParserVariadicSegmentResolution(op, body);
body << " return success();\n";
}
void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
@@ -829,6 +971,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
genElementParser(opVar, body, attrTypeCtx);
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
} else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
const NamedRegion *region = regionVar->getVar();
if (region->isVariadic()) {
genElementParser(regionVar, body, attrTypeCtx);
body << " if (!" << region->name << "Regions.empty()) {\n";
} else {
body << llvm::formatv(optionalRegionParserCode, region->name);
body << " if (!" << region->name << "Region->empty()) {\n ";
if (hasImplicitTermTrait)
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
}
}
// If the anchor is a unit attribute, we don't need to print it. When
@@ -907,6 +1060,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << llvm::formatv(optionalOperandParserCode, name);
else
body << formatv(operandParserCode, name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
bool isVariadic = region->getVar()->isVariadic();
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
region->getVar()->name);
if (hasImplicitTermTrait) {
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
: regionEnsureTerminatorParserCode,
region->getVar()->name);
}
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@@ -925,8 +1089,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
} else if (isa<RegionsDirective>(element)) {
body << llvm::formatv(regionListParserCode, "full");
if (hasImplicitTermTrait)
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -946,36 +1117,6 @@ static void genElementParser(Element *element, OpMethodBody &body,
}
}
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
auto &method = opClass.newMethod(
"::mlir::ParseResult", "parse",
"::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
OpMethod::MP_Static);
auto &body = method.body();
// Generate variables to store the operands and type within the format. This
// allows for referencing these variables in the presence of optional
// groupings.
for (auto &element : elements)
genElementParserStorage(&*element, body);
// A format context used when parsing attributes with buildable types.
FmtContext attrTypeCtx;
attrTypeCtx.withBuilder("parser.getBuilder()");
// Generate parsers for each of the elements.
for (auto &element : elements)
genElementParser(element.get(), body, attrTypeCtx);
// Generate the code to resolve the operand/result types and successors now
// that they have been parsed.
genParserTypeResolution(op, body);
genParserSuccessorResolution(op, body);
genParserVariadicSegmentResolution(op, body);
body << " return success();\n";
}
void OperationFormat::genParserTypeResolution(Operator &op,
OpMethodBody &body) {
// If any of type resolutions use transformed variables, make sure that the
@@ -1133,6 +1274,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
}
}
void OperationFormat::genParserRegionResolution(Operator &op,
OpMethodBody &body) {
// Check for the case where all regions were parsed.
bool hasAllRegions = llvm::any_of(
elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
if (hasAllRegions) {
body << " result.addRegions(fullRegions);\n";
return;
}
// Otherwise, handle each region individually.
for (const NamedRegion &region : op.getRegions()) {
if (region.isVariadic())
body << " result.addRegions(" << region.name << "Regions);\n";
else
body << " result.addRegion(std::move(" << region.name << "Region));\n";
}
}
void OperationFormat::genParserSuccessorResolution(Operator &op,
OpMethodBody &body) {
// Check for the case where all successors were parsed.
@@ -1186,23 +1346,26 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
//===----------------------------------------------------------------------===//
// PrinterGen
/// The code snippet used to generate a printer call for a region of an
// operation that has the SingleBlockImplicitTerminator trait.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
bool printTerminator = true;
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
printTerminator = !term->getMutableAttrDict().empty() ||
term->getNumOperands() != 0 ||
term->getNumResults() != 0;
}
p.printRegion({0}, /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/printTerminator);
}
)";
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
OpMethodBody &body, bool withKeyword) {
// Collect all of the attributes used in the format, these will be elided.
SmallVector<const NamedAttribute *, 1> usedAttributes;
for (auto &it : fmt.elements) {
if (auto *attr = dyn_cast<AttributeVariable>(it.get()))
usedAttributes.push_back(attr->getVar());
// Collect the optional attributes.
if (auto *opt = dyn_cast<OptionalElement>(it.get())) {
for (auto &elem : opt->getElements()) {
if (auto *attr = dyn_cast<AttributeVariable>(&elem))
usedAttributes.push_back(attr->getVar());
}
}
}
body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
<< "(getAttrs(), /*elidedAttrs=*/{";
// Elide the variadic segment size attributes if necessary.
@@ -1210,9 +1373,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
body << "\"operand_segment_sizes\", ";
if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
body << "\"result_segment_sizes\", ";
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
body << "\"" << attr->name << "\"";
});
llvm::interleaveComma(
fmt.usedAttributes, body,
[&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
body << "});\n";
}
@@ -1255,6 +1418,9 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << operand->getVar()->name << "()";
} else if (auto *region = dyn_cast<RegionVariable>(&param)) {
body << region->getVar()->name << "()";
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
body << successor->getVar()->name << "()";
@@ -1277,6 +1443,24 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
body << ");\n";
}
/// Generate the printer for a region with the given variable name.
static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
bool hasImplicitTermTrait) {
if (hasImplicitTermTrait)
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
regionName);
else
body << " p.printRegion(" << regionName << ");\n";
}
static void genVariadicRegionPrinter(const Twine &regionListName,
OpMethodBody &body,
bool hasImplicitTermTrait) {
body << " llvm::interleaveComma(" << regionListName
<< ", p, [&](::mlir::Region &region) {\n ";
genRegionPrinter("region", body, hasImplicitTermTrait);
body << " });\n";
}
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg))
@@ -1296,10 +1480,9 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
<< "().getType())";
}
/// Generate the code for printing the given element.
static void genElementPrinter(Element *element, OpMethodBody &body,
OperationFormat &fmt, Operator &op,
bool &shouldEmitSpace, bool &lastWasPunctuation) {
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
Operator &op, bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation);
@@ -1314,6 +1497,11 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " if (" << var->name << "()) {\n";
else if (var->isVariadic())
body << " if (!" << var->name << "().empty()) {\n";
} else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
const NamedRegion *var = region->getVar();
// TODO: Add a check for optional here when ODS supports it.
body << " if (!" << var->name << "().empty()) {\n";
} else {
body << " if (getAttr(\""
<< cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
@@ -1332,7 +1520,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit each of the elements.
for (Element &childElement : elements) {
if (&childElement != elidedAnchorElement) {
genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
genElementPrinter(&childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
@@ -1342,7 +1530,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit the attribute dictionary.
if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
genAttrDictPrinter(fmt, op, body, attrDict->isWithKeyword());
genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
lastWasPunctuation = false;
return;
}
@@ -1384,6 +1572,13 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
} else {
body << " p << " << operand->getVar()->name << "();\n";
}
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
const NamedRegion *var = region->getVar();
if (var->isVariadic()) {
genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
} else {
genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
}
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic())
@@ -1394,6 +1589,9 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
genCustomDirectivePrinter(dir, body);
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (isa<RegionsDirective>(element)) {
genVariadicRegionPrinter("getOperation()->getRegions()", body,
hasImplicitTermTrait);
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@@ -1426,7 +1624,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false;
for (auto &element : elements)
genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
genElementPrinter(element.get(), body, op, shouldEmitSpace,
lastWasPunctuation);
}
@@ -1460,6 +1658,7 @@ public:
kw_custom,
kw_functional_type,
kw_operands,
kw_regions,
kw_results,
kw_successors,
kw_type,
@@ -1663,6 +1862,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
.Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("regions", Token::kw_regions)
.Case("results", Token::kw_results)
.Case("successors", Token::kw_successors)
.Case("type", Token::kw_type)
@@ -1676,7 +1876,8 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
template <typename RangeT>
static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
@@ -1719,6 +1920,9 @@ private:
verifyOperands(llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation regions within the format.
LogicalResult verifyRegions(llvm::SMLoc loc);
/// Verify the state of operation results within the format.
LogicalResult
verifyResults(llvm::SMLoc loc,
@@ -1775,6 +1979,8 @@ private:
Token tok, bool isTopLevel);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
@@ -1821,11 +2027,12 @@ private:
// The following are various bits of format state used for verification
// during parsing.
bool hasAllOperands = false, hasAttrDict = false;
bool hasAllSuccessors = false;
bool hasAttrDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
llvm::DenseSet<const NamedAttribute *> seenAttrs;
llvm::DenseSet<const NamedRegion *> seenRegions;
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
};
@@ -1867,13 +2074,11 @@ LogicalResult FormatParser::parse() {
if (failed(verifyAttributes(loc)) ||
failed(verifyResults(loc, variableTyResolver)) ||
failed(verifyOperands(loc, variableTyResolver)) ||
failed(verifySuccessors(loc)))
failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
return failure();
// Check to see if we are formatting all of the operands.
fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
return isa<OperandsDirective>(elt.get());
});
// Collect the set of used attributes in the format.
fmt.usedAttributes = seenAttrs.takeVector();
return success();
}
@@ -1953,7 +2158,7 @@ LogicalResult FormatParser::verifyOperands(
NamedTypeConstraint &operand = op.getOperand(i);
// Check that the operand itself is in the format.
if (!hasAllOperands && !seenOperands.count(&operand)) {
if (!fmt.allOperands && !seenOperands.count(&operand)) {
return emitErrorAndNote(loc,
"operand #" + Twine(i) + ", named '" +
operand.name + "', not found",
@@ -1976,7 +2181,7 @@ LogicalResult FormatParser::verifyOperands(
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
if (!builder || (hasAllOperands && operand.isVariableLength())) {
if (!builder || (fmt.allOperands && operand.isVariableLength())) {
return emitErrorAndNote(
loc,
"type of operand #" + Twine(i) + ", named '" + operand.name +
@@ -1991,6 +2196,24 @@ LogicalResult FormatParser::verifyOperands(
return success();
}
LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
// Check that all of the regions are within the format.
if (hasAllRegions)
return success();
for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
const NamedRegion &region = op.getRegion(i);
if (!seenRegions.count(&region)) {
return emitErrorAndNote(loc,
"region #" + Twine(i) + ", named '" +
region.name + "', not found",
"suggest adding a '$" + region.name +
"' directive to the custom assembly format");
}
}
return success();
}
LogicalResult FormatParser::verifyResults(
llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
@@ -2108,7 +2331,7 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
return seenAttrs.count(attr) ? attr : nullptr;
return nullptr;
}
@@ -2142,7 +2365,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
// op.
/// Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
if (isTopLevel && !seenAttrs.insert(attr).second)
if (isTopLevel && !seenAttrs.insert(attr))
return emitError(loc, "attribute '" + name + "' is already bound");
element = std::make_unique<AttributeVariable>(attr);
return success();
@@ -2150,12 +2373,21 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
/// Operands
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
if (isTopLevel) {
if (hasAllOperands || !seenOperands.insert(operand).second)
if (fmt.allOperands || !seenOperands.insert(operand).second)
return emitError(loc, "operand '" + name + "' is already bound");
}
element = std::make_unique<OperandVariable>(operand);
return success();
}
/// Regions
if (const NamedRegion *region = findArg(op.getRegions(), name)) {
if (!isTopLevel)
return emitError(loc, "regions can only be used at the top level");
if (hasAllRegions || !seenRegions.insert(region).second)
return emitError(loc, "region '" + name + "' is already bound");
element = std::make_unique<RegionVariable>(region);
return success();
}
/// Results.
if (const auto *result = findArg(op.getResults(), name)) {
if (isTopLevel)
@@ -2172,8 +2404,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
element = std::make_unique<SuccessorVariable>(successor);
return success();
}
return emitError(
loc, "expected variable to refer to an argument, result, or successor");
return emitError(loc, "expected variable to refer to an argument, region, "
"result, or successor");
}
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -2194,6 +2426,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
case Token::kw_operands:
return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_regions:
return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_successors:
@@ -2247,9 +2481,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
// optional fashion.
Element *firstElement = &*elements.front();
if (!isa<AttributeVariable>(firstElement) &&
!isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
!isa<LiteralElement>(firstElement) &&
!isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
return emitError(curLoc, "first element of an operand group must be an "
"attribute, literal, or operand");
"attribute, literal, operand, or region");
// After parsing all of the elements, ensure that all type directives refer
// only to elements within the group.
@@ -2314,10 +2549,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
seenVariables.insert(ele->getVar());
return success();
})
.Case<RegionVariable>([&](RegionVariable *) {
// TODO: When ODS has proper support for marking "optional" regions, add
// a check here.
return success();
})
// Literals, custom directives, and type directives may be used,
// but they can't anchor the group.
.Case<LiteralElement, CustomDirective, TypeDirective,
FunctionalTypeDirective>([&](Element *) {
.Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
OptionalElement, TypeDirective>([&](Element *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
@@ -2401,7 +2641,7 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return failure();
// Verify that the element can be placed within a custom directive.
if (!isa<TypeDirective, AttributeVariable, OperandVariable,
if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
SuccessorVariable>(parameters.back().get())) {
return emitError(childLoc, "only variables and types may be used as "
"parameters to a custom directive");
@@ -2433,13 +2673,27 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
LogicalResult
FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
if (isTopLevel && (hasAllOperands || !seenOperands.empty()))
return emitError(loc, "'operands' directive creates overlap in format");
hasAllOperands = true;
if (isTopLevel) {
if (fmt.allOperands || !seenOperands.empty())
return emitError(loc, "'operands' directive creates overlap in format");
fmt.allOperands = true;
}
element = std::make_unique<OperandsDirective>();
return success();
}
LogicalResult
FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
if (!isTopLevel)
return emitError(loc, "'regions' is only valid as a top-level directive");
if (hasAllRegions || !seenRegions.empty())
return emitError(loc, "'regions' directive creates overlap in format");
hasAllRegions = true;
element = std::make_unique<RegionsDirective>();
return success();
}
LogicalResult
FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {