[mlir][GreedyPatternRewriter] Add out param to detect changes in IR in applyPatternsAndFoldGreedily

This allows users of `applyPatternsAndFoldGreedily` to detect if any MLIR changes have occurred. An example use-case is where we expect the `applyPatternsAndFoldGreedily` to change the IR and want to validate that it indeed does change it.

Differential Revision: https://reviews.llvm.org/D153986
This commit is contained in:
Joel Wee
2023-06-29 12:46:54 +02:00
committed by Matthias Springer
parent d34bc66ef9
commit 8498c9e948
2 changed files with 33 additions and 19 deletions

View File

@@ -94,25 +94,36 @@ public:
/// in absence of convergence.
///
/// Return success if the iterative process converged and no more patterns can
/// be matched in the result operation regions.
/// be matched in the result operation regions. `changed` is set to true if the
/// IR was modified at all.
///
/// Note: This does not apply patterns to the top-level operation itself.
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
/// You may configure several aspects of this with GreedyRewriteConfig.
LogicalResult applyPatternsAndFoldGreedily(
Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig());
LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Rewrite ops in all regions of the given op, which must be isolated from
/// above.
inline LogicalResult applyPatternsAndFoldGreedily(
Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig()) {
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
bool failed = false;
for (Region &region : op->getRegions())
failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
for (Region &region : op->getRegions()) {
bool regionChanged;
failed |=
applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
.failed();
if (changed)
*changed |= regionChanged;
}
return failure(failed);
}

View File

@@ -616,7 +616,7 @@ public:
/// Simplify ops inside `region` and simplify the region itself. Return
/// success if the transformation converged.
LogicalResult simplify() &&;
LogicalResult simplify(bool *changed) &&;
private:
/// The region that is simplified.
@@ -652,7 +652,7 @@ private:
};
} // namespace
LogicalResult RegionPatternRewriteDriver::simplify() && {
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
@@ -663,12 +663,12 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
return false;
};
bool changed = false;
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
do {
// Check if the iteration limit was reached.
if (iteration++ >= config.maxIterations &&
if (++iteration > config.maxIterations &&
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
@@ -696,24 +696,27 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
changed = processWorklist();
continueRewrites = processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, region));
continueRewrites |= succeeded(simplifyRegions(*this, region));
},
{&region}, iteration);
} while (changed);
} while (continueRewrites);
if (changed)
*changed = iteration > 1;
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return success(!changed);
return success(!continueRewrites);
}
LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config) {
GreedyRewriteConfig config, bool *changed) {
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
@@ -727,7 +730,7 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
// Start the pattern driver.
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
LogicalResult converged = std::move(driver).simplify();
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";