[mlir][vector] Group re-order patterns together (#102856)

Group all patterns that re-order vector.transpose and vector.broadcast
Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are
normally used to "sink" redundant Vector Ops, hence grouping together.
Example:

```mlir
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
```
would get converted to:
```mlir
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```

This patch also moves all tests for these patterns so that all of them
are:
  * run under one test-flag: `test-vector-sink-patterns`,
  * located in one file: "vector-sink.mlir".

To facilitate this change:
  * `-test-sink-vector-broadcast` is renamed as
    `test-vector-sink-patterns`,
  * "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir",
  * tests for `ReorderCastOpsOnBroadcast` and
    `ReorderElementwiseOpsOnTranspose` patterns are moved from
    "vector-reduce-to-contract.mlir" to "vector-sink.mlir",
  * `ReorderElementwiseOpsOnTranspose` patterns are removed from
    `populateVectorReductionToContractPatterns` and added to (newly
    created) `populateSinkVectorOpsPatterns`,
  * `ReorderCastOpsOnBroadcast` patterns are removed from
    `populateVectorReductionToContractPatterns` - these are already
    present in `populateSinkVectorOpsPatterns`.

This should allow us better layering and more straightforward testing.
For the latter, the goal is to be able to easily identify which pattern
a particular test is exercising (especially when it's a specific
pattern).

NOTES FOR DOWNSTREAM USERS

In order to preserve the current functionality, please make sure to add
  * `populateSinkVectorOpsPatterns`,

wherever you are using `populateVectorReductionToContractPatterns`.
Also, rename `populateSinkVectorBroadcastPatterns` as
`populateSinkVectorOpsPatterns`.

(*) I didn't notice any other re-order patterns.
This commit is contained in:
Andrzej Warzyński
2024-08-16 16:53:53 +01:00
committed by GitHub
parent a434cac523
commit 42944da5ba
6 changed files with 145 additions and 142 deletions

View File

@@ -144,9 +144,22 @@ void populateVectorTransferFullPartialPatterns(
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Patterns that remove redundant vector broadcasts.
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Patterns that remove redundant Vector Ops by re-ordering them with
/// e.g. elementwise Ops:
/// ```
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// gets converted to:
/// ```
/// %0 = arith.addf %a, %b : vector<4x2xf32>
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
/// At the moment, these patterns are limited to vector.broadcast and
/// vector.transpose.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are

View File

@@ -3452,7 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
if (!getDisableMultiReductionToContractPatterns())
vector::populateVectorReductionToContractPatterns(patterns);
vector::populateSinkVectorBroadcastPatterns(patterns);
vector::populateSinkVectorOpsPatterns(patterns);
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(ctx,

View File

@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractABTranspose, CombineContractResultTranspose,
ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
CombineContractABTranspose, CombineContractResultTranspose>(
patterns.getContext(), benefit);
}
@@ -2043,10 +2042,11 @@ void mlir::vector::
benefit);
}
void mlir::vector::populateSinkVectorBroadcastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
patterns.getContext(), benefit);
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(

View File

@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
}
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//
// -----
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
// -----
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
%b = vector.broadcast %a : i8 to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
// -----
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
//===----------------------------------------------------------------------===//
// Reorder elementwise ops and vector ops.
// TODO: Potential duplication with sink-vector-broadcast.mlir
//===----------------------------------------------------------------------===//
// -----
// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
// CHECK: return %[[T]]
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
return %r : vector<2x4xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
return %r : vector<2x4xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
return %r : vector<2x4xi1>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK: arith.addf
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
// -----
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
%1 = vector.fma %0, %0, %0 : vector<1xf32>
return %1 : vector<1xf32>
}
//===----------------------------------------------------------------------===//
// [Pattern: ReorderCastOpsOnBroadcast]
//
// Reorder casting ops and vector ops. The casting ops have almost identical
// pattern, so only arith.extsi op is tested.
//===----------------------------------------------------------------------===//
// -----
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
// -----
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
%b = vector.broadcast %a : i8 to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
// CHECK: return %[[T]]
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
return %r : vector<2x4xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
return %r : vector<2x4xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
// CHECK: return %[[T]]
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
return %r : vector<2x4xi1>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
// -----
// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
// CHECK: arith.addf
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}

View File

@@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
};
struct TestSinkVectorBroadcast
: public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
struct TestVectorSinkPatterns
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
TestSinkVectorBroadcast() = default;
TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
TestVectorSinkPatterns() = default;
TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
}
StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
StringRef getArgument() const final { return "test-vector-sink-patterns"; }
StringRef getDescription() const final {
return "Test lowering patterns that eliminate redundant brodacast "
"operations.";
"and transpose operations.";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateSinkVectorBroadcastPatterns(patterns);
populateSinkVectorOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
@@ -919,7 +919,7 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
PassRegistration<TestSinkVectorBroadcast>();
PassRegistration<TestVectorSinkPatterns>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();