mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
[mlir][python] fix constructor generation for optional operands in presence of segment attribute
The ODS-based Python op bindings generator has been generating incorrect specification of the operand segment in presence if both optional and variadic operand groups: optional groups were treated as variadic whereas they require separate treatement. Make sure it is the case. Also harden the tests around generated op constructors as they could hitherto accept the code for both optional and variadic arguments. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D113259
This commit is contained in:
@@ -1153,7 +1153,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
|
||||
throw py::value_error((llvm::Twine("Operation \"") + name +
|
||||
"\" requires " +
|
||||
llvm::Twine(resultSegmentSpec.size()) +
|
||||
"result segments but was provided " +
|
||||
" result segments but was provided " +
|
||||
llvm::Twine(resultTypeList.size()))
|
||||
.str());
|
||||
}
|
||||
@@ -1164,7 +1164,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
|
||||
if (segmentSpec == 1 || segmentSpec == 0) {
|
||||
// Unpack unary element.
|
||||
try {
|
||||
auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
|
||||
auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
|
||||
if (resultType) {
|
||||
resultTypes.push_back(resultType);
|
||||
resultSegmentLengths.push_back(1);
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedOperandsOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
|
||||
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,]
|
||||
// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,0,]
|
||||
def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
[AttrSizedOperandSegments]> {
|
||||
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
|
||||
@@ -28,7 +28,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_results_or_values(variadic1))
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_variadic))
|
||||
// CHECK: if variadic2 is not None: operands.append(_get_op_result_or_value(variadic2))
|
||||
// CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
@@ -40,6 +40,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
// CHECK: self.operation.operands,
|
||||
// CHECK: self.operation.attributes["operand_segment_sizes"], 0)
|
||||
// CHECK: return operand_range
|
||||
// CHECK-NOT: if len(operand_range)
|
||||
//
|
||||
// CHECK: @builtins.property
|
||||
// CHECK: def non_variadic(self):
|
||||
@@ -61,7 +62,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
|
||||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
|
||||
// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,]
|
||||
// CHECK: _ODS_RESULT_SEGMENTS = [0,1,-1,]
|
||||
def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
||||
[AttrSizedResultSegments]> {
|
||||
// CHECK: def __init__(self, variadic1, non_variadic, variadic2, *, loc=None, ip=None):
|
||||
@@ -71,7 +72,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
||||
// CHECK: regions = None
|
||||
// CHECK: if variadic1 is not None: results.append(variadic1)
|
||||
// CHECK: results.append(non_variadic)
|
||||
// CHECK: if variadic2 is not None: results.append(variadic2)
|
||||
// CHECK: results.append(variadic2)
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
@@ -97,8 +98,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
|
||||
// CHECK: self.operation.results,
|
||||
// CHECK: self.operation.attributes["result_segment_sizes"], 2)
|
||||
// CHECK: return result_range
|
||||
// CHECK-NOT: if len(result_range)
|
||||
let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic,
|
||||
Optional<AnyType>:$variadic2);
|
||||
Variadic<AnyType>:$variadic2);
|
||||
}
|
||||
|
||||
|
||||
@@ -277,6 +279,35 @@ def MissingNamesOp : TestOp<"missing_names"> {
|
||||
let results = (outs I32:$i32, AnyFloat, I64:$i64);
|
||||
}
|
||||
|
||||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_optional_operand"
|
||||
// CHECK-NOT: _ODS_OPERAND_SEGMENTS
|
||||
// CHECK-NOT: _ODS_RESULT_SEGMENTS
|
||||
def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
|
||||
let arguments = (ins AnyType:$non_optional, Optional<AnyType>:$optional);
|
||||
// CHECK: def __init__(self, non_optional, optional, *, loc=None, ip=None):
|
||||
// CHECK: operands = []
|
||||
// CHECK: results = []
|
||||
// CHECK: attributes = {}
|
||||
// CHECK: regions = None
|
||||
// CHECK: operands.append(_get_op_result_or_value(non_optional))
|
||||
// CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional))
|
||||
// CHECK: _ods_successors = None
|
||||
// CHECK: super().__init__(self.build_generic(
|
||||
// CHECK: attributes=attributes, results=results, operands=operands,
|
||||
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip))
|
||||
|
||||
// CHECK: @builtins.property
|
||||
// CHECK: def non_optional(self):
|
||||
// CHECK: return self.operation.operands[0]
|
||||
|
||||
// CHECK: @builtins.property
|
||||
// CHECK: def optional(self):
|
||||
// CHECK: return self.operation.operands[1] if len(self.operation.operands) > 2 else None
|
||||
|
||||
}
|
||||
|
||||
// CHECK: @_ods_cext.register_operation(_Dialect)
|
||||
// CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
|
||||
// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
|
||||
|
||||
@@ -2,25 +2,58 @@
|
||||
|
||||
from mlir.ir import *
|
||||
import mlir.dialects.builtin as builtin
|
||||
import mlir.dialects.std as std
|
||||
import mlir.dialects.vector as vector
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
with Context(), Location.unknown():
|
||||
f()
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: TEST: testPrintOp
|
||||
@run
|
||||
def testPrintOp():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
@builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
|
||||
def print_vector(arg):
|
||||
return vector.PrintOp(arg)
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
# CHECK-LABEL: func @print_vector(
|
||||
# CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
|
||||
# CHECK: vector.print %[[ARG]] : vector<12x5xf32>
|
||||
# CHECK: return
|
||||
# CHECK: }
|
||||
print(module)
|
||||
@builtin.FuncOp.from_py_func(VectorType.get((12, 5), F32Type.get()))
|
||||
def print_vector(arg):
|
||||
return vector.PrintOp(arg)
|
||||
|
||||
# CHECK-LABEL: func @print_vector(
|
||||
# CHECK-SAME: %[[ARG:.*]]: vector<12x5xf32>) {
|
||||
# CHECK: vector.print %[[ARG]] : vector<12x5xf32>
|
||||
# CHECK: return
|
||||
# CHECK: }
|
||||
print(module)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testTransferReadOp
|
||||
@run
|
||||
def testTransferReadOp():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
vector_type = VectorType.get([2, 3], F32Type.get())
|
||||
memref_type = MemRefType.get([-1, -1], F32Type.get())
|
||||
index_type = IndexType.get()
|
||||
mask_type = VectorType.get(vector_type.shape, IntegerType.get_signless(1))
|
||||
identity_map = AffineMap.get_identity(vector_type.rank)
|
||||
identity_map_attr = AffineMapAttr.get(identity_map)
|
||||
func = builtin.FuncOp("transfer_read",
|
||||
([memref_type, index_type,
|
||||
F32Type.get(), mask_type], []))
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
A, zero, padding, mask = func.arguments
|
||||
vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
|
||||
padding, mask, None)
|
||||
vector.TransferReadOp(vector_type, A, [zero, zero], identity_map_attr,
|
||||
padding, None, None)
|
||||
std.ReturnOp([])
|
||||
|
||||
# CHECK: @transfer_read(%[[MEM:.*]]: memref<?x?xf32>, %[[IDX:.*]]: index,
|
||||
# CHECK: %[[PAD:.*]]: f32, %[[MASK:.*]]: vector<2x3xi1>)
|
||||
# CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]]
|
||||
# CHECK: vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[PAD]]
|
||||
# CHECK-NOT: %[[MASK]]
|
||||
print(module)
|
||||
|
||||
@@ -67,6 +67,7 @@ class {0}(_ods_ir.OpView):
|
||||
/// Each segment spec is either None (default) or an array of integers
|
||||
/// where:
|
||||
/// 1 = single element (expect non sequence operand/result)
|
||||
/// 0 = optional element (expect a value or None)
|
||||
/// -1 = operand/result is a sequence corresponding to a variadic
|
||||
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
|
||||
_ODS_{0}_SEGMENTS = {1}
|
||||
@@ -505,6 +506,9 @@ constexpr const char *singleResultAppendTemplate = "results.append({0})";
|
||||
/// {0} is the field name.
|
||||
constexpr const char *optionalAppendOperandTemplate =
|
||||
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
|
||||
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
|
||||
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
|
||||
"None)";
|
||||
constexpr const char *optionalAppendResultTemplate =
|
||||
"if {0} is not None: results.append({0})";
|
||||
|
||||
@@ -693,7 +697,11 @@ populateBuilderLinesOperand(const Operator &op,
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleOperandAppendTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
formatString = optionalAppendOperandTemplate;
|
||||
if (sizedSegments) {
|
||||
formatString = optionalAppendAttrSizedOperandsTemplate;
|
||||
} else {
|
||||
formatString = optionalAppendOperandTemplate;
|
||||
}
|
||||
} else {
|
||||
assert(element.isVariadic() && "unhandled element group type");
|
||||
// If emitting with sizedSegments, then we add the actual list-typed
|
||||
@@ -882,10 +890,10 @@ static void emitSegmentSpec(
|
||||
std::string segmentSpec("[");
|
||||
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (element.isVariableLength()) {
|
||||
segmentSpec.append("-1,");
|
||||
} else if (element.isOptional()) {
|
||||
if (element.isOptional()) {
|
||||
segmentSpec.append("0,");
|
||||
} else if (element.isVariadic()) {
|
||||
segmentSpec.append("-1,");
|
||||
} else {
|
||||
segmentSpec.append("1,");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user