Revert "[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews" (#172225)

Reverts llvm/llvm-project#171544 ; bots are broken.
This commit is contained in:
Mehdi Amini
2025-12-14 22:27:02 +01:00
committed by GitHub
parent bcbbe2c2bc
commit b9fe6532a7
4 changed files with 6 additions and 107 deletions

View File

@@ -1519,12 +1519,12 @@ public:
static void bindDerived(ClassTy &c) {
c.def_prop_ro(
"owner",
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
assert(mlirOperationEqual(self.getParentOperation()->get(),
mlirOpResultGetOwner(self.get())) &&
"expected the owner of the value in Python to match that in "
"the IR");
return self.getParentOperation()->createOpView();
return self.getParentOperation().getObject();
},
"Returns the operation that produces this result.");
c.def_prop_ro(
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kDumpDocstring)
.def_prop_ro(
"owner",
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
[](PyValue &self) -> nb::object {
MlirValue v = self.get();
if (mlirValueIsAOpResult(v)) {
assert(mlirOperationEqual(self.getParentOperation()->get(),
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"expected the owner of the value in Python to match "
"that in "
"the IR");
return self.getParentOperation()->createOpView();
return self.getParentOperation().getObject();
}
if (mlirValueIsABlockArgument(v)) {

View File

@@ -14,7 +14,8 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, ConstantOp)
and isinstance(i.owner, Operation)
and isinstance(i.owner.opview, ConstantOp)
and _is_integer_like_type(i.type)
)

View File

@@ -310,8 +310,6 @@ class NamedSequenceOp(NamedSequenceOp):
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*input_types)
@@ -470,54 +468,6 @@ def apply_registered_pass(
).result
@_ods_cext.register_operation(_Dialect, replace=True)
class ForeachOp(ForeachOp):
def __init__(
self,
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
):
targets = [_get_op_result_or_value(target) for target in targets]
super().__init__(
results_=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*[target.type for target in targets])
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTargets(self) -> BlockArgumentList:
return self.regions[0].blocks[0].arguments
def foreach(
results: Sequence[Type],
targets: Sequence[Union[Operation, Value, OpView]],
*,
with_zip_shortest: Optional[bool] = False,
loc=None,
ip=None,
) -> Union[OpResult, OpResultList, ForeachOp]:
results = ForeachOp(
results=results,
targets=targets,
with_zip_shortest=with_zip_shortest,
loc=loc,
ip=ip,
).results
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
AnyOpTypeT = NewType("AnyOpType", AnyOpType)

View File

@@ -401,55 +401,3 @@ def testApplyRegisteredPassOp(module: Module):
options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
# CHECK-LABEL: TEST: testForeachOp
@run
def testForeachOp(module: Module):
# CHECK: transform.sequence
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[transform.AnyOpType.get()],
transform.AnyOpType.get(),
)
with InsertionPoint(sequence.body):
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
foreach1 = transform.ForeachOp(
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
)
with InsertionPoint(foreach1.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_(foreach1.bodyTargets)
a_val = transform.get_operand(
transform.AnyValueType.get(), foreach1.result, [0]
)
a_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("a_param")
)
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
foreach2 = transform.foreach(
(transform.AnyValueType.get(), transform.AnyParamType.get()),
(sequence.bodyTarget, a_val, a_param),
)
with InsertionPoint(foreach2.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
transform.yield_(foreach2.owner.bodyTargets[1:3])
another_param = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("another_param")
)
params = transform.merge_handles([a_param, another_param])
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
foreach3 = transform.foreach(
(transform.AnyOpType.get(),),
(foreach1.result, foreach2[1], params),
with_zip_shortest=True,
)
with InsertionPoint(foreach3.owner.body):
# CHECK: transform.yield {{.*}} : !transform.any_op
transform.yield_((foreach3.owner.bodyTargets[0],))
transform.yield_((foreach3,))