[mlir][ods] Handle DeclareOpInterfaceMethods in formatgen

Previously it would not consider ops with
DeclareOpInterfaceMethods<InferTypeOpInterface> as having the
InferTypeOpInterface interfaces added. The OpInterface nested inside
DeclareOpInterfaceMethods is not retained so that one could query it, so
check for the the C++ class directly (a bit raw/low level - will be
addressed in follow up).

Differential Revision: https://reviews.llvm.org/D116572
This commit is contained in:
Jacques Pienaar
2022-01-04 08:28:59 -08:00
parent da6b0d0b76
commit 05594de2d7
4 changed files with 29 additions and 4 deletions

View File

@@ -264,6 +264,15 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
return builder.create<TestOpConstant>(loc, type, value);
}
::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
}
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
OperationName opName) {
if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&

View File

@@ -2139,6 +2139,12 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
}];
}
// Check that formatget supports DeclareOpInterfaceMethods.
def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let results = (outs AnyType);
let assemblyFormat = "attr-dict";
}
// Base class for testing mixing allOperandTypes, allOperands, and
// inferResultTypes.
class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>

View File

@@ -409,7 +409,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
//===----------------------------------------------------------------------===//
// CHECK: test.format_infer_type
%ignored_res7 = test.format_infer_type
%ignored_res7a = test.format_infer_type
// CHECK: test.format_infer_type2
%ignored_res7b = test.format_infer_type2
// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32

View File

@@ -2345,9 +2345,16 @@ LogicalResult FormatParser::parse() {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
} else if (def.getName() == "InferTypeOpInterface" &&
!op.allResultTypesKnown()) {
canInferResultTypes = true;
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
// DeclareOpInterfaceMethods<InferTypeOpInterface>
// and the like.
// TODO: Add hasCppInterface check.
if (auto name = def.getValueAsOptionalString("cppClassName")) {
if (*name == "InferTypeOpInterface" &&
def.getValueAsString("cppNamespace") == "::mlir")
canInferResultTypes = true;
}
}
}