From 34d8275e4fcd619226e2872ea0ee07f8a1634ff7 Mon Sep 17 00:00:00 2001 From: asraa Date: Tue, 3 Jun 2025 11:16:03 -0500 Subject: [PATCH] [mlir][tensor] add tensor insert/extract op folders (#142458) Adds a few canonicalizers, folders, and rewrite patterns to tensor ops: * tensor.insert folder: insert into a constant is replaced with a new constant * tensor.extract folder: extract from a parent tensor that was inserted at the same indices is folded into the inserted value * rewrite pattern added that replaces an extract of a collapse shape with an extract of the source tensor (requires static source dimensions) Signed-off-by: Asra Ali --- mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 4 + .../mlir/Dialect/Tensor/IR/TensorOps.td | 1 + mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 165 ++++++++++++++++++ mlir/test/Dialect/Tensor/canonicalize.mlir | 29 ++- .../Tensor/extract-from-collapse-shape.mlir | 31 ++++ .../Dialect/Tensor/TestTensorTransforms.cpp | 13 ++ 6 files changed, 240 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h index eb550bb469b9..e8e1342ef36f 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns( return false; }); +/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the +/// source tensor. +void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns); + } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 35d0b1662841..c0885a376382 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ let hasFolder = 1; let hasVerifier = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 30ca20fc0d88..f2a7220b4bed 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferIntRangeInterface.h" @@ -33,10 +34,12 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include #include +#include using namespace mlir; using namespace mlir::tensor; @@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern { } }; +/// Canonicalizes the pattern of the form +/// +/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into +/// tensor<12xf64> +/// %extracted_element = tensor.extract %val[%c10] : +/// tensor<12xf64> +/// +/// to +/// +/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64> +struct ExtractFromCollapseShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const final { + auto collapseOp = + extractOp.getTensor().getDefiningOp(); + if (!collapseOp) + return failure(); + if (!collapseOp.getSrcType().hasStaticShape()) + return failure(); + + auto sourceSizes = collapseOp.getSrcType().getShape(); + + SmallVector indices(extractOp.getIndices().begin(), + extractOp.getIndices().end()); + SmallVector sourceIndices; + for (auto [index, group] : + llvm::zip(indices, collapseOp.getReassociationIndices())) { + assert(!group.empty() && "association indices groups cannot be empty"); + auto groupSize = group.size(); + + if (groupSize == 1) { + sourceIndices.push_back(index); + continue; + } + + SmallVector basis = + llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); + auto delinearize = rewriter.create( + extractOp.getLoc(), index, basis, /*hasOuterBound=*/true); + llvm::append_range(sourceIndices, delinearize.getResults()); + } + if (collapseOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast(collapseOp.getSrcType()).getRank(); + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, extractOp.getLoc(), zeroAffineMap, + ArrayRef{}); + for (int64_t i = 0; i < srcRank; i++) { + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr)); + } + } + + rewriter.replaceOpWithNewOp( + extractOp, collapseOp.getSrc(), sourceIndices); + return success(); + } +}; + } // namespace void ExtractOp::getAsmResultNames( @@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() { return success(); } +/// If we have an ExtractOp consuming an InsertOp with the same +/// indices, we can return the InsertOp's scalar directly. +// TODO: This only checks the immediate producer; extend to go up the +// insert/extract chain if the slices are disjoint. +static Value foldExtractAfterInsert(ExtractOp extractOp) { + auto insertOp = extractOp.getTensor().getDefiningOp(); + + auto isSame = [](Value a, Value b) { + return getAsOpFoldResult(a) == getAsOpFoldResult(b); + }; + if (insertOp && insertOp.getScalar().getType() == extractOp.getType() && + llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame)) + return insertOp.getScalar(); + + return {}; +} + OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { if (Attribute tensor = adaptor.getTensor()) { // If this is a splat elements attribute, simply return the value. @@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { return elementsAttr.getValues()[indices]; } + if (Value result = foldExtractAfterInsert(*this)) + return result; + return {}; } @@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } +void mlir::tensor::populateFoldCollapseExtractPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// @@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) { // InsertOp //===----------------------------------------------------------------------===// +namespace { + +/// Pattern to fold an insert op of a constant destination and scalar to a new +/// constant. +/// +/// Example: +/// ``` +/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> +/// %c0 = arith.constant 0 : index +/// %c4_f32 = arith.constant 4.0 : f32 +/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32> +/// ``` +/// is rewritten into: +/// ``` +/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32> +/// ``` +class InsertOpConstantFold final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertOp insertOp, + PatternRewriter &rewriter) const override { + // Requires a ranked tensor type. + auto destType = + llvm::dyn_cast(insertOp.getDest().getType()); + if (!destType) + return failure(); + + // Pattern requires constant indices + SmallVector indices; + for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) { + auto indiceAttr = dyn_cast(indice); + if (!indiceAttr) + return failure(); + indices.push_back(llvm::cast(indiceAttr).getInt()); + } + + // Requires a constant scalar to insert + OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar()); + Attribute scalarAttr = dyn_cast(scalar); + if (!scalarAttr) + return failure(); + + if (auto constantOp = dyn_cast_or_null( + insertOp.getDest().getDefiningOp())) { + if (auto sourceAttr = + llvm::dyn_cast(constantOp.getValue())) { + // Update the attribute at the inserted index. + auto sourceValues = sourceAttr.getValues(); + auto flattenedIndex = sourceAttr.getFlattenedIndex(indices); + std::vector updatedValues; + updatedValues.reserve(sourceAttr.getNumElements()); + for (auto i = 0; i < sourceAttr.getNumElements(); ++i) { + updatedValues.push_back(i == flattenedIndex ? scalarAttr + : sourceValues[i]); + } + rewriter.replaceOpWithNewOp( + insertOp, sourceAttr.getType(), + DenseElementsAttr::get(cast(sourceAttr.getType()), + updatedValues)); + return success(); + } + } + + return failure(); + } +}; + +} // namespace + void InsertOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "inserted"); @@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { return {}; } +void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //===----------------------------------------------------------------------===// // GenerateOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 3eaf824b9911..646b2197d9aa 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor (f32, f16, f16, i32, complex) { +func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex, i32) { %const_0 = arith.constant 0 : index %const_1 = arith.constant 1 : index %const_3 = arith.constant 3 : index @@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex) { %4 = arith.constant dense<(1.2, 2.3)> : tensor> %ext_5 = tensor.extract %4[] : tensor> - // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]] - return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex + // Fold an extract after an insert. + // CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32 + %c4_i32 = arith.constant 4 : i32 + %5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32> + %inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32> + %ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32> + + // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]] + return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex, i32 } // ----- @@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { return %ins_1 : tensor<4xf32> } + +// ----- + +func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) { + // Fold an insert into a splat. + // CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32> + // CHECK-LITERAL: + // CHECK-NEXT: return %[[C4]] + %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4_i32 = arith.constant 4 : i32 + %inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32> + return %inserted : tensor<2x2xi32> +} + // ----- // CHECK-LABEL: func @extract_from_tensor.cast diff --git a/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir new file mode 100644 index 000000000000..c301f494a7c8 --- /dev/null +++ b/mlir/test/Dialect/Tensor/extract-from-collapse-shape.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s + +// CHECK-LABEL: @extract_from_collapse_shape +// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>) +func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8> + %extracted = tensor.extract %collapsed[%c0] : tensor<8xi8> + %extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8> + func.return %extracted, %extracted_0 : i8, i8 +} + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8> +// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8> +// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8 + +// ----- + +// CHECK-LABEL: @extract_from_static_shape +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32> + %1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32> + return %1 : f32 +} +// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6) +// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32> +// CHECK-NEXT: return %[[RESULT]] : f32 diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index e435130c2a41..0e191c32f009 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -77,6 +77,11 @@ struct TestTensorTransforms llvm::cl::desc("Test folding of expand_shape/collapse_shape"), llvm::cl::init(false)}; + Option testFoldExtractFromCollapseShape{ + *this, "test-fold-extract-from-collapse-shape", + llvm::cl::desc("Test folding of extract from collapse_shape"), + llvm::cl::init(false)}; + Option useForeach{ *this, "use-foreach", llvm::cl::desc( @@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { (void)applyPatternsGreedily(rootOp, std::move(patterns)); } +static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) { + RewritePatternSet patterns(rootOp->getContext()); + tensor::populateFoldCollapseExtractPatterns(patterns); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); +} + namespace { /// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`. /// The `tensor.extract_slice` is replaced by a loop or gather operation that @@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() { applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) return signalPassFailure(); } + if (testFoldExtractFromCollapseShape) + applyFoldExtractFromCollapseShapePatterns(rootOp); if (testTrackingListener) if (failed(testTrackingListenerReplacements(rootOp))) return signalPassFailure();