mirror of
https://github.com/intel/llvm.git
synced 2026-01-22 07:01:03 +08:00
[mlir][GreedyPatternRewriter] Add out param to detect changes in IR in applyPatternsAndFoldGreedily
This allows users of `applyPatternsAndFoldGreedily` to detect if any MLIR changes have occurred. An example use-case is where we expect the `applyPatternsAndFoldGreedily` to change the IR and want to validate that it indeed does change it. Differential Revision: https://reviews.llvm.org/D153986
This commit is contained in:
committed by
Matthias Springer
parent
d34bc66ef9
commit
8498c9e948
@@ -94,25 +94,36 @@ public:
|
||||
/// in absence of convergence.
|
||||
///
|
||||
/// Return success if the iterative process converged and no more patterns can
|
||||
/// be matched in the result operation regions.
|
||||
/// be matched in the result operation regions. `changed` is set to true if the
|
||||
/// IR was modified at all.
|
||||
///
|
||||
/// Note: This does not apply patterns to the top-level operation itself.
|
||||
/// These methods also perform folding and simple dead-code elimination
|
||||
/// before attempting to match any of the provided patterns.
|
||||
///
|
||||
/// You may configure several aspects of this with GreedyRewriteConfig.
|
||||
LogicalResult applyPatternsAndFoldGreedily(
|
||||
Region ®ion, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig());
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr);
|
||||
|
||||
/// Rewrite ops in all regions of the given op, which must be isolated from
|
||||
/// above.
|
||||
inline LogicalResult applyPatternsAndFoldGreedily(
|
||||
Operation *op, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig()) {
|
||||
inline LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr) {
|
||||
bool failed = false;
|
||||
for (Region ®ion : op->getRegions())
|
||||
failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
bool regionChanged;
|
||||
failed |=
|
||||
applyPatternsAndFoldGreedily(region, patterns, config, ®ionChanged)
|
||||
.failed();
|
||||
if (changed)
|
||||
*changed |= regionChanged;
|
||||
}
|
||||
return failure(failed);
|
||||
}
|
||||
|
||||
|
||||
@@ -616,7 +616,7 @@ public:
|
||||
|
||||
/// Simplify ops inside `region` and simplify the region itself. Return
|
||||
/// success if the transformation converged.
|
||||
LogicalResult simplify() &&;
|
||||
LogicalResult simplify(bool *changed) &&;
|
||||
|
||||
private:
|
||||
/// The region that is simplified.
|
||||
@@ -652,7 +652,7 @@ private:
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult RegionPatternRewriteDriver::simplify() && {
|
||||
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
|
||||
auto insertKnownConstant = [&](Operation *op) {
|
||||
// Check for existing constants when populating the worklist. This avoids
|
||||
// accidentally reversing the constant order during processing.
|
||||
@@ -663,12 +663,12 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
|
||||
return false;
|
||||
};
|
||||
|
||||
bool changed = false;
|
||||
bool continueRewrites = false;
|
||||
int64_t iteration = 0;
|
||||
MLIRContext *ctx = getContext();
|
||||
do {
|
||||
// Check if the iteration limit was reached.
|
||||
if (iteration++ >= config.maxIterations &&
|
||||
if (++iteration > config.maxIterations &&
|
||||
config.maxIterations != GreedyRewriteConfig::kNoLimit)
|
||||
break;
|
||||
|
||||
@@ -696,24 +696,27 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
|
||||
|
||||
ctx->executeAction<GreedyPatternRewriteIteration>(
|
||||
[&] {
|
||||
changed = processWorklist();
|
||||
continueRewrites = processWorklist();
|
||||
|
||||
// After applying patterns, make sure that the CFG of each of the
|
||||
// regions is kept up to date.
|
||||
if (config.enableRegionSimplification)
|
||||
changed |= succeeded(simplifyRegions(*this, region));
|
||||
continueRewrites |= succeeded(simplifyRegions(*this, region));
|
||||
},
|
||||
{®ion}, iteration);
|
||||
} while (changed);
|
||||
} while (continueRewrites);
|
||||
|
||||
if (changed)
|
||||
*changed = iteration > 1;
|
||||
|
||||
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
|
||||
return success(!changed);
|
||||
return success(!continueRewrites);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config) {
|
||||
GreedyRewriteConfig config, bool *changed) {
|
||||
// The top-level operation must be known to be isolated from above to
|
||||
// prevent performing canonicalizations on operations defined at or above
|
||||
// the region containing 'op'.
|
||||
@@ -727,7 +730,7 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion,
|
||||
// Start the pattern driver.
|
||||
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
|
||||
region);
|
||||
LogicalResult converged = std::move(driver).simplify();
|
||||
LogicalResult converged = std::move(driver).simplify(changed);
|
||||
LLVM_DEBUG(if (failed(converged)) {
|
||||
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
|
||||
<< config.maxIterations << " times\n";
|
||||
|
||||
Reference in New Issue
Block a user