mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 11:57:39 +08:00
[mlir] GreedyPatternRewriter: fix counting of iterations
The GreedyPatternRewriteDriver did previously not count the first iteration. I.e., when setting `config.maxIterations = 1`, two iterations were performed. In pratice, this number is not really important; we usually just need a limit in some reasonable order of magnitude. However, this fix allows us to write better convergence/worklist tests with carefully crafted test patterns to purposely trigger edge cases in the driver. Similarly, the first rewrite was previously not counted towards `config.maxNumRewrites`. For consistency, `OpPatternRewriteDriver` now uses `config.maxNumRewrites` instead of `config.maxIterations`; this driver does not have "iterations", it consists of a single loop (corresponding to the inner loop in the GreedyPatternRewriteDriver). Differential Revision: https://reviews.llvm.org/D141365
This commit is contained in:
@@ -96,10 +96,11 @@ protected:
|
||||
/// Non-pattern based folder for operations.
|
||||
OperationFolder folder;
|
||||
|
||||
private:
|
||||
protected:
|
||||
/// Configuration information for how to simplify.
|
||||
GreedyRewriteConfig config;
|
||||
|
||||
private:
|
||||
#ifndef NDEBUG
|
||||
/// A logger used to emit information during the application process.
|
||||
llvm::ScopedPrinter logger{llvm::dbgs()};
|
||||
@@ -147,8 +148,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
|
||||
};
|
||||
|
||||
bool changed = false;
|
||||
unsigned iteration = 0;
|
||||
int64_t iteration = 0;
|
||||
do {
|
||||
// Check if the iteration limit was reached.
|
||||
if (iteration++ >= config.maxIterations &&
|
||||
config.maxIterations != GreedyRewriteConfig::kNoLimit)
|
||||
break;
|
||||
|
||||
worklist.clear();
|
||||
worklistMap.clear();
|
||||
|
||||
@@ -184,7 +190,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
|
||||
|
||||
changed = false;
|
||||
int64_t numRewrites = 0;
|
||||
while (!worklist.empty()) {
|
||||
while (!worklist.empty() &&
|
||||
(numRewrites < config.maxNumRewrites ||
|
||||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
|
||||
auto *op = popFromWorklist();
|
||||
|
||||
// Nulls get added to the worklist when operations are removed, ignore
|
||||
@@ -280,11 +288,10 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
|
||||
#else
|
||||
LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
|
||||
#endif
|
||||
|
||||
if (succeeded(matchResult)) {
|
||||
changed = true;
|
||||
if (numRewrites++ >= config.maxNumRewrites &&
|
||||
config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
|
||||
break;
|
||||
++numRewrites;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,8 +299,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
|
||||
// is kept up to date.
|
||||
if (config.enableRegionSimplification)
|
||||
changed |= succeeded(simplifyRegions(*this, regions));
|
||||
} while (changed && (iteration++ < config.maxIterations ||
|
||||
config.maxIterations == GreedyRewriteConfig::kNoLimit));
|
||||
} while (changed);
|
||||
|
||||
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
|
||||
return !changed;
|
||||
@@ -421,7 +427,7 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
|
||||
bool converged = driver.simplify(regions);
|
||||
LLVM_DEBUG(if (!converged) {
|
||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
|
||||
<< config.maxIterations << " times\n";
|
||||
});
|
||||
return success(converged);
|
||||
@@ -443,7 +449,8 @@ public:
|
||||
matcher.applyDefaultCostModel();
|
||||
}
|
||||
|
||||
LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
|
||||
LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
|
||||
bool &erased);
|
||||
|
||||
// These are hooks implemented for PatternRewriter.
|
||||
protected:
|
||||
@@ -473,18 +480,22 @@ private:
|
||||
/// Performs the rewrites and folding only on `op`. The simplification
|
||||
/// converges if the op is erased as a result of being folded, replaced, or
|
||||
/// becoming dead, or no more changes happen in an iteration. Returns success if
|
||||
/// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
|
||||
/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
|
||||
/// gets erased.
|
||||
LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
|
||||
int maxIterations,
|
||||
int64_t maxNumRewrites,
|
||||
bool &erased) {
|
||||
bool changed = false;
|
||||
erased = false;
|
||||
opErasedViaPatternRewrites = false;
|
||||
int iterations = 0;
|
||||
// Iterate until convergence or until maxIterations. Deletion of the op as
|
||||
int64_t numRewrites = 0;
|
||||
// Iterate until convergence or until maxNumRewrites. Deletion of the op as
|
||||
// a result of being dead or folded is convergence.
|
||||
do {
|
||||
if (numRewrites >= maxNumRewrites &&
|
||||
maxNumRewrites != GreedyRewriteConfig::kNoLimit)
|
||||
break;
|
||||
|
||||
changed = false;
|
||||
|
||||
// If the operation is trivially dead - remove it.
|
||||
@@ -508,11 +519,13 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
|
||||
|
||||
// Try to match one of the patterns. The rewriter is automatically
|
||||
// notified of any necessary changes, so there is nothing else to do here.
|
||||
changed |= succeeded(matcher.matchAndRewrite(op, *this));
|
||||
if (succeeded(matcher.matchAndRewrite(op, *this))) {
|
||||
changed = true;
|
||||
++numRewrites;
|
||||
}
|
||||
if ((erased = opErasedViaPatternRewrites))
|
||||
return success();
|
||||
} while (changed && (++iterations < maxIterations ||
|
||||
maxIterations == GreedyRewriteConfig::kNoLimit));
|
||||
} while (changed);
|
||||
|
||||
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
|
||||
return failure(changed);
|
||||
@@ -601,7 +614,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
|
||||
|
||||
// These are scratch vectors used in the folding loop below.
|
||||
SmallVector<Value, 8> originalOperands, resultValues;
|
||||
while (!worklist.empty()) {
|
||||
int64_t numRewrites = 0;
|
||||
while (!worklist.empty() &&
|
||||
(numRewrites < config.maxNumRewrites ||
|
||||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
|
||||
Operation *op = popFromWorklist();
|
||||
|
||||
// Nulls get added to the worklist when operations are removed, ignore
|
||||
@@ -656,7 +672,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
|
||||
// Try to match one of the patterns. The rewriter is automatically
|
||||
// notified of any necessary changes, so there is nothing else to do
|
||||
// here.
|
||||
changed |= succeeded(matcher.matchAndRewrite(op, *this));
|
||||
if (succeeded(matcher.matchAndRewrite(op, *this))) {
|
||||
changed = true;
|
||||
++numRewrites;
|
||||
}
|
||||
}
|
||||
|
||||
return changed;
|
||||
@@ -672,12 +691,12 @@ LogicalResult mlir::applyOpPatternsAndFold(
|
||||
OpPatternRewriteDriver driver(op->getContext(), patterns);
|
||||
bool opErased;
|
||||
LogicalResult converged =
|
||||
driver.simplifyLocally(op, config.maxIterations, opErased);
|
||||
driver.simplifyLocally(op, config.maxNumRewrites, opErased);
|
||||
if (erased)
|
||||
*erased = opErased;
|
||||
LLVM_DEBUG(if (failed(converged)) {
|
||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||
<< config.maxIterations << " times";
|
||||
llvm::dbgs() << "The pattern rewrite did not converge after "
|
||||
<< config.maxNumRewrites << " rewrites";
|
||||
});
|
||||
return converged;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user