[mlir][NFC] GreedyPatternRewriteDriver: Consistent return values

All `apply...` functions now return a LogicalResult indicating whether the iterative process converged or not.

Differential Revision: https://reviews.llvm.org/D141845
This commit is contained in:
Matthias Springer
2023-01-16 16:23:58 +01:00
parent 6e5021b8dc
commit fefe655baa
3 changed files with 40 additions and 18 deletions

View File

@@ -574,7 +574,8 @@ public:
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
strictMode(strict) {}
bool simplifyLocally(ArrayRef<Operation *> op);
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
bool *changed = nullptr);
void addToWorklist(Operation *op) override {
if (!strictMode || strictModeFilteredOps.contains(op))
@@ -625,13 +626,16 @@ private:
// there is no strong rationale to re-add all operations into the worklist and
// rerun until an iteration changes nothing. If more widereaching simplification
// is desired, GreedyPatternRewriteDriver should be used.
bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
LogicalResult
MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
bool *changed) {
if (strictMode) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
bool changed = false;
if (changed)
*changed = false;
worklist.clear();
worklistMap.clear();
for (Operation *op : ops)
@@ -657,7 +661,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
if (isOpTriviallyDead(op)) {
notifyOperationRemoved(op);
op->erase();
changed = true;
if (changed)
*changed = true;
continue;
}
@@ -687,7 +692,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
bool inPlaceUpdate;
if (succeeded(folder.tryToFold(op, processGeneratedConstants,
preReplaceAction, &inPlaceUpdate))) {
changed = true;
if (changed)
*changed = true;
if (!inPlaceUpdate) {
// Op has been erased.
continue;
@@ -698,12 +704,13 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// notified of any necessary changes, so there is nothing else to do
// here.
if (succeeded(matcher.matchAndRewrite(op, *this))) {
changed = true;
if (changed)
*changed = true;
++numRewrites;
}
}
return changed;
return success(worklist.empty());
}
/// Rewrites only `op` using the supplied canonicalization patterns and
@@ -726,14 +733,18 @@ LogicalResult mlir::applyOpPatternsAndFold(
return converged;
}
bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
bool strict) {
if (ops.empty())
return false;
LogicalResult
mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
bool strict, bool *changed) {
if (ops.empty()) {
if (changed)
*changed = false;
return success();
}
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strict);
return driver.simplifyLocally(ops);
return driver.simplifyLocally(ops, changed);
}