Mirror of https://github.com/intel/llvm.git (synced 2026-02-01 08:56:15 +08:00)
[mlir][tensor] add tensor insert/extract op folders (#142458)
Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:

* tensor.insert folder: an insert into a constant is replaced with a new constant
* tensor.extract folder: an extract from a parent tensor that was inserted at the same indices is folded into the inserted value
* rewrite pattern: an extract of a collapse_shape is replaced with an extract of the source tensor (requires static source dimensions)

Signed-off-by: Asra Ali <asraa@google.com>
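As a quick illustration of the first two items (a hand-written sketch, not taken from the patch; the value names are made up):

  %cst = arith.constant dense<[1, 2]> : tensor<2xi32>
  %c0 = arith.constant 0 : index
  %c7 = arith.constant 7 : i32
  %ins = tensor.insert %c7 into %cst[%c0] : tensor<2xi32>
  %e = tensor.extract %ins[%c0] : tensor<2xi32>

After canonicalization, %e folds directly to %c7 (an extract of an insert at the same indices), and %ins, if it still has other uses, is rewritten to a fresh constant arith.constant dense<[7, 2]> : tensor<2xi32>.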
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
          return false;
        });

/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
/// source tensor.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);

} // namespace tensor
} // namespace mlir
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [

  let hasFolder = 1;
  let hasVerifier = 1;
  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>

using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
/// tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
          extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};

} // namespace

void ExtractOp::getAsmResultNames(
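To make the example in the pattern's comment concrete (a hand-worked check, not part of the patch): the collapsed group [0, 1] of tensor<3x4xf64> yields the basis (3, 4), and delinearizing the constant index 10 over that basis gives 10 = 2 * 4 + 2, i.e. source indices (2, 2). The rewrite itself emits roughly:

  %indices:2 = affine.delinearize_index %c10 into (3, 4) : index, index
  %extracted_element = tensor.extract %src[%indices#0, %indices#1] : tensor<3x4xf64>

which later constant folding can reduce to the %src[%c2, %c2] form shown in the comment.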
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
  return success();
}

/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
    return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}

@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<ExtractFromTensorCast>(context);
}

void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//

namespace {

/// Pattern to fold an insert op of a constant destination and scalar to a new
/// constant.
///
/// Example:
/// ```
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// %c0 = arith.constant 0 : index
/// %c4_f32 = arith.constant 4.0 : f32
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
/// ```
/// is rewritten into:
/// ```
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// ```
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern<InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    // Requires a ranked tensor type.
    auto destType =
        llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
    if (!destType)
      return failure();

    // Pattern requires constant indices
    SmallVector<uint64_t, 8> indices;
    for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
      auto indiceAttr = dyn_cast<Attribute>(indice);
      if (!indiceAttr)
        return failure();
      indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
    }

    // Requires a constant scalar to insert
    OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
    Attribute scalarAttr = dyn_cast<Attribute>(scalar);
    if (!scalarAttr)
      return failure();

    if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
            insertOp.getDest().getDefiningOp())) {
      if (auto sourceAttr =
              llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
        // Update the attribute at the inserted index.
        auto sourceValues = sourceAttr.getValues<Attribute>();
        auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
        std::vector<Attribute> updatedValues;
        updatedValues.reserve(sourceAttr.getNumElements());
        for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
          updatedValues.push_back(i == flattenedIndex ? scalarAttr
                                                      : sourceValues[i]);
        }
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            insertOp, sourceAttr.getType(),
            DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
                                   updatedValues));
        return success();
      }
    }

    return failure();
  }
};

} // namespace

void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  return {};
}

void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertOpConstantFold>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x

// -----

// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
  %const_0 = arith.constant 0 : index
  %const_1 = arith.constant 1 : index
  %const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
  %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>

  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
  // Fold an extract after an insert.
  // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
  %c4_i32 = arith.constant 4 : i32
  %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
  %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
  %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>

  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
}

// -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
  return %ins_1 : tensor<4xf32>
}


// -----

func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
  // Fold an insert into a splat.
  // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
  // CHECK-LITERAL:
  // CHECK-NEXT: return %[[C4]]
  %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4_i32 = arith.constant 4 : i32
  %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
  return %inserted : tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @extract_from_tensor.cast
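A quick hand check of the @canonicalize_insert_after_constant expectation (not part of the patch): the destination constant is dense<[[1, 2], [3, 4]]>, and the insert writes 4 : i32 at indices (%c1, %c0). With the row-major layout used by the elements attribute, that element sits at flattened index 1 * 2 + 0 = 2, so the value 3 is overwritten and the canonicalized constant becomes dense<[[1, 2], [4, 4]]> : tensor<2x2xi32>, matching the CHECK line above.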
mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir (new file, 31 lines)
@@ -0,0 +1,31 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s

// CHECK-LABEL: @extract_from_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
  %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
  %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
  func.return %extracted, %extracted_0 : i8, i8
}

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8

// -----

// CHECK-LABEL: @extract_from_static_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
  %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
  return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32
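For @extract_from_static_shape, the only collapsed group is [0, 1] with source sizes (2, 6), so %arg1 is delinearized over the basis (2, 6) while %arg2 maps straight onto the untouched trailing dimension, giving the three source indices in the CHECK lines. As a hand-worked instance (not part of the test): for %arg1 = 7 and %arg2 = 5 the rewritten extract reads element [1, 1, 5] of the source tensor, since 7 = 1 * 6 + 1.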
@@ -77,6 +77,11 @@ struct TestTensorTransforms
      llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
      llvm::cl::init(false)};

  Option<bool> testFoldExtractFromCollapseShape{
      *this, "test-fold-extract-from-collapse-shape",
      llvm::cl::desc("Test folding of extract from collapse_shape"),
      llvm::cl::init(false)};

  Option<bool> useForeach{
      *this, "use-foreach",
      llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
  (void)applyPatternsGreedily(rootOp, std::move(patterns));
}

static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateFoldCollapseExtractPatterns(patterns);
  (void)applyPatternsGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
            applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
      return signalPassFailure();
  }
  if (testFoldExtractFromCollapseShape)
    applyFoldExtractFromCollapseShapePatterns(rootOp);
  if (testTrackingListener)
    if (failed(testTrackingListenerReplacements(rootOp)))
      return signalPassFailure();