[mlir][PatternRewriter] Add a new hook to selectively replace uses of an operation

This revision adds a new `replaceOpWithIf` hook that replaces uses of an operation that satisfy a given functor. If all uses are replaced, the operation gets erased in a similar manner to `replaceOp`. DialectConversion support will be added in a followup as this requires adjusting how replacements are tracked there.

Differential Revision: https://reviews.llvm.org/D94632
This commit is contained in:
River Riddle
2021-01-14 11:57:17 -08:00
parent 387d3c2479
commit c8fb6ee341
6 changed files with 143 additions and 0 deletions

View File

@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/FunctionExtras.h"
namespace mlir {
@@ -447,6 +448,30 @@ public:
Region::iterator before);
void cloneRegionBefore(Region &region, Block *before);
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
/// the uses of `op` were replaced. Note that in some pattern rewriters, the
/// given 'functor' may be stored beyond the lifetime of the pattern being
/// applied. As such, the function should not capture by reference and instead
/// use value capture as necessary.
virtual void
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor);
void replaceOpWithIf(Operation *op, ValueRange newValues,
llvm::unique_function<bool(OpOperand &) const> functor) {
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
std::move(functor));
}
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr);
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values.

View File

@@ -470,6 +470,12 @@ public:
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
/// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true.
void replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) override;
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ValueRange newValues) override;
using PatternRewriter::replaceOp;

View File

@@ -155,6 +155,41 @@ PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
}
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`.
void PatternRewriter::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
assert(op->getNumResults() == newValues.size() &&
"incorrect number of values to replace operation");
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
// Replace each use of the results when the functor is true.
bool replacedAllUses = true;
for (auto it : llvm::zip(op->getResults(), newValues)) {
std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
replacedAllUses &= std::get<0>(it).use_empty();
}
if (allUsesReplaced)
*allUsesReplaced = replacedAllUses;
}
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
Block *block,
bool *allUsesReplaced) {
replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
return block->getParentOp()->isProperAncestor(use.getOwner());
});
}
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values.

View File

@@ -1250,6 +1250,21 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
impl(new detail::ConversionPatternRewriterImpl(*this)) {}
ConversionPatternRewriter::~ConversionPatternRewriter() {}
/// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true.
void ConversionPatternRewriter::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
// TODO: To support this we will need to rework a bit of how replacements are
// tracked, given that this isn't guranteed to replace all of the uses of an
// operation. The main change is that now an operation can be replaced
// multiple times, in parts. The current "set" based tracking is mainly useful
// for tracking if a replaced operation should be ignored, i.e. if all of the
// uses will be replaced.
llvm_unreachable(
"replaceOpWithIf is currently not supported by DialectConversion");
}
/// PatternRewriter hook for replacing the results of an operation.
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
LLVM_DEBUG({

View File

@@ -0,0 +1,15 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-pattern-selective-replacement -verify-diagnostics %s | FileCheck %s
// Test that operations can be selectively replaced.
// CHECK-LABEL: @test1
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func @test1(%arg0: i32, %arg1 : i32) -> () {
// CHECK: addi %[[ARG1]], %[[ARG1]]
// CHECK-NEXT: "test.return"(%[[ARG0]]
%cast = "test.cast"(%arg0, %arg1) : (i32, i32) -> (i32)
%non_terminator = addi %cast, %cast : i32
"test.return"(%cast, %non_terminator) : (i32, i32) -> ()
}
// -----

View File

@@ -847,6 +847,10 @@ struct TestTypeConversionDriver
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Test Block Merging
//===----------------------------------------------------------------------===//
namespace {
/// A rewriter pattern that tests that blocks can be merged.
struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
@@ -955,6 +959,46 @@ struct TestMergeBlocksPatternDriver
};
} // namespace
//===----------------------------------------------------------------------===//
// Test Selective Replacement
//===----------------------------------------------------------------------===//
namespace {
/// A rewrite mechanism to inline the body of the op into its parent, when both
/// ops can have a single block.
struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
using OpRewritePattern<TestCastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TestCastOp op,
PatternRewriter &rewriter) const final {
if (op.getNumOperands() != 2)
return failure();
OperandRange operands = op.getOperands();
// Replace non-terminator uses with the first operand.
rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
return operand.getOwner()->isKnownTerminator();
});
// Replace everything else with the second operand if the operation isn't
// dead.
rewriter.replaceOp(op, op.getOperand(1));
return success();
}
};
struct TestSelectiveReplacementPatternDriver
: public PassWrapper<TestSelectiveReplacementPatternDriver,
OperationPass<>> {
void runOnOperation() override {
mlir::OwningRewritePatternList patterns;
MLIRContext *context = &getContext();
patterns.insert<TestSelectiveOpReplacementPattern>(context);
applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
};
} // namespace
//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
@@ -992,6 +1036,9 @@ void registerPatternsTestPass() {
PassRegistration<TestMergeBlocksPatternDriver>{
"test-merge-blocks",
"Test Merging operation in ConversionPatternRewriter"};
PassRegistration<TestSelectiveReplacementPatternDriver>{
"test-pattern-selective-replacement",
"Test selective replacement in the PatternRewriter"};
}
} // namespace test
} // namespace mlir