[mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225)

If the python callback throws an error, the c++ code will throw a
py::error_already_set that needs to be caught and handled in the c++
code .

This change is inspired by the similar solution in
PySymbolTable::walkSymbolTables.
This commit is contained in:
tomnatan30
2024-04-18 15:09:31 +01:00
committed by GitHub
parent 8f07a67f97
commit bc5536469d
2 changed files with 23 additions and 6 deletions

View File

@@ -1255,14 +1255,31 @@ void PyOperationBase::walk(
MlirWalkOrder walkOrder) {
PyOperation &operation = getOperation();
operation.checkValid();
struct UserData {
std::function<MlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
py::object exceptionType;
};
UserData userData{callback, false, {}, {}};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
auto *fn =
static_cast<std::function<MlirWalkResult(MlirOperation)> *>(userData);
return (*fn)(op);
UserData *calleeUserData = static_cast<UserData *>(userData);
try {
return (calleeUserData->callback)(op);
} catch (py::error_already_set &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = e.what();
calleeUserData->exceptionType = e.type();
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
mlirOperationWalk(operation, walkCallback, &callback, walkOrder);
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
throw std::runtime_error(message);
}
}
py::object PyOperationBase::getAsm(bool binary,

View File

@@ -1088,5 +1088,5 @@ def testOpWalk():
try:
module.operation.walk(callback)
except ValueError:
except RuntimeError:
print("Exception raised")