[mlir][vector] Split TransposeOpLowering into 2 patterns (#91935)

Splits `TransposeOpLowering` into two patterns:
1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose`
    as `vector.shape_cast` (there has to be at least one unit dim),
  2. `TransposeOpLowering` - the original pattern without the part
    extracted into `Transpose2DWithUnitDimToShapeCast`.

The rationale behind the split:
  * the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't
    really match the intended output from `TransposeOpLowering` as
    documented in the source file - it doesn't make much sense to keep
    it embedded inside `TransposeOpLowering`,
* `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors,
    `TransposeOpLowering` _does_ not.
This commit is contained in:
Andrzej Warzyński
2024-05-14 12:08:57 +01:00
committed by GitHub
parent 363258a3cc
commit cbd72cb0de

View File

@@ -326,6 +326,10 @@ public:
VectorType inputType = op.getSourceVectorType();
VectorType resType = op.getResultVectorType();
if (inputType.isScalable())
return rewriter.notifyMatchFailure(
op, "This lowering does not support scalable vectors");
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
@@ -334,28 +338,6 @@ public:
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");
// Replace:
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
// vector<1xnxelty>
// with:
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
//
// Source with leading unit dim (inverse) is also replaced. Unit dim must
// be fixed. Non-unit can be scalable.
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}
// TODO: Add support for scalable vectors
if (inputType.isScalable())
return failure();
// Handle a true 2-D matrix transpose differently when requested.
if (vectorTransformOptions.vectorTransposeLowering ==
vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ private:
vector::VectorTransformsOptions vectorTransformOptions;
};
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
/// vector<1x4xi32>
/// with:
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getVector();
VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();
if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}
return failure();
}
};
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ private:
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns, VectorTransformsOptions options,
PatternBenefit benefit) {
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
options, patterns.getContext(), benefit);
}