mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 03:50:17 +08:00
[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#171544)
Friendlier wrapper for `transform.foreach`. To facilitate that friendliness, makes it so that `OpResult.owner` returns the relevant `OpView` instead of `Operation`. For good measure, also changes `Value.owner` to return `OpView` instead of `Operation`, thereby ensuring consistency. That is, makes it is so that all op-returning `.owner` accessors return `OpView` (and thereby give access to all goodies available on registered `OpView`s.)
This commit is contained in:
@@ -14,8 +14,7 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner, Operation)
|
||||
and isinstance(i.owner.opview, ConstantOp)
|
||||
and isinstance(i.owner, ConstantOp)
|
||||
and _is_integer_like_type(i.type)
|
||||
)
|
||||
|
||||
|
||||
@@ -310,6 +310,8 @@ class NamedSequenceOp(NamedSequenceOp):
|
||||
sym_visibility=sym_visibility,
|
||||
arg_attrs=arg_attrs,
|
||||
res_attrs=res_attrs,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*input_types)
|
||||
|
||||
@@ -468,6 +470,54 @@ def apply_registered_pass(
|
||||
).result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ForeachOp(ForeachOp):
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
targets = [_get_op_result_or_value(target) for target in targets]
|
||||
super().__init__(
|
||||
results_=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*[target.type for target in targets])
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTargets(self) -> BlockArgumentList:
|
||||
return self.regions[0].blocks[0].arguments
|
||||
|
||||
|
||||
def foreach(
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Union[OpResult, OpResultList, ForeachOp]:
|
||||
results = ForeachOp(
|
||||
results=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
).results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
|
||||
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user