mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 01:15:50 +08:00
This change adds a new helper function `mlir::reifyResultShapes` that calls the corresponding interface method and also checks the result produced by the implementation when running in debug mode. Bugs due to incorrect interface implementations can be difficult to debug. This helper function also reduces the amount of code needed at call sites: the cast to `ReifyRankedShapedTypeOpInterface` is done in the helper function. Differential Revision: https://reviews.llvm.org/D145777
76 lines
2.9 KiB
C++
76 lines
2.9 KiB
C++
//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
namespace {
|
|
|
|
template <typename ReshapeOp>
|
|
struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
|
|
using OpRewritePattern<ReshapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
|
|
return failure();
|
|
Location loc = reshapeOp.getLoc();
|
|
ReifiedRankedShapedTypeDims resultShapes;
|
|
if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
|
|
!llvm::hasSingleElement(resultShapes))
|
|
return failure();
|
|
// TODO: Do not drop tensor type encoding.
|
|
Value emptyTensor = rewriter.create<EmptyOp>(
|
|
loc, resultShapes[0], reshapeOp.getResultType().getElementType());
|
|
if (emptyTensor.getType() != reshapeOp.getResultType()) {
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
|
reshapeOp, reshapeOp.getResultType(), emptyTensor);
|
|
} else {
|
|
rewriter.replaceOp(reshapeOp, emptyTensor);
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// `tensor.empty` does not define any tensor contents, so a slice of a
|
|
/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
|
|
struct FoldEmptyTensorWithExtractSliceOp
|
|
: public OpRewritePattern<ExtractSliceOp> {
|
|
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
|
|
return failure();
|
|
|
|
// ExtractSliceOp may be rank-reducing; its dynamic sizes must be
|
|
// preserved as well as its result type.
|
|
auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
|
|
sliceOp.getType().getElementType(),
|
|
sliceOp.getType().getEncoding());
|
|
rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
|
|
sliceOp.getSizes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the patterns that fold `tensor.empty` through
/// `tensor.extract_slice`, `tensor.expand_shape` and `tensor.collapse_shape`.
void mlir::tensor::populateFoldTensorEmptyPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  // Slice folding.
  patterns.add<FoldEmptyTensorWithExtractSliceOp>(context);
  // Reshape folding, one instantiation per reshape op kind.
  patterns.add<FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
               FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(context);
}
|