[mlir][py] invalidate nested operations when parent is deleted (#93339)

When an operation is erased in Python, its children may still be in the
"live" list inside Python bindings. After this, if some of the newly
allocated operations happen to reuse the same pointer address, this will
trigger an assertion in the bindings. This assertion would be incorrect
because the operations aren't actually live. Make sure we remove the
children operations from the "live" list when erasing the parent.

This also concentrates responsibility over the removal from the "live"
list and invalidation in a single place.

Note that this requires the IR to be sufficiently structurally valid so
that a walk through it can succeed. If this invariant was broken by, e.g.,
a C++ pass called from Python, there isn't much we can do.
This commit is contained in:
Oleksandr "Alex" Zinenko
2024-05-30 10:06:02 +02:00
committed by GitHub
parent 540a36ad7e
commit 67897d77ed
3 changed files with 75 additions and 13 deletions

View File

@@ -697,6 +697,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
clearOperationsInside(opRef->getOperation());
}
/// Clears `op` together with every operation nested under it. Each operation
/// reached by a pre-order walk rooted at `op` is forwarded to
/// `clearOperation(MlirOperation)` on the context reported by `op`.
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
  auto &rootOperation = op.getOperation();
  // The C walk API only carries an opaque pointer, so the context reference is
  // handed over through `userData` and recovered inside the callback.
  void *walkUserData = &rootOperation.getContext();
  MlirOperationWalkCallback clearVisited =
      [](MlirOperation visited, void *userData) -> MlirWalkResult {
    auto *owningContext = static_cast<PyMlirContextRef *>(userData);
    (*owningContext)->clearOperation(visited);
    // Always continue so that the whole nested structure is traversed.
    return MlirWalkResult::MlirWalkResultAdvance;
  };
  mlirOperationWalk(rootOperation, clearVisited, walkUserData,
                    MlirWalkPreOrder);
}
/// Returns the number of entries currently stored in `liveModules`.
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
pybind11::object PyMlirContext::contextEnter() {
@@ -1125,12 +1136,16 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
auto &liveOperations = getContext()->liveOperations;
assert(liveOperations.count(operation.ptr) == 1 &&
"destroying operation not in live map");
liveOperations.erase(operation.ptr);
if (!isAttached()) {
mlirOperationDestroy(operation);
// Otherwise, invalidate the operation and remove it from live map when it is
// attached.
if (isAttached()) {
getContext()->clearOperation(*this);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated and removed from the live map as
// well.
erase();
}
}
@@ -1540,14 +1555,8 @@ py::object PyOperation::createOpView() {
/// Erases this operation: clears it and all operations nested inside it from
/// the owning context's bookkeeping, then destroys the underlying
/// MlirOperation.
void PyOperation::erase() {
  checkValid();
  // Removal from the "live" list and invalidation are delegated to the
  // context, which handles this operation and everything nested inside it in
  // one place. This must happen before the operation is destroyed because it
  // walks the IR rooted at this operation.
  getContext()->clearOperationAndInside(*this);
  mlirOperationDestroy(operation);
  // Set unconditionally so this wrapper is marked invalid regardless of
  // whether it was being tracked by the context.
  valid = false;
}
//------------------------------------------------------------------------------

View File

@@ -218,6 +218,8 @@ public:
/// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to be made aware. For example, in the case when a pass
/// manager is run.
///
/// Note that this does *NOT* clear the nested operations.
void clearOperation(MlirOperation op);
/// Clears all operations nested inside the given op using
@@ -225,6 +227,10 @@ public:
void clearOperationsInside(PyOperationBase &op);
void clearOperationsInside(MlirOperation op);
/// Clears the operation _and_ all operations inside using
/// `clearOperation(MlirOperation)`. The operations are visited by a
/// pre-order walk starting at `op`.
void clearOperationAndInside(PyOperationBase &op);
/// Gets the count of live modules associated with this context.
/// Used for testing.
size_t getLiveModuleCount();
@@ -246,6 +252,7 @@ public:
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an

View File

@@ -0,0 +1,46 @@
# RUN: %PYTHON %s
# It is sufficient that this doesn't assert.
from mlir.ir import *
def createDetachedModule():
module = Module.create()
with InsertionPoint(module.body):
# TODO: Python bindings are currently unaware that modules are also
# operations, so having a module erased won't trigger the cascading
# removal of live operations (#93337). Use a non-module operation
# instead.
nested = Operation.create("test.some_operation", regions=1)
# When the operation is detached from parent, it is considered to be
# owned by Python. It will therefore be erased when the Python object
# is destroyed.
nested.detach_from_parent()
# However, we create and maintain references to operations within
# `nested`. These references keep the corresponding operations in the
# "live" list even if they have been erased in C++, making them
# "zombie". If the C++ allocator reuses one of the addresses previously
# used for a now-"zombie" operation, this used to result in an
# assertion "cannot create detached operation that already exists" from
# the bindings code. Erasing the detached operation should result in
# removing all nested operations from the live list.
#
# Note that the assertion is not guaranteed since it depends on the
# behavior of the allocator on the C++ side, so this test may fail
# intermittently.
with InsertionPoint(nested.regions[0].blocks.append()):
a = [Operation.create("test.some_other_operation") for i in range(100)]
return a
def createManyDetachedModules():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
for j in range(100):
a = createDetachedModule()
if __name__ == "__main__":
createManyDetachedModules()