[mlir][linalg] Enhance padding LinalgOps to handle tensor.empty cases.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D143043
This commit is contained in:
Hanhan Wang
2023-01-31 20:36:18 -08:00
parent 0b5c51b040
commit 061201ec3d
2 changed files with 54 additions and 7 deletions

View File

@@ -115,19 +115,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
}
// Fail if `currOpOperand` is not defined by an ExtractSliceOp.
// Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
if (!sliceOp && !emptyOp)
return failure();
// Compute the dropped dimensions if `sliceOp` is rank-reducing.
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
llvm::SmallBitVector droppedDims;
SmallVector<OpFoldResult> mixedSizes;
if (sliceOp) {
// Compute the dropped dimensions if `sliceOp` is rank-reducing.
droppedDims = sliceOp.getDroppedDims();
mixedSizes = sliceOp.getMixedSizes();
}
if (emptyOp) {
mixedSizes = emptyOp.getMixedSizes();
droppedDims.resize(mixedSizes.size());
}
// Upper bound the `sliceOp` sizes to obtain a static bounding box.
// Upper bound the sizes to obtain a static bounding box.
SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
int64_t shapeIdx = 0;
for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
for (const auto &en : enumerate(mixedSizes)) {
// Skip dropped dimensions.
if (droppedDims.test(en.index()))
continue;

View File

@@ -39,6 +39,44 @@ transform.sequence failures(propagate) {
// -----
#map = affine_map<()[s0] -> (-s0 + 12, 7)>
// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,
%iv1: index, %iv2: index) -> tensor<24x25xf32> {
%0 = affine.min #map()[%iv2]
// CHECK: %[[T0:.*]] = tensor.empty
// CHECK: %[[T1:.*]] = tensor.empty
// CHECK: %[[T2:.*]] = tensor.empty
%1 = tensor.empty(%0) : tensor<4x?xf32>
%2 = tensor.empty(%0) : tensor<?x5xf32>
%3 = tensor.empty() : tensor<4x5xf32>
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold
// CHECK: tensor.yield %[[CST]]
// CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold
// CHECK: %[[T5:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<4x5xf32>)
%4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
%5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
func.return %5 : tensor<24x25xf32>
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
%1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
}
// -----
func.func @pad(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {