[mlir] Add refineReturnTypes to InferTypeOpInterface

refineReturnType method shares the same parameters as inferReturnTypes
but gets passed in the return types of the op if known that can be used
during refinement passes or for more op specific error reporting.
Currently the error reporting on failure is generic and doesn't allow
for specializing the returned result based on failure, with this change
what would previously have been a separate trait with specialized
verification can just be handled as part of inferrence rather than
duplicated.

refineReturnTypes behaves like inferReturnTypes if no result types are fed in,
while the current verification is recast as the default implementation for
refineReturnTypes with it calling inferReturnTypes (and so the default type
verification now goes through refine and allows for more op specific inference
mismatch errors).

Differential Revision: https://reviews.llvm.org/D129955
This commit is contained in:
Jacques Pienaar
2022-07-18 22:18:52 -07:00
parent 83e922562f
commit c8598fa22f
5 changed files with 96 additions and 13 deletions

View File

@@ -47,6 +47,54 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
>,
StaticInterfaceMethod<
/*desc=*/[{Refine the return types that an op would generate.
This method computes the return types as `inferReturnTypes` does but
additionally takes the existing result types as input. The existing
result types can be checked as part of inference to provide more
op-specific error messages as well as part of inference to merge
additional information, attributes, during inference. It is called during
verification for ops implementing this trait with default behavior
reporting mismatch with current and inferred types printed.
The operands and attributes correspond to those with which an Operation
would be created (e.g., as used in Operation::create) and the regions of
the op. The method takes an optional location which, if set, will be used
to report errors on.
The return types may be elided or specific elements be null for elements
that should just be returned but not verified.
Be aware that this method is supposed to be called with valid arguments,
e.g., operands are verified, or it may result in an undefined behavior.
}],
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"refineReturnTypes",
/*args=*/(ins "::mlir::MLIRContext *":$context,
"::llvm::Optional<::mlir::Location>":$location,
"::mlir::ValueRange":$operands,
"::mlir::DictionaryAttr":$attributes,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
llvm::SmallVector<Type, 4> inferredReturnTypes;
if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
attributes, regions,
inferredReturnTypes)))
return failure();
if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
returnTypes)) {
return emitOptionalError(
location, "'", ConcreteOp::getOperationName(),
"' op inferred type(s) ", inferredReturnTypes,
" are incompatible with return type(s) of operation ",
returnTypes);
}
return success();
}]
>,
StaticInterfaceMethod<
/*desc=*/"Returns whether two array of types are compatible result types"
" for an op.",

View File

@@ -204,17 +204,9 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
}
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
SmallVector<Type, 4> inferredReturnTypes;
SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
auto retTypeFn = cast<InferTypeOpInterface>(op);
if (failed(retTypeFn.inferReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(),
op->getAttrDictionary(), op->getRegions(), inferredReturnTypes)))
return failure();
if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes,
op->getResultTypes()))
return op->emitOpError("inferred type(s) ")
<< inferredReturnTypes
<< " are incompatible with return type(s) of operation "
<< op->getResultTypes();
return success();
return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getRegions(), inferredReturnTypes);
}

View File

@@ -1128,6 +1128,36 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
return success();
}
// TODO: We should be able to only define either inferReturnType or
// refineReturnType, currently only refineReturnType can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
returnTypes.clear();
return OpWithRefineTypeInterfaceOp::refineReturnTypes(
context, location, operands, attributes, regions, returnTypes);
}
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
MLIRContext *, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
operands[0].getType(), " vs ",
operands[1].getType());
}
// TODO: Add helper to make this more concise to write.
if (returnTypes.empty())
returnTypes.resize(1, nullptr);
if (returnTypes[0] && returnTypes[0] != operands[0].getType())
return emitOptionalError(location,
"required first operand and result to match");
returnTypes[0] = operands[0].getType();
return success();
}
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, RegionRange regions,

View File

@@ -645,8 +645,14 @@ def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}
def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface,
["inferReturnTypeComponents"]>]> {
["refineReturnTypes"]>]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}

View File

@@ -39,6 +39,13 @@ func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : ten
// -----
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+1 {{required first operand and result to match}}
%bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
}
// -----
// CHECK-LABEL: testReifyFunctions
func.func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
// expected-remark@+1 {{arith.constant 10}}