[mlir] Do not bubble up extract slice when it is rank-reducing.

The bubble up logic was written by assuming the slice operation is
always a normal slice that outputs a tensor with the same rank.

Differential Revision: https://reviews.llvm.org/D124283
This commit is contained in:
Okwan Kwon
2022-04-22 10:49:22 -07:00
parent c94a02e0e2
commit ee285faed2
2 changed files with 23 additions and 0 deletions

View File

@@ -76,6 +76,10 @@ struct BubbleUpExtractSliceOpPattern
if (!sliceOp.hasUnitStride())
return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
}
OpOperand *outOperand = linalgOp.getOutputOperand(0);
AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
if (!indexingMap.isProjectedPermutation()) {

View File

@@ -156,3 +156,22 @@ func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32x
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>
//-----
// The slice is not supposed to be bubbled up when it is rank-reducing.
func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
%cst = arith.constant 1.000000e+00 : f32
%init = linalg.init_tensor [1, %width] : tensor<1x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
%slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%expand = tensor.expand_shape %slice [[0, 1, 2, 3]] : tensor<?xf32> into tensor<1x1x1x?xf32>
return %expand : tensor<1x1x1x?xf32>
}
// CHECK: func @rank_reducing_slice
// CHECK: %[[INIT:.+]] = linalg.init_tensor
// CHECK: %[[FILL:.+]] = linalg.fill ins
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
// CHECK: return %[[EXPAND]]