[mlir][transform] ApplyPatternsOp: Add check to prevent modifying the transform IR

Add an extra check to make sure that the transform IR is not modified by this op while it is being interpreted. Such modifications are generally dangerous, and we may want to enforce this check for all transform ops that modify the payload in the future.

Users should generally apply patterns only to the piece of IR where they are needed (e.g., a matched function) and not to the entire module (which may contain the transform IR).
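
To illustrate (a minimal sketch based on the tests updated in this revision; the `canonicalization` pattern set is just an example), match the function first and apply the patterns to the match result rather than to the module:

```mlir
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
  // Scope the pattern application to the payload function so that the
  // transform IR contained in the module is never rewritten.
  %func_op = transform.structured.match ops{["func.func"]} in %module_op
    : (!transform.any_op) -> !transform.any_op
  transform.apply_patterns to %func_op {
    transform.apply_patterns.canonicalization
  } : !transform.any_op
}
```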

This revision is in response to a crash in a downstream compiler that was caused by a dead `transform.structured.match` op that was removed by the GreedyPatternRewriteDriver's DCE while the enclosing sequence was being interpreted.
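
For reference, the problematic shape is roughly the following sketch (simplified, with hypothetical names; see also the new negative test added below): the patterns are applied to the module that also contains the running transform sequence, so an unused `transform.structured.match` op inside that sequence can be erased as dead code mid-interpretation.

```mlir
module {
  func.func @payload() {
    return
  }

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
    // Unused match result: once patterns run on %arg1 (the module enclosing
    // this sequence), the greedy driver's DCE may erase this op while the
    // sequence is still being interpreted. The new check rejects this setup.
    %0 = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %arg1 {
      transform.apply_patterns.canonicalization
    } : !transform.any_op
  }
}
```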

Differential Revision: https://reviews.llvm.org/D153113
Author: Matthias Springer
Date: 2023-06-19 08:56:49 +02:00
parent 16b46dde0b
commit 726d076784
14 changed files with 108 additions and 92 deletions

View File

@@ -239,6 +239,22 @@ DiagnosedSilenceableFailure
transform::ApplyPatternsOp::applyToOne(Operation *target,
                                       ApplyToEachResultList &results,
                                       transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
  // performs many additional simplifications such as dead code elimination.
  Operation *transformAncestor = getOperation();
  while (transformAncestor) {
    if (transformAncestor == target) {
      DiagnosedDefiniteFailure diag =
          emitDefiniteFailure()
          << "cannot apply transform to itself (or one of its ancestors)";
      diag.attachNote(target->getLoc()) << "target payload op";
      return diag;
    }
    transformAncestor = transformAncestor->getParentOp();
  }

  // Gather all specified patterns.
  MLIRContext *ctx = target->getContext();
  RewritePatternSet patterns(ctx);

View File

@@ -2,7 +2,9 @@
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.tensor.fold_tensor_empty
} : !transform.any_op
}
@@ -66,7 +68,9 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.tensor.fold_tensor_empty
{fold_single_use_only = true}
} : !transform.any_op

View File

@@ -2,7 +2,9 @@
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.tensor.rewrite_as_constant
} : !transform.any_op
}

View File

@@ -155,3 +155,22 @@ transform.sequence failures(propagate) {
} : !transform.any_op
transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
}
// -----

// expected-note @below{{target payload op}}
module {
  func.func @invalid_pattern_application_to_transform_ir() {
    return
  }

  module {
    transform.sequence failures(propagate) {
    ^bb1(%arg1: !transform.any_op):
      // expected-error @below {{cannot apply transform to itself (or one of its ancestors)}}
      transform.apply_patterns to %arg1 {
        transform.apply_patterns.canonicalization
      } : !transform.any_op
    }
  }
}

View File

@@ -210,7 +210,9 @@ func.func @redpar_vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<v
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
} : !transform.any_op
}

View File

@@ -19,8 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
// -----
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
return %0 : f32
@@ -33,8 +31,6 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -
// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
// CHECK: return %[[RES]]
// -----
func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
@@ -76,8 +72,6 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]
// -----
func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
return %0 : vector<2x5xf32>
@@ -90,8 +84,6 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: v
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]
// -----
func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
@@ -143,8 +135,6 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]
// -----
func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
@@ -187,8 +177,6 @@ func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf
// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// CHECK: %[[VAL_33:.*]] = vector.insertelement
// -----
func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?xf32>
@@ -207,8 +195,6 @@ func.func @vectorize_1d_dynamic_reduction(%arg0: tensor<?xf32>) -> f32 {
// CHECK: %[[VAL_5:.*]] = vector.create_mask {{.*}} : vector<8xi1>
// CHECK: %[[VAL_7:.*]] = vector.mask %[[VAL_5]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
// -----
func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
@@ -254,8 +240,6 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
// -----
func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
return %0 : vector<4xf32>
@@ -267,7 +251,9 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction"
} : !transform.any_op
}

View File

@@ -190,7 +190,9 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
} : !transform.any_op
}

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
@@ -8,44 +8,24 @@ func.func @transfer_read_rank_reducing(
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
return %v : vector<3x2xi8>
}
// CHECK-LABEL: func @transfer_read_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}
// -----
func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
return
}
// CHECK-LABEL: func @transfer_write_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}
// -----
func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
@@ -54,22 +34,12 @@ func.func @transfer_read_and_vector_rank_reducing(
memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
return %v : vector<3x2x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}
// -----
func.func @transfer_write_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>,
%vec : vector<3x2x1xf32>) {
@@ -78,22 +48,12 @@ func.func @transfer_write_and_vector_rank_reducing(
vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}
// -----
func.func @transfer_read_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
%c0 = arith.constant 0 : index
@@ -102,22 +62,12 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
return %v : vector<1x1x1xf32>
}
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}
// -----
func.func @transfer_write_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>) {
@@ -126,7 +76,6 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
return
}
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
@@ -135,7 +84,9 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.any_op
}

View File

@@ -108,7 +108,9 @@ func.func @split_vector_transfer_read_strided_2d(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
} : !transform.any_op
}
@@ -169,7 +171,9 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
} : !transform.any_op
}
@@ -237,7 +241,9 @@ func.func @split_vector_transfer_write_strided_2d(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
} : !transform.any_op
}

View File

@@ -103,7 +103,9 @@ func.func @split_vector_transfer_read_strided_2d(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
} : !transform.any_op
}
@@ -161,7 +163,9 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf3
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
} : !transform.any_op
}
@@ -223,7 +227,9 @@ func.func @split_vector_transfer_write_strided_2d(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
} : !transform.any_op
}
@@ -265,7 +271,9 @@ func.func @transfer_read_within_scf_for(%A : memref<?x?xf32>, %lb : index, %ub :
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
} : !transform.any_op
}

View File

@@ -2,7 +2,9 @@
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.fold_tensor_slice_into_transfer
} : !transform.any_op
}

View File

@@ -240,7 +240,9 @@ func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : i
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
transform.apply_patterns.vector.transfer_permutation_patterns
} : !transform.any_op
@@ -362,7 +364,9 @@ func.func @transfer_write_broadcast_unit_dim(
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99
transform.apply_patterns.vector.transfer_permutation_patterns
} : !transform.any_op

View File

@@ -76,7 +76,9 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise"
} : !transform.any_op
}
@@ -99,7 +101,9 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
} : !transform.any_op
}
@@ -118,7 +122,9 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose"
} : !transform.any_op
}
@@ -605,7 +611,9 @@ func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32>
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true
} : !transform.any_op
}
@@ -683,7 +691,9 @@ func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16x
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
} : !transform.any_op
}
@@ -762,7 +772,9 @@ func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
} : !transform.any_op
}

View File

@@ -31,7 +31,9 @@ func.func @entry() {
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.apply_patterns to %module_op {
%func_op = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16"
} : !transform.any_op
}