mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][linalg] Fix SemiFunctionType custom parsing crash on missing () (#110365)
The `SemiFunctionType` allows printing/parsing a set of argument and result types, where there is always exactly one argument type and zero or more result types. If there are no result types, the argument type can be written without enclosing parens in the assembly. If there is at least one result type, the parens are mandatory. This patch fixes a bug where omitting the parens around the argument types for a `SemiFunctionType` with non-optional result Types would crash the parser. It introduces a `bool` argument `resultOptional` to the parser and printer which, when `false`, correctly enforces the parens around argument types, otherwise printing an error. Fix https://github.com/llvm/llvm-project/issues/109128
This commit is contained in:
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$operand_handle);
|
||||
let results = (outs TransformParamTypeInterface:$rank);
|
||||
let assemblyFormat =
|
||||
"$operand_handle attr-dict `:`"
|
||||
"custom<SemiFunctionType>(type($operand_handle), type($rank))";
|
||||
let assemblyFormat = [{
|
||||
$operand_handle attr-dict `:`
|
||||
custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
|
||||
}
|
||||
|
||||
@@ -418,9 +418,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$transformed);
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` "
|
||||
"custom<SemiFunctionType>(type($target), type($transformed))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:`
|
||||
custom<SemiFunctionType>(type($target), type($transformed), "false")
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
@@ -455,9 +456,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$transformed);
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:` "
|
||||
"custom<SemiFunctionType>(type($target), type($transformed))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:`
|
||||
custom<SemiFunctionType>(type($target), type($transformed), "false")
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
@@ -500,7 +502,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
|
||||
let assemblyFormat = [{
|
||||
$target
|
||||
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
|
||||
`:` custom<SemiFunctionType>(type($target), type($transformed))
|
||||
`:` custom<SemiFunctionType>(type($target), type($transformed), "false")
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
|
||||
@@ -1233,9 +1235,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
|
||||
OptionalAttr<I64Attr>:$alignment);
|
||||
let results = (outs TransformHandleTypeInterface:$transformed);
|
||||
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:`"
|
||||
"custom<SemiFunctionType>(type($target), type($transformed))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:`
|
||||
custom<SemiFunctionType>(type($target), type($transformed), "false")
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
@@ -1269,9 +1272,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$replacement);
|
||||
let regions = (region SizedRegion<1>:$bodyRegion);
|
||||
let assemblyFormat =
|
||||
"$target attr-dict-with-keyword regions `:` "
|
||||
"custom<SemiFunctionType>(type($target), type($replacement))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict-with-keyword regions `:`
|
||||
custom<SemiFunctionType>(type($target), type($replacement), "false")
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
@@ -1310,9 +1314,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$result);
|
||||
|
||||
let assemblyFormat =
|
||||
"$target attr-dict `:`"
|
||||
"custom<SemiFunctionType>(type($target), type($result))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:`
|
||||
custom<SemiFunctionType>(type($target), type($result), "false")
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||
|
||||
@@ -30,7 +30,7 @@ class Operation;
|
||||
/// the argument type in absence of result types, and does not accept the
|
||||
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
|
||||
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
|
||||
Type &resultType);
|
||||
Type &resultType, bool resultOptional = true);
|
||||
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
|
||||
SmallVectorImpl<Type> &resultTypes);
|
||||
|
||||
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
|
||||
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
|
||||
Type argumentType, TypeRange resultType);
|
||||
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
|
||||
Type argumentType, Type resultType);
|
||||
Type argumentType, Type resultType,
|
||||
bool resultOptional = true);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
|
||||
|
||||
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
|
||||
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||
let results = (outs TransformHandleTypeInterface:$result);
|
||||
|
||||
let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
|
||||
let assemblyFormat = [{
|
||||
$target attr-dict `:`
|
||||
custom<SemiFunctionType>(type($target), type($result), "false")
|
||||
}];
|
||||
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
|
||||
::mlir::Value getOperandHandle() { return getTarget(); }
|
||||
}];
|
||||
|
||||
@@ -12,9 +12,13 @@
|
||||
using namespace mlir;
|
||||
|
||||
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
|
||||
Type &resultType) {
|
||||
Type &resultType, bool resultOptional) {
|
||||
argumentType = resultType = nullptr;
|
||||
bool hasLParen = parser.parseOptionalLParen().succeeded();
|
||||
|
||||
bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
|
||||
: parser.parseLParen().succeeded();
|
||||
if (!resultOptional && !hasLParen)
|
||||
return failure();
|
||||
if (parser.parseType(argumentType).failed())
|
||||
return failure();
|
||||
if (!hasLParen)
|
||||
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
|
||||
}
|
||||
|
||||
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
|
||||
Type argumentType, Type resultType) {
|
||||
Type argumentType, Type resultType,
|
||||
bool resultOptional) {
|
||||
assert(resultOptional || resultType != nullptr);
|
||||
return printSemiFunctionType(printer, op, argumentType,
|
||||
resultType ? TypeRange(resultType)
|
||||
: TypeRange());
|
||||
|
||||
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
|
||||
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
// expected-error@below {{expected '('}}
|
||||
%res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user