mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir] Add refineReturnTypes to InferTypeOpInterface
refineReturnType method shares the same parameters as inferReturnTypes but gets passed in the return types of the op if known that can be used during refinement passes or for more op specific error reporting. Currently the error reporting on failure is generic and doesn't allow for specializing the returned result based on failure, with this change what would previously have been a separate trait with specialized verification can just be handled as part of inferrence rather than duplicated. refineReturnTypes behaves like inferReturnTypes if no result types are fed in, while the current verification is recast as the default implementation for refineReturnTypes with it calling inferReturnTypes (and so the default type verification now goes through refine and allows for more op specific inference mismatch errors). Differential Revision: https://reviews.llvm.org/D129955
This commit is contained in:
@@ -47,6 +47,54 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
||||
"::mlir::RegionRange":$regions,
|
||||
"::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/[{Refine the return types that an op would generate.
|
||||
|
||||
This method computes the return types as `inferReturnTypes` does but
|
||||
additionally takes the existing result types as input. The existing
|
||||
result types can be checked as part of inference to provide more
|
||||
op-specific error messages as well as part of inference to merge
|
||||
additional information, attributes, during inference. It is called during
|
||||
verification for ops implementing this trait with default behavior
|
||||
reporting mismatch with current and inferred types printed.
|
||||
|
||||
The operands and attributes correspond to those with which an Operation
|
||||
would be created (e.g., as used in Operation::create) and the regions of
|
||||
the op. The method takes an optional location which, if set, will be used
|
||||
to report errors on.
|
||||
|
||||
The return types may be elided or specific elements be null for elements
|
||||
that should just be returned but not verified.
|
||||
|
||||
Be aware that this method is supposed to be called with valid arguments,
|
||||
e.g., operands are verified, or it may result in an undefined behavior.
|
||||
}],
|
||||
/*retTy=*/"::mlir::LogicalResult",
|
||||
/*methodName=*/"refineReturnTypes",
|
||||
/*args=*/(ins "::mlir::MLIRContext *":$context,
|
||||
"::llvm::Optional<::mlir::Location>":$location,
|
||||
"::mlir::ValueRange":$operands,
|
||||
"::mlir::DictionaryAttr":$attributes,
|
||||
"::mlir::RegionRange":$regions,
|
||||
"::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImplementation=*/[{
|
||||
llvm::SmallVector<Type, 4> inferredReturnTypes;
|
||||
if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
|
||||
attributes, regions,
|
||||
inferredReturnTypes)))
|
||||
return failure();
|
||||
if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
|
||||
returnTypes)) {
|
||||
return emitOptionalError(
|
||||
location, "'", ConcreteOp::getOperationName(),
|
||||
"' op inferred type(s) ", inferredReturnTypes,
|
||||
" are incompatible with return type(s) of operation ",
|
||||
returnTypes);
|
||||
}
|
||||
return success();
|
||||
}]
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/"Returns whether two array of types are compatible result types"
|
||||
" for an op.",
|
||||
|
||||
@@ -204,17 +204,9 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
|
||||
}
|
||||
|
||||
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
|
||||
SmallVector<Type, 4> inferredReturnTypes;
|
||||
SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
|
||||
auto retTypeFn = cast<InferTypeOpInterface>(op);
|
||||
if (failed(retTypeFn.inferReturnTypes(
|
||||
op->getContext(), op->getLoc(), op->getOperands(),
|
||||
op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
|
||||
return failure();
|
||||
if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
|
||||
op->getResultTypes()))
|
||||
return op->emitOpError("inferred type(s) ")
|
||||
<< inferredReturnTypes
|
||||
<< " are incompatible with return type(s) of operation "
|
||||
<< op->getResultTypes();
|
||||
return success();
|
||||
return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
|
||||
op->getOperands(), op->getAttrDictionary(),
|
||||
op->getRegions(), inferredReturnTypes);
|
||||
}
|
||||
|
||||
@@ -1128,6 +1128,36 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: We should be able to only define either inferReturnType or
|
||||
// refineReturnType, currently only refineReturnType can be omitted.
|
||||
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &returnTypes) {
|
||||
returnTypes.clear();
|
||||
return OpWithRefineTypeInterfaceOp::refineReturnTypes(
|
||||
context, location, operands, attributes, regions, returnTypes);
|
||||
}
|
||||
|
||||
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
|
||||
MLIRContext *, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &returnTypes) {
|
||||
if (operands[0].getType() != operands[1].getType()) {
|
||||
return emitOptionalError(location, "operand type mismatch ",
|
||||
operands[0].getType(), " vs ",
|
||||
operands[1].getType());
|
||||
}
|
||||
// TODO: Add helper to make this more concise to write.
|
||||
if (returnTypes.empty())
|
||||
returnTypes.resize(1, nullptr);
|
||||
if (returnTypes[0] && returnTypes[0] != operands[0].getType())
|
||||
return emitOptionalError(location,
|
||||
"required first operand and result to match");
|
||||
returnTypes[0] = operands[0].getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
|
||||
MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
|
||||
@@ -645,8 +645,14 @@ def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
|
||||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>]> {
|
||||
["refineReturnTypes"]>]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
@@ -39,6 +39,13 @@ func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : ten
|
||||
|
||||
// -----
|
||||
|
||||
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||
// expected-error@+1 {{required first operand and result to match}}
|
||||
%bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: testReifyFunctions
|
||||
func.func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
|
||||
// expected-remark@+1 {{arith.constant 10}}
|
||||
|
||||
Reference in New Issue
Block a user