[mlir][Tensor] Fold destination-style ops into tensor.unpack operation. (#71468)

The destination operand of the `tensor.unpack` operation is only needed
to carry shape information. So if the producer of the destination
operand implements the `DestinationStyleOpInterface`, then fold it into
the `tensor.unpack` operation by replacing the destination operand with
the destination for the source.
This commit is contained in:
MaheshRavishankar
2023-11-07 21:42:32 -08:00
committed by GitHub
parent 11c182740a
commit 14e7846d6e
2 changed files with 37 additions and 10 deletions

View File

@@ -3922,18 +3922,29 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
metadata.outerDimsPerm);
}
/// pack(unpack(x)) -> x
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
return failure();
if (packOp.getPaddingValue() ||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
!haveSameTiles(packOp, unPackOp))
return failure();
rewriter.replaceOp(unPackOp, packOp.getSource());
return success();
/// pack(unpack(x)) -> x
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
if (packOp.getDestType() != unPackOp.getSourceType())
return failure();
if (packOp.getPaddingValue() ||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
!haveSameTiles(packOp, unPackOp))
return failure();
rewriter.replaceOp(unPackOp, packOp.getSource());
return success();
}
/// unpack(destinationStyleOp(x)) -> unpack(x)
if (auto dstStyleOp =
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
auto destValue = unPackOp.getDest().cast<OpResult>();
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
rewriter.updateRootInPlace(
unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
return failure();
}
bool UnPackOp::isLikeUnPad() {

View File

@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
%1 = tensor.empty(%0) : tensor<4x5x?xf32>
return %1 : tensor<4x5x?xf32>
}
// -----
// Fold DstStyleOp -> tensor.unpack operations.
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.0 : f32
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
return %unpack : tensor<?x?xf32>
}
// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
// CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]