[mlir] Add support for parsing optional Attribute values.

This adds a `parseOptionalAttribute` method to the OpAsmParser that allows for parsing optional attributes, in a similar fashion to how optional types are parsed. This also enables the use of attribute values as the first element of an assembly format optional group.

Differential Revision: https://reviews.llvm.org/D83712
This commit is contained in:
River Riddle
2020-07-14 13:14:14 -07:00
parent aef60af34e
commit 6b476e2426
9 changed files with 129 additions and 17 deletions

View File

@@ -713,7 +713,8 @@ of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
* The first element of the group must either be a literal or an operand.
* The first element of the group must either be a literal, attribute, or an
operand.
- This is because the first element must be optionally parsable.
* Exactly one argument variable within the group must be marked as the anchor
of the group.

View File

@@ -384,6 +384,17 @@ public:
StringRef attrName,
NamedAttrList &attrs) = 0;
/// Parse an optional attribute.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
OptionalParseResult parseOptionalAttribute(Attribute &result,
StringRef attrName,
NamedAttrList &attrs) {
return parseOptionalAttribute(result, Type(), attrName, attrs);
}
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,

View File

@@ -187,6 +187,40 @@ Attribute Parser::parseAttribute(Type type) {
}
}
/// Parse an optional attribute with the provided type.
OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
Type type) {
switch (getToken().getKind()) {
case Token::at_identifier:
case Token::floatliteral:
case Token::integer:
case Token::hash_identifier:
case Token::kw_affine_map:
case Token::kw_affine_set:
case Token::kw_dense:
case Token::kw_false:
case Token::kw_loc:
case Token::kw_opaque:
case Token::kw_sparse:
case Token::kw_true:
case Token::kw_unit:
case Token::l_brace:
case Token::l_square:
case Token::minus:
case Token::string:
attribute = parseAttribute(type);
return success(attribute != nullptr);
default:
// Parse an optional type attribute.
Type type;
OptionalParseResult result = parseOptionalType(type);
if (result.hasValue() && succeeded(*result))
attribute = TypeAttr::get(type);
return result;
}
}
/// Attribute dictionary.
///
/// attribute-dict ::= `{` `}`

View File

@@ -1011,6 +1011,17 @@ public:
return success();
}
/// Parse an optional attribute.
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
OptionalParseResult parseResult =
parser.parseOptionalAttribute(result, type);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
if (parser.getToken().isNot(Token::l_brace))

View File

@@ -184,6 +184,10 @@ public:
/// Parse an arbitrary attribute with an optional type.
Attribute parseAttribute(Type type = {});
/// Parse an optional attribute with the provided type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
/// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes);

View File

@@ -1253,9 +1253,13 @@ def FormatAttrOp : TEST_Op<"format_attr_op"> {
}
// Test that we elide optional attributes that are within the syntax.
def FormatOptAttrOp : TEST_Op<"format_opt_attr_op"> {
def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
let assemblyFormat = "(`(`$opt_attr^`)`)? attr-dict";
let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
}
def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
let assemblyFormat = "($opt_attr^)? attr-dict";
}
// Test that we elide attributes that are within the syntax.

View File

@@ -206,10 +206,10 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: first element of an operand group must be a literal or operand
// CHECK: error: first element of an operand group must be an attribute, literal, or operand
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
($attr^)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
(type($operand) $operand^)? attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
// CHECK: error: type directive can only refer to variables within the optional group
def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
(`,` $attr^ type(operands))? attr-dict

View File

@@ -12,9 +12,15 @@ test.format_literal_op keyword_$. -> :, = <> () [] {foo.some_attr}
// CHECK-NOT: {attr
test.format_attr_op 10
// CHECK: test.format_opt_attr_op(10)
// CHECK: test.format_opt_attr_op_a(10)
// CHECK-NOT: {opt_attr
test.format_opt_attr_op(10)
test.format_opt_attr_op_a(10)
test.format_opt_attr_op_a
// CHECK: test.format_opt_attr_op_b 10
// CHECK-NOT: {opt_attr
test.format_opt_attr_op_b 10
test.format_opt_attr_op_b
// CHECK: test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
test.format_attr_dict_w_keyword attributes {attr = 10 : i64}

View File

@@ -373,6 +373,15 @@ const char *const attrParserCode = R"(
if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
return failure();
)";
const char *const optionalAttrParserCode = R"(
{0} {1}Attr;
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return failure();
}
)";
/// The code snippet used to generate a parser call for an enum attribute.
///
@@ -397,6 +406,30 @@ const char *const enumAttrParserCode = R"(
result.addAttribute("{0}", {3});
}
)";
const char *const optionalEnumAttrParserCode = R"(
Attribute {0}Attr;
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
"{0}", attrStorage);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return failure();
auto attrOptional = {1}::{2}(attrVal.getValue());
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;
{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
}
}
)";
/// The code snippet used to generate a parser call for an operand.
///
@@ -599,11 +632,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
Element *firstElement = &*elements.begin();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
genElementParser(attrVar, body, attrTypeCtx);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (succeeded(parser.parseOptional";
genLiteralParser(literal->getLiteral(), body);
body << ")) {\n";
} else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
genElementParser(opVar, body, attrTypeCtx);
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
}
@@ -635,7 +672,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
"attrOptional.getValue()");
}
body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
: enumAttrParserCode,
var->name, enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr);
return;
}
@@ -648,8 +687,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
}
body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
attrTypeStr);
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
: attrParserCode,
var->attr.getStorageType(), var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@@ -1910,10 +1950,11 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
// The first element of the group must be one that can be parsed/printed in an
// optional fashion.
if (!isa<LiteralElement>(&*elements.front()) &&
!isa<OperandVariable>(&*elements.front()))
return emitError(curLoc, "first element of an operand group must be a "
"literal or operand");
Element *firstElement = &*elements.front();
if (!isa<AttributeVariable>(firstElement) &&
!isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
return emitError(curLoc, "first element of an operand group must be an "
"attribute, literal, or operand");
// After parsing all of the elements, ensure that all type directives refer
// only to elements within the group.