mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 13:35:38 +08:00
[mlir] Do not bubble up extract slice when it is rank-reducing.
The bubble up logic was written by assuming the slice operation is always a normal slice that outputs a tensor with the same rank. Differential Revision: https://reviews.llvm.org/D124283
This commit is contained in:
@@ -76,6 +76,10 @@ struct BubbleUpExtractSliceOpPattern
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
|
||||
|
||||
if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
|
||||
return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
|
||||
}
|
||||
|
||||
OpOperand *outOperand = linalgOp.getOutputOperand(0);
|
||||
AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
|
||||
if (!indexingMap.isProjectedPermutation()) {
|
||||
|
||||
@@ -156,3 +156,22 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
|
||||
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
|
||||
// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
|
||||
|
||||
//-----
|
||||
|
||||
// The slice is not supposed to be bubbled up when it is rank-reducing.
|
||||
func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
|
||||
%cst = arith.constant 1.000000e+00 : f32
|
||||
%init = linalg.init_tensor [1, %width] : tensor<1x?xf32>
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
|
||||
%slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
|
||||
%expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
|
||||
return %expand : tensor<1x1x1x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @rank_reducing_slice
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill ins
|
||||
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
|
||||
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
|
||||
// CHECK: return %[[EXPAND]]
|
||||
|
||||
Reference in New Issue
Block a user