diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.h b/mlir/include/mlir/Analysis/InferTypeOpInterface.h index b80723e45f13..2d68ada0d135 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.h +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.h @@ -35,6 +35,19 @@ namespace mlir { #include "mlir/Analysis/InferTypeOpInterface.h.inc" +namespace OpTrait { +template +class TypeOpInterfaceDefault + : public TraitBase { +public: + /// Returns whether two arrays are equal as strongest check for compatibility + /// by default. + static bool isCompatibleReturnTypes(ArrayRef lhs, ArrayRef rhs) { + return lhs == rhs; + }; +}; +} // namespace OpTrait + } // namespace mlir #endif // MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td index addd0b353201..325c550f5ea2 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -1,4 +1,4 @@ -//===- InferTypeOpInterface.td - Infer Type interfaces -*- tablegen -----*-===// +//===- InferTypeOpInterface.td - Infer Type interfaces -----*- tablegen -*-===// // // Copyright 2019 The MLIR Authors. // @@ -54,7 +54,21 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { "ArrayRef":$attributes, "ArrayRef":$regions) >, + StaticInterfaceMethod< + /*desc=*/"Returns whether two array of types are compatible result types" + " for an op.", + /*retTy=*/"bool", + /*methodName=*/"isCompatibleReturnTypes", + /*args=*/(ins "ArrayRef":$lhs, "ArrayRef":$rhs), + [{ + return ConcreteOp::isCompatibleReturnTypes(lhs, rhs); + }] + >, ]; } +// Default implementations for some of the interface methods above: +// - compatibleReturnTypes returns whether strictly true. +def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">; + #endif // MLIR_INFERTYPEOPINTERFACE diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 5528e8f0daab..a6ec6ad9d7ed 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -270,9 +270,14 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold( SmallVector mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( llvm::Optional location, ArrayRef operands, ArrayRef attributes, ArrayRef regions) { - if (location) - mlir::emitError(*location) << "expected to fail"; - return SmallVector{nullptr}; + if (operands[0]->getType() != operands[1]->getType()) { + if (location) + mlir::emitError(*location) + << "operand type mismatch " << operands[0]->getType() << " vs " + << operands[1]->getType(); + return {nullptr}; + } + return {operands[0]->getType()}; } // Static initialization for Test dialect registration. diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index cc1e22278baf..4071f7e232f3 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -395,8 +395,9 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> { let arguments = (ins I32ElementsAttr:$attr); } -def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", - [DeclareOpInterfaceMethods]> { +def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ + DeclareOpInterfaceMethods, + InferTypeOpInterfaceDefault]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); } diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index dfde4f8bc994..cc935c71c28d 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -77,8 +77,13 @@ struct ReturnTypeOpMatch : public RewritePattern { values.reserve(op->getNumOperands()); for (auto &operand : op->getOpOperands()) values.push_back(operand.get()); - (void)retTypeFn.inferReturnTypes(op->getLoc(), values, op->getAttrs(), - op->getRegions()); + auto res = retTypeFn.inferReturnTypes(op->getLoc(), values, + op->getAttrs(), op->getRegions()); + SmallVector result_types(op->getResultTypes()); + if (!retTypeFn.isCompatibleReturnTypes(res, result_types)) + return op->emitOpError( + "inferred type incompatible with return type of operation"), + matchFailure(); } return matchFailure(); } diff --git a/mlir/test/mlir-tblgen/return-types.mlir b/mlir/test/mlir-tblgen/return-types.mlir index f203677546ee..13eda19e144a 100644 --- a/mlir/test/mlir-tblgen/return-types.mlir +++ b/mlir/test/mlir-tblgen/return-types.mlir @@ -2,7 +2,18 @@ // CHECK-LABEL: testReturnTypeOpInterface func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) { - // expected-error@+1 {{expected to fail}} - %0 = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + %good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // expected-error@+1 {{incompatible with return type}} + %bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32> + return +} + +// ----- + +// CHECK-LABEL: testReturnTypeOpInterface +func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { + // expected-error@+2 {{incompatible with return type}} + // expected-error@+1 {{operand type mismatch}} + %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32> return }