[mlir][vector] Enable distribution over multiple dimensions

This commit starts enabling vector distribution over multiple
dimensions. It requires delinearizing the lane ID to match the
expected rank: for example, distributing vector<32x4x32xf32> to
vector<1x1x4xf32> across 1024 lanes splits the lane ID into
per-dimension factors of 32, 4, and 8. vector.shape_cast and
vector.transfer_read can now properly handle multiple dimensions.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D157931
Author: Lei Zhang
Date:   2023-08-16 12:05:38 -07:00
Parent: 42dad521e3
Commit: 73ddc4474b
5 changed files with 204 additions and 22 deletions

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

@@ -5760,22 +5760,26 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
         "expected distributed vectors to have same rank and element type.");
-  bool foundDistributedDim = false;
+  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
-    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
+    int64_t eDim = expandedVecType.getDimSize(i);
+    int64_t dDim = distributedVecType.getDimSize(i);
+    if (eDim == dDim)
       continue;
-    if (expandedVecType.getDimSize(i) ==
-        distributedVecType.getDimSize(i) * warpSize) {
-      if (foundDistributedDim)
-        return op->emitOpError()
-               << "expected only one dimension to be distributed from "
-               << expandedVecType << " to " << distributedVecType;
-      foundDistributedDim = true;
-      continue;
-    }
-    return op->emitOpError() << "incompatible distribution dimensions from "
-                             << expandedVecType << " to " << distributedVecType;
+    if (eDim % dDim != 0)
+      return op->emitOpError()
+             << "expected expanded vector dimension #" << i << " (" << eDim
+             << ") to be a multiple of the distributed vector dimension ("
+             << dDim << ")";
+    scales[i] = eDim / dDim;
   }
+  if (std::accumulate(scales.begin(), scales.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return op->emitOpError()
+           << "incompatible distribution dimensions from " << expandedVecType
+           << " to " << distributedVecType << " with warp size = " << warpSize;
   return success();
 }
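
In essence, the verifier now lets every dimension contribute a distribution factor and only requires that the per-dimension scales multiply out to the warp size. A minimal standalone sketch of the check in plain C++ (illustrative names; not the MLIR API):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Returns true if `expanded` can be distributed to `distributed`:
// every dimension divides evenly and the per-dimension scales
// multiply to exactly `warpSize`.
bool isValidDistribution(const std::vector<int64_t> &expanded,
                         const std::vector<int64_t> &distributed,
                         int64_t warpSize) {
  assert(expanded.size() == distributed.size() && "ranks must match");
  std::vector<int64_t> scales(expanded.size(), 1);
  for (size_t i = 0; i < expanded.size(); ++i) {
    if (expanded[i] % distributed[i] != 0)
      return false; // e.g. distributing 8 over 3 is rejected
    scales[i] = expanded[i] / distributed[i];
  }
  return std::accumulate(scales.begin(), scales.end(), int64_t(1),
                         std::multiplies<int64_t>()) == warpSize;
}

int main() {
  assert(isValidDistribution({4, 32}, {1, 4}, 32));     // 4 * 8 == 32
  assert(!isValidDistribution({128, 128}, {4, 4}, 32)); // 32 * 32 != 32
}

The 128x128 to 4x4 case now trips the product check (32 * 32 != 32) instead of the old one-distributed-dimension rule, which is why the expected error in invalid.mlir changes below.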

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include <numeric>
 #include <utility>

 using namespace mlir;
@@ -45,8 +46,6 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   }
   auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                             distributedType.getContext());
-  assert(map.getNumResults() <= 1 &&
-         "only support distribution along one dimension for now.");
   return map;
 }
@@ -702,6 +701,49 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };

+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects that the total number of threads needed
+/// for distribution is equal to `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                       ArrayRef<int64_t> originalShape,
+                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] :
+       llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads: no modulo is needed, and the
+      // remaining outer dimensions keep their initial zero value.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(sizes[i]), {laneId});
+  }
+  return true;
+}
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
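
To see what the chain of composed affine.apply ops computes, here is the same delinearization in plain integer arithmetic, as an illustrative standalone sketch (not the MLIR code path), using the 32x4x32 to 1x1x4 distribution exercised by the tests below:

#include <cassert>
#include <cstdint>
#include <vector>

// Splits `laneId` into one id per dimension, innermost varying fastest,
// given per-dimension lane counts `sizes` (original dim / distributed dim).
std::vector<int64_t> delinearize(int64_t laneId,
                                 const std::vector<int64_t> &sizes) {
  std::vector<int64_t> ids(sizes.size(), 0);
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    ids[i] = laneId % sizes[i];
    laneId /= sizes[i];
  }
  return ids;
}

int main() {
  // vector<32x4x32xf32> -> vector<1x1x4xf32> over 1024 lanes gives
  // per-dimension lane counts {32/1, 4/1, 32/4} = {32, 4, 8}.
  std::vector<int64_t> ids = delinearize(1000, {32, 4, 8});
  assert(ids[0] == 31 && ids[1] == 1 && ids[2] == 0);
}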
@@ -743,6 +785,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineMap indexMap = map.compose(read.getPermutationMap());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
+
+    // Try to delinearize the lane ID to match the rank expected for
+    // distribution.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+                           distributedType.getShape(), warpOp.getWarpSize(),
+                           warpOp.getLaneid(), delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          read, "cannot delinearize lane ID for distribution");
+
     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
@@ -751,11 +803,10 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      int64_t scale =
-          cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
+      int64_t scale = distributedType.getDimSize(vectorPos);
       indices[indexPos] = affine::makeComposedAffineApply(
           rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], warpOp.getLaneid()});
+          {indices[indexPos], delinearizedIds[vectorPos]});
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
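
Each memory index then becomes base + scale * id: the lane's position along the matching vector dimension, scaled by the per-lane slice size in that dimension. A tiny illustrative check (the helper name is mine, not the pattern's):

#include <cassert>
#include <cstdint>

// Mirrors the affine map d0 + scale * d1 used to offset a read index.
int64_t distributedIndex(int64_t base, int64_t scale, int64_t laneDimId) {
  return base + scale * laneDimId;
}

int main() {
  // For the innermost dimension of vector<32x4x32xf32> distributed to
  // vector<1x1x4xf32>, each lane owns a 4-element slice, so the lane
  // with dimension id 5 starts reading at column 20.
  assert(distributedIndex(0, 4, 5) == 20);
}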
@@ -918,6 +969,48 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };

+/// Pattern to move shape cast out of the warp op. shape cast is basically a
+/// no-op for warp distribution; we need to handle the shape though.
+struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+    if (!operand)
+      return failure();
+
+    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
+
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto castDistributedType =
+        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    VectorType castOriginalType = oldCastOp.getSourceVectorType();
+    VectorType castResultType = castDistributedType;
+
+    // We expect the distributed type to have a smaller rank than the original
+    // type. Prepend with size-one dimensions to make them the same.
+    unsigned castDistributedRank = castDistributedType.getRank();
+    unsigned castOriginalRank = castOriginalType.getRank();
+    if (castDistributedRank < castOriginalRank) {
+      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
+      llvm::append_range(shape, castDistributedType.getShape());
+      castDistributedType =
+          VectorType::get(shape, castDistributedType.getElementType());
+    }
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newCast = rewriter.create<vector::ShapeCastOp>(
+        oldCastOp.getLoc(), castResultType,
+        newWarpOp->getResult(newRetIndices[0]));
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
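
The only real work in WarpOpShapeCast is the rank fix-up above. A hedged sketch of just that shape computation in plain C++ (padToRank is an illustrative helper, not an MLIR function):

#include <cassert>
#include <cstdint>
#include <vector>

// Prepends size-one dimensions so the distributed shape matches the
// rank of the original (pre-cast) shape.
std::vector<int64_t> padToRank(std::vector<int64_t> distributed,
                               size_t originalRank) {
  if (distributed.size() < originalRank)
    distributed.insert(distributed.begin(),
                       originalRank - distributed.size(), int64_t(1));
  return distributed;
}

int main() {
  // shape_cast vector<32x4x32xf32> -> vector<4096xf32>, distributed to
  // vector<4xf32>: the value yielded inside the warp op stays at the
  // original rank, i.e. vector<1x1x4xf32>.
  assert(padToRank({4}, 3) == std::vector<int64_t>({1, 1, 4}));
}

The flat distributed shape is then re-established by a fresh shape cast inserted after the warp op, as the @warp_propagate_shape_cast test at the end shows.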
@@ -1557,9 +1650,9 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
-      patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
+               WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

mlir/test/Dialect/Vector/invalid.mlir

@@ -1593,7 +1593,7 @@ func.func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
 // -----

 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  // expected-error@+1 {{incompatible distribution dimensions from 'vector<128x128xi32>' to 'vector<4x4xi32>' with warp size = 32}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128x128xi32>
     vector.yield %0 : vector<128x128xi32>
@@ -1603,6 +1603,17 @@ func.func @warp_2_distributed_dims(%laneid: index) {

 // -----

+func.func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multiple of the distributed vector dimension (3)}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
+    %0 = arith.constant dense<2>: vector<4x8xi32>
+    vector.yield %0 : vector<4x8xi32>
+  }
+  return
+}
+
+// -----
+
 func.func @warp_mismatch_rank(%laneid: index) {
   // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {

mlir/test/Dialect/Vector/ops.mlir

@@ -849,6 +849,17 @@ func.func @warp_execute_on_lane_0(%laneid: index) {
   return
 }

+// CHECK-LABEL: func.func @warp_execute_on_lane_0_2d
+func.func @warp_execute_on_lane_0_2d(%laneid: index) {
+  // CHECK: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x4xi32>)
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x4xi32>) {
+    %0 = arith.constant dense<2>: vector<4x32xi32>
+    // CHECK: vector.yield %{{.+}} : vector<4x32xi32>
+    vector.yield %0 : vector<4x32xi32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @warp_operand_result(
 func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
   // CHECK-NEXT: %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

@@ -827,6 +827,50 @@ func.func @lane_dependent_warp_propagate_read(

 // -----

+func.func @warp_propagate_read_3d(%laneid: index, %src: memref<32x4x32xf32>) -> vector<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    vector.yield %2 : vector<32x4x32xf32>
+  }
+  return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$ID0MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK-PROP-DAG: #[[$ID1MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 4)>
+// CHECK-PROP-DAG: #[[$ID2MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_3d
+// CHECK-PROP-SAME:  (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x4x32xf32>)
+// CHECK-PROP-DAG:   %[[ID0:.+]] = affine.apply #[[$ID0MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG:   %[[ID1:.+]] = affine.apply #[[$ID1MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG:   %[[ID2:.+]] = affine.apply #[[$ID2MAP]]()[%[[LANE]]]
+// CHECK-PROP:       %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID2]], %[[ID1]], %[[ID0]]], %{{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP:       return %[[READ]] : vector<1x1x4xf32>
+
+// -----
+
+func.func @warp_propagate_read_broadcast(%laneid: index, %src: memref<32x1xf32>) -> vector<1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[512] -> (vector<1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>} : memref<32x1xf32>, vector<32x64xf32>
+    vector.yield %2 : vector<32x64xf32>
+  }
+  return %r : vector<1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 16)>
+// CHECK-PROP-DAG: #[[$READMAP:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_broadcast
+// CHECK-PROP-SAME:  (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x1xf32>)
+// CHECK-PROP:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP:       %[[ID:.+]] = affine.apply #[[$MAP]]()[%[[LANE]]]
+// CHECK-PROP:       %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID]], %[[C0]]], %{{.+}} {in_bounds = [true, true], permutation_map = #[[$READMAP]]} : memref<32x1xf32>, vector<1x4xf32>
+// CHECK-PROP:       return %[[READ]] : vector<1x4xf32>
+
 // -----

 // CHECK-PROP: func @dont_duplicate_read
 func.func @dont_duplicate_read(
     %laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {
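
As a plain-C++ sanity check of the three index maps in @warp_propagate_read_3d above, evaluate them for lane 1000 (integer division throughout):

#include <cassert>
#include <cstdint>

int main() {
  int64_t lane = 1000;
  int64_t id0 = lane * 4 - (lane / 8) * 32; // equals 4 * (lane % 8)
  int64_t id1 = (lane / 8) % 4;
  int64_t id2 = lane / 32;
  // Lane 1000 reads the vector<1x1x4xf32> slice at %src[31, 1, 0].
  assert(id2 == 31 && id1 == 1 && id0 == 0);
}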
@@ -1173,3 +1217,22 @@ func.func @dont_fold_vector_broadcast(%laneid: index) {
   vector.print %r : vector<1x2xf32>
   return
 }
+
+// -----
+
+func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    %3 = vector.shape_cast %2 : vector<32x4x32xf32> to vector<4096xf32>
+    vector.yield %3 : vector<4096xf32>
+  }
+  return %r : vector<4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast
+// CHECK-PROP:   %[[READ:.+]] = vector.transfer_read {{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP:   %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK-PROP:   return %[[CAST]] : vector<4xf32>