mirror of
https://github.com/intel/llvm.git
synced 2026-01-12 18:27:07 +08:00
Revert "[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews" (#172225)
Reverts llvm/llvm-project#171544 ; bots are broken.
This commit is contained in:
@@ -1519,12 +1519,12 @@ public:
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro(
|
||||
"owner",
|
||||
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
|
||||
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
|
||||
assert(mlirOperationEqual(self.getParentOperation()->get(),
|
||||
mlirOpResultGetOwner(self.get())) &&
|
||||
"expected the owner of the value in Python to match that in "
|
||||
"the IR");
|
||||
return self.getParentOperation()->createOpView();
|
||||
return self.getParentOperation().getObject();
|
||||
},
|
||||
"Returns the operation that produces this result.");
|
||||
c.def_prop_ro(
|
||||
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
kDumpDocstring)
|
||||
.def_prop_ro(
|
||||
"owner",
|
||||
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
|
||||
[](PyValue &self) -> nb::object {
|
||||
MlirValue v = self.get();
|
||||
if (mlirValueIsAOpResult(v)) {
|
||||
assert(mlirOperationEqual(self.getParentOperation()->get(),
|
||||
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
"expected the owner of the value in Python to match "
|
||||
"that in "
|
||||
"the IR");
|
||||
return self.getParentOperation()->createOpView();
|
||||
return self.getParentOperation().getObject();
|
||||
}
|
||||
|
||||
if (mlirValueIsABlockArgument(v)) {
|
||||
|
||||
@@ -14,7 +14,8 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner, ConstantOp)
|
||||
and isinstance(i.owner, Operation)
|
||||
and isinstance(i.owner.opview, ConstantOp)
|
||||
and _is_integer_like_type(i.type)
|
||||
)
|
||||
|
||||
|
||||
@@ -310,8 +310,6 @@ class NamedSequenceOp(NamedSequenceOp):
|
||||
sym_visibility=sym_visibility,
|
||||
arg_attrs=arg_attrs,
|
||||
res_attrs=res_attrs,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*input_types)
|
||||
|
||||
@@ -470,54 +468,6 @@ def apply_registered_pass(
|
||||
).result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ForeachOp(ForeachOp):
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
targets = [_get_op_result_or_value(target) for target in targets]
|
||||
super().__init__(
|
||||
results_=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*[target.type for target in targets])
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTargets(self) -> BlockArgumentList:
|
||||
return self.regions[0].blocks[0].arguments
|
||||
|
||||
|
||||
def foreach(
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Union[OpResult, OpResultList, ForeachOp]:
|
||||
results = ForeachOp(
|
||||
results=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
).results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
|
||||
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
|
||||
|
||||
|
||||
|
||||
@@ -401,55 +401,3 @@ def testApplyRegisteredPassOp(module: Module):
|
||||
options={"exclude": (symbol_a, symbol_b)},
|
||||
)
|
||||
transform.YieldOp()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testForeachOp
|
||||
@run
|
||||
def testForeachOp(module: Module):
|
||||
# CHECK: transform.sequence
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[transform.AnyOpType.get()],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
|
||||
foreach1 = transform.ForeachOp(
|
||||
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
|
||||
)
|
||||
with InsertionPoint(foreach1.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_op
|
||||
transform.yield_(foreach1.bodyTargets)
|
||||
|
||||
a_val = transform.get_operand(
|
||||
transform.AnyValueType.get(), foreach1.result, [0]
|
||||
)
|
||||
a_param = transform.param_constant(
|
||||
transform.AnyParamType.get(), StringAttr.get("a_param")
|
||||
)
|
||||
|
||||
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
|
||||
foreach2 = transform.foreach(
|
||||
(transform.AnyValueType.get(), transform.AnyParamType.get()),
|
||||
(sequence.bodyTarget, a_val, a_param),
|
||||
)
|
||||
with InsertionPoint(foreach2.owner.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
|
||||
transform.yield_(foreach2.owner.bodyTargets[1:3])
|
||||
|
||||
another_param = transform.param_constant(
|
||||
transform.AnyParamType.get(), StringAttr.get("another_param")
|
||||
)
|
||||
params = transform.merge_handles([a_param, another_param])
|
||||
|
||||
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
|
||||
foreach3 = transform.foreach(
|
||||
(transform.AnyOpType.get(),),
|
||||
(foreach1.result, foreach2[1], params),
|
||||
with_zip_shortest=True,
|
||||
)
|
||||
with InsertionPoint(foreach3.owner.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_op
|
||||
transform.yield_((foreach3.owner.bodyTargets[0],))
|
||||
|
||||
transform.yield_((foreach3,))
|
||||
|
||||
Reference in New Issue
Block a user