mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[MLIR][Python] restore liveModuleMap (#158506)
There are cases where the same module can have multiple references (via `PyModule::forModule` via `PyModule::createFromCapsule`) and thus when `PyModule`s get gc'd `mlirModuleDestroy` can get called multiple times for the same actual underlying `mlir::Module` (i.e., double free). So we do actually need a "liveness map" for modules. Note, if `type_caster<MlirModule>::from_cpp` weren't a thing we could guarantree this never happened except explicitly when users called `PyModule::createFromCapsule`.
This commit is contained in:
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
|
||||
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
|
||||
: BaseContextObject(std::move(contextRef)), module(module) {}
|
||||
|
||||
PyModule::~PyModule() { mlirModuleDestroy(module); }
|
||||
PyModule::~PyModule() {
|
||||
nb::gil_scoped_acquire acquire;
|
||||
auto &liveModules = getContext()->liveModules;
|
||||
assert(liveModules.count(module.ptr) == 1 &&
|
||||
"destroying module not in live map");
|
||||
liveModules.erase(module.ptr);
|
||||
mlirModuleDestroy(module);
|
||||
}
|
||||
|
||||
PyModuleRef PyModule::forModule(MlirModule module) {
|
||||
MlirContext context = mlirModuleGetContext(module);
|
||||
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
|
||||
|
||||
// Create.
|
||||
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
|
||||
// Note that the default return value policy on cast is `automatic_reference`,
|
||||
// which means "does not take ownership, does not call delete/dtor".
|
||||
// We use `take_ownership`, which means "Python will call the C++ destructor
|
||||
// and delete operator when the Python wrapper is garbage collected", because
|
||||
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
|
||||
// etc).
|
||||
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
|
||||
unownedModule->handle = pyRef;
|
||||
return PyModuleRef(unownedModule, std::move(pyRef));
|
||||
nb::gil_scoped_acquire acquire;
|
||||
auto &liveModules = contextRef->liveModules;
|
||||
auto it = liveModules.find(module.ptr);
|
||||
if (it == liveModules.end()) {
|
||||
// Create.
|
||||
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
|
||||
// Note that the default return value policy on cast is automatic_reference,
|
||||
// which does not take ownership (delete will not be called).
|
||||
// Just be explicit.
|
||||
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
|
||||
unownedModule->handle = pyRef;
|
||||
liveModules[module.ptr] =
|
||||
std::make_pair(unownedModule->handle, unownedModule);
|
||||
return PyModuleRef(unownedModule, std::move(pyRef));
|
||||
}
|
||||
// Use existing.
|
||||
PyModule *existing = it->second.second;
|
||||
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
|
||||
return PyModuleRef(existing, std::move(pyRef));
|
||||
}
|
||||
|
||||
nb::object PyModule::createFromCapsule(nb::object capsule) {
|
||||
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
|
||||
return PyInsertionPoint{block, std::move(nextOpRef)};
|
||||
}
|
||||
|
||||
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
|
||||
|
||||
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
|
||||
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
|
||||
}
|
||||
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
|
||||
return ref.releaseObject();
|
||||
})
|
||||
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
|
||||
.def("__enter__", &PyMlirContext::contextEnter)
|
||||
|
||||
@@ -218,6 +218,10 @@ public:
|
||||
/// Gets the count of live context objects. Used for testing.
|
||||
static size_t getLiveCount();
|
||||
|
||||
/// Gets the count of live modules associated with this context.
|
||||
/// Used for testing.
|
||||
size_t getLiveModuleCount();
|
||||
|
||||
/// Enter and exit the context manager.
|
||||
static nanobind::object contextEnter(nanobind::object context);
|
||||
void contextExit(const nanobind::object &excType,
|
||||
@@ -244,6 +248,14 @@ private:
|
||||
static nanobind::ft_mutex live_contexts_mutex;
|
||||
static LiveContextMap &getLiveContexts();
|
||||
|
||||
// Interns all live modules associated with this context. Modules tracked
|
||||
// in this map are valid. When a module is invalidated, it is removed
|
||||
// from this map, and while it still exists as an instance, any
|
||||
// attempt to access it will raise an error.
|
||||
using LiveModuleMap =
|
||||
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
|
||||
LiveModuleMap liveModules;
|
||||
|
||||
bool emitErrorDiagnostics = false;
|
||||
|
||||
MlirContext context;
|
||||
|
||||
@@ -121,6 +121,7 @@ def testRoundtripBinary():
|
||||
def testModuleOperation():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""module @successfulParse {}""", ctx)
|
||||
assert ctx._get_live_module_count() == 1
|
||||
op1 = module.operation
|
||||
# CHECK: module @successfulParse
|
||||
print(op1)
|
||||
@@ -145,6 +146,7 @@ def testModuleOperation():
|
||||
op1 = None
|
||||
op2 = None
|
||||
gc.collect()
|
||||
assert ctx._get_live_module_count() == 0
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testModuleCapsule
|
||||
@@ -152,17 +154,17 @@ def testModuleOperation():
|
||||
def testModuleCapsule():
|
||||
ctx = Context()
|
||||
module = Module.parse(r"""module @successfulParse {}""", ctx)
|
||||
assert ctx._get_live_module_count() == 1
|
||||
# CHECK: "mlir.ir.Module._CAPIPtr"
|
||||
module_capsule = module._CAPIPtr
|
||||
print(module_capsule)
|
||||
module_dup = Module._CAPICreate(module_capsule)
|
||||
assert module is not module_dup
|
||||
assert module is module_dup
|
||||
assert module == module_dup
|
||||
module._clear_mlir_module()
|
||||
assert module != module_dup
|
||||
assert module_dup.context is ctx
|
||||
# Gc and verify destructed.
|
||||
module = None
|
||||
module_capsule = None
|
||||
module_dup = None
|
||||
gc.collect()
|
||||
assert ctx._get_live_module_count() == 0
|
||||
|
||||
Reference in New Issue
Block a user