[Offload] Use amd_signal_async_handler for host function calls (#154131)

This commit is contained in:
Ross Brunton
2025-10-21 13:08:30 +01:00
committed by GitHub
parent 531d45d767
commit 186182bb64

View File

@@ -923,6 +923,10 @@ private:
/// devices. This class relies on signals to implement streams and define the
/// dependencies between asynchronous operations.
struct AMDGPUStreamTy {
public:
/// Function pointer type for `pushHostCallback`
using HostFnType = void (*)(void *);
private:
/// Utility struct holding arguments for async H2H memory copies.
struct MemcpyArgsTy {
@@ -1084,18 +1088,19 @@ private:
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;
struct CallbackDataType {
HostFnType UserFn;
void *UserData;
AMDGPUSignalTy *OutputSignal;
};
/// Wrapper function for implementing host callbacks
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
AMDGPUSignalTy *OutputSignal,
void (*Callback)(void *), void *UserData) {
// The wait call will not error in this context.
if (InputSignal)
if (auto Err = InputSignal->wait())
reportFatalInternalError(std::move(Err));
Callback(UserData);
OutputSignal->signal();
static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal,
void *UserData) {
auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData);
CallbackData->UserFn(CallbackData->UserData);
CallbackData->OutputSignal->signal();
delete CallbackData;
return false;
}
/// Return the current number of asynchronous operations on the stream.
@@ -1540,7 +1545,7 @@ public:
OutputSignal->get());
}
Error pushHostCallback(void (*Callback)(void *), void *UserData) {
Error pushHostCallback(HostFnType Callback, void *UserData) {
// Retrieve an available signal for the operation's output.
AMDGPUSignalTy *OutputSignal = nullptr;
if (auto Err = SignalManager.getResource(OutputSignal))
@@ -1556,12 +1561,21 @@ public:
InputSignal = consume(OutputSignal).second;
}
// "Leaking" the thread here is consistent with other work added to the
// queue. The input and output signals will remain valid until the output is
// signaled.
std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
.detach();
auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal};
if (InputSignal && InputSignal->load()) {
hsa_status_t Status = hsa_amd_signal_async_handler(
InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper,
CallbackData);
return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s");
}
// No dependencies - schedule it now.
// Using a seperate thread because this function should run asynchronously
// and not block the main thread.
std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); },
CallbackData)
.detach();
return Plugin::success();
}
@@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
AMDGPUStreamTy *Stream = nullptr;
if (auto Err = getStream(AsyncInfo, Stream))