Files
llvm/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
Felix Schneider a07b422e90 [mlir][linalg] Fix SemiFunctionType custom parsing crash on missing () (#110365)
The `SemiFunctionType` allows printing/parsing a set of argument and
result types, where there is always exactly one argument type and zero
or more result types. If there are no result types, the argument type
can be written without enclosing parens in the assembly. If there is at
least one result type, the parens are mandatory.

This patch fixes a bug where omitting the parens around the argument
type of a `SemiFunctionType` with a non-optional result type would
crash the parser. It introduces a `bool` argument `resultOptional` to
the parser and printer which, when `false`, enforces the parens around
the argument type and makes the parser emit an error if they are missing.

Fix https://github.com/llvm/llvm-project/issues/109128
2024-11-03 15:31:25 +01:00

83 lines
2.8 KiB
C++

//===- Syntax.cpp - Custom syntax for Linalg transform ops ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
/// Parses a "semi-function" type with at most one result, in one of the forms
///   argumentType
///   (argumentType) -> resultType
/// When `resultOptional` is false only the parenthesized form is accepted, and
/// a missing `(` is reported by the parser as an error. Both out-parameters are
/// reset to null first, so `resultType` stays null for the bare form.
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
                                        Type &resultType, bool resultOptional) {
  argumentType = nullptr;
  resultType = nullptr;

  // Decide whether the parenthesized form is being parsed. A required result
  // means the `(` is mandatory, so use the erroring variant in that case.
  bool parenthesized;
  if (resultOptional) {
    parenthesized = parser.parseOptionalLParen().succeeded();
  } else {
    if (parser.parseLParen().failed())
      return failure();
    parenthesized = true;
  }

  if (parser.parseType(argumentType).failed())
    return failure();

  // Bare form: the argument type alone, no result.
  if (!parenthesized)
    return success();

  // Parenthesized form: `)` `->` resultType.
  if (parser.parseRParen().failed())
    return failure();
  if (parser.parseArrow().failed())
    return failure();
  if (parser.parseType(resultType).failed())
    return failure();
  return success();
}
/// Parses a "semi-function" type with zero or more results, in one of the forms
///   argumentType                          (no results)
///   (argumentType) -> resultType          (one result)
///   (argumentType) -> (type, type, ...)   (parenthesized result list)
/// Parsed result types are added to `resultTypes`. If the parenthesized result
/// list fails to parse, `resultTypes` is cleared before returning failure.
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
                                        SmallVectorImpl<Type> &resultTypes) {
  argumentType = nullptr;
  const bool parenthesized = parser.parseOptionalLParen().succeeded();
  if (parser.parseType(argumentType).failed())
    return failure();

  // Without a leading `(` this is the bare argument form with no results.
  if (!parenthesized)
    return success();

  if (parser.parseRParen().failed() || parser.parseArrow().failed())
    return failure();

  // A `(` after the arrow introduces a result type list.
  if (parser.parseOptionalLParen().succeeded()) {
    if (parser.parseTypeList(resultTypes).succeeded() &&
        parser.parseRParen().succeeded())
      return success();
    // Do not leave partially parsed results behind on failure.
    resultTypes.clear();
    return failure();
  }

  // Otherwise a single, unparenthesized result type follows.
  Type singleResult;
  if (parser.parseType(singleResult).failed())
    return failure();
  resultTypes.push_back(singleResult);
  return success();
}
/// Prints a "semi-function" type: `argumentType` alone when `resultType` is
/// empty, otherwise `(argumentType) -> results`, where the results are wrapped
/// in parens only when there is more than one of them.
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
                                 Type argumentType, TypeRange resultType) {
  // No results: the argument type is printed bare, without parens.
  if (resultType.empty()) {
    printer << argumentType;
    return;
  }

  printer << "(";
  printer << argumentType;
  printer << ") -> ";

  const bool wrapResults = resultType.size() > 1;
  if (wrapResults)
    printer << "(";
  // Print through the OpAsmPrinter (not its raw stream) so result types use
  // the same type-printing path as `argumentType` above, e.g. type aliases.
  llvm::interleaveComma(resultType, printer);
  if (wrapResults)
    printer << ")";
}
/// Prints a "semi-function" type with at most one result by forwarding to the
/// `TypeRange` overload. A null `resultType` prints the bare argument form;
/// it is only permitted when `resultOptional` is true (asserted).
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
                                 Type argumentType, Type resultType,
                                 bool resultOptional) {
  assert(resultOptional || resultType != nullptr);
  // Map a null result to an empty range. Otherwise build a one-element range
  // that refers to the `resultType` parameter, which outlives the call below.
  TypeRange results;
  if (resultType)
    results = TypeRange(resultType);
  printSemiFunctionType(printer, op, argumentType, results);
}