[mlir][bufferization] Enable moving dependent values in eliminate-empty-tensors (#169718)

Currently, empty tensor elimination works by constructing a
SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS chain.
This fails if any operand required by the insertion op does not dominate
the insertion point for the extraction op.

This change improves the transformation by attempting to move all pure
producers of required operands to the insertion point of the extraction
op. In the process this improves a number of tests for empty tensor
elimination.
This commit is contained in:
Quinn Dawkins
2025-12-05 14:40:08 -05:00
committed by GitHub
parent 29fa151a07
commit bb17dfa7d1
6 changed files with 194 additions and 115 deletions

View File

@@ -84,7 +84,8 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
/// Move definitions of `values` before an insertion point. Current support is
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are.
/// before insertion point, or none of them are. Any side-effecting operation
/// in the producer chain pessimistically blocks movement.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
DominanceInfo &dominance);

View File

@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace bufferization {
@@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
if (!insertionPoint)
return {};
if (!insertionPoint) {
// If no already suitable insertion point was found, attempt to move all
// needed values before the user.
if (failed(moveValueDefinitions(rewriter, neededValues, user)))
return {};
insertionPoint = user;
}
rewriter.setInsertionPoint(insertionPoint);
Value replacement =

View File

@@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
if (dominance.properlyDominates(value, insertionPoint)) {
if (dominance.properlyDominates(value, insertionPoint))
continue;
}
// Block arguments are not supported.
if (isa<BlockArgument>(value)) {
return rewriter.notifyMatchFailure(
@@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Since current support is to only move within the same basic block,
// the slices don't need to look past block arguments.
options.omitBlockArguments = true;
bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
bool mustMove =
!dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
if (mustMove && !isPure(sliceBoundaryOp))
dependsOnSideEffectingOp = true;
return mustMove;
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
@@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
(void)result;
}
// Check if any operation in the slice is side-effecting.
if (dependsOnSideEffectingOp)
return failure();
// If the slice contains `insertionPoint`, the dependencies cannot be moved.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
@@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Sort operations topologically before moving.
mlir::topologicalSort(slice);
for (Operation *op : slice) {
for (Operation *op : slice)
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

View File

@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
// -----
// `EmptyTensorElimination` fails to find a valid insertion
// point for the new injected `SubsetExtraction`.
// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// CHECK-LABEL: func.func @eliminate_all_empty_tensors
func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
// CHECK: memref.copy
// CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// -----
// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK: memref.alloc
// CHECK-NOT: memref.alloc
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
// CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
// CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.copy
// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
// CHECK-ELIM: linalg.fill
// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
// CHECK-ELIM: linalg.fill
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
// CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
// CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
// CHECK-NOT: memref.copy
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}
// -----
// Test that dependent pure operations are moved before the
// insertion point to enable empty tensor elimination.
// CHECK-LABEL: func.func @move_dependent_arith_op(
// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-NOT: memref.alloc
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
// CHECK: linalg.fill {{.*}} outs(%[[SV]]
// CHECK: return %[[ARG0]]
// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
// Input IR for the test: the offset used by the insert_slice is computed
// after the fill, so its producers must be hoisted above the tensor.empty
// for the extract_slice expected by the checks above to be created.
func.func @move_dependent_arith_op(
%arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
%arg1: index, %f: f32) -> tensor<10xf32>
{
// Destination of the fill; this is the tensor.empty to be eliminated.
%0 = tensor.empty() : tensor<5xf32>
%1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
// Producers of the insertion offset, defined only after the fill.
%c5 = arith.constant 5 : index
%offset = arith.addi %arg1, %c5 : index
// Subset insertion at the end of the chain, using the late-defined offset.
%2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
: tensor<5xf32> into tensor<10xf32>
return %2 : tensor<10xf32>
}
// -----
// Test that side-effecting operations are not moved, preventing empty
// tensor elimination.
// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
// CHECK: memref.alloc
// CHECK: linalg.fill
// CHECK: memref.load
// CHECK: memref.subview
// CHECK: memref.copy
// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
// CHECK-ELIM: tensor.empty
// CHECK-ELIM: linalg.fill
// CHECK-ELIM: memref.load
// CHECK-ELIM: tensor.insert_slice
// Input IR for the test: the insertion offset comes from a memref.load placed
// after the fill. The checks above expect the tensor.empty (and, after
// bufferization, an alloc plus a copy) to remain.
func.func @side_effecting_op_blocks_movement(
%arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
%mem: memref<index>, %f: f32) -> tensor<10xf32>
{
%0 = tensor.empty() : tensor<5xf32>
%1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
// The load is the side-effecting producer this test relies on to block
// movement of the offset definition.
%offset = memref.load %mem[] : memref<index>
%2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
: tensor<5xf32> into tensor<10xf32>
return %2 : tensor<10xf32>
}

View File

@@ -238,25 +238,26 @@ module attributes {transform.with_named_sequence} {
// -----
// Check simple move value definitions before insertion operation.
func.func @simple_move_values() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "foo"(%1, %2) : (f32, f32) -> (f32)
return %3 : f32
func.func @simple_move_values(%arg0 : index) -> index {
%c0 = arith.constant 0 : index
%0 = "before"() : () -> (index)
%1 = arith.addi %arg0, %c0 {"moved_op_1"} : index
%2 = arith.subi %arg0, %c0 {"moved_op_2"} : index
%3 = "foo"(%1, %2) : (index, index) -> (index)
return %3 : index
}
// CHECK-LABEL: func @simple_move_values()
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
// CHECK-LABEL: func @simple_move_values(
// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1}
// CHECK: %[[MOVED2:.+]] = arith.subi {{.*}} {moved_op_2}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]])
// CHECK: return %[[FOO]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_1"]} in %arg0
%op1 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op_2"]} in %arg0
%op2 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -271,23 +272,26 @@ module attributes {transform.with_named_sequence} {
// -----
// Compute slice including the implicitly captured values.
func.func @move_region_dependencies_values() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() ({
%3 = "inner_op"(%1) : (f32) -> (f32)
"yield"(%3) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
func.func @move_region_dependencies_values(%arg0 : index, %cond : i1) -> index {
%0 = "before"() : () -> (index)
%1 = arith.addi %arg0, %arg0 {moved_op_1} : index
%2 = scf.if %cond -> index {
%3 = arith.muli %1, %1 {inner_op} : index
scf.yield %3 : index
} else {
scf.yield %1 : index
}
return %2 : index
}
// CHECK-LABEL: func @move_region_dependencies_values()
// CHECK: %[[MOVED1:.+]] = "moved_op_1"
// CHECK: %[[MOVED2:.+]] = "moved_op_2"
// CHECK-LABEL: func @move_region_dependencies_values(
// CHECK: %[[MOVED1:.+]] = arith.addi {{.*}} {moved_op_1}
// CHECK: scf.if
// CHECK: arith.muli %[[MOVED1]], %[[MOVED1]] {inner_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_2"]} in %arg0
%op1 = transform.structured.match ops{["scf.if"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -301,31 +305,31 @@ module attributes {transform.with_named_sequence} {
// -----
// Move operations in topological sort order
func.func @move_values_in_topological_sort_order() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "moved_op_3"(%1) : (f32) -> (f32)
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
%5 = "moved_op_5"(%2) : (f32) -> (f32)
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
return %6 : f32
func.func @move_values_in_topological_sort_order(%arg0 : index, %arg1 : index) -> index {
%0 = "before"() : () -> (index)
%1 = arith.addi %arg0, %arg0 {moved_op_1} : index
%2 = arith.addi %arg1, %arg1 {moved_op_2} : index
%3 = arith.muli %1, %1 {moved_op_3} : index
%4 = arith.andi %1, %3 {moved_op_4} : index
%5 = arith.subi %2, %2 {moved_op_5} : index
%6 = "foo"(%4, %5) : (index, index) -> (index)
return %6 : index
}
// CHECK-LABEL: func @move_values_in_topological_sort_order()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
// CHECK-LABEL: func @move_values_in_topological_sort_order(
// CHECK: %[[MOVED_1:.+]] = arith.addi {{.*}} {moved_op_1}
// CHECK-DAG: %[[MOVED_2:.+]] = arith.muli %[[MOVED_1]], %[[MOVED_1]] {moved_op_3}
// CHECK-DAG: %[[MOVED_3:.+]] = arith.andi %[[MOVED_1]], %[[MOVED_2]] {moved_op_4}
// CHECK-DAG: %[[MOVED_4:.+]] = arith.addi {{.*}} {moved_op_2}
// CHECK-DAG: %[[MOVED_5:.+]] = arith.subi %[[MOVED_4]], %[[MOVED_4]] {moved_op_5}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["moved_op_4"]} in %arg0
%op1 = transform.structured.match ops{["arith.andi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op_5"]} in %arg0
%op2 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
@@ -341,17 +345,17 @@ module attributes {transform.with_named_sequence} {
// Move only those value definitions that are not dominated by insertion point
func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "dummy_op"() : () -> (f32)
%2 = "before"() : () -> (f32)
%3 = "moved_op"() : () -> (f32)
return %0, %1, %2, %3 : f32, f32, f32, f32
func.func @move_only_required_defns(%arg0 : index) -> (index, index, index, index) {
%0 = "unmoved_op"() : () -> (index)
%1 = "dummy_op"() : () -> (index)
%2 = "before"() : () -> (index)
%3 = arith.addi %arg0, %arg0 {moved_op} : index
return %0, %1, %2, %3 : index, index, index, index
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK-LABEL: func @move_only_required_defns(
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
@@ -362,7 +366,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
%op4 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
@@ -374,19 +378,19 @@ module attributes {transform.with_named_sequence} {
// -----
// Move only those value definitions that are not dominated by insertion point
// Move only those value definitions that are not dominated by insertion point (duplicate test)
func.func @move_only_required_defns() -> (f32, f32, f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "dummy_op"() : () -> (f32)
%2 = "before"() : () -> (f32)
%3 = "moved_op"() : () -> (f32)
return %0, %1, %2, %3 : f32, f32, f32, f32
func.func @move_only_required_defns_2(%arg0 : index) -> (index, index, index, index) {
%0 = "unmoved_op"() : () -> (index)
%1 = "dummy_op"() : () -> (index)
%2 = "before"() : () -> (index)
%3 = arith.subi %arg0, %arg0 {moved_op} : index
return %0, %1, %2, %3 : index, index, index, index
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK-LABEL: func @move_only_required_defns_2(
// CHECK: %[[UNMOVED:.+]] = "unmoved_op"
// CHECK: %[[DUMMY:.+]] = "dummy_op"
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[MOVED:.+]] = arith.subi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
@@ -397,7 +401,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
%op3 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op4 = transform.structured.match ops{["moved_op"]} in %arg0
%op4 = transform.structured.match ops{["arith.subi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value
%v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value
@@ -410,23 +414,23 @@ module attributes {transform.with_named_sequence} {
// -----
// Check handling of block arguments
func.func @move_only_required_defns() -> (f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
cf.br ^bb0(%0 : f32)
^bb0(%arg0 : f32) :
%1 = "before"() : () -> (f32)
%2 = "moved_op"(%arg0) : (f32) -> (f32)
return %1, %2 : f32, f32
func.func @move_with_block_arguments() -> (index, index) {
%0 = "unmoved_op"() : () -> (index)
cf.br ^bb0(%0 : index)
^bb0(%arg0 : index) :
%1 = "before"() : () -> (index)
%2 = arith.addi %arg0, %arg0 {moved_op} : index
return %1, %2 : index, index
}
// CHECK-LABEL: func @move_only_required_defns()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK-LABEL: func @move_with_block_arguments()
// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {moved_op}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
%op2 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
@@ -438,20 +442,20 @@ module attributes {transform.with_named_sequence} {
// -----
// Do not move across basic blocks
func.func @no_move_across_basic_blocks() -> (f32, f32) {
%0 = "unmoved_op"() : () -> (f32)
%1 = "before"() : () -> (f32)
cf.br ^bb0(%0 : f32)
^bb0(%arg0 : f32) :
%2 = "moved_op"(%arg0) : (f32) -> (f32)
return %1, %2 : f32, f32
func.func @no_move_across_basic_blocks() -> (index, index) {
%0 = "unmoved_op"() : () -> (index)
%1 = "before"() : () -> (index)
cf.br ^bb0(%0 : index)
^bb0(%arg0 : index) :
%2 = arith.addi %arg0, %arg0 {moved_op} : index
return %1, %2 : index, index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved_op"]} in %arg0
%op2 = transform.structured.match ops{["arith.addi"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
@@ -463,24 +467,22 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @move_isolated_from_above() -> () {
%1 = "before"() : () -> (f32)
%2 = "moved0"() : () -> (f32)
%3 = test.isolated_one_region_op %2 {} : f32 -> f32
%4 = "moved1"(%3) : (f32) -> (f32)
func.func @move_isolated_from_above(%arg0 : index) -> () {
%1 = "before"() : () -> (index)
%2 = arith.addi %arg0, %arg0 {moved0} : index
%3 = arith.muli %2, %2 {moved1} : index
return
}
// CHECK-LABEL: func @move_isolated_from_above()
// CHECK: %[[MOVED0:.+]] = "moved0"
// CHECK: %[[ISOLATED:.+]] = test.isolated_one_region_op %[[MOVED0]]
// CHECK: %[[MOVED1:.+]] = "moved1"(%[[ISOLATED]])
// CHECK-LABEL: func @move_isolated_from_above(
// CHECK: %[[MOVED0:.+]] = arith.addi {{.*}} {moved0}
// CHECK: %[[MOVED1:.+]] = arith.muli %[[MOVED0]], %[[MOVED0]] {moved1}
// CHECK: %[[BEFORE:.+]] = "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["moved1"]} in %arg0
%op2 = transform.structured.match ops{["arith.muli"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1

View File

@@ -44,7 +44,9 @@ def TestMoveValueDefns :
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Moves all dependencies of on operation before another operation.
Moves all dependencies of a list of values before another operation.
Only pure operations are moved. If there is a side-effecting op in the
dependency chain, no operations are moved.
}];
let arguments =