[mlir][linalg] Fix SemiFunctionType custom parsing crash on missing () (#110365)

The `SemiFunctionType` allows printing/parsing a set of argument and
result types, where there is always exactly one argument type and zero
or more result types. If there are no result types, the argument type
can be written without enclosing parens in the assembly. If there is at
least one result type, the parens are mandatory.

This patch fixes a bug where omitting the parens around the argument
types for a `SemiFunctionType` with non-optional result Types would
crash the parser. It introduces a `bool` argument `resultOptional` to
the parser and printer which, when `false`, correctly enforces the
parens around argument types, otherwise printing an error.

Fix https://github.com/llvm/llvm-project/issues/109128
This commit is contained in:
Felix Schneider
2024-11-03 15:31:25 +01:00
committed by GitHub
parent 30213e99b8
commit a07b422e90
6 changed files with 49 additions and 25 deletions

View File

@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
let arguments = (ins TransformHandleTypeInterface:$operand_handle);
let results = (outs TransformParamTypeInterface:$rank);
let assemblyFormat =
"$operand_handle attr-dict `:`"
"custom<SemiFunctionType>(type($operand_handle), type($rank))";
let assemblyFormat = [{
$operand_handle attr-dict `:`
custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
}];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
}

View File

@@ -418,9 +418,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` "
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -455,9 +456,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` "
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -500,7 +502,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let assemblyFormat = [{
$target
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
`:` custom<SemiFunctionType>(type($target), type($transformed))
`:` custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let hasVerifier = 1;
@@ -1233,9 +1235,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
OptionalAttr<I64Attr>:$alignment);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:`"
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1269,9 +1272,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$replacement);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat =
"$target attr-dict-with-keyword regions `:` "
"custom<SemiFunctionType>(type($target), type($replacement))";
let assemblyFormat = [{
$target attr-dict-with-keyword regions `:`
custom<SemiFunctionType>(type($target), type($replacement), "false")
}];
let hasVerifier = 1;
}
@@ -1310,9 +1314,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat =
"$target attr-dict `:`"
"custom<SemiFunctionType>(type($target), type($result))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($result), "false")
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(

View File

@@ -30,7 +30,7 @@ class Operation;
/// the argument type in absence of result types, and does not accept the
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
Type &resultType);
Type &resultType, bool resultOptional = true);
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
SmallVectorImpl<Type> &resultTypes);
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, TypeRange resultType);
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, Type resultType);
Type argumentType, Type resultType,
bool resultOptional = true);
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H

View File

@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($result), "false")
}];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
::mlir::Value getOperandHandle() { return getTarget(); }
}];

View File

@@ -12,9 +12,13 @@
using namespace mlir;
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
Type &resultType) {
Type &resultType, bool resultOptional) {
argumentType = resultType = nullptr;
bool hasLParen = parser.parseOptionalLParen().succeeded();
bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
: parser.parseLParen().succeeded();
if (!resultOptional && !hasLParen)
return failure();
if (parser.parseType(argumentType).failed())
return failure();
if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
}
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, Type resultType) {
Type argumentType, Type resultType,
bool resultOptional) {
assert(resultOptional || resultType != nullptr);
return printSemiFunctionType(printer, op, argumentType,
resultType ? TypeRange(resultType)
: TypeRange());

View File

@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
}
// -----
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error@below {{expected '('}}
%res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
}