mirror of
https://github.com/intel/llvm.git
synced 2026-02-08 08:57:43 +08:00
[mlir][vector] Allow vector distribution with multiple written elements (#75122)
Add a configuration option to allow vector distribution with multiple elements written by a single lane. This is so that we can perform vector multi-reduction with multiple results per workgroup.
This commit is contained in:
@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
|
||||
using DistributionMapFn = std::function<AffineMap(Value)>;
|
||||
|
||||
/// Distribute transfer_write ops based on the affine map returned by
|
||||
/// `distributionMapFn`.
|
||||
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
|
||||
/// will not be distributed (it should be less than the warp size).
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %0 = vector.warp_execute_on_lane_0(%id){
|
||||
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
|
||||
/// distribute, meaning writes should propagate first.
|
||||
void populateDistributeTransferWriteOpPatterns(
|
||||
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
|
||||
PatternBenefit benefit = 2);
|
||||
unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
|
||||
|
||||
/// Move scalar operations with no dependency on the warp op outside of the
|
||||
/// region.
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
|
||||
}
|
||||
|
||||
/// Distribute transfer_write ops based on the affine map returned by
|
||||
/// `distributionMapFn`.
|
||||
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
|
||||
/// will not be distributed (it should be less than the warp size).
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %0 = vector.warp_execute_on_lane_0(%id){
|
||||
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
|
||||
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
|
||||
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
|
||||
PatternBenefit b = 1)
|
||||
unsigned maxNumElementsToExtract, PatternBenefit b = 1)
|
||||
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
|
||||
distributionMapFn(std::move(fn)) {}
|
||||
distributionMapFn(std::move(fn)),
|
||||
maxNumElementsToExtract(maxNumElementsToExtract) {}
|
||||
|
||||
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
|
||||
/// are multiples of the distribution ratio are supported at the moment.
|
||||
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
Location loc = writeOp.getLoc();
|
||||
VectorType vecType = writeOp.getVectorType();
|
||||
|
||||
// Only sink out vector of 1 element for now to not serialize large vector
|
||||
// store. This can later be controlled by user.
|
||||
if (vecType.getNumElements() != 1)
|
||||
return failure();
|
||||
if (vecType.getNumElements() > maxNumElementsToExtract) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
warpOp,
|
||||
llvm::formatv(
|
||||
"writes more elements ({0}) than allowed to extract ({1})",
|
||||
vecType.getNumElements(), maxNumElementsToExtract));
|
||||
}
|
||||
|
||||
// Do not process warp ops that contain only TransferWriteOps.
|
||||
if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
|
||||
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
|
||||
private:
|
||||
DistributionMapFn distributionMapFn;
|
||||
unsigned maxNumElementsToExtract = 1;
|
||||
};
|
||||
|
||||
/// Sink out elementwise op feeding into a warp op yield.
|
||||
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
|
||||
|
||||
void mlir::vector::populateDistributeTransferWriteOpPatterns(
|
||||
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
|
||||
PatternBenefit benefit) {
|
||||
unsigned maxNumElementsToExtract, PatternBenefit benefit) {
|
||||
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
|
||||
benefit);
|
||||
maxNumElementsToExtract, benefit);
|
||||
}
|
||||
|
||||
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
|
||||
|
||||
@@ -1,8 +1,20 @@
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
|
||||
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
|
||||
// RUN: --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
|
||||
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
|
||||
// RUN: --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
|
||||
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
|
||||
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
|
||||
// RUN: | FileCheck --check-prefixes=CHECK-D %s
|
||||
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
|
||||
// RUN: --test-vector-warp-distribute=propagate-distribution --canonicalize \
|
||||
// RUN: | FileCheck --check-prefixes=CHECK-PROP %s
|
||||
|
||||
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
|
||||
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
|
||||
// RUN: --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
|
||||
|
||||
// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
|
||||
// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
|
||||
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we can distribute writes of the maximum allowed number of elements.
|
||||
|
||||
// CHECK-D-LABEL: func @warp_extract_4_elems(
|
||||
// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
|
||||
// CHECK-D: "test.dummy_op"
|
||||
// CHECK-D: "test.dummy_op"
|
||||
// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
|
||||
// CHECK-D: }
|
||||
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
|
||||
// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
|
||||
// CHECK-D: }
|
||||
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
|
||||
// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
|
||||
// CHECK-D: }
|
||||
|
||||
func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
|
||||
vector.warp_execute_on_lane_0(%laneid)[32] {
|
||||
%c0 = arith.constant 0 : index
|
||||
%v = "test.dummy_op"() : () -> (vector<4xf32>)
|
||||
%v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
|
||||
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
|
||||
vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we do not distribute writes larger than the maximum allowed
|
||||
// number of elements.
|
||||
|
||||
// CHECK-D-LABEL: func @warp_extract_5_elems(
|
||||
// CHECK-D: arith.constant 0 : index
|
||||
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
|
||||
// CHECK-D: %[[V:.+]] = "test.dummy_op"
|
||||
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
|
||||
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
|
||||
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
|
||||
// CHECK-D: }
|
||||
|
||||
func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
|
||||
vector.warp_execute_on_lane_0(%laneid)[32] {
|
||||
%c0 = arith.constant 0 : index
|
||||
%v = "test.dummy_op"() : () -> (vector<5xf32>)
|
||||
%v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
|
||||
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
|
||||
vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that we do not distribute writes larger than the maximum allowed
|
||||
// number of elements, or multiples of the maximum number of elements.
|
||||
|
||||
// CHECK-D-LABEL: func @warp_extract_8_elems(
|
||||
// CHECK-D: arith.constant 0 : index
|
||||
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
|
||||
// CHECK-D: %[[V:.+]] = "test.dummy_op"
|
||||
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
|
||||
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
|
||||
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
|
||||
// CHECK-D: }
|
||||
|
||||
func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
|
||||
vector.warp_execute_on_lane_0(%laneid)[32] {
|
||||
%c0 = arith.constant 0 : index
|
||||
%v = "test.dummy_op"() : () -> (vector<8xf32>)
|
||||
%v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
|
||||
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
|
||||
vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-PROP-LABEL: func @warp_dead_result(
|
||||
func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
|
||||
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
|
||||
|
||||
@@ -568,6 +568,11 @@ struct TestVectorDistribution
|
||||
llvm::cl::desc("Test distribution of transfer write"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
Option<unsigned> maxTransferWriteElements{
|
||||
*this, "max-transfer-write-elements",
|
||||
llvm::cl::desc("Maximum number of transfer write elements to distribute"),
|
||||
llvm::cl::init(1)};
|
||||
|
||||
Option<bool> hoistUniform{*this, "hoist-uniform",
|
||||
llvm::cl::desc("Test hoist uniform"),
|
||||
llvm::cl::init(false)};
|
||||
@@ -624,7 +629,8 @@ struct TestVectorDistribution
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
} else if (distributeTransferWriteOps) {
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
|
||||
populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
|
||||
maxTransferWriteElements);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
} else if (propagateDistribution) {
|
||||
RewritePatternSet patterns(ctx);
|
||||
|
||||
Reference in New Issue
Block a user