From 3812d956eaef834eb3794d311aef2097aac268e0 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 9 Jun 2019 07:00:09 -0700 Subject: [PATCH] [ODS] Support variadic operand/result verification This CL enables verification code generation for variadic operands and results. In verify(), we use fallback getter methods to access all the dynamic values belonging to one static variadic operand/result to reuse the value range calculation there. PiperOrigin-RevId: 252288219 --- mlir/include/mlir/IR/OpBase.td | 7 +- mlir/include/mlir/LLVMIR/LLVMOps.td | 4 +- mlir/include/mlir/TableGen/Operator.h | 5 +- mlir/lib/StandardOps/Ops.cpp | 6 -- mlir/lib/TableGen/Operator.cpp | 4 +- mlir/test/IR/invalid-ops.mlir | 2 +- mlir/test/IR/operand.mlir | 35 ++++++++++ mlir/test/IR/result.mlir | 36 +++++++++++ mlir/test/TestDialect/TestOps.td | 25 +++++++ mlir/test/mlir-tblgen/op-operand.td | 18 ------ mlir/test/mlir-tblgen/op-result.td | 18 ------ mlir/test/mlir-tblgen/predicate.td | 6 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 72 +++++++++++---------- 13 files changed, 152 insertions(+), 86 deletions(-) create mode 100644 mlir/test/IR/operand.mlir create mode 100644 mlir/test/IR/result.mlir diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ac8c6528be09..0d3221345988 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -255,9 +255,7 @@ class TypeAlias : // class is used for supporting variadic operands/results. An op can declare no // more than one variadic operand/result, and that operand/result must be the // last one in the operand/result list. -class Variadic - // TODO(b/132908002): support variadic type conditions - : TypeConstraint, descr> { +class Variadic : TypeConstraint { Type baseType = type; } @@ -907,6 +905,9 @@ def Terminator : NativeOpTrait<"IsTerminator">; def FirstAttrDerivedResultType : GenInternalOpTrait<"FirstAttrDerivedResultType">; +// TODO(antiagainst): Turn the following into normal traits and generate +// verification for them. + // All variadic operands of the op have the same number of values. // A variadic operand contains an array of values whose array size is only // known at runtime. This trait requires all variadic operands of an op diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index a207e947023f..e9f235ac95c5 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -203,7 +203,9 @@ def LLVM_PtrToIntOp // Call-related operations. def LLVM_CallOp : LLVM_Op<"call">, Arguments<(ins OptionalAttr:$callee, - Variadic)>, + // TODO(b/133216756): fix test failure and + // change to LLVM_Type + Variadic)>, Results<(outs Variadic)>, LLVM_TwoBuilders { diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 6cc6bbc35f56..4551790cef23 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -69,11 +69,12 @@ public: std::string getQualCppClassName() const; using value_iterator = NamedTypeConstraint *; + using value_range = llvm::iterator_range; // Op result iterators. value_iterator result_begin(); value_iterator result_end(); - llvm::iterator_range getResults(); + value_range getResults(); // Returns the number of results this op produces. int getNumResults() const; @@ -110,7 +111,7 @@ public: // Op operand iterators. value_iterator operand_begin(); value_iterator operand_end(); - llvm::iterator_range getOperands(); + value_range getOperands(); int getNumOperands() const { return operands.size(); } NamedTypeConstraint &getOperand(int index) { return operands[index]; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 4b2940a17303..1ef3fcd0f82f 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) { if (op.getType() != aggregateType.getElementType()) return op.emitOpError("result type must match element type of aggregate"); - // TODO(b/132908002) This should be covered by the op specification in - // tablegen, but for some reason it's not. - for (auto *idx : op.getIndices()) - if (!idx->getType().isIndex()) - return op.emitOpError("index to extract_element must have 'index' type"); - // Verify the # indices match if we have a ranked type. if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 3c269ba65472..a5dd9c25027d 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator { auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } -auto tblgen::Operator::getResults() -> llvm::iterator_range { +auto tblgen::Operator::getResults() -> value_range { return {result_begin(), result_end()}; } @@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator { auto tblgen::Operator::operand_end() -> value_iterator { return operands.end(); } -auto tblgen::Operator::getOperands() -> llvm::iterator_range { +auto tblgen::Operator::getOperands() -> value_range { return {operand_begin(), operand_end()}; } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 562c2ce831e8..6baa104b3b3b 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) { // ----- func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) { - // expected-error@+1 {{index to extract_element must have 'index' type}} + // expected-error@+1 {{operand #1 must be index}} %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32 return } diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir new file mode 100644 index 000000000000..0d7939f756f2 --- /dev/null +++ b/mlir/test/IR/operand.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic operands +//===----------------------------------------------------------------------===// + +func @correct_variadic_operand(%arg0: tensor, %arg1: f32) { + // CHECK: mixed_normal_variadic_operand + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor, tensor, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_first_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #0 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor, f32, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_normal_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #1 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor, tensor, f32, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_second_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #2 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor, tensor, tensor, f32, tensor) -> () + return +} diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir new file mode 100644 index 000000000000..fc5d597c31d3 --- /dev/null +++ b/mlir/test/IR/result.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic results +//===----------------------------------------------------------------------===// + +func @correct_variadic_result() -> tensor { + // CHECK: mixed_normal_variadic_result + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_first_variadic_result() -> tensor { + // expected-error @+1 {{result #0 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, f32, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_normal_result() -> tensor { + // expected-error @+1 {{result #1 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, f32, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_second_variadic_result() -> tensor { + // expected-error @+1 {{result #2 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, f32, tensor) + return %0#4 : tensor +} + diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 845b08ded41b..10c144f1d05e 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { let results = (outs NestedTupleOf<[I32, F32]>); } +//===----------------------------------------------------------------------===// +// Test Operands +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicOperandOp : TEST_Op< + "mixed_normal_variadic_operand", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$input1, + AnyTensor:$input2, + Variadic:$input3 + ); +} + +//===----------------------------------------------------------------------===// +// Test Results +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicResults : TEST_Op< + "mixed_normal_variadic_result", [SameVariadicResultSize]> { + let results = (outs + Variadic:$output1, + AnyTensor:$output2, + Variadic:$output3 + ); +} //===----------------------------------------------------------------------===// // Test Attributes diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 6055081d49db..ea567e408b40 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK: assert(operands.size() == 1u && "mismatched number of parameters"); // CHECK: tblgen_state->addOperands(operands); -// CHECK: LogicalResult OpA::verify() { -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); - def OpB : NS_Op<"one_variadic_operand_op", []> { let arguments = (ins Variadic:$input); } @@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> // CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumOperands() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)}; - // CHECK-LABEL: Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index e0f14e45d478..83f804abf105 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> { // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: tblgen_state->addTypes(resultTypes); -// CHECK-LABEL: LogicalResult OpA::verify() -// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); - def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { let arguments = (ins I32:$x); let results = (outs I32:$y); @@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> let results = (outs Variadic:$output1, AnyTensor:$output2, Variadic:$output3); } -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumResults() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)}; - // CHECK-LABEL: Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 454a01b2a794..7cf5a8dc378d 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> { } // CHECK-LABEL: OpA::verify -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32()))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ PredOpTrait<"both first and second holds", @@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> { } // CHECK-LABEL: OpK::verify -// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa())) && (((this->getOperation()->getOperand(0)->getType().cast().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast().getElementType().isInteger(32)))))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!(((v->getType().isa())) && (((v->getType().cast().getElementType().isF32())) || ((v->getType().cast().getElementType().isInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7183f34803ea..7718a0dc49b0 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -448,6 +448,12 @@ private: // Generates verify method for the operation. void genVerifier(); + // Generates verify statements for operands and results in the operation. + // The generated code will be attached to `body`. + void genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind); + // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. void genRegionVerifier(OpMethodBody &body); @@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() { body << " }\n"; } - // Emits verification code for an operand or result. - auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index, - bool isOperand) -> void { - // TODO: Handle variadic operand/result verification. - if (value.isVariadic()) - return; - - // TODO: Commonality between matchers could be extracted to have a more - // concise code. - if (value.hasPredicate()) { - auto description = value.constraint.getDescription(); - body << " if (!(" - << tgfmt( - value.constraint.getConditionTemplate(), - &verifyCtx.withSelf("this->getOperation()->get" + - Twine(isOperand ? "Operand" : "Result") + - "(" + Twine(index) + ")->getType()")) - << ")) {\n"; - body << " return emitOpError(\"" << (isOperand ? "operand" : "result") - << " #" << index - << (description.empty() ? " type precondition failed" - : " must be " + Twine(description)) - << "\");\n }\n"; - } - }; - - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - verifyValue(op.getOperand(i), i, /*isOperand=*/true); - } - - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - verifyValue(op.getResult(i), i, /*isOperand=*/false); - } + genOperandResultVerifier(body, op.getOperands(), "operand"); + genOperandResultVerifier(body, op.getResults(), "result"); for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { @@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() { body << " return mlir::success();\n"; } +void OpEmitter::genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind) { + FmtContext fctx; + unsigned i = 0; + for (auto &staticValue : values) { + if (!staticValue.hasPredicate()) + continue; + + // Emit a loop to check all the dynamic values in the pack. + body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n", + // Capitalize the first letter to match the function name + valueKind.substr(0, 1).upper(), valueKind.substr(1), i); + + auto description = staticValue.constraint.getDescription(); + body << " (void)v;\n"; + body << " if (!(" + << tgfmt(staticValue.constraint.getConditionTemplate(), + &fctx.withSelf("v->getType()")) + << "))\n"; + body << " return emitOpError(\"" + // TODO(b/129706806): Use the name of the operand/result here + << valueKind << " #" << i + << (description.empty() ? " type precondition failed" + : " must be " + Twine(description)) + << "\");\n"; + body << " }\n"; + ++i; + } +} + void OpEmitter::genRegionVerifier(OpMethodBody &body) { unsigned numRegions = op.getNumRegions();