Add compatible query method to infer type interface

A return type that differs from the inferred return type need not indicate that an operation is invalid (e.g., tensor<*xf32> vs tensor<10xf32>) but they should be compatible for the operation to be considered valid. Add method to query if inferred type is compatible with return type.

Also add InferTypeOpIntefaceDefault trait that considers equality and compatibility as the same. Currently an op has to opt in to using it explicitly.

PiperOrigin-RevId: 279085639
This commit is contained in:
Jacques Pienaar
2019-11-07 07:51:12 -08:00
committed by A. Unique TensorFlower
parent 72040bf7c8
commit 7af61f6bcd
6 changed files with 59 additions and 10 deletions

View File

@@ -35,6 +35,19 @@ namespace mlir {
#include "mlir/Analysis/InferTypeOpInterface.h.inc"
namespace OpTrait {
template <typename ConcreteType>
class TypeOpInterfaceDefault
: public TraitBase<ConcreteType, TypeOpInterfaceDefault> {
public:
/// Returns whether two arrays are equal as strongest check for compatibility
/// by default.
static bool isCompatibleReturnTypes(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
return lhs == rhs;
};
};
} // namespace OpTrait
} // namespace mlir
#endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_

View File

@@ -1,4 +1,4 @@
//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===//
//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@@ -54,7 +54,21 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"ArrayRef<NamedAttribute>":$attributes,
"ArrayRef<Region>":$regions)
>,
StaticInterfaceMethod<
/*desc=*/"Returns whether two array of types are compatible result types"
" for an op.",
/*retTy=*/"bool",
/*methodName=*/"isCompatibleReturnTypes",
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
[{
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
}]
>,
];
}
// Default implementations for some of the interface methods above:
// - compatibleReturnTypes returns whether strictly true.
def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
#endif // MLIR_INFERTYPEOPINTERFACE

View File

@@ -270,9 +270,14 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
llvm::Optional<Location> location, ArrayRef<Value *> operands,
ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions) {
if (location)
mlir::emitError(*location) << "expected to fail";
return SmallVector<Type, 2>{nullptr};
if (operands[0]->getType() != operands[1]->getType()) {
if (location)
mlir::emitError(*location)
<< "operand type mismatch " << operands[0]->getType() << " vs "
<< operands[1]->getType();
return {nullptr};
}
return {operands[0]->getType()};
}
// Static initialization for Test dialect registration.

View File

@@ -395,8 +395,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
let arguments = (ins I32ElementsAttr:$attr);
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
InferTypeOpInterfaceDefault]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}

View File

@@ -77,8 +77,13 @@ struct ReturnTypeOpMatch : public RewritePattern {
values.reserve(op->getNumOperands());
for (auto &operand : op->getOpOperands())
values.push_back(operand.get());
(void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(),
op->getRegions());
auto res = retTypeFn.inferReturnTypes(op->getLoc(), values,
op->getAttrs(), op->getRegions());
SmallVector<Type, 1> result_types(op->getResultTypes());
if (!retTypeFn.isCompatibleReturnTypes(res, result_types))
return op->emitOpError(
"inferred type incompatible with return type of operation"),
matchFailure();
}
return matchFailure();
}

View File

@@ -2,7 +2,18 @@
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+1 {{expected to fail}}
%0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
%good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// expected-error@+1 {{incompatible with return type}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
}
// -----
// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
// expected-error@+2 {{incompatible with return type}}
// expected-error@+1 {{operand type mismatch}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
return
}