[mlir] int-range-optmizations: Fix referencing of deleted ops (#91807)

The pass runs a `DataFlowSolver` and collects state information on the
input IR. Then, the rewrite driver and folding is applied. During
pattern application and folding it can happen that an Op from the input
IR is deleted and a new Op is created at the same address. When the
newly created Ops is looked up in the `DataFlowSolver` state memory, the
state of the original Op is returned.

This patch adds a method to `DataFlowSolver` which removes all state
related to a `ProgramPoint`. It also adds a listener to the Pass which
clears the state information of deleted Ops from the `DataFlowSolver`.

Fix https://github.com/llvm/llvm-project/issues/81228
This commit is contained in:
Felix Schneider
2024-05-12 18:11:42 +02:00
committed by GitHub
parent 502e77df1f
commit 78b3a00418
2 changed files with 35 additions and 1 deletions

View File

@@ -242,6 +242,17 @@ public:
return static_cast<const StateT *>(it->second.get());
}
/// Erase any analysis state associated with the given program point.
template <typename PointT>
void eraseState(PointT point) {
ProgramPoint pp(point);
for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
if (it->first.first == pp)
analysisStates.erase(it);
}
}
/// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
template <typename PointT, typename... Args>

View File

@@ -102,6 +102,24 @@ static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
}
namespace {
/// This class listens on IR transformations performed during a pass relying on
/// information from a `DataflowSolver`. It erases state associated with the
/// erased operation and its results from the `DataFlowSolver` so that Patterns
/// do not accidentally query old state information for newly created Ops.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}
protected:
void notifyOperationErased(Operation *op) override {
s.eraseState(op);
for (Value res : op->getResults())
s.eraseState(res);
}
DataFlowSolver &s;
};
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
@@ -167,10 +185,15 @@ struct IntRangeOptimizationsPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
DataFlowListener listener(solver);
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};