mirror of
https://github.com/intel/llvm.git
synced 2026-01-23 16:06:39 +08:00
[mlir][vector] Extend pattern to trim lead unit dimension to Splat Op
Differential Revision: https://reviews.llvm.org/D102091
This commit is contained in:
@@ -3175,27 +3175,31 @@ struct CastAwayTransferWriteLeadingOneDim
|
||||
}
|
||||
};
|
||||
|
||||
struct CastAwayBroadcastLeadingOneDim
|
||||
: public OpRewritePattern<vector::BroadcastOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
template <typename BroadCastType>
|
||||
struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
|
||||
using OpRewritePattern<BroadCastType>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
|
||||
LogicalResult matchAndRewrite(BroadCastType broadcastOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
VectorType newDstType = trimLeadingOneDims(broadcastOp.getVectorType());
|
||||
if (newDstType == broadcastOp.getVectorType())
|
||||
VectorType dstType =
|
||||
broadcastOp.getResult().getType().template dyn_cast<VectorType>();
|
||||
if (!dstType)
|
||||
return failure();
|
||||
VectorType newDstType = trimLeadingOneDims(dstType);
|
||||
if (newDstType == dstType)
|
||||
return failure();
|
||||
Location loc = broadcastOp.getLoc();
|
||||
VectorType srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
|
||||
Value source = broadcastOp->getOperand(0);
|
||||
VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
|
||||
if (srcVecType)
|
||||
srcVecType = trimLeadingOneDims(srcVecType);
|
||||
Value source = broadcastOp.source();
|
||||
if (srcVecType && srcVecType != broadcastOp.getSourceType()) {
|
||||
if (srcVecType && srcVecType != source.getType()) {
|
||||
source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
|
||||
}
|
||||
Value newBroadcastOp =
|
||||
rewriter.create<vector::BroadcastOp>(loc, newDstType, source);
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
|
||||
broadcastOp, broadcastOp.getVectorType(), newBroadcastOp);
|
||||
rewriter.create<BroadCastType>(loc, newDstType, source);
|
||||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
|
||||
newBroadcastOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -3833,13 +3837,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
|
||||
|
||||
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
.add<CastAwayExtractStridedSliceLeadingOneDim,
|
||||
CastAwayInsertStridedSliceLeadingOneDim,
|
||||
CastAwayTransferReadLeadingOneDim,
|
||||
CastAwayTransferWriteLeadingOneDim, CastAwayBroadcastLeadingOneDim,
|
||||
CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>(
|
||||
patterns.getContext());
|
||||
patterns.add<
|
||||
CastAwayExtractStridedSliceLeadingOneDim,
|
||||
CastAwayInsertStridedSliceLeadingOneDim,
|
||||
CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
|
||||
CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
|
||||
CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
|
||||
ShapeCastOpFolder>(patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
|
||||
|
||||
@@ -675,7 +675,7 @@ func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x
|
||||
// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
|
||||
func @cast_away_broadcast_leading_one_dims(
|
||||
%arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
|
||||
(vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) {
|
||||
(vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
|
||||
// CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
|
||||
// CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
|
||||
%0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
|
||||
@@ -686,7 +686,10 @@ func @cast_away_broadcast_leading_one_dims(
|
||||
// CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
|
||||
// CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
|
||||
%2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
|
||||
return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>
|
||||
// CHECK: splat %{{.*}} : vector<4xf32>
|
||||
// CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
|
||||
%3 = splat %arg1 : vector<1x1x4xf32>
|
||||
return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
|
||||
|
||||
Reference in New Issue
Block a user