From 42944da5ba7617bbc02f341e9ef401c325310a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Fri, 16 Aug 2024 16:53:53 +0100 Subject: [PATCH] [mlir][vector] Group re-order patterns together (#102856) Group all patterns that re-order vector.transpose and vector.broadcast Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are normally used to "sink" redundant Vector Ops, hence grouping together. Example: ```mlir %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> %r = arith.addf %at, %bt : vector<2x4xf32> ``` would get converted to: ```mlir %0 = arith.addf %a, %b : vector<4x2xf32> %r = vector.transpose %0, [1, 0] : vector<2x4xf32> ``` This patch also moves all tests for these patterns so that all of them are: * run under one test-flag: `test-vector-sink-patterns`, * located in one file: "vector-sink.mlir". To facilitate this change: * `-test-sink-vector-broadcast` is renamed as `test-vector-sink-patterns`, * "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir", * tests for `ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose` patterns are moved from "vector-reduce-to-contract.mlir" to "vector-sink.mlir", * `ReorderElementwiseOpsOnTranspose` patterns are removed from `populateVectorReductionToContractPatterns` and added to (newly created) `populateSinkVectorOpsPatterns`, * `ReorderCastOpsOnBroadcast` patterns are removed from `populateVectorReductionToContractPatterns` - these are already present in `populateSinkVectorOpsPatterns`. This should allow us better layering and more straightforward testing. For the latter, the goal is to be able to easily identify which pattern a particular test is exercising (especially when it's a specific pattern). NOTES FOR DOWNSTREAM USERS In order to preserve the current functionality, please make sure to add * `populateSinkVectorOpsPatterns`, wherever you are using `populateVectorReductionToContractPatterns`. Also, rename `populateSinkVectorBroadcastPatterns` as `populateSinkVectorOpsPatterns`. (*) I didn't notice any other re-order patterns. --- .../Vector/Transforms/VectorRewritePatterns.h | 19 ++- .../TransformOps/LinalgTransformOps.cpp | 2 +- .../Vector/Transforms/VectorTransforms.cpp | 12 +- .../Vector/vector-reduce-to-contract.mlir | 122 ------------------ ...vector-broadcast.mlir => vector-sink.mlir} | 114 +++++++++++++++- .../Dialect/Vector/TestVectorTransforms.cpp | 18 +-- 6 files changed, 145 insertions(+), 142 deletions(-) rename mlir/test/Dialect/Vector/{sink-vector-broadcast.mlir => vector-sink.mlir} (65%) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 10970fd03e6e..456d021bc793 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -144,9 +144,22 @@ void populateVectorTransferFullPartialPatterns( void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Patterns that remove redundant vector broadcasts. -void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); +/// Patterns that remove redundant Vector Ops by re-ordering them with +/// e.g. elementwise Ops: +/// ``` +/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %r = arith.addf %at, %bt : vector<2x4xf32> +/// ``` +/// gets converted to: +/// ``` +/// %0 = arith.addf %a, %b : vector<4x2xf32> +/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> +/// ``` +/// At the moment, these patterns are limited to vector.broadcast and +/// vector.transpose. +void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Patterns that fold chained vector reductions. These patterns assume that /// elementwise operations (e.g., `arith.addf` with vector operands) are diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index fbf4e29024f7..29b5631f61b4 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3452,7 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( if (!getDisableMultiReductionToContractPatterns()) vector::populateVectorReductionToContractPatterns(patterns); - vector::populateSinkVectorBroadcastPatterns(patterns); + vector::populateSinkVectorOpsPatterns(patterns); patterns.add(ctx, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index ccbaa3e97599..ad4e42b31962 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( void mlir::vector::populateVectorReductionToContractPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add( + CombineContractABTranspose, CombineContractResultTranspose>( patterns.getContext(), benefit); } @@ -2043,10 +2042,11 @@ void mlir::vector:: benefit); } -void mlir::vector::populateSinkVectorBroadcastPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add( - patterns.getContext(), benefit); +void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); } void mlir::vector::populateChainedVectorReductionFoldingPatterns( diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir index c0dbea81df89..24070dbf017a 100644 --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x } -//===----------------------------------------------------------------------===// -// [Pattern: ReorderCastOpsOnBroadcast] -// -// Reorder casting ops and vector ops. The casting ops have almost identical -// pattern, so only arith.extsi op is tested. -// -// TODO: Potential duplication with sink-vector-broadcast.mlir -//===----------------------------------------------------------------------===// - -// ----- - -func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32> - // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32> - %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8> - %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> - return %r : vector<2x4xi32> -} - -// ----- - -func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 - // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32> - %b = vector.broadcast %a : i8 to vector<2x4xi8> - %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> - return %r : vector<2x4xi32> -} - -// ----- - -//===----------------------------------------------------------------------===// -// [Pattern: ReorderElementwiseOpsOnTranspose] -// -// TODO: Potential duplication with sink-vector-broadcast.mlir -//===----------------------------------------------------------------------===// -func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> { - // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32> - // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32> - %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8> - %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> - return %r : vector<2x4xi32> -} - -//===----------------------------------------------------------------------===// -// Reorder elementwise ops and vector ops. -// TODO: Potential duplication with sink-vector-broadcast.mlir -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @transpose_elementwise_same_type -// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) -// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32> -// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0] -// CHECK: return %[[T]] - -func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { - %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %r = arith.addf %at, %bt : vector<2x4xf32> - return %r : vector<2x4xf32> -} - -// ----- - -// CHECK-LABEL: func @transpose_elementwise_diff_operand_types -// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) -// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32> -// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32> -// CHECK: return %[[T]] -func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { - %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1> - %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32> - return %r : vector<2x4xf32> -} - -// ----- - -// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type -// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) -// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32> -// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1> -// CHECK: return %[[T]] -func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> { - %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> - %r = arith.cmpf olt, %at, %bt : vector<2x4xf32> - return %r : vector<2x4xi1> -} - -// ----- - -// CHECK-LABEL: func @transpose_elementwise_splat_constant -// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>) -// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32> -// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32> -// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32> -// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32> - -func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> { - %b = arith.constant dense<5.0> : vector<6x4x2x3xf32> - %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> - %r = arith.addf %at, %b : vector<6x4x2x3xf32> - return %r : vector<6x4x2x3xf32> -} - -// ----- - -// CHECK-LABEL: func @transpose_elementwise_diff_map -// CHECK: vector.transpose -// CHECK: vector.transpose -// CHECK: arith.addf -func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> { - %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> - %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32> - %r = arith.addf %at, %bt : vector<6x4x2x3xf32> - return %r : vector<6x4x2x3xf32> -} - // ----- // CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)> diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir similarity index 65% rename from mlir/test/Dialect/Vector/sink-vector-broadcast.mlir rename to mlir/test/Dialect/Vector/vector-sink.mlir index dd2e98831a70..5a3699333265 100644 --- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir +++ b/mlir/test/Dialect/Vector/vector-sink.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s //----------------------------------------------------------------------------- // [Pattern: ReorderElementwiseOpsOnBroadcast] @@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> { %1 = vector.fma %0, %0, %0 : vector<1xf32> return %1 : vector<1xf32> } + +//===----------------------------------------------------------------------===// +// [Pattern: ReorderCastOpsOnBroadcast] +// +// Reorder casting ops and vector ops. The casting ops have almost identical +// pattern, so only arith.extsi op is tested. +//===----------------------------------------------------------------------===// + +// ----- + +func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32> + // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32> + %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8> + %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> + return %r : vector<2x4xi32> +} + +// ----- + +func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32 + // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32> + %b = vector.broadcast %a : i8 to vector<2x4xi8> + %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> + return %r : vector<2x4xi32> +} + +//===----------------------------------------------------------------------===// +// [Pattern: ReorderElementwiseOpsOnTranspose] +//===----------------------------------------------------------------------===// + +func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> { + // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32> + // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32> + %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8> + %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> + return %r : vector<2x4xi32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_same_type +// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0] +// CHECK: return %[[T]] + +func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.addf %at, %bt : vector<2x4xf32> + return %r : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_operand_types +// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: return %[[T]] +func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { + %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1> + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32> + return %r : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type +// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1> +// CHECK: return %[[T]] +func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> { + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.cmpf olt, %at, %bt : vector<2x4xf32> + return %r : vector<2x4xi1> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_splat_constant +// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>) +// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32> +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32> +// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32> + +func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> { + %b = arith.constant dense<5.0> : vector<6x4x2x3xf32> + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %b : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_map +// CHECK: vector.transpose +// CHECK: vector.transpose +// CHECK: arith.addf +func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> { + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %bt : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 29c763b622e8..72aaa7dc4f89 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims } }; -struct TestSinkVectorBroadcast - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) +struct TestVectorSinkPatterns + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns) - TestSinkVectorBroadcast() = default; - TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; + TestVectorSinkPatterns() = default; + TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } - StringRef getArgument() const final { return "test-sink-vector-broadcast"; } + StringRef getArgument() const final { return "test-vector-sink-patterns"; } StringRef getDescription() const final { return "Test lowering patterns that eliminate redundant brodacast " - "operations."; + "and transpose operations."; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateSinkVectorBroadcastPatterns(patterns); + populateSinkVectorOpsPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -919,7 +919,7 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); + PassRegistration(); PassRegistration();