mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
Add compatible query method to infer type interface
A return type that differs from the inferred return type need not indicate that an operation is invalid (e.g., tensor<*xf32> vs tensor<10xf32>) but they should be compatible for the operation to be considered valid. Add method to query if inferred type is compatible with return type. Also add InferTypeOpIntefaceDefault trait that considers equality and compatibility as the same. Currently an op has to opt in to using it explicitly. PiperOrigin-RevId: 279085639
This commit is contained in:
committed by
A. Unique TensorFlower
parent
72040bf7c8
commit
7af61f6bcd
@@ -35,6 +35,19 @@ namespace mlir {
|
||||
|
||||
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
|
||||
|
||||
namespace OpTrait {
|
||||
template <typename ConcreteType>
|
||||
class TypeOpInterfaceDefault
|
||||
: public TraitBase<ConcreteType, TypeOpInterfaceDefault> {
|
||||
public:
|
||||
/// Returns whether two arrays are equal as strongest check for compatibility
|
||||
/// by default.
|
||||
static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
|
||||
return lhs == rhs;
|
||||
};
|
||||
};
|
||||
} // namespace OpTrait
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
|
||||
//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
@@ -54,7 +54,21 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
||||
"ArrayRef<NamedAttribute>":$attributes,
|
||||
"ArrayRef<Region>":$regions)
|
||||
>,
|
||||
StaticInterfaceMethod<
|
||||
/*desc=*/"Returns whether two array of types are compatible result types"
|
||||
" for an op.",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isCompatibleReturnTypes",
|
||||
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
|
||||
[{
|
||||
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
|
||||
}]
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
// Default implementations for some of the interface methods above:
|
||||
// - compatibleReturnTypes returns whether strictly true.
|
||||
def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
|
||||
|
||||
#endif // MLIR_INFERTYPEOPINTERFACE
|
||||
|
||||
@@ -270,9 +270,14 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
|
||||
SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
|
||||
llvm::Optional<Location> location, ArrayRef<Value *> operands,
|
||||
ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
|
||||
if (location)
|
||||
mlir::emitError(*location) << "expected to fail";
|
||||
return SmallVector<Type, 2>{nullptr};
|
||||
if (operands[0]->getType() != operands[1]->getType()) {
|
||||
if (location)
|
||||
mlir::emitError(*location)
|
||||
<< "operand type mismatch " << operands[0]->getType() << " vs "
|
||||
<< operands[1]->getType();
|
||||
return {nullptr};
|
||||
}
|
||||
return {operands[0]->getType()};
|
||||
}
|
||||
|
||||
// Static initialization for Test dialect registration.
|
||||
|
||||
@@ -395,8 +395,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
|
||||
let arguments = (ins I32ElementsAttr:$attr);
|
||||
}
|
||||
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
InferTypeOpInterfaceDefault]> {
|
||||
let arguments = (ins AnyTensor, AnyTensor);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
@@ -77,8 +77,13 @@ struct ReturnTypeOpMatch : public RewritePattern {
|
||||
values.reserve(op->getNumOperands());
|
||||
for (auto &operand : op->getOpOperands())
|
||||
values.push_back(operand.get());
|
||||
(void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
|
||||
op->getRegions());
|
||||
auto res = retTypeFn.inferReturnTypes(op->getLoc(), values,
|
||||
op->getAttrs(), op->getRegions());
|
||||
SmallVector<Type, 1> result_types(op->getResultTypes());
|
||||
if (!retTypeFn.isCompatibleReturnTypes(res, result_types))
|
||||
return op->emitOpError(
|
||||
"inferred type incompatible with return type of operation"),
|
||||
matchFailure();
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
@@ -2,7 +2,18 @@
|
||||
|
||||
// CHECK-LABEL: testReturnTypeOpInterface
|
||||
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
|
||||
// expected-error@+1 {{expected to fail}}
|
||||
%0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||
%good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
|
||||
// expected-error@+1 {{incompatible with return type}}
|
||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testReturnTypeOpInterface
|
||||
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
|
||||
// expected-error@+2 {{incompatible with return type}}
|
||||
// expected-error@+1 {{operand type mismatch}}
|
||||
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user