mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 09:57:08 +08:00
[mlir] [VectorOps] A "reference" lowering of vector.transpose to LLVM IR
Summary: Makes the vector.tranpose runnable on CPU. Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76644
This commit is contained in:
@@ -864,6 +864,67 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of OuterProductOp.
|
||||
/// One:
|
||||
/// %x = vector.transpose %y, [1, 0]
|
||||
/// is replaced by:
|
||||
/// %z = constant dense<0.000000e+00>
|
||||
/// %0 = vector.extract %y[0, 0]
|
||||
/// %1 = vector.insert %0, %z [0, 0]
|
||||
/// ..
|
||||
/// %x = vector.insert .., .. [.., ..]
|
||||
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
VectorType resType = op.getResultType();
|
||||
Type eltType = resType.getElementType();
|
||||
|
||||
// Set up convenience transposition table.
|
||||
SmallVector<int64_t, 4> transp;
|
||||
for (auto attr : op.transp())
|
||||
transp.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
|
||||
// Generate fully unrolled extract/insert ops.
|
||||
Value zero = rewriter.create<ConstantOp>(loc, eltType,
|
||||
rewriter.getZeroAttr(eltType));
|
||||
Value result = rewriter.create<SplatOp>(loc, resType, zero);
|
||||
SmallVector<int64_t, 4> lhs(transp.size(), 0);
|
||||
SmallVector<int64_t, 4> rhs(transp.size(), 0);
|
||||
rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
|
||||
op.vector(), result, rewriter));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Builds the indices arrays for the lhs and rhs. Generates the extract/insert
|
||||
// operation when al ranks are exhausted.
|
||||
Value expandIndices(Location loc, VectorType resType, int64_t pos,
|
||||
SmallVector<int64_t, 4> &transp,
|
||||
SmallVector<int64_t, 4> &lhs,
|
||||
SmallVector<int64_t, 4> &rhs, Value input, Value result,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (pos >= resType.getRank()) {
|
||||
auto ridx = rewriter.getI64ArrayAttr(rhs);
|
||||
auto lidx = rewriter.getI64ArrayAttr(lhs);
|
||||
Type eltType = resType.getElementType();
|
||||
Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
|
||||
return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
|
||||
}
|
||||
for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
|
||||
lhs[pos] = d;
|
||||
rhs[transp[pos]] = d;
|
||||
result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
|
||||
result, rewriter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Progressive lowering of OuterProductOp.
|
||||
/// One:
|
||||
/// %x = vector.outerproduct %lhs, %rhs, %acc
|
||||
@@ -1353,7 +1414,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context,
|
||||
VectorTransformsOptions parameters) {
|
||||
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern, OuterProductOpLowering>(
|
||||
context);
|
||||
ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
|
||||
OuterProductOpLowering>(context);
|
||||
patterns.insert<ContractionOpLowering>(parameters, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user