Files
llvm/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
Hanhan Wang 65388086e6 [mlir][tensor] Add patterns that fold ops into pack and unpack ops.
The tensor.pack op has pad semantics, so pad + pack can be folded into
pack when

1. They have the same padding values or the pack op does not have
   padding values.
2. The pad op does not have low paddings.

The tensor.unpack op has extract_slice semantics, so unpack +
extract_slice can be folded into unpack when

1. All the offsets are 0s.
2. All the strides are 1s.

Reviewed By: tyb0807

Differential Revision: https://reviews.llvm.org/D141099
2023-01-11 13:51:49 -08:00

88 lines
3.2 KiB
C++

//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
namespace mlir {
namespace tensor {
namespace {
/// Returns true when each entry of `ofrs` is a constant integer equal to
/// `value`. An empty list trivially satisfies the check.
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
  for (OpFoldResult ofr : ofrs) {
    if (!isConstantIntValue(ofr, value))
      return false;
  }
  return true;
}
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
auto padOp = packOp.getSource().getDefiningOp<PadOp>();
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
return failure();
Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue)
return failure();
if (auto paddingValue = packOp.getPaddingValue())
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();
rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
packOp.getOuterDimsPerm());
return success();
}
};
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
///
/// Only slices that start at the origin, use unit strides, and keep the rank
/// of the unpacked tensor are folded. The replacement `unpack` writes into a
/// new `tensor.empty` whose sizes are the slice sizes.
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();
    // A rank-reducing slice drops unit dimensions, but the replacement
    // `unpack` produces a tensor of the full unpacked rank, so its type would
    // not match the slice result.
    if (sliceOp.getResultType().getRank() !=
        unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }
    // Check all offsets are zeros, and all strides are ones.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }
    // Create a new empty output tensor shaped like the slice.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};
} // namespace
/// Adds to `patterns` the rewrites that fold a `tensor.extract_slice` into a
/// producing `tensor.unpack`, and a `tensor.pad` into a consuming
/// `tensor.pack`.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  // `add` is the preferred spelling; `insert` is its soft-deprecated alias.
  patterns.add<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
      patterns.getContext());
}
} // namespace tensor
} // namespace mlir