[mlir][python] fix constructor generation for optional operands in presence of segment attribute

The ODS-based Python op bindings generator has been generating incorrect
specification of the operand segment in presence if both optional and variadic
operand groups: optional groups were treated as variadic whereas they require
separate treatement. Make sure it is the case. Also harden the tests around
generated op constructors as they could hitherto accept the code for both
optional and variadic arguments.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D113259
This commit is contained in:
Alex Zinenko
2021-11-05 12:05:02 +01:00
parent 5e9ac7c0a5
commit 6981e5ec91
4 changed files with 96 additions and 24 deletions

View File

@@ -1153,7 +1153,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
throw py::value_error((llvm::Twine("Operation \"") + name +
"\" requires " +
llvm::Twine(resultSegmentSpec.size()) +
"result segments but was provided " +
" result segments but was provided " +
llvm::Twine(resultTypeList.size()))
.str());
}
@@ -1164,7 +1164,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
try {
auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
if (resultType) {
resultTypes.push_back(resultType);
resultSegmentLengths.push_back(1);

View File

@@ -18,7 +18,7 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
[AttrSizedOperandSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
@@ -28,7 +28,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: regions = None
// CHECK: operands.append(_get_op_results_or_values(variadic1))
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
// CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
// CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -40,6 +40,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: self.operation.operands,
// CHECK: self.operation.attributes["operand_segment_sizes"], 0)
// CHECK: return operand_range
// CHECK-NOT: if len(operand_range)
//
// CHECK: @builtins.property
// CHECK: def non_variadic(self):
@@ -61,7 +62,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
def AttrSizedResultsOp : TestOp<"attr_sized_results",
[AttrSizedResultSegments]> {
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
@@ -71,7 +72,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: regions = None
// CHECK: if variadic1 is not None: results.append(variadic1)
// CHECK: results.append(non_variadic)
// CHECK: if variadic2 is not None: results.append(variadic2)
// CHECK: results.append(variadic2)
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -97,8 +98,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
// CHECK: self.operation.results,
// CHECK: self.operation.attributes["result_segment_sizes"], 2)
// CHECK: return result_range
// CHECK-NOT: if len(result_range)
let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic,
Optional<AnyType>:$variadic2);
Variadic<AnyType>:$variadic2);
}
@@ -277,6 +279,35 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
// CHECK-NOT: _ODS_RESULT_SEGMENTS
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
// CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_get_op_result_or_value(non_optional))
// CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional))
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
// CHECK: @builtins.property
// CHECK: def non_optional(self):
// CHECK: return self.operation.operands[0]
// CHECK: @builtins.property
// CHECK: def optional(self):
// CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
}
// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"

View File

@@ -2,25 +2,58 @@
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std
import mlir.dialects.vector as vector
def run(f):
print("\nTEST:", f.__name__)
f()
with Context(), Location.unknown():
f()
return f
# CHECK-LABEL: TEST: testPrintOp
@run
def testPrintOp():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
def print_vector(arg):
return vector.PrintOp(arg)
module = Module.create()
with InsertionPoint(module.body):
# CHECK-LABEL: func @print_vector(
# CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
# CHECK: vector.print %[[ARG]] : vector<12x5xf32>
# CHECK: return
# CHECK: }
print(module)
@builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
def print_vector(arg):
return vector.PrintOp(arg)
# CHECK-LABEL: func @print_vector(
# CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
# CHECK: vector.print %[[ARG]] : vector<12x5xf32>
# CHECK: return
# CHECK: }
print(module)
# CHECK-LABEL: TEST: testTransferReadOp
@run
def testTransferReadOp():
module = Module.create()
with InsertionPoint(module.body):
vector_type = VectorType.get([2, 3], F32Type.get())
memref_type = MemRefType.get([-1, -1], F32Type.get())
index_type = IndexType.get()
mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
identity_map = AffineMap.get_identity(vector_type.rank)
identity_map_attr = AffineMapAttr.get(identity_map)
func = builtin.FuncOp("transfer_read",
([memref_type, index_type,
F32Type.get(), mask_type], []))
with InsertionPoint(func.add_entry_block()):
A, zero, padding, mask = func.arguments
vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
padding, mask, None)
vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
padding, None, None)
std.ReturnOp([])
# CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
# CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
# CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
# CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
# CHECK-NOT: %[[MASK]]
print(module)

View File

@@ -67,6 +67,7 @@ class {0}(_ods_ir.OpView):
/// Each segment spec is either None (default) or an array of integers
/// where:
/// 1 = single element (expect non sequence operand/result)
/// 0 = optional element (expect a value or None)
/// -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
_ODS_{0}_SEGMENTS = {1}
@@ -505,6 +506,9 @@ constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
"None)";
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
@@ -693,7 +697,11 @@ populateBuilderLinesOperand(const Operator &op,
if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendOperandTemplate;
if (sizedSegments) {
formatString = optionalAppendAttrSizedOperandsTemplate;
} else {
formatString = optionalAppendOperandTemplate;
}
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
@@ -882,10 +890,10 @@ static void emitSegmentSpec(
std::string segmentSpec("[");
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength()) {
segmentSpec.append("-1,");
} else if (element.isOptional()) {
if (element.isOptional()) {
segmentSpec.append("0,");
} else if (element.isVariadic()) {
segmentSpec.append("-1,");
} else {
segmentSpec.append("1,");
}