mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][vector] Split TransposeOpLowering into 2 patterns (#91935)
Splits `TransposeOpLowering` into two patterns:
1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose`
as `vector.shape_cast` (there has to be at least one unit dim),
2. `TransposeOpLowering` - the original pattern without the part
extracted into `Transpose2DWithUnitDimToShapeCast`.
The rationale behind the split:
* the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't
really match the intended output from `TransposeOpLowering` as
documented in the source file - it doesn't make much sense to keep
it embedded inside `TransposeOpLowering`,
* `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors,
`TransposeOpLowering` _does_ not.
This commit is contained in:
committed by
GitHub
parent
363258a3cc
commit
cbd72cb0de
@@ -326,6 +326,10 @@ public:
|
||||
VectorType inputType = op.getSourceVectorType();
|
||||
VectorType resType = op.getResultVectorType();
|
||||
|
||||
if (inputType.isScalable())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "This lowering does not support scalable vectors");
|
||||
|
||||
// Set up convenience transposition table.
|
||||
ArrayRef<int64_t> transp = op.getPermutation();
|
||||
|
||||
@@ -334,28 +338,6 @@ public:
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Options specifies lowering to shuffle");
|
||||
|
||||
// Replace:
|
||||
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
|
||||
// vector<1xnxelty>
|
||||
// with:
|
||||
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
|
||||
//
|
||||
// Source with leading unit dim (inverse) is also replaced. Unit dim must
|
||||
// be fixed. Non-unit can be scalable.
|
||||
if (resType.getRank() == 2 &&
|
||||
((resType.getShape().front() == 1 &&
|
||||
!resType.getScalableDims().front()) ||
|
||||
(resType.getShape().back() == 1 &&
|
||||
!resType.getScalableDims().back())) &&
|
||||
transp == ArrayRef<int64_t>({1, 0})) {
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: Add support for scalable vectors
|
||||
if (inputType.isScalable())
|
||||
return failure();
|
||||
|
||||
// Handle a true 2-D matrix transpose differently when requested.
|
||||
if (vectorTransformOptions.vectorTransposeLowering ==
|
||||
vector::VectorTransposeLowering::Flat &&
|
||||
@@ -411,6 +393,64 @@ private:
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
};
|
||||
|
||||
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
|
||||
/// to 2D vectors with at least one unit dim. For example:
|
||||
///
|
||||
/// Replace:
|
||||
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
|
||||
/// vector<1x4xi32>
|
||||
/// with:
|
||||
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
|
||||
///
|
||||
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
|
||||
/// be fixed. Non-unit dim can be scalable.
|
||||
///
|
||||
/// TODO: This pattern was introduced specifically to help lower scalable
|
||||
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
|
||||
/// to cancel out) would be preferable:
|
||||
///
|
||||
/// BEFORE:
|
||||
/// %0 = some_op
|
||||
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
|
||||
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
|
||||
/// AFTER:
|
||||
/// %0 = some_op
|
||||
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
|
||||
///
|
||||
/// Given the context above, we may want to consider (re-)moving this pattern
|
||||
/// at some later time. I am leaving it for now in case there are other users
|
||||
/// that I am not aware of.
|
||||
class Transpose2DWithUnitDimToShapeCast
|
||||
: public OpRewritePattern<vector::TransposeOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
|
||||
|
||||
LogicalResult matchAndRewrite(vector::TransposeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value input = op.getVector();
|
||||
VectorType resType = op.getResultVectorType();
|
||||
|
||||
// Set up convenience transposition table.
|
||||
ArrayRef<int64_t> transp = op.getPermutation();
|
||||
|
||||
if (resType.getRank() == 2 &&
|
||||
((resType.getShape().front() == 1 &&
|
||||
!resType.getScalableDims().front()) ||
|
||||
(resType.getShape().back() == 1 &&
|
||||
!resType.getScalableDims().back())) &&
|
||||
transp == ArrayRef<int64_t>({1, 0})) {
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
|
||||
/// If the strategy is Shuffle1D, it will be lowered to:
|
||||
/// vector.shape_cast 2D -> 1D
|
||||
@@ -483,6 +523,8 @@ private:
|
||||
void mlir::vector::populateVectorTransposeLoweringPatterns(
|
||||
RewritePatternSet &patterns, VectorTransformsOptions options,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
|
||||
benefit);
|
||||
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
|
||||
options, patterns.getContext(), benefit);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user