diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d875f4eba2b1..01678a9719f9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1255,14 +1255,31 @@ void PyOperationBase::walk( MlirWalkOrder walkOrder) { PyOperation &operation = getOperation(); operation.checkValid(); + struct UserData { + std::function callback; + bool gotException; + std::string exceptionWhat; + py::object exceptionType; + }; + UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, void *userData) { - auto *fn = - static_cast *>(userData); - return (*fn)(op); + UserData *calleeUserData = static_cast(userData); + try { + return (calleeUserData->callback)(op); + } catch (py::error_already_set &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = e.what(); + calleeUserData->exceptionType = e.type(); + return MlirWalkResult::MlirWalkResultInterrupt; + } }; - - mlirOperationWalk(operation, walkCallback, &callback, walkOrder); + mlirOperationWalk(operation, walkCallback, &userData, walkOrder); + if (userData.gotException) { + std::string message("Exception raised in callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(message); + } } py::object PyOperationBase::getAsm(bool binary, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index 9666e63bda1e..3a5d850b86e3 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -1088,5 +1088,5 @@ def testOpWalk(): try: module.operation.walk(callback) - except ValueError: + except RuntimeError: print("Exception raised")