From 05594de2d77b6f4735b8d8d417039b60987b3a79 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 4 Jan 2022 08:28:59 -0800 Subject: [PATCH] [mlir][ods] Handle DeclareOpInterfaceMethods in formatgen Previously it would not consider ops with DeclareOpInterfaceMethods as having the InferTypeOpInterface interfaces added. The OpInterface nested inside DeclareOpInterfaceMethods is not retained so that one could query it, so check for the the C++ class directly (a bit raw/low level - will be addressed in follow up). Differential Revision: https://reviews.llvm.org/D116572 --- mlir/test/lib/Dialect/Test/TestDialect.cpp | 9 +++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 6 ++++++ mlir/test/mlir-tblgen/op-format.mlir | 5 ++++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 13 ++++++++++--- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index aee0bdb13970..441817803ef0 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -264,6 +264,15 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, return builder.create(loc, type, value); } +::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); + return ::mlir::success(); +} + void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, OperationName opName) { if (opName.getIdentifier() == "test.unregistered_side_effect_op" && diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 39f0b0b7da56..6fad11b85ad8 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2139,6 +2139,12 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { }]; } +// Check that formatget supports DeclareOpInterfaceMethods. +def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods]> { + let results = (outs AnyType); + let assemblyFormat = "attr-dict"; +} + // Base class for testing mixing allOperandTypes, allOperands, and // inferResultTypes. class FormatInferAllTypesBaseOp traits = []> diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 152cd0a554f1..77afc41f6541 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -409,7 +409,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 //===----------------------------------------------------------------------===// // CHECK: test.format_infer_type -%ignored_res7 = test.format_infer_type +%ignored_res7a = test.format_infer_type + +// CHECK: test.format_infer_type2 +%ignored_res7b = test.format_infer_type2 // CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32 %ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 02d0e81b6860..b5218030b64d 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2345,9 +2345,16 @@ LogicalResult FormatParser::parse() { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { handleTypesMatchConstraint(variableTyResolver, def); - } else if (def.getName() == "InferTypeOpInterface" && - !op.allResultTypesKnown()) { - canInferResultTypes = true; + } else if (!op.allResultTypesKnown()) { + // This doesn't check the name directly to handle + // DeclareOpInterfaceMethods + // and the like. + // TODO: Add hasCppInterface check. + if (auto name = def.getValueAsOptionalString("cppClassName")) { + if (*name == "InferTypeOpInterface" && + def.getValueAsString("cppNamespace") == "::mlir") + canInferResultTypes = true; + } } }