[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#172228)

Friendlier wrapper for transform.foreach.

To facilitate that friendliness, makes it so that OpResult.owner returns
the relevant OpView instead of Operation. For good measure, also changes
Value.owner to return OpView instead of Operation, thereby ensuring
consistency. That is, makes it is so that all op-returning .owner
accessors return OpView (and thereby give access to all goodies
available on registered OpViews.)

Reland of #171544 due to fixup for integration test.
This commit is contained in:
Rolf Morel
2025-12-14 22:10:31 +00:00
committed by GitHub
parent 423919d31f
commit f12fcf030c
5 changed files with 108 additions and 7 deletions

View File

@@ -1519,12 +1519,12 @@ public:
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
[](PyValue &self) -> nb::object {
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
return self.getParentOperation().getObject();
return self.getParentOperation()->createOpView();
}
if (mlirValueIsABlockArgument(v)) {

View File

@@ -14,8 +14,7 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
)

View File

@@ -310,6 +310,8 @@ class NamedSequenceOp(NamedSequenceOp):
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -468,6 +470,54 @@ def apply_registered_pass(
).result
@_ods_cext.register_operation(_Dialect, replace=True)
class ForeachOp(ForeachOp):
def __init__(
self,
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
):
targets = [_get_op_result_or_value(target) for target in targets]
super().__init__(
results_=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[target.type for target in targets])
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTargets(self) -> BlockArgumentList:
return self.regions[0].blocks[0].arguments
def foreach(
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, ForeachOp]:
results = ForeachOp(
results=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
).results
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
AnyOpTypeT = NewType("AnyOpType", AnyOpType)

View File

@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
# CHECK-LABEL: TEST: testForeachOp
@run
def testForeachOp(module: Module):
# CHECK: transform.sequence
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
foreach1 = transform.ForeachOp(
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
)
with InsertionPoint(foreach1.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_(foreach1.bodyTargets)
a_val = transform.get_operand(
transform.AnyValueType.get(), foreach1.result, [0]
)
a_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("a_param")
)
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
foreach2 = transform.foreach(
(transform.AnyValueType.get(), transform.AnyParamType.get()),
(sequence.bodyTarget, a_val, a_param),
)
with InsertionPoint(foreach2.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
transform.yield_(foreach2.owner.bodyTargets[1:3])
another_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("another_param")
)
params = transform.merge_handles([a_param, another_param])
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
foreach3 = transform.foreach(
(transform.AnyOpType.get(),),
(foreach1.result, foreach2[1], params),
with_zip_shortest=True,
)
with InsertionPoint(foreach3.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_((foreach3.owner.bodyTargets[0],))
transform.yield_((foreach3,))

View File

@@ -174,7 +174,7 @@ def get_pdl_pattern_fold():
def is_zero(value):
op = value.owner
if isinstance(op, Operation):
if isinstance(op, OpView):
return op.name == "myint.constant" and op.attributes["value"].value == 0
return False