diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 7011c478fefb..ca8a6f6d82a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -326,6 +326,10 @@ public:
     VectorType inputType = op.getSourceVectorType();
     VectorType resType = op.getResultVectorType();
 
+    if (inputType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "This lowering does not support scalable vectors");
+
     // Set up convenience transposition table.
     ArrayRef<int64_t> transp = op.getPermutation();
 
@@ -334,28 +338,6 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
 
-    // Replace:
-    //   vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
-    //                                 vector<1xnxelty>
-    // with:
-    //   vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
-    //
-    // Source with leading unit dim (inverse) is also replaced. Unit dim must
-    // be fixed. Non-unit can be scalable.
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    // TODO: Add support for scalable vectors
-    if (inputType.isScalable())
-      return failure();
-
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ private:
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
+/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
+/// to 2D vectors with at least one unit dim. For example:
+///
+/// Replace:
+///   vector.transpose %0, [1, 0] : vector<4x1xi32> to
+///                                 vector<1x4xi32>
+/// with:
+///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
+///
+/// Source with leading unit dim (inverse) is also replaced. Unit dim must
+/// be fixed. Non-unit dim can be scalable.
+///
+/// TODO: This pattern was introduced specifically to help lower scalable
+/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
+/// to cancel out) would be preferable:
+///
+/// BEFORE:
+///   %0 = some_op
+///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
+///   %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// AFTER:
+///   %0 = some_op
+///   %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
+///
+/// Given the context above, we may want to consider (re-)moving this pattern
+/// at some later time. I am leaving it for now in case there are other users
+/// that I am not aware of.
+class Transpose2DWithUnitDimToShapeCast
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getVector();
+    VectorType resType = op.getResultVectorType();
+
+    // Set up convenience transposition table.
+    ArrayRef<int64_t> transp = op.getPermutation();
+
+    if (resType.getRank() == 2 &&
+        ((resType.getShape().front() == 1 &&
+          !resType.getScalableDims().front()) ||
+         (resType.getShape().back() == 1 &&
+          !resType.getScalableDims().back())) &&
+        transp == ArrayRef<int64_t>({1, 0})) {
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ private:
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns, VectorTransformsOptions options,
     PatternBenefit benefit) {
+  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
+                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       options, patterns.getContext(), benefit);
 }