[mlir][vector] Allow vector distribution with multiple written elements (#75122)

Add a configuration option to allow vector distribution with multiple
elements written by a single lane.

This is so that we can perform vector multi-reduction with multiple
results per workgroup.
This commit is contained in:
Jakub Kuderski
2023-12-12 13:15:17 -05:00
committed by GitHub
parent 42e4967140
commit 8063622721
4 changed files with 123 additions and 17 deletions

View File

@@ -43,7 +43,9 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
using DistributionMapFn = std::function<AffineMap(Value)>;
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
/// will not be distributed (it should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -67,7 +69,7 @@ using DistributionMapFn = std::function<AffineMap(Value)>;
/// distribute, meaning writes should propagate first.
void populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
PatternBenefit benefit = 2);
unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
/// Move scalar operations with no dependency on the warp op outside of the
/// region.

View File

@@ -16,6 +16,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>
@@ -458,7 +459,9 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
}
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
/// will not be distributed (it should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
@@ -476,9 +479,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
PatternBenefit b = 1)
unsigned maxNumElementsToExtract, PatternBenefit b = 1)
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
distributionMapFn(std::move(fn)) {}
distributionMapFn(std::move(fn)),
maxNumElementsToExtract(maxNumElementsToExtract) {}
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
/// are multiples of the distribution ratio are supported at the moment.
@@ -553,10 +557,13 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();
// Only sink out vector of 1 element for now to not serialize large vector
// store. This can later be controlled by user.
if (vecType.getNumElements() != 1)
return failure();
if (vecType.getNumElements() > maxNumElementsToExtract) {
return rewriter.notifyMatchFailure(
warpOp,
llvm::formatv(
"writes more elements ({0}) than allowed to extract ({1})",
vecType.getNumElements(), maxNumElementsToExtract));
}
// Do not process warp ops that contain only TransferWriteOps.
if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
@@ -616,6 +623,7 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
private:
DistributionMapFn distributionMapFn;
unsigned maxNumElementsToExtract = 1;
};
/// Sink out elementwise op feeding into a warp op yield.
@@ -1833,9 +1841,9 @@ void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
void mlir::vector::populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
PatternBenefit benefit) {
unsigned maxNumElementsToExtract, PatternBenefit benefit) {
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
benefit);
maxNumElementsToExtract, benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(

View File

@@ -1,8 +1,20 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
// RUN: --test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
// RUN: --test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write max-transfer-write-elements=4" \
// RUN: | FileCheck --check-prefixes=CHECK-D %s
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
// RUN: --test-vector-warp-distribute=propagate-distribution --canonicalize \
// RUN: | FileCheck --check-prefixes=CHECK-PROP %s
// RUN: mlir-opt %s --allow-unregistered-dialect --split-input-file \
// RUN: --test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \
// RUN: --canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
@@ -134,6 +146,84 @@ func.func @warp_extract(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : ind
// -----
// Check that we can distribute writes of the maximum allowed number of elements.
// CHECK-D-LABEL: func @warp_extract_4_elems(
// CHECK-D: %[[WARPOP:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4x1xf32>)
// CHECK-D: "test.dummy_op"
// CHECK-D: "test.dummy_op"
// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<4x1xf32>
// CHECK-D: }
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: vector.transfer_write %[[WARPOP]]#1, %{{.*}}[%{{.*}}] {{.*}} : vector<4x1xf32>
// CHECK-D: }
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: vector.transfer_write %[[WARPOP]]#0, %{{.*}}[%{{.*}}] {{.*}} : vector<4xf32>
// CHECK-D: }
func.func @warp_extract_4_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
%c0 = arith.constant 0 : index
%v = "test.dummy_op"() : () -> (vector<4xf32>)
%v1 = "test.dummy_op"() : () -> (vector<4x1xf32>)
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<4x1xf32>, memref<1024x1024xf32>
vector.transfer_write %v, %arg1[%c0, %c0] : vector<4xf32>, memref<1024x1024xf32>
}
return
}
// -----
// Check that we do not distribute writes larger than the maximum allowed
// number of elements.
// CHECK-D-LABEL: func @warp_extract_5_elems(
// CHECK-D: arith.constant 0 : index
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: %[[V:.+]] = "test.dummy_op"
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<5x1xf32>
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<5xf32>
// CHECK-D: }
func.func @warp_extract_5_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
%c0 = arith.constant 0 : index
%v = "test.dummy_op"() : () -> (vector<5xf32>)
%v1 = "test.dummy_op"() : () -> (vector<5x1xf32>)
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<5x1xf32>, memref<1024x1024xf32>
vector.transfer_write %v, %arg1[%c0, %c0] : vector<5xf32>, memref<1024x1024xf32>
}
return
}
// -----
// Check that we do not distribute writes larger than the maximum allowed
// number of elements, or multiples of the maximum number of elements.
// CHECK-D-LABEL: func @warp_extract_8_elems(
// CHECK-D: arith.constant 0 : index
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: %[[V:.+]] = "test.dummy_op"
// CHECK-D: %[[V1:.+]] = "test.dummy_op"
// CHECK-D: vector.transfer_write %[[V1]], %{{.*}}[%{{.*}}] {{.*}} : vector<8x1xf32>
// CHECK-D: vector.transfer_write %[[V]], %{{.*}}[%{{.*}}] {{.*}} : vector<8xf32>
// CHECK-D: }
func.func @warp_extract_8_elems(%laneid: index, %arg1: memref<1024x1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
%c0 = arith.constant 0 : index
%v = "test.dummy_op"() : () -> (vector<8xf32>)
%v1 = "test.dummy_op"() : () -> (vector<8x1xf32>)
vector.transfer_write %v1, %arg1[%c0, %c0] : vector<8x1xf32>, memref<1024x1024xf32>
vector.transfer_write %v, %arg1[%c0, %c0] : vector<8xf32>, memref<1024x1024xf32>
}
return
}
// -----
// CHECK-PROP-LABEL: func @warp_dead_result(
func.func @warp_dead_result(%laneid: index) -> (vector<1xf32>) {
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)

View File

@@ -568,6 +568,11 @@ struct TestVectorDistribution
llvm::cl::desc("Test distribution of transfer write"),
llvm::cl::init(false)};
Option<unsigned> maxTransferWriteElements{
*this, "max-transfer-write-elements",
llvm::cl::desc("Maximum number of transfer write elements to distribute"),
llvm::cl::init(1)};
Option<bool> hoistUniform{*this, "hoist-uniform",
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
maxTransferWriteElements);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (propagateDistribution) {
RewritePatternSet patterns(ctx);