mirror of
https://github.com/intel/llvm.git
synced 2026-01-12 10:17:28 +08:00
[MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#172228)
Friendlier wrapper for transform.foreach. To facilitate that friendliness, makes it so that OpResult.owner returns the relevant OpView instead of Operation. For good measure, also changes Value.owner to return OpView instead of Operation, thereby ensuring consistency. That is, makes it is so that all op-returning .owner accessors return OpView (and thereby give access to all goodies available on registered OpViews.) Reland of #171544 due to fixup for integration test.
This commit is contained in:
@@ -1519,12 +1519,12 @@ public:
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_prop_ro(
|
||||
"owner",
|
||||
[](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
|
||||
[](PyOpResult &self) -> nb::typed<nb::object, PyOpView> {
|
||||
assert(mlirOperationEqual(self.getParentOperation()->get(),
|
||||
mlirOpResultGetOwner(self.get())) &&
|
||||
"expected the owner of the value in Python to match that in "
|
||||
"the IR");
|
||||
return self.getParentOperation().getObject();
|
||||
return self.getParentOperation()->createOpView();
|
||||
},
|
||||
"Returns the operation that produces this result.");
|
||||
c.def_prop_ro(
|
||||
@@ -4646,7 +4646,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
kDumpDocstring)
|
||||
.def_prop_ro(
|
||||
"owner",
|
||||
[](PyValue &self) -> nb::object {
|
||||
[](PyValue &self) -> nb::typed<nb::object, PyOpView> {
|
||||
MlirValue v = self.get();
|
||||
if (mlirValueIsAOpResult(v)) {
|
||||
assert(mlirOperationEqual(self.getParentOperation()->get(),
|
||||
@@ -4654,7 +4654,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
"expected the owner of the value in Python to match "
|
||||
"that in "
|
||||
"the IR");
|
||||
return self.getParentOperation().getObject();
|
||||
return self.getParentOperation()->createOpView();
|
||||
}
|
||||
|
||||
if (mlirValueIsABlockArgument(v)) {
|
||||
|
||||
@@ -14,8 +14,7 @@ from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner, Operation)
|
||||
and isinstance(i.owner.opview, ConstantOp)
|
||||
and isinstance(i.owner, ConstantOp)
|
||||
and _is_integer_like_type(i.type)
|
||||
)
|
||||
|
||||
|
||||
@@ -310,6 +310,8 @@ class NamedSequenceOp(NamedSequenceOp):
|
||||
sym_visibility=sym_visibility,
|
||||
arg_attrs=arg_attrs,
|
||||
res_attrs=res_attrs,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*input_types)
|
||||
|
||||
@@ -468,6 +470,54 @@ def apply_registered_pass(
|
||||
).result
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ForeachOp(ForeachOp):
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
targets = [_get_op_result_or_value(target) for target in targets]
|
||||
super().__init__(
|
||||
results_=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
self.regions[0].blocks.append(*[target.type for target in targets])
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTargets(self) -> BlockArgumentList:
|
||||
return self.regions[0].blocks[0].arguments
|
||||
|
||||
|
||||
def foreach(
|
||||
results: Sequence[Type],
|
||||
targets: Sequence[Union[Operation, Value, OpView]],
|
||||
*,
|
||||
with_zip_shortest: Optional[bool] = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Union[OpResult, OpResultList, ForeachOp]:
|
||||
results = ForeachOp(
|
||||
results=results,
|
||||
targets=targets,
|
||||
with_zip_shortest=with_zip_shortest,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
).results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
|
||||
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
|
||||
|
||||
|
||||
|
||||
@@ -401,3 +401,55 @@ def testApplyRegisteredPassOp(module: Module):
|
||||
options={"exclude": (symbol_a, symbol_b)},
|
||||
)
|
||||
transform.YieldOp()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testForeachOp
|
||||
@run
|
||||
def testForeachOp(module: Module):
|
||||
# CHECK: transform.sequence
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[transform.AnyOpType.get()],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
# CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
|
||||
foreach1 = transform.ForeachOp(
|
||||
(transform.AnyOpType.get(),), (sequence.bodyTarget,)
|
||||
)
|
||||
with InsertionPoint(foreach1.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_op
|
||||
transform.yield_(foreach1.bodyTargets)
|
||||
|
||||
a_val = transform.get_operand(
|
||||
transform.AnyValueType.get(), foreach1.result, [0]
|
||||
)
|
||||
a_param = transform.param_constant(
|
||||
transform.AnyParamType.get(), StringAttr.get("a_param")
|
||||
)
|
||||
|
||||
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
|
||||
foreach2 = transform.foreach(
|
||||
(transform.AnyValueType.get(), transform.AnyParamType.get()),
|
||||
(sequence.bodyTarget, a_val, a_param),
|
||||
)
|
||||
with InsertionPoint(foreach2.owner.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
|
||||
transform.yield_(foreach2.owner.bodyTargets[1:3])
|
||||
|
||||
another_param = transform.param_constant(
|
||||
transform.AnyParamType.get(), StringAttr.get("another_param")
|
||||
)
|
||||
params = transform.merge_handles([a_param, another_param])
|
||||
|
||||
# CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
|
||||
foreach3 = transform.foreach(
|
||||
(transform.AnyOpType.get(),),
|
||||
(foreach1.result, foreach2[1], params),
|
||||
with_zip_shortest=True,
|
||||
)
|
||||
with InsertionPoint(foreach3.owner.body):
|
||||
# CHECK: transform.yield {{.*}} : !transform.any_op
|
||||
transform.yield_((foreach3.owner.bodyTargets[0],))
|
||||
|
||||
transform.yield_((foreach3,))
|
||||
|
||||
@@ -174,7 +174,7 @@ def get_pdl_pattern_fold():
|
||||
|
||||
def is_zero(value):
|
||||
op = value.owner
|
||||
if isinstance(op, Operation):
|
||||
if isinstance(op, OpView):
|
||||
return op.name == "myint.constant" and op.attributes["value"].value == 0
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user