[MLIR][Transform] FuseOp: accept transform params, add use_forall argument (#161883)

Changes to linalg `structured.fuse` transform op:

* Adds an optional `use_forall` boolean argument which generates a tiled
  `scf.forall` loop instead of `scf.for` loops.
* `tile_sizes` can now be any parameter or handle.
* `tile_interchange` can now be any parameter or handle.
* IR formatting changes from `transform.structured.fuse %0 [4, 8] ...`
  to `transform.structured.fuse %0 tile_sizes [4, 8] ...`
- boolean arguments are now `UnitAttrs` and should be set via the op
  attr-dict: `{apply_cleanup, use_forall}`
This commit is contained in:
Tuomas Kärnä
2025-10-13 13:41:29 +03:00
committed by GitHub
parent 34c7cf0750
commit 032df4b6f7
7 changed files with 359 additions and 74 deletions

View File

@@ -109,11 +109,27 @@ def testFuseOpCompact(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
# CHECK-SAME: interchange [0, 1] apply_cleanup = true
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1] {apply_cleanup}
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
def testFuseOpCompactForall(target):
structured.FuseOp(
target,
tile_sizes=[4, 8],
apply_cleanup=True,
use_forall=True,
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: {apply_cleanup, use_forall}
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
@create_sequence
def testFuseOpNoArg(target):
@@ -124,6 +140,44 @@ def testFuseOpNoArg(target):
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
@run
@create_sequence
def testFuseOpParams(target):
structured.FuseOp(
target,
tile_sizes=[constant_param(4), Attribute.parse("8")],
tile_interchange=[constant_param(0), Attribute.parse("1")],
)
# CHECK-LABEL: TEST: testFuseOpParams
# CHECK: transform.sequence
# CHECK-DAG: %[[P:.*]] = transform.param.constant 4
# CHECK-DAG: %[[I:.*]] = transform.param.constant 0
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
# CHECK-SAME: tile_sizes [%[[P]], 8]
# CHECK-SAME: interchange [%[[I]], 1]
# CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
def testFuseOpHandles(target):
size1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
structured.FuseOp(
target,
tile_sizes=[size1, 8],
tile_interchange=[ichange1, 1],
)
# CHECK-LABEL: TEST: testFuseOpHandles
# CHECK: transform.sequence
# CHECK: %[[H:.*]] = transform.structured.match
# CHECK: %[[I:.*]] = transform.structured.match
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
# CHECK-SAME: tile_sizes [%[[H]], 8]
# CHECK-SAME: interchange [%[[I]], 1]
# CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
def testFuseOpAttributes(target):
@@ -132,7 +186,7 @@ def testFuseOpAttributes(target):
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)