[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#171544)

Friendlier wrapper for `transform.foreach`.

To facilitate that friendliness, makes it so that `OpResult.owner`
returns the relevant `OpView` instead of `Operation`. For good measure,
also changes `Value.owner` to return `OpView` instead of `Operation`,
thereby ensuring consistency. That is, makes it is so that all
op-returning `.owner` accessors return `OpView` (and thereby give access
to all goodies available on registered `OpView`s.)
This commit is contained in:
Rolf Morel
2025-12-14 20:44:15 +00:00
committed by GitHub
parent 53cf22f3a1
commit 4cdec92827
4 changed files with 107 additions and 6 deletions

View File

@@ -14,8 +14,7 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
)

View File

@@ -310,6 +310,8 @@ class NamedSequenceOp(NamedSequenceOp):
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -468,6 +470,54 @@ def apply_registered_pass(
).result
@_ods_cext.register_operation(_Dialect, replace=True)
class ForeachOp(ForeachOp):
def __init__(
self,
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
):
targets = [_get_op_result_or_value(target) for target in targets]
super().__init__(
results_=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[target.type for target in targets])
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTargets(self) -> BlockArgumentList:
return self.regions[0].blocks[0].arguments
def foreach(
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, ForeachOp]:
results = ForeachOp(
results=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
).results
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
AnyOpTypeT = NewType("AnyOpType", AnyOpType)