mirror of
https://github.com/intel/llvm.git
synced 2026-01-20 10:58:11 +08:00
[mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation
This revision adds a new `replaceOpWithIf` hook that replaces uses of an operation that satisfy a given functor. If all uses are replaced, the operation gets erased in a similar manner to `replaceOp`. DialectConversion support will be added in a followup as this requires adjusting how replacements are tracked there. Differential Revision: https://reviews.llvm.org/D94632
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -447,6 +448,30 @@ public:
|
||||
Region::iterator before);
|
||||
void cloneRegionBefore(Region ®ion, Block *before);
|
||||
|
||||
/// This method replaces the uses of the results of `op` with the values in
|
||||
/// `newValues` when the provided `functor` returns true for a specific use.
|
||||
/// The number of values in `newValues` is required to match the number of
|
||||
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
|
||||
/// the uses of `op` were replaced. Note that in some pattern rewriters, the
|
||||
/// given 'functor' may be stored beyond the lifetime of the pattern being
|
||||
/// applied. As such, the function should not capture by reference and instead
|
||||
/// use value capture as necessary.
|
||||
virtual void
|
||||
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
||||
llvm::unique_function<bool(OpOperand &) const> functor);
|
||||
void replaceOpWithIf(Operation *op, ValueRange newValues,
|
||||
llvm::unique_function<bool(OpOperand &) const> functor) {
|
||||
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
|
||||
std::move(functor));
|
||||
}
|
||||
|
||||
/// This method replaces the uses of the results of `op` with the values in
|
||||
/// `newValues` when a use is nested within the given `block`. The number of
|
||||
/// values in `newValues` is required to match the number of results of `op`.
|
||||
/// If all uses of this operation are replaced, the operation is erased.
|
||||
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
|
||||
bool *allUsesReplaced = nullptr);
|
||||
|
||||
/// This method performs the final replacement for a pattern, where the
|
||||
/// results of the operation are updated to use the specified list of SSA
|
||||
/// values.
|
||||
|
||||
@@ -470,6 +470,12 @@ public:
|
||||
// PatternRewriter Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation when the
|
||||
/// given functor returns true.
|
||||
void replaceOpWithIf(
|
||||
Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
||||
llvm::unique_function<bool(OpOperand &) const> functor) override;
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void replaceOp(Operation *op, ValueRange newValues) override;
|
||||
using PatternRewriter::replaceOp;
|
||||
|
||||
@@ -155,6 +155,41 @@ PatternRewriter::~PatternRewriter() {
|
||||
// Out of line to provide a vtable anchor for the class.
|
||||
}
|
||||
|
||||
/// This method replaces the uses of the results of `op` with the values in
|
||||
/// `newValues` when the provided `functor` returns true for a specific use.
|
||||
/// The number of values in `newValues` is required to match the number of
|
||||
/// results of `op`.
|
||||
void PatternRewriter::replaceOpWithIf(
|
||||
Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
||||
llvm::unique_function<bool(OpOperand &) const> functor) {
|
||||
assert(op->getNumResults() == newValues.size() &&
|
||||
"incorrect number of values to replace operation");
|
||||
|
||||
// Notify the rewriter subclass that we're about to replace this root.
|
||||
notifyRootReplaced(op);
|
||||
|
||||
// Replace each use of the results when the functor is true.
|
||||
bool replacedAllUses = true;
|
||||
for (auto it : llvm::zip(op->getResults(), newValues)) {
|
||||
std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
|
||||
replacedAllUses &= std::get<0>(it).use_empty();
|
||||
}
|
||||
if (allUsesReplaced)
|
||||
*allUsesReplaced = replacedAllUses;
|
||||
}
|
||||
|
||||
/// This method replaces the uses of the results of `op` with the values in
|
||||
/// `newValues` when a use is nested within the given `block`. The number of
|
||||
/// values in `newValues` is required to match the number of results of `op`.
|
||||
/// If all uses of this operation are replaced, the operation is erased.
|
||||
void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
|
||||
Block *block,
|
||||
bool *allUsesReplaced) {
|
||||
replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
|
||||
return block->getParentOp()->isProperAncestor(use.getOwner());
|
||||
});
|
||||
}
|
||||
|
||||
/// This method performs the final replacement for a pattern, where the
|
||||
/// results of the operation are updated to use the specified list of SSA
|
||||
/// values.
|
||||
|
||||
@@ -1250,6 +1250,21 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
|
||||
impl(new detail::ConversionPatternRewriterImpl(*this)) {}
|
||||
ConversionPatternRewriter::~ConversionPatternRewriter() {}
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation when the
|
||||
/// given functor returns true.
|
||||
void ConversionPatternRewriter::replaceOpWithIf(
|
||||
Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
||||
llvm::unique_function<bool(OpOperand &) const> functor) {
|
||||
// TODO: To support this we will need to rework a bit of how replacements are
|
||||
// tracked, given that this isn't guranteed to replace all of the uses of an
|
||||
// operation. The main change is that now an operation can be replaced
|
||||
// multiple times, in parts. The current "set" based tracking is mainly useful
|
||||
// for tracking if a replaced operation should be ignored, i.e. if all of the
|
||||
// uses will be replaced.
|
||||
llvm_unreachable(
|
||||
"replaceOpWithIf is currently not supported by DialectConversion");
|
||||
}
|
||||
|
||||
/// PatternRewriter hook for replacing the results of an operation.
|
||||
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
|
||||
LLVM_DEBUG({
|
||||
|
||||
15
mlir/test/Transforms/test-pattern-selective-replacement.mlir
Normal file
15
mlir/test/Transforms/test-pattern-selective-replacement.mlir
Normal file
@@ -0,0 +1,15 @@
|
||||
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-pattern-selective-replacement -verify-diagnostics %s | FileCheck %s
|
||||
|
||||
// Test that operations can be selectively replaced.
|
||||
|
||||
// CHECK-LABEL: @test1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
|
||||
func @test1(%arg0: i32, %arg1 : i32) -> () {
|
||||
// CHECK: addi %[[ARG1]], %[[ARG1]]
|
||||
// CHECK-NEXT: "test.return"(%[[ARG0]]
|
||||
%cast = "test.cast"(%arg0, %arg1) : (i32, i32) -> (i32)
|
||||
%non_terminator = addi %cast, %cast : i32
|
||||
"test.return"(%cast, %non_terminator) : (i32, i32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -847,6 +847,10 @@ struct TestTypeConversionDriver
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Block Merging
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A rewriter pattern that tests that blocks can be merged.
|
||||
struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
|
||||
@@ -955,6 +959,46 @@ struct TestMergeBlocksPatternDriver
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Selective Replacement
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A rewrite mechanism to inline the body of the op into its parent, when both
|
||||
/// ops can have a single block.
|
||||
struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
|
||||
using OpRewritePattern<TestCastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(TestCastOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
if (op.getNumOperands() != 2)
|
||||
return failure();
|
||||
OperandRange operands = op.getOperands();
|
||||
|
||||
// Replace non-terminator uses with the first operand.
|
||||
rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
|
||||
return operand.getOwner()->isKnownTerminator();
|
||||
});
|
||||
// Replace everything else with the second operand if the operation isn't
|
||||
// dead.
|
||||
rewriter.replaceOp(op, op.getOperand(1));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestSelectiveReplacementPatternDriver
|
||||
: public PassWrapper<TestSelectiveReplacementPatternDriver,
|
||||
OperationPass<>> {
|
||||
void runOnOperation() override {
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
MLIRContext *context = &getContext();
|
||||
patterns.insert<TestSelectiveOpReplacementPattern>(context);
|
||||
applyPatternsAndFoldGreedily(getOperation()->getRegions(),
|
||||
std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassRegistration
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -992,6 +1036,9 @@ void registerPatternsTestPass() {
|
||||
PassRegistration<TestMergeBlocksPatternDriver>{
|
||||
"test-merge-blocks",
|
||||
"Test Merging operation in ConversionPatternRewriter"};
|
||||
PassRegistration<TestSelectiveReplacementPatternDriver>{
|
||||
"test-pattern-selective-replacement",
|
||||
"Test selective replacement in the PatternRewriter"};
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user