[mlir] GreedyPatternRewriter: fix counting of iterations

The GreedyPatternRewriteDriver did previously not count the first iteration. I.e., when setting `config.maxIterations = 1`, two iterations were performed. In pratice, this number is not really important; we usually just need a limit in some reasonable order of magnitude. However, this fix allows us to write better convergence/worklist tests with carefully crafted test patterns to purposely trigger edge cases in the driver.

Similarly, the first rewrite was previously not counted towards `config.maxNumRewrites`.

For consistency, `OpPatternRewriteDriver` now uses `config.maxNumRewrites` instead of `config.maxIterations`; this driver does not have "iterations", it consists of a single loop (corresponding to the inner loop in the GreedyPatternRewriteDriver).

Differential Revision: https://reviews.llvm.org/D141365
This commit is contained in:
Matthias Springer
2023-01-10 12:02:33 +01:00
parent 094ccee2c8
commit 0ff3cf0c0c

View File

@@ -96,10 +96,11 @@ protected:
/// Non-pattern based folder for operations.
OperationFolder folder;
private:
protected:
/// Configuration information for how to simplify.
GreedyRewriteConfig config;
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -147,8 +148,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
};
bool changed = false;
unsigned iteration = 0;
int64_t iteration = 0;
do {
// Check if the iteration limit was reached.
if (iteration++ >= config.maxIterations &&
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
worklist.clear();
worklistMap.clear();
@@ -184,7 +190,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
changed = false;
int64_t numRewrites = 0;
while (!worklist.empty()) {
while (!worklist.empty() &&
(numRewrites < config.maxNumRewrites ||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
auto *op = popFromWorklist();
// Nulls get added to the worklist when operations are removed, ignore
@@ -280,11 +288,10 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
#else
LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
#endif
if (succeeded(matchResult)) {
changed = true;
if (numRewrites++ >= config.maxNumRewrites &&
config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
break;
++numRewrites;
}
}
@@ -292,8 +299,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
// is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, regions));
} while (changed && (iteration++ < config.maxIterations ||
config.maxIterations == GreedyRewriteConfig::kNoLimit));
} while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return !changed;
@@ -421,7 +427,7 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
bool converged = driver.simplify(regions);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
});
return success(converged);
@@ -443,7 +449,8 @@ public:
matcher.applyDefaultCostModel();
}
LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
bool &erased);
// These are hooks implemented for PatternRewriter.
protected:
@@ -473,18 +480,22 @@ private:
/// Performs the rewrites and folding only on `op`. The simplification
/// converges if the op is erased as a result of being folded, replaced, or
/// becoming dead, or no more changes happen in an iteration. Returns success if
/// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
/// gets erased.
LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
int maxIterations,
int64_t maxNumRewrites,
bool &erased) {
bool changed = false;
erased = false;
opErasedViaPatternRewrites = false;
int iterations = 0;
// Iterate until convergence or until maxIterations. Deletion of the op as
int64_t numRewrites = 0;
// Iterate until convergence or until maxNumRewrites. Deletion of the op as
// a result of being dead or folded is convergence.
do {
if (numRewrites >= maxNumRewrites &&
maxNumRewrites != GreedyRewriteConfig::kNoLimit)
break;
changed = false;
// If the operation is trivially dead - remove it.
@@ -508,11 +519,13 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do here.
changed |= succeeded(matcher.matchAndRewrite(op, *this));
if (succeeded(matcher.matchAndRewrite(op, *this))) {
changed = true;
++numRewrites;
}
if ((erased = opErasedViaPatternRewrites))
return success();
} while (changed && (++iterations < maxIterations ||
maxIterations == GreedyRewriteConfig::kNoLimit));
} while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return failure(changed);
@@ -601,7 +614,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// These are scratch vectors used in the folding loop below.
SmallVector<Value, 8> originalOperands, resultValues;
while (!worklist.empty()) {
int64_t numRewrites = 0;
while (!worklist.empty() &&
(numRewrites < config.maxNumRewrites ||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
Operation *op = popFromWorklist();
// Nulls get added to the worklist when operations are removed, ignore
@@ -656,7 +672,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do
// here.
changed |= succeeded(matcher.matchAndRewrite(op, *this));
if (succeeded(matcher.matchAndRewrite(op, *this))) {
changed = true;
++numRewrites;
}
}
return changed;
@@ -672,12 +691,12 @@ LogicalResult mlir::applyOpPatternsAndFold(
OpPatternRewriteDriver driver(op->getContext(), patterns);
bool opErased;
LogicalResult converged =
driver.simplifyLocally(op, config.maxIterations, opErased);
driver.simplifyLocally(op, config.maxNumRewrites, opErased);
if (erased)
*erased = opErased;
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
<< config.maxIterations << " times";
llvm::dbgs() << "The pattern rewrite did not converge after "
<< config.maxNumRewrites << " rewrites";
});
return converged;
}