mirror of
https://github.com/intel/llvm.git
synced 2026-02-01 08:56:15 +08:00
[mlir][Tensor] Fold destination-style ops into tensor.unpack operation. (#71468)
The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand with the destination for the source.
This commit is contained in:
committed by
GitHub
parent
11c182740a
commit
14e7846d6e
@@ -3922,18 +3922,29 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
|
||||
metadata.outerDimsPerm);
|
||||
}
|
||||
|
||||
/// pack(unpack(x)) -> x
|
||||
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
|
||||
PatternRewriter &rewriter) {
|
||||
PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
|
||||
if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
|
||||
return failure();
|
||||
if (packOp.getPaddingValue() ||
|
||||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
|
||||
!haveSameTiles(packOp, unPackOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(unPackOp, packOp.getSource());
|
||||
return success();
|
||||
/// pack(unpack(x)) -> x
|
||||
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
|
||||
if (packOp.getDestType() != unPackOp.getSourceType())
|
||||
return failure();
|
||||
if (packOp.getPaddingValue() ||
|
||||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
|
||||
!haveSameTiles(packOp, unPackOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(unPackOp, packOp.getSource());
|
||||
return success();
|
||||
}
|
||||
/// unpack(destinationStyleOp(x)) -> unpack(x)
|
||||
if (auto dstStyleOp =
|
||||
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
|
||||
auto destValue = unPackOp.getDest().cast<OpResult>();
|
||||
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
|
||||
rewriter.updateRootInPlace(
|
||||
unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool UnPackOp::isLikeUnPad() {
|
||||
|
||||
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
|
||||
%1 = tensor.empty(%0) : tensor<4x5x?xf32>
|
||||
return %1 : tensor<4x5x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Fold DstStyleOp -> tensor.unpack operations.
|
||||
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
|
||||
return %unpack : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
|
||||
// CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
|
||||
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
|
||||
// CHECK-SAME: into %[[INIT]]
|
||||
// CHECK: return %[[UNPACK]]
|
||||
|
||||
Reference in New Issue
Block a user