[MLIR][Python] restore liveModuleMap (#158506)

There are cases where the same module can have multiple references (via
`PyModule::forModule` via `PyModule::createFromCapsule`) and thus when
`PyModule`s get gc'd `mlirModuleDestroy` can get called multiple times
for the same actual underlying `mlir::Module` (i.e., double free). So we
do actually need a "liveness map" for modules.

Note, if `type_caster<MlirModule>::from_cpp` weren't a thing we could guarantree
this never happened except explicitly when users called `PyModule::createFromCapsule`.
This commit is contained in:
Maksim Levental
2025-09-15 00:45:30 -04:00
committed by GitHub
parent 65ad21d730
commit 6a4f66476f
3 changed files with 47 additions and 15 deletions

View File

@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModule::~PyModule() {
nb::gil_scoped_acquire acquire;
auto &liveModules = getContext()->liveModules;
assert(liveModules.count(module.ptr) == 1 &&
"destroying module not in live map");
liveModules.erase(module.ptr);
mlirModuleDestroy(module);
}
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is `automatic_reference`,
// which means "does not take ownership, does not call delete/dtor".
// We use `take_ownership`, which means "Python will call the C++ destructor
// and delete operator when the Python wrapper is garbage collected", because
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
// etc).
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
return PyModuleRef(unownedModule, std::move(pyRef));
nb::gil_scoped_acquire acquire;
auto &liveModules = contextRef->liveModules;
auto it = liveModules.find(module.ptr);
if (it == liveModules.end()) {
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
liveModules[module.ptr] =
std::make_pair(unownedModule->handle, unownedModule);
return PyModuleRef(unownedModule, std::move(pyRef));
}
// Use existing.
PyModule *existing = it->second.second;
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyModuleRef(existing, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
return PyInsertionPoint{block, std::move(nextOpRef)};
}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)

View File

@@ -218,6 +218,10 @@ public:
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ private:
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
// Interns all live modules associated with this context. Modules tracked
// in this map are valid. When a module is invalidated, it is removed
// from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveModuleMap =
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
LiveModuleMap liveModules;
bool emitErrorDiagnostics = false;
MlirContext context;

View File

@@ -121,6 +121,7 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
op1 = module.operation
# CHECK: module @successfulParse
print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,17 @@ def testModuleOperation():
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
assert module is not module_dup
assert module is module_dup
assert module == module_dup
module._clear_mlir_module()
assert module != module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
assert ctx._get_live_module_count() == 0