[mlir][tensor] add tensor insert/extract op folders (#142458)

Adds a few canonicalizers, folders, and rewrite patterns to tensor ops:

* tensor.insert folder: an insert into a constant destination is replaced
with a new constant
* tensor.extract folder: an extract from a parent tensor that was produced
by an insert at the same indices is folded to the inserted scalar
* rewrite pattern that replaces an extract of a collapse_shape with an
extract of the source tensor (requires a statically shaped source); see the
IR sketch below
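A minimal IR sketch of the three rewrites (operand names, types, and
constants are invented for illustration, not taken from the tests below):

```
// tensor.extract folder: %e folds to %val.
%ins = tensor.insert %val into %dest[%c0, %c1] : tensor<2x2xf32>
%e = tensor.extract %ins[%c0, %c1] : tensor<2x2xf32>

// tensor.insert folder: %r is replaced by a new constant dense<[5.0, 2.0]>.
%cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>
%r = tensor.insert %five into %cst[%c0] : tensor<2xf32>

// collapse_shape rewrite: the extract is redirected to the statically
// shaped source, with %i delinearized via affine.delinearize_index.
%flat = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf32> into tensor<12xf32>
%x = tensor.extract %flat[%i] : tensor<12xf32>
```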

Signed-off-by: Asra Ali <asraa@google.com>
Author: asraa
Date: 2025-06-03 11:16:03 -05:00
Committed by: GitHub
Parent: b9dec5aa79
Commit: 34d8275e4f
6 changed files with 240 additions and 3 deletions


@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
return false;
});
/// Patterns to fold a tensor.extract of a collapse_shape result into an
/// extract of the source tensor.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
} // namespace tensor
} // namespace mlir


@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
let hasFolder = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//


@@ -22,6 +22,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>
using namespace mlir;
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
}
};
/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
/// tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
PatternRewriter &rewriter) const final {
auto collapseOp =
extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp)
return failure();
if (!collapseOp.getSrcType().hasStaticShape())
return failure();
auto sourceSizes = collapseOp.getSrcType().getShape();
SmallVector<Value> indices(extractOp.getIndices().begin(),
extractOp.getIndices().end());
SmallVector<Value> sourceIndices;
for (auto [index, group] :
llvm::zip(indices, collapseOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
auto groupSize = group.size();
if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}
SmallVector<int64_t> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(), zeroAffineMap,
ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
}
}
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, collapseOp.getSrc(), sourceIndices);
return success();
}
};
} // namespace
void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
return success();
}
/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
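/// Illustrative example (operand names invented, not taken from the tests):
///
/// ```
/// %inserted = tensor.insert %scalar into %dest[%c0] : tensor<4xf32>
/// %extracted = tensor.extract %inserted[%c0] : tensor<4xf32>
/// ```
///
/// folds `%extracted` to `%scalar`, provided the index operands fold to the
/// same values and the element types match.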
static Value foldExtractAfterInsert(ExtractOp extractOp) {
auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
auto isSame = [](Value a, Value b) {
return getAsOpFoldResult(a) == getAsOpFoldResult(b);
};
if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
return insertOp.getScalar();
return {};
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (Attribute tensor = adaptor.getTensor()) {
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return elementsAttr.getValues<Attribute>()[indices];
}
if (Value result = foldExtractAfterInsert(*this))
return result;
return {};
}
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ExtractFromTensorCast>(context);
}
void mlir::tensor::populateFoldCollapseExtractPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
// InsertOp
//===----------------------------------------------------------------------===//
namespace {
/// Pattern to fold an insert op with a constant destination and a constant
/// scalar into a new constant.
///
/// Example:
/// ```
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// %c0 = arith.constant 0 : index
/// %c4_f32 = arith.constant 4.0 : f32
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
/// ```
/// is rewritten into:
/// ```
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
/// ```
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern<InsertOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
// Requires a ranked tensor type.
auto destType =
llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
if (!destType)
return failure();
// The pattern requires constant indices.
SmallVector<uint64_t, 8> indices;
for (OpFoldResult index : getAsOpFoldResult(insertOp.getIndices())) {
auto indexAttr = dyn_cast<Attribute>(index);
if (!indexAttr)
return failure();
indices.push_back(llvm::cast<IntegerAttr>(indexAttr).getInt());
}
}
// Requires a constant scalar to insert
OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
Attribute scalarAttr = dyn_cast<Attribute>(scalar);
if (!scalarAttr)
return failure();
if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
insertOp.getDest().getDefiningOp())) {
if (auto sourceAttr =
llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
// Update the attribute at the inserted index.
auto sourceValues = sourceAttr.getValues<Attribute>();
auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
std::vector<Attribute> updatedValues;
updatedValues.reserve(sourceAttr.getNumElements());
for (int64_t i = 0; i < sourceAttr.getNumElements(); ++i) {
updatedValues.push_back(i == flattenedIndex ? scalarAttr
: sourceValues[i]);
}
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
insertOp, sourceAttr.getType(),
DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
updatedValues));
return success();
}
}
return failure();
}
};
} // namespace
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
return {};
}
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertOpConstantFold>(context);
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//


@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
// -----
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
%const_0 = arith.constant 0 : index
%const_1 = arith.constant 1 : index
%const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
// Fold an extract after an insert.
// CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
%c4_i32 = arith.constant 4 : i32
%5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
%inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
%ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
}
// -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
return %ins_1 : tensor<4xf32>
}
// -----
func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
// Fold an insert into a constant.
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
// CHECK-NEXT: return %[[C4]]
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4_i32 = arith.constant 4 : i32
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
return %inserted : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @extract_from_tensor.cast


@@ -0,0 +1,31 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
// CHECK-LABEL: @extract_from_collapse_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
%extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
%extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
func.return %extracted, %extracted_0 : i8, i8
}
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
// -----
// CHECK-LABEL: @extract_from_static_shape
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
%1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
return %1 : f32
}
// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
// CHECK-NEXT: return %[[RESULT]] : f32


@@ -77,6 +77,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};
Option<bool> testFoldExtractFromCollapseShape{
*this, "test-fold-extract-from-collapse-shape",
llvm::cl::desc("Test folding of extract from collapse_shape"),
llvm::cl::init(false)};
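// For reference, this option is exercised in the new test via:
//   mlir-opt -split-input-file \
//     -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape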
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateFoldCollapseExtractPatterns(patterns);
(void)applyPatternsGreedily(rootOp, std::move(patterns));
}
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
if (testFoldExtractFromCollapseShape)
applyFoldExtractFromCollapseShapePatterns(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();