mirror of https://github.com/intel/llvm.git
[mlir][vector] Enable distribution over multiple dimensions
This commit starts enabling vector distribution over multiple dimensions. It requires delinearizing the lane ID to match the expected rank. shape_cast and transfer_read can now properly handle multiple dimensions.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D157931
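As a worked illustration of the shape arithmetic this enables (a standalone C++ sketch, not code from this commit): distributing a `vector<32x4x32xf32>` across a warp of 1024 lanes gives each lane a `vector<1x1x4xf32>` slice; the per-dimension lane counts are 32, 4, and 8, and their product must equal the warp size. The variable names below are illustrative only.

```cpp
// Standalone sketch of multi-dimensional distribution shape arithmetic;
// plain C++, not MLIR code from this commit.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> original{32, 4, 32};   // shape yielded by the warp op
  std::vector<int64_t> distributed{1, 1, 4};  // shape each lane ends up with
  const int64_t warpSize = 1024;

  std::vector<int64_t> lanesPerDim;
  for (size_t i = 0; i < original.size(); ++i) {
    assert(original[i] % distributed[i] == 0);
    lanesPerDim.push_back(original[i] / distributed[i]);  // {32, 4, 8}
  }
  // All dimensions together must consume exactly the whole warp.
  assert(std::accumulate(lanesPerDim.begin(), lanesPerDim.end(), int64_t(1),
                         std::multiplies<int64_t>()) == warpSize);
}
```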
@@ -5760,22 +5760,26 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
         "expected distributed vectors to have same rank and element type.");
-  bool foundDistributedDim = false;
+
+  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
-    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
+    int64_t eDim = expandedVecType.getDimSize(i);
+    int64_t dDim = distributedVecType.getDimSize(i);
+    if (eDim == dDim)
       continue;
-    if (expandedVecType.getDimSize(i) ==
-        distributedVecType.getDimSize(i) * warpSize) {
-      if (foundDistributedDim)
-        return op->emitOpError()
-               << "expected only one dimension to be distributed from "
-               << expandedVecType << " to " << distributedVecType;
-      foundDistributedDim = true;
-      continue;
-    }
-    return op->emitOpError() << "incompatible distribution dimensions from "
-                             << expandedVecType << " to " << distributedVecType;
+    if (eDim % dDim != 0)
+      return op->emitOpError()
+             << "expected expanded vector dimension #" << i << " (" << eDim
+             << ") to be a multipler of the distributed vector dimension ("
+             << dDim << ")";
+    scales[i] = eDim / dDim;
   }
+  if (std::accumulate(scales.begin(), scales.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return op->emitOpError()
+           << "incompatible distribution dimensions from " << expandedVecType
+           << " to " << distributedVecType << " with warp size = " << warpSize;
+
   return success();
 }

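To see what the reworked verifier accepts and rejects, here is a minimal stand-in (plain C++; `verify` is a hypothetical name, not the MLIR function) exercising the two new error paths with the shapes used in the updated invalid tests:

```cpp
// Hypothetical stand-in for the verifier logic above, checking the two
// failure modes: a non-divisible dimension and a wrong scale product.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

enum class Result { Ok, NotAMultiple, BadProduct };

Result verify(const std::vector<int64_t> &expanded,
              const std::vector<int64_t> &distributed, int64_t warpSize) {
  std::vector<int64_t> scales(expanded.size(), 1);
  for (size_t i = 0; i < expanded.size(); ++i) {
    if (expanded[i] == distributed[i])
      continue;
    if (expanded[i] % distributed[i] != 0)
      return Result::NotAMultiple;  // e.g. 8 is not a multiple of 3
    scales[i] = expanded[i] / distributed[i];
  }
  // The per-dimension scales must multiply out to exactly the warp size.
  if (std::accumulate(scales.begin(), scales.end(), int64_t(1),
                      std::multiplies<int64_t>()) != warpSize)
    return Result::BadProduct;
  return Result::Ok;
}

int main() {
  assert(verify({4, 8}, {1, 3}, 32) == Result::NotAMultiple);
  assert(verify({128, 128}, {4, 4}, 32) == Result::BadProduct);  // 32*32 != 32
  assert(verify({4, 32}, {1, 4}, 32) == Result::Ok);             // 4*8 == 32
}
```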
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include <numeric>
 #include <utility>
 
 using namespace mlir;

@@ -45,8 +46,6 @@ static AffineMap calculateImplicitMap(VectorType sequentialType,
   }
   auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                             distributedType.getContext());
-  assert(map.getNumResults() <= 1 &&
-         "only support distribution along one dimension for now.");
   return map;
 }
 
@@ -702,6 +701,49 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects the total numbers of threads needed for
+/// distribution is equal to `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                       ArrayRef<int64_t> originalShape,
+                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(usedThreads), {laneId});
+  }
+  return true;
+}
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {

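For intuition, the following standalone C++ sketch (not from the commit) replays the loop above for a two-dimensional case: original `vector<32x64>` distributed to `vector<1x4>` over 512 lanes, so `sizes = {32, 16}`. The innermost id varies fastest across consecutive lanes; lane 17 maps to ids (1, 1), i.e. row 1 and column group 1.

```cpp
// Plain-C++ replay of the delinearization loop above, specialized to a 2-D
// case; the shapes and lane ids here are made up for illustration.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t sizes[2] = {32, 16};  // original {32,64} / distributed {1,4}
  const int64_t warpSize = 512;
  const int64_t lanes[] = {0, 17, 511};

  for (int64_t laneId : lanes) {
    int64_t ids[2] = {0, 0};
    int64_t id = laneId;
    int64_t usedThreads = 1;
    for (int i = 1; i >= 0; --i) {
      usedThreads *= sizes[i];
      if (usedThreads == warpSize) {
        ids[i] = id;  // all lanes consumed; no modulo needed
        break;
      }
      ids[i] = id % sizes[i];
      id = id / usedThreads;  // mirrors s0.floorDiv(usedThreads) above
    }
    // lane 17 -> (1, 1); lane 511 -> (31, 15).
    std::printf("lane %3lld -> (%lld, %lld)\n", (long long)laneId,
                (long long)ids[0], (long long)ids[1]);
  }
}
```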
@@ -743,6 +785,16 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineMap indexMap = map.compose(read.getPermutationMap());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
+
+    // Try to delinearize the lane ID to match the rank expected for
+    // distribution.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+                           distributedType.getShape(), warpOp.getWarpSize(),
+                           warpOp.getLaneid(), delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          read, "cannot delinearize lane ID for distribution");
+
     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);

@@ -751,11 +803,10 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      int64_t scale =
-          cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
+      int64_t scale = distributedType.getDimSize(vectorPos);
       indices[indexPos] = affine::makeComposedAffineApply(
           rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], warpOp.getLaneid()});
+          {indices[indexPos], delinearizedIds[vectorPos]});
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,

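The per-dimension index update above is `d0 + scale * d1`: the original index plus the lane's delinearized id for that dimension, scaled by the per-lane dimension size. A trivial standalone check (hypothetical function name, not from the commit):

```cpp
// newIndex = oldIndex + scale * delinearizedId; plain-C++ illustration only.
#include <cassert>
#include <cstdint>

int64_t distributedIndex(int64_t oldIndex, int64_t scale, int64_t laneDimId) {
  return oldIndex + scale * laneDimId;  // the d0 + scale * d1 expression above
}

int main() {
  // With per-lane size 4, lane-dimension id 7 starts reading at element 28.
  assert(distributedIndex(0, 4, 7) == 28);
}
```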
@@ -918,6 +969,48 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Pattern to move shape cast out of the warp op. shape cast is basically a
+/// no-op for warp distribution; we need to handle the shape though.
+struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+    if (!operand)
+      return failure();
+    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
+
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto castDistributedType =
+        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    VectorType castOriginalType = oldCastOp.getSourceVectorType();
+    VectorType castResultType = castDistributedType;
+
+    // We expect the distributed type to have a smaller rank than the original
+    // type. Prepend with size-one dimensions to make them the same.
+    unsigned castDistributedRank = castDistributedType.getRank();
+    unsigned castOriginalRank = castOriginalType.getRank();
+    if (castDistributedRank < castOriginalRank) {
+      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
+      llvm::append_range(shape, castDistributedType.getShape());
+      castDistributedType =
+          VectorType::get(shape, castDistributedType.getElementType());
+    }
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newCast = rewriter.create<vector::ShapeCastOp>(
+        oldCastOp.getLoc(), castResultType,
+        newWarpOp->getResult(newRetIndices[0]));
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {

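The rank-alignment step in this pattern just prepends size-one dimensions until the distributed shape has the same rank as the cast source. A standalone C++ sketch of that step (plain `std::vector` shapes as a stand-in for VectorType; `prependOnes` is a hypothetical name):

```cpp
// Prepend size-1 dimensions so a lower-rank distributed shape matches the
// rank of the original shape; illustration only, not MLIR code.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> prependOnes(std::vector<int64_t> shape, size_t rank) {
  std::vector<int64_t> result(rank - shape.size(), 1);
  result.insert(result.end(), shape.begin(), shape.end());
  return result;
}

int main() {
  // A distributed vector<4xf32> against a rank-3 cast source becomes
  // vector<1x1x4xf32>, matching the shape_cast test below.
  assert(prependOnes({4}, 3) == (std::vector<int64_t>{1, 1, 4}));
}
```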
@@ -1557,9 +1650,9 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
-      patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
+               WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,

@@ -1593,7 +1593,7 @@ func.func @warp_wrong_arg_distribution(%laneid: index, %v0 : vector<4xi32>) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  // expected-error@+1 {{incompatible distribution dimensions from 'vector<128x128xi32>' to 'vector<4x4xi32>' with warp size = 32}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128x128xi32>
     vector.yield %0 : vector<128x128xi32>

@@ -1603,6 +1603,17 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 
 // -----
 
+func.func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
+    %0 = arith.constant dense<2>: vector<4x8xi32>
+    vector.yield %0 : vector<4x8xi32>
+  }
+  return
+}
+
+// -----
+
 func.func @warp_mismatch_rank(%laneid: index) {
   // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {

@@ -849,6 +849,17 @@ func.func @warp_execute_on_lane_0(%laneid: index) {
   return
 }
 
+// CHECK-LABEL: func.func @warp_execute_on_lane_0_2d
+func.func @warp_execute_on_lane_0_2d(%laneid: index) {
+  // CHECK: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x4xi32>)
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x4xi32>) {
+    %0 = arith.constant dense<2>: vector<4x32xi32>
+    // CHECK: vector.yield %{{.+}} : vector<4x32xi32>
+    vector.yield %0 : vector<4x32xi32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @warp_operand_result(
 func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
   // CHECK-NEXT: %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {

@@ -827,6 +827,50 @@ func.func @lane_dependent_warp_propagate_read(
 
 // -----
 
+func.func @warp_propagate_read_3d(%laneid: index, %src: memref<32x4x32xf32>) -> vector<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    vector.yield %2 : vector<32x4x32xf32>
+  }
+  return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$ID0MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK-PROP-DAG: #[[$ID1MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 4)>
+// CHECK-PROP-DAG: #[[$ID2MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) floordiv 32)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_3d
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x4x32xf32>)
+// CHECK-PROP-DAG: %[[ID0:.+]] = affine.apply #[[$ID0MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID1:.+]] = affine.apply #[[$ID1MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID2:.+]] = affine.apply #[[$ID2MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID2]], %[[ID1]], %[[ID0]]], %{{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x1x4xf32>
+
+// -----
+
+func.func @warp_propagate_read_broadcast(%laneid: index, %src: memref<32x1xf32>) -> vector<1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[512] -> (vector<1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>} : memref<32x1xf32>, vector<32x64xf32>
+    vector.yield %2 : vector<32x64xf32>
+  }
+  return %r : vector<1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 16)>
+// CHECK-PROP-DAG: #[[$READMAP:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_broadcast
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x1xf32>)
+// CHECK-PROP: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP: %[[ID:.+]] = affine.apply #[[$MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID]], %[[C0]]], %{{.+}} {in_bounds = [true, true], permutation_map = #[[$READMAP]]} : memref<32x1xf32>, vector<1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x4xf32>
+
+// -----
+
 // CHECK-PROP: func @dont_duplicate_read
 func.func @dont_duplicate_read(
   %laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {

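As a quick numeric trace of the three lane-ID maps checked above (plain C++ mirroring the printed affine maps as-is, not re-deriving them; the sample lane id is made up):

```cpp
// Evaluate the three CHECK'd affine maps for one sample lane id. C++ integer
// division on non-negative values matches affine floordiv here.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t s0 = 100;                      // sample lane id in [0, 1024)
  const int64_t id0 = s0 * 4 - (s0 / 8) * 32;  // $ID0MAP: (s0 mod 8) * 4
  const int64_t id1 = (s0 / 8) % 4;            // $ID1MAP
  const int64_t id2 = (s0 / 8) / 32;           // $ID2MAP
  assert(id0 == 16 && id1 == 0 && id2 == 0);   // 100 = 12 * 8 + 4; 4 * 4 = 16
}
```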
@@ -1173,3 +1217,22 @@ func.func @dont_fold_vector_broadcast(%laneid: index) {
   vector.print %r : vector<1x2xf32>
   return
 }
+
+// -----
+
+func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    %3 = vector.shape_cast %2 : vector<32x4x32xf32> to vector<4096xf32>
+    vector.yield %3 : vector<4096xf32>
+  }
+  return %r : vector<4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read {{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK-PROP: return %[[CAST]] : vector<4xf32>

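The shapes in this test line up because distribution preserves the per-lane element count across the sunk shape_cast: 32 * 4 * 32 = 4096 elements over 1024 lanes is 4 per lane, so the distributed `vector<1x1x4xf32>` casts to `vector<4xf32>`. A trivial standalone check of that arithmetic:

```cpp
// Per-lane element count before and after the shape cast must agree.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t total = 32 * 4 * 32;  // elements yielded inside the warp op
  const int64_t warpSize = 1024;
  assert(total / warpSize == 4);      // each lane keeps vector<4xf32>
}
```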