mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
[mlir] [VectorOps] Progressive lowering of vector.broadcast
Summary: Rather than having a full, recursive, lowering of vector.broadcast to LLVM IR, it is much more elegant to have a progressive lowering of each vector.broadcast into a lower dimensional vector.broadcast, until only elementary vector operations remain. This results in more elegant, step-wise code, that is easier to understand. Also makes some optimizations in the generated code. Reviewers: nicolasvasilache, mehdi_amini, andydavis1, grosul1 Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, frgossen, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D78071
This commit is contained in:
@@ -979,7 +979,114 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of OuterProductOp.
|
||||
/// Progressive lowering of BroadcastOp.
|
||||
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::BroadcastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
VectorType dstType = op.getVectorType();
|
||||
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
|
||||
Type eltType = dstType.getElementType();
|
||||
|
||||
// Determine rank of source and destination.
|
||||
int64_t srcRank = srcType ? srcType.getRank() : 0;
|
||||
int64_t dstRank = dstType.getRank();
|
||||
|
||||
// Duplicate this rank.
|
||||
// For example:
|
||||
// %x = broadcast %y : k-D to n-D, k < n
|
||||
// becomes:
|
||||
// %b = broadcast %y : k-D to (n-1)-D
|
||||
// %x = [%b,%b,%b,%b] : n-D
|
||||
// becomes:
|
||||
// %b = [%y,%y] : (n-1)-D
|
||||
// %x = [%b,%b,%b,%b] : n-D
|
||||
if (srcRank < dstRank) {
|
||||
// Scalar to any vector can use splat.
|
||||
if (srcRank == 0) {
|
||||
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
|
||||
return success();
|
||||
}
|
||||
// Duplication.
|
||||
VectorType resType =
|
||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||
Value bcst =
|
||||
rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
|
||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
|
||||
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Find non-matching dimension, if any.
|
||||
assert(srcRank == dstRank);
|
||||
int64_t m = -1;
|
||||
for (int64_t r = 0; r < dstRank; r++)
|
||||
if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
|
||||
m = r;
|
||||
break;
|
||||
}
|
||||
|
||||
// All trailing dimensions are the same. Simply pass through.
|
||||
if (m == -1) {
|
||||
rewriter.replaceOp(op, op.source());
|
||||
return success();
|
||||
}
|
||||
|
||||
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
|
||||
if (srcRank == 1) {
|
||||
assert(m == 0);
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
|
||||
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Any non-matching dimension forces a stretch along this rank.
|
||||
// For example:
|
||||
// %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
|
||||
// becomes:
|
||||
// %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
|
||||
// %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
|
||||
// %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
|
||||
// %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
|
||||
// %x = [%a,%b,%c,%d]
|
||||
// becomes:
|
||||
// %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
|
||||
// %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
|
||||
// %a = [%u, %v]
|
||||
// ..
|
||||
// %x = [%a,%b,%c,%d]
|
||||
VectorType resType =
|
||||
VectorType::get(dstType.getShape().drop_front(), eltType);
|
||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
|
||||
if (m == 0) {
|
||||
// Stetch at start.
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
|
||||
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
|
||||
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
|
||||
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
|
||||
} else {
|
||||
// Stetch not at start.
|
||||
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
|
||||
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
|
||||
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
|
||||
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of TransposeOp.
|
||||
/// One:
|
||||
/// %x = vector.transpose %y, [1, 0]
|
||||
/// is replaced by:
|
||||
@@ -1518,7 +1625,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions parameters) {
|
||||
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
|
||||
OuterProductOpLowering>(context);
|
||||
ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
|
||||
TransposeOpLowering, OuterProductOpLowering>(context);
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user