From 6981e5ec91c98a23753d2dae590156107d857fda Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 5 Nov 2021 12:05:02 +0100 Subject: [PATCH] [mlir][python] fix constructor generation for optional operands in presence of segment attribute The ODS-based Python op bindings generator has been generating incorrect specification of the operand segment in presence if both optional and variadic operand groups: optional groups were treated as variadic whereas they require separate treatement. Make sure it is the case. Also harden the tests around generated op constructors as they could hitherto accept the code for both optional and variadic arguments. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D113259 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 +- mlir/test/mlir-tblgen/op-python-bindings.td | 41 +++++++++++-- mlir/test/python/dialects/vector.py | 59 +++++++++++++++---- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 16 +++-- 4 files changed, 96 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d465c1382459..cf59a67f9c8f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1153,7 +1153,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, throw py::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + - "result segments but was provided " + + " result segments but was provided " + llvm::Twine(resultTypeList.size())) .str()); } @@ -1164,7 +1164,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto resultType = py::cast(std::get<0>(it.value())); + auto *resultType = py::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index d6dc56428eb5..becce13050a1 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -18,7 +18,7 @@ class TestOp traits = []> : // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedOperandsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands" -// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] +// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): @@ -28,7 +28,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: regions = None // CHECK: operands.append(_get_op_results_or_values(variadic1)) // CHECK: operands.append(_get_op_result_or_value(non_variadic)) - // CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2)) + // CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -40,6 +40,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: self.operation.operands, // CHECK: self.operation.attributes["operand_segment_sizes"], 0) // CHECK: return operand_range + // CHECK-NOT: if len(operand_range) // // CHECK: @builtins.property // CHECK: def non_variadic(self): @@ -61,7 +62,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" -// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] +// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None): @@ -71,7 +72,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", // CHECK: regions = None // CHECK: if variadic1 is not None: results.append(variadic1) // CHECK: results.append(non_variadic) - // CHECK: if variadic2 is not None: results.append(variadic2) + // CHECK: results.append(variadic2) // CHECK: _ods_successors = None // CHECK: super().__init__(self.build_generic( // CHECK: attributes=attributes, results=results, operands=operands, @@ -97,8 +98,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", // CHECK: self.operation.results, // CHECK: self.operation.attributes["result_segment_sizes"], 2) // CHECK: return result_range + // CHECK-NOT: if len(result_range) let results = (outs Optional:$variadic1, AnyType:$non_variadic, - Optional:$variadic2); + Variadic:$variadic2); } @@ -277,6 +279,35 @@ def MissingNamesOp : TestOp<"missing_names"> { let results = (outs I32:$i32, AnyFloat, I64:$i64); } +// CHECK: @_ods_cext.register_operation(_Dialect) +// CHECK: class OneOptionalOperandOp(_ods_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS +def OneOptionalOperandOp : TestOp<"one_optional_operand"> { + let arguments = (ins AnyType:$non_optional, Optional:$optional); + // CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: regions = None + // CHECK: operands.append(_get_op_result_or_value(non_optional)) + // CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional)) + // CHECK: _ods_successors = None + // CHECK: super().__init__(self.build_generic( + // CHECK: attributes=attributes, results=results, operands=operands, + // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) + + // CHECK: @builtins.property + // CHECK: def non_optional(self): + // CHECK: return self.operation.operands[0] + + // CHECK: @builtins.property + // CHECK: def optional(self): + // CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None + +} + // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" diff --git a/mlir/test/python/dialects/vector.py b/mlir/test/python/dialects/vector.py index 4d7052859e7d..b8db94070d6a 100644 --- a/mlir/test/python/dialects/vector.py +++ b/mlir/test/python/dialects/vector.py @@ -2,25 +2,58 @@ from mlir.ir import * import mlir.dialects.builtin as builtin +import mlir.dialects.std as std import mlir.dialects.vector as vector def run(f): print("\nTEST:", f.__name__) - f() + with Context(), Location.unknown(): + f() + return f # CHECK-LABEL: TEST: testPrintOp @run def testPrintOp(): - with Context() as ctx, Location.unknown(): - module = Module.create() - with InsertionPoint(module.body): - @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) - def print_vector(arg): - return vector.PrintOp(arg) + module = Module.create() + with InsertionPoint(module.body): - # CHECK-LABEL: func @print_vector( - # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { - # CHECK: vector.print %[[ARG]] : vector<12x5xf32> - # CHECK: return - # CHECK: } - print(module) + @builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get())) + def print_vector(arg): + return vector.PrintOp(arg) + + # CHECK-LABEL: func @print_vector( + # CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) { + # CHECK: vector.print %[[ARG]] : vector<12x5xf32> + # CHECK: return + # CHECK: } + print(module) + + +# CHECK-LABEL: TEST: testTransferReadOp +@run +def testTransferReadOp(): + module = Module.create() + with InsertionPoint(module.body): + vector_type = VectorType.get([2, 3], F32Type.get()) + memref_type = MemRefType.get([-1, -1], F32Type.get()) + index_type = IndexType.get() + mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1)) + identity_map = AffineMap.get_identity(vector_type.rank) + identity_map_attr = AffineMapAttr.get(identity_map) + func = builtin.FuncOp("transfer_read", + ([memref_type, index_type, + F32Type.get(), mask_type], [])) + with InsertionPoint(func.add_entry_block()): + A, zero, padding, mask = func.arguments + vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, + padding, mask, None) + vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr, + padding, None, None) + std.ReturnOp([]) + + # CHECK: @transfer_read(%[[MEM:.*]]: memref, %[[IDX:.*]]: index, + # CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>) + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] + # CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]] + # CHECK-NOT: %[[MASK]] + print(module) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index d9ce2963a8f3..8babff25db07 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -67,6 +67,7 @@ class {0}(_ods_ir.OpView): /// Each segment spec is either None (default) or an array of integers /// where: /// 1 = single element (expect non sequence operand/result) +/// 0 = optional element (expect a value or None) /// -1 = operand/result is a sequence corresponding to a variadic constexpr const char *opClassSizedSegmentsTemplate = R"Py( _ODS_{0}_SEGMENTS = {1} @@ -505,6 +506,9 @@ constexpr const char *singleResultAppendTemplate = "results.append({0})"; /// {0} is the field name. constexpr const char *optionalAppendOperandTemplate = "if {0} is not None: operands.append(_get_op_result_or_value({0}))"; +constexpr const char *optionalAppendAttrSizedOperandsTemplate = + "operands.append(_get_op_result_or_value({0}) if {0} is not None else " + "None)"; constexpr const char *optionalAppendResultTemplate = "if {0} is not None: results.append({0})"; @@ -693,7 +697,11 @@ populateBuilderLinesOperand(const Operator &op, if (!element.isVariableLength()) { formatString = singleOperandAppendTemplate; } else if (element.isOptional()) { - formatString = optionalAppendOperandTemplate; + if (sizedSegments) { + formatString = optionalAppendAttrSizedOperandsTemplate; + } else { + formatString = optionalAppendOperandTemplate; + } } else { assert(element.isVariadic() && "unhandled element group type"); // If emitting with sizedSegments, then we add the actual list-typed @@ -882,10 +890,10 @@ static void emitSegmentSpec( std::string segmentSpec("["); for (int i = 0, e = getNumElements(op); i < e; ++i) { const NamedTypeConstraint &element = getElement(op, i); - if (element.isVariableLength()) { - segmentSpec.append("-1,"); - } else if (element.isOptional()) { + if (element.isOptional()) { segmentSpec.append("0,"); + } else if (element.isVariadic()) { + segmentSpec.append("-1,"); } else { segmentSpec.append("1,"); }