diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ac8c6528be09..0d3221345988 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -255,9 +255,7 @@ class TypeAlias : // class is used for supporting variadic operands/results. An op can declare no // more than one variadic operand/result, and that operand/result must be the // last one in the operand/result list. -class Variadic - // TODO(b/132908002): support variadic type conditions - : TypeConstraint, descr> { +class Variadic : TypeConstraint { Type baseType = type; } @@ -907,6 +905,9 @@ def Terminator : NativeOpTrait<"IsTerminator">; def FirstAttrDerivedResultType : GenInternalOpTrait<"FirstAttrDerivedResultType">; +// TODO(antiagainst): Turn the following into normal traits and generate +// verification for them. + // All variadic operands of the op have the same number of values. // A variadic operand contains an array of values whose array size is only // known at runtime. This trait requires all variadic operands of an op diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index a207e947023f..e9f235ac95c5 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -203,7 +203,9 @@ def LLVM_PtrToIntOp // Call-related operations. def LLVM_CallOp : LLVM_Op<"call">, Arguments<(ins OptionalAttr:$callee, - Variadic)>, + // TODO(b/133216756): fix test failure and + // change to LLVM_Type + Variadic)>, Results<(outs Variadic)>, LLVM_TwoBuilders { diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 6cc6bbc35f56..4551790cef23 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -69,11 +69,12 @@ public: std::string getQualCppClassName() const; using value_iterator = NamedTypeConstraint *; + using value_range = llvm::iterator_range; // Op result iterators. value_iterator result_begin(); value_iterator result_end(); - llvm::iterator_range getResults(); + value_range getResults(); // Returns the number of results this op produces. int getNumResults() const; @@ -110,7 +111,7 @@ public: // Op operand iterators. value_iterator operand_begin(); value_iterator operand_end(); - llvm::iterator_range getOperands(); + value_range getOperands(); int getNumOperands() const { return operands.size(); } NamedTypeConstraint &getOperand(int index) { return operands[index]; } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 4b2940a17303..1ef3fcd0f82f 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1595,12 +1595,6 @@ static LogicalResult verify(ExtractElementOp op) { if (op.getType() != aggregateType.getElementType()) return op.emitOpError("result type must match element type of aggregate"); - // TODO(b/132908002) This should be covered by the op specification in - // tablegen, but for some reason it's not. - for (auto *idx : op.getIndices()) - if (!idx->getType().isIndex()) - return op.emitOpError("index to extract_element must have 'index' type"); - // Verify the # indices match if we have a ranked type. if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 3c269ba65472..a5dd9c25027d 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -95,7 +95,7 @@ auto tblgen::Operator::result_begin() -> value_iterator { auto tblgen::Operator::result_end() -> value_iterator { return results.end(); } -auto tblgen::Operator::getResults() -> llvm::iterator_range { +auto tblgen::Operator::getResults() -> value_range { return {result_begin(), result_end()}; } @@ -205,7 +205,7 @@ auto tblgen::Operator::operand_begin() -> value_iterator { auto tblgen::Operator::operand_end() -> value_iterator { return operands.end(); } -auto tblgen::Operator::getOperands() -> llvm::iterator_range { +auto tblgen::Operator::getOperands() -> value_range { return {operand_begin(), operand_end()}; } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 562c2ce831e8..6baa104b3b3b 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -639,7 +639,7 @@ func @extract_element_no_indices(%v : vector<3xf32>) { // ----- func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) { - // expected-error@+1 {{index to extract_element must have 'index' type}} + // expected-error@+1 {{operand #1 must be index}} %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32 return } diff --git a/mlir/test/IR/operand.mlir b/mlir/test/IR/operand.mlir new file mode 100644 index 000000000000..0d7939f756f2 --- /dev/null +++ b/mlir/test/IR/operand.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic operands +//===----------------------------------------------------------------------===// + +func @correct_variadic_operand(%arg0: tensor, %arg1: f32) { + // CHECK: mixed_normal_variadic_operand + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg0, %arg0) : (tensor, tensor, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_first_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #0 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg1, %arg0, %arg0, %arg0) : (tensor, f32, tensor, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_normal_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #1 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg1, %arg0, %arg0) : (tensor, tensor, f32, tensor, tensor) -> () + return +} + +// ----- + +func @error_in_second_variadic_operand(%arg0: tensor, %arg1: f32) { + // expected-error @+1 {{operand #2 must be tensor of any type}} + "test.mixed_normal_variadic_operand"(%arg0, %arg0, %arg0, %arg1, %arg0) : (tensor, tensor, tensor, f32, tensor) -> () + return +} diff --git a/mlir/test/IR/result.mlir b/mlir/test/IR/result.mlir new file mode 100644 index 000000000000..fc5d597c31d3 --- /dev/null +++ b/mlir/test/IR/result.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test mixed normal and variadic results +//===----------------------------------------------------------------------===// + +func @correct_variadic_result() -> tensor { + // CHECK: mixed_normal_variadic_result + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_first_variadic_result() -> tensor { + // expected-error @+1 {{result #0 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, f32, tensor, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_normal_result() -> tensor { + // expected-error @+1 {{result #1 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, f32, tensor, tensor) + return %0#4 : tensor +} + +// ----- + +func @error_in_second_variadic_result() -> tensor { + // expected-error @+1 {{result #2 must be tensor of any type}} + %0:5 = "test.mixed_normal_variadic_result"() : () -> (tensor, tensor, tensor, f32, tensor) + return %0#4 : tensor +} + diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 845b08ded41b..10c144f1d05e 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -60,6 +60,31 @@ def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { let results = (outs NestedTupleOf<[I32, F32]>); } +//===----------------------------------------------------------------------===// +// Test Operands +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicOperandOp : TEST_Op< + "mixed_normal_variadic_operand", [SameVariadicOperandSize]> { + let arguments = (ins + Variadic:$input1, + AnyTensor:$input2, + Variadic:$input3 + ); +} + +//===----------------------------------------------------------------------===// +// Test Results +//===----------------------------------------------------------------------===// + +def MixedNormalVariadicResults : TEST_Op< + "mixed_normal_variadic_result", [SameVariadicResultSize]> { + let results = (outs + Variadic:$output1, + AnyTensor:$output2, + Variadic:$output3 + ); +} //===----------------------------------------------------------------------===// // Test Attributes diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 6055081d49db..ea567e408b40 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -26,10 +26,6 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK: assert(operands.size() == 1u && "mismatched number of parameters"); // CHECK: tblgen_state->addOperands(operands); -// CHECK: LogicalResult OpA::verify() { -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); - def OpB : NS_Op<"one_variadic_operand_op", []> { let arguments = (ins Variadic:$input); } @@ -52,20 +48,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> // CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::operand_range OpD::getODSOperands(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumOperands() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->operand_begin(), offset), std::next(getOperation()->operand_begin(), offset + size)}; - // CHECK-LABEL: Operation::operand_range OpD::input1 // CHECK-NEXT: return getODSOperands(0); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index e0f14e45d478..83f804abf105 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -17,10 +17,6 @@ def OpA : NS_Op<"one_normal_result_op", []> { // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: tblgen_state->addTypes(resultTypes); -// CHECK-LABEL: LogicalResult OpA::verify() -// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) -// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); - def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> { let arguments = (ins I32:$x); let results = (outs I32:$y); @@ -90,20 +86,6 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> let results = (outs Variadic:$output1, AnyTensor:$output2, Variadic:$output3); } -// TODO(b/134305899): Move to use TestDialect after fixing verification. - -// CHECK-LABEL: Operation::result_range OpI::getODSResults(unsigned index) -// CHECK-NEXT: bool isVariadic[] = {true, false, true}; -// CHECK-NEXT: int prevVariadicCount = 0; -// CHECK-NEXT: for (int i = 0; i < index; ++i) -// CHECK-NEXT: if (isVariadic[i]) ++prevVariadicCount; - -// CHECK: int variadicSize = (getOperation()->getNumResults() - 1) / 2; -// CHECK: int offset = index + (variadicSize - 1) * prevVariadicCount; -// CHECK-NEXT: int size = isVariadic[index] ? variadicSize : 1; - -// CHECK: return {std::next(getOperation()->result_begin(), offset), std::next(getOperation()->result_begin(), offset + size)}; - // CHECK-LABEL: Operation::result_range OpI::output1 // CHECK-NEXT: return getODSResults(0); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 454a01b2a794..7cf5a8dc378d 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -16,7 +16,8 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> { } // CHECK-LABEL: OpA::verify -// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32) || this->getOperation()->getOperand(0)->getType().isF32()))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!((v->getType().isInteger(32) || v->getType().isF32()))) def OpB : NS_Op<"op_for_And_PredOpTrait", [ PredOpTrait<"both first and second holds", @@ -103,4 +104,5 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> { } // CHECK-LABEL: OpK::verify -// CHECK: if (!(((this->getOperation()->getOperand(0)->getType().isa())) && (((this->getOperation()->getOperand(0)->getType().cast().getElementType().isF32())) || ((this->getOperation()->getOperand(0)->getType().cast().getElementType().isInteger(32)))))) +// CHECK: for (Value *v : getODSOperands(0)) { +// CHECK: if (!(((v->getType().isa())) && (((v->getType().cast().getElementType().isF32())) || ((v->getType().cast().getElementType().isInteger(32)))))) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 7183f34803ea..7718a0dc49b0 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -448,6 +448,12 @@ private: // Generates verify method for the operation. void genVerifier(); + // Generates verify statements for operands and results in the operation. + // The generated code will be attached to `body`. + void genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind); + // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. void genRegionVerifier(OpMethodBody &body); @@ -1022,39 +1028,8 @@ void OpEmitter::genVerifier() { body << " }\n"; } - // Emits verification code for an operand or result. - auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index, - bool isOperand) -> void { - // TODO: Handle variadic operand/result verification. - if (value.isVariadic()) - return; - - // TODO: Commonality between matchers could be extracted to have a more - // concise code. - if (value.hasPredicate()) { - auto description = value.constraint.getDescription(); - body << " if (!(" - << tgfmt( - value.constraint.getConditionTemplate(), - &verifyCtx.withSelf("this->getOperation()->get" + - Twine(isOperand ? "Operand" : "Result") + - "(" + Twine(index) + ")->getType()")) - << ")) {\n"; - body << " return emitOpError(\"" << (isOperand ? "operand" : "result") - << " #" << index - << (description.empty() ? " type precondition failed" - : " must be " + Twine(description)) - << "\");\n }\n"; - } - }; - - for (int i = 0, e = op.getNumOperands(); i < e; ++i) { - verifyValue(op.getOperand(i), i, /*isOperand=*/true); - } - - for (int i = 0, e = op.getNumResults(); i < e; ++i) { - verifyValue(op.getResult(i), i, /*isOperand=*/false); - } + genOperandResultVerifier(body, op.getOperands(), "operand"); + genOperandResultVerifier(body, op.getResults(), "result"); for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { @@ -1073,6 +1048,37 @@ void OpEmitter::genVerifier() { body << " return mlir::success();\n"; } +void OpEmitter::genOperandResultVerifier(OpMethodBody &body, + Operator::value_range values, + StringRef valueKind) { + FmtContext fctx; + unsigned i = 0; + for (auto &staticValue : values) { + if (!staticValue.hasPredicate()) + continue; + + // Emit a loop to check all the dynamic values in the pack. + body << formatv(" for (Value *v : getODS{0}{1}s({2})) {{\n", + // Capitalize the first letter to match the function name + valueKind.substr(0, 1).upper(), valueKind.substr(1), i); + + auto description = staticValue.constraint.getDescription(); + body << " (void)v;\n"; + body << " if (!(" + << tgfmt(staticValue.constraint.getConditionTemplate(), + &fctx.withSelf("v->getType()")) + << "))\n"; + body << " return emitOpError(\"" + // TODO(b/129706806): Use the name of the operand/result here + << valueKind << " #" << i + << (description.empty() ? " type precondition failed" + : " must be " + Twine(description)) + << "\");\n"; + body << " }\n"; + ++i; + } +} + void OpEmitter::genRegionVerifier(OpMethodBody &body) { unsigned numRegions = op.getNumRegions();