mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[mlir][vector] Group re-order patterns together (#102856)
Group all patterns that re-order vector.transpose and vector.broadcast
Ops (*) under `populateSinkVectorOpsPatterns`. These patterns are
normally used to "sink" redundant Vector Ops, hence grouping together.
Example:
```mlir
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
%r = arith.addf %at, %bt : vector<2x4xf32>
```
would get converted to:
```mlir
%0 = arith.addf %a, %b : vector<4x2xf32>
%r = vector.transpose %0, [1, 0] : vector<2x4xf32>
```
This patch also moves all tests for these patterns so that all of them
are:
* run under one test-flag: `test-vector-sink-patterns`,
* located in one file: "vector-sink.mlir".
To facilitate this change:
* `-test-sink-vector-broadcast` is renamed as
`test-vector-sink-patterns`,
* "sink-vector-broadcast.mlir" is renamed as "vector-sink.mlir",
* tests for `ReorderCastOpsOnBroadcast` and
`ReorderElementwiseOpsOnTranspose` patterns are moved from
"vector-reduce-to-contract.mlir" to "vector-sink.mlir",
* `ReorderElementwiseOpsOnTranspose` patterns are removed from
`populateVectorReductionToContractPatterns` and added to (newly
created) `populateSinkVectorOpsPatterns`,
* `ReorderCastOpsOnBroadcast` patterns are removed from
`populateVectorReductionToContractPatterns` - these are already
present in `populateSinkVectorOpsPatterns`.
This should allow us better layering and more straightforward testing.
For the latter, the goal is to be able to easily identify which pattern
a particular test is exercising (especially when it's a specific
pattern).
NOTES FOR DOWNSTREAM USERS
In order to preserve the current functionality, please make sure to add
* `populateSinkVectorOpsPatterns`,
wherever you are using `populateVectorReductionToContractPatterns`.
Also, rename `populateSinkVectorBroadcastPatterns` as
`populateSinkVectorOpsPatterns`.
(*) I didn't notice any other re-order patterns.
This commit is contained in:
committed by
GitHub
parent
a434cac523
commit
42944da5ba
@@ -144,9 +144,22 @@ void populateVectorTransferFullPartialPatterns(
|
||||
void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit = 1);
|
||||
|
||||
/// Patterns that remove redundant vector broadcasts.
|
||||
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit = 1);
|
||||
/// Patterns that remove redundant Vector Ops by re-ordering them with
|
||||
/// e.g. elementwise Ops:
|
||||
/// ```
|
||||
/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
/// %r = arith.addf %at, %bt : vector<2x4xf32>
|
||||
/// ```
|
||||
/// gets converted to:
|
||||
/// ```
|
||||
/// %0 = arith.addf %a, %b : vector<4x2xf32>
|
||||
/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
|
||||
/// ```
|
||||
/// At the moment, these patterns are limited to vector.broadcast and
|
||||
/// vector.transpose.
|
||||
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Patterns that fold chained vector reductions. These patterns assume that
|
||||
/// elementwise operations (e.g., `arith.addf` with vector operands) are
|
||||
|
||||
@@ -3452,7 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
|
||||
if (!getDisableMultiReductionToContractPatterns())
|
||||
vector::populateVectorReductionToContractPatterns(patterns);
|
||||
|
||||
vector::populateSinkVectorBroadcastPatterns(patterns);
|
||||
vector::populateSinkVectorOpsPatterns(patterns);
|
||||
|
||||
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
|
||||
linalg::LinalgCopyVTWForwardingPattern>(ctx,
|
||||
|
||||
@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
|
||||
void mlir::vector::populateVectorReductionToContractPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit) {
|
||||
patterns.add<MultiReduceToContract, CombineContractBroadcast,
|
||||
CombineContractABTranspose, CombineContractResultTranspose,
|
||||
ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
|
||||
CombineContractABTranspose, CombineContractResultTranspose>(
|
||||
patterns.getContext(), benefit);
|
||||
}
|
||||
|
||||
@@ -2043,10 +2042,11 @@ void mlir::vector::
|
||||
benefit);
|
||||
}
|
||||
|
||||
void mlir::vector::populateSinkVectorBroadcastPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit) {
|
||||
patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnBroadcast>(
|
||||
patterns.getContext(), benefit);
|
||||
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
|
||||
ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
|
||||
benefit);
|
||||
}
|
||||
|
||||
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
|
||||
|
||||
@@ -245,128 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// [Pattern: ReorderCastOpsOnBroadcast]
|
||||
//
|
||||
// Reorder casting ops and vector ops. The casting ops have almost identical
|
||||
// pattern, so only arith.extsi op is tested.
|
||||
//
|
||||
// TODO: Potential duplication with sink-vector-broadcast.mlir
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
|
||||
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
|
||||
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
|
||||
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
|
||||
%b = vector.broadcast %a : i8 to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// [Pattern: ReorderElementwiseOpsOnTranspose]
|
||||
//
|
||||
// TODO: Potential duplication with sink-vector-broadcast.mlir
|
||||
//===----------------------------------------------------------------------===//
|
||||
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
|
||||
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
|
||||
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reorder elementwise ops and vector ops.
|
||||
// TODO: Potential duplication with sink-vector-broadcast.mlir
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_same_type
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
|
||||
// CHECK: return %[[T]]
|
||||
|
||||
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.addf %at, %bt : vector<2x4xf32>
|
||||
return %r : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
|
||||
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
|
||||
// CHECK: return %[[T]]
|
||||
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
|
||||
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
|
||||
return %r : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
|
||||
// CHECK: return %[[T]]
|
||||
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
|
||||
return %r : vector<2x4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_splat_constant
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
|
||||
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
|
||||
|
||||
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
|
||||
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
|
||||
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
|
||||
return %r : vector<6x4x2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_map
|
||||
// CHECK: vector.transpose
|
||||
// CHECK: vector.transpose
|
||||
// CHECK: arith.addf
|
||||
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
|
||||
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
|
||||
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
|
||||
return %r : vector<6x4x2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// [Pattern: ReorderElementwiseOpsOnBroadcast]
|
||||
@@ -208,3 +208,115 @@ func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
|
||||
%1 = vector.fma %0, %0, %0 : vector<1xf32>
|
||||
return %1 : vector<1xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// [Pattern: ReorderCastOpsOnBroadcast]
|
||||
//
|
||||
// Reorder casting ops and vector ops. The casting ops have almost identical
|
||||
// pattern, so only arith.extsi op is tested.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
|
||||
// CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
|
||||
%b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
|
||||
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
|
||||
%b = vector.broadcast %a : i8 to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// [Pattern: ReorderElementwiseOpsOnTranspose]
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
|
||||
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
|
||||
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
|
||||
%b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
|
||||
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
|
||||
return %r : vector<2x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_same_type
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
|
||||
// CHECK: return %[[T]]
|
||||
|
||||
func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.addf %at, %bt : vector<2x4xf32>
|
||||
return %r : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
|
||||
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
|
||||
// CHECK: return %[[T]]
|
||||
func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
|
||||
%condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
|
||||
return %r : vector<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
|
||||
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
|
||||
// CHECK: return %[[T]]
|
||||
func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
|
||||
%at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
|
||||
%r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
|
||||
return %r : vector<2x4xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_splat_constant
|
||||
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
|
||||
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
|
||||
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
|
||||
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
|
||||
|
||||
func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
|
||||
%b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
|
||||
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
%r = arith.addf %at, %b : vector<6x4x2x3xf32>
|
||||
return %r : vector<6x4x2x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_elementwise_diff_map
|
||||
// CHECK: vector.transpose
|
||||
// CHECK: vector.transpose
|
||||
// CHECK: arith.addf
|
||||
func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
|
||||
%at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
|
||||
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
|
||||
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
|
||||
return %r : vector<6x4x2x3xf32>
|
||||
}
|
||||
@@ -374,27 +374,27 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
|
||||
}
|
||||
};
|
||||
|
||||
struct TestSinkVectorBroadcast
|
||||
: public PassWrapper<TestSinkVectorBroadcast, OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast)
|
||||
struct TestVectorSinkPatterns
|
||||
: public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns)
|
||||
|
||||
TestSinkVectorBroadcast() = default;
|
||||
TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default;
|
||||
TestVectorSinkPatterns() = default;
|
||||
TestVectorSinkPatterns(const TestVectorSinkPatterns &pass) = default;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<memref::MemRefDialect, affine::AffineDialect>();
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-sink-vector-broadcast"; }
|
||||
StringRef getArgument() const final { return "test-vector-sink-patterns"; }
|
||||
|
||||
StringRef getDescription() const final {
|
||||
return "Test lowering patterns that eliminate redundant brodacast "
|
||||
"operations.";
|
||||
"and transpose operations.";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateSinkVectorBroadcastPatterns(patterns);
|
||||
populateSinkVectorOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
@@ -919,7 +919,7 @@ void registerTestVectorLowerings() {
|
||||
|
||||
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
|
||||
|
||||
PassRegistration<TestSinkVectorBroadcast>();
|
||||
PassRegistration<TestVectorSinkPatterns>();
|
||||
|
||||
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user