mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[Offload] Use amd_signal_async_handler for host function calls (#154131)
This commit is contained in:
@@ -923,6 +923,10 @@ private:
|
||||
/// devices. This class relies on signals to implement streams and define the
|
||||
/// dependencies between asynchronous operations.
|
||||
struct AMDGPUStreamTy {
|
||||
public:
|
||||
/// Function pointer type for `pushHostCallback`
|
||||
using HostFnType = void (*)(void *);
|
||||
|
||||
private:
|
||||
/// Utility struct holding arguments for async H2H memory copies.
|
||||
struct MemcpyArgsTy {
|
||||
@@ -1084,18 +1088,19 @@ private:
|
||||
/// Indicate to spread data transfers across all available SDMAs
|
||||
bool UseMultipleSdmaEngines;
|
||||
|
||||
/// State bundle handed to `callbackWrapper` through its `void *` argument.
/// `pushHostCallback` allocates one instance with `new` per host callback and
/// `callbackWrapper` deletes it after the callback has run.
struct CallbackDataType {
|
||||
/// Host function to invoke.
HostFnType UserFn;
|
||||
/// Opaque argument forwarded unchanged to `UserFn`.
void *UserData;
|
||||
/// Signal that `callbackWrapper` signals once `UserFn` has returned.
AMDGPUSignalTy *OutputSignal;
|
||||
};
|
||||
/// Wrapper function for implementing host callbacks
|
||||
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
|
||||
AMDGPUSignalTy *OutputSignal,
|
||||
void (*Callback)(void *), void *UserData) {
|
||||
// The wait call will not error in this context.
|
||||
if (InputSignal)
|
||||
if (auto Err = InputSignal->wait())
|
||||
reportFatalInternalError(std::move(Err));
|
||||
|
||||
Callback(UserData);
|
||||
|
||||
OutputSignal->signal();
|
||||
/// Run a queued host callback, signal its output signal and free the
/// callback's bookkeeping data.
///
/// `pushHostCallback` passes this function as the handler argument of
/// `hsa_amd_signal_async_handler`, and also calls it directly from a detached
/// thread when the callback has no pending input dependency.
///
/// \param Signal   Signal value supplied by the caller; not used.
/// \param UserData Pointer to a `CallbackDataType` created with `new`. This
///                 function takes ownership and deletes it before returning.
/// \returns Always false. NOTE(review): presumably this tells the HSA runtime
///          not to keep monitoring the signal with this handler - confirm
///          against the `hsa_amd_signal_async_handler` documentation.
static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal,
|
||||
void *UserData) {
|
||||
auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData);
|
||||
// Run the user's host function with the argument it was registered with.
CallbackData->UserFn(CallbackData->UserData);
|
||||
// Signal completion before freeing: the output signal pointer is read from
// CallbackData, so it must still be alive here.
CallbackData->OutputSignal->signal();
|
||||
// The data was heap-allocated for this single invocation; release it.
delete CallbackData;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Return the current number of asynchronous operations on the stream.
|
||||
@@ -1540,7 +1545,7 @@ public:
|
||||
OutputSignal->get());
|
||||
}
|
||||
|
||||
Error pushHostCallback(void (*Callback)(void *), void *UserData) {
|
||||
Error pushHostCallback(HostFnType Callback, void *UserData) {
|
||||
// Retrieve an available signal for the operation's output.
|
||||
AMDGPUSignalTy *OutputSignal = nullptr;
|
||||
if (auto Err = SignalManager.getResource(OutputSignal))
|
||||
@@ -1556,12 +1561,21 @@ public:
|
||||
InputSignal = consume(OutputSignal).second;
|
||||
}
|
||||
|
||||
// "Leaking" the thread here is consistent with other work added to the
|
||||
// queue. The input and output signals will remain valid until the output is
|
||||
// signaled.
|
||||
std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
|
||||
.detach();
|
||||
auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal};
|
||||
if (InputSignal && InputSignal->load()) {
|
||||
hsa_status_t Status = hsa_amd_signal_async_handler(
|
||||
InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper,
|
||||
CallbackData);
|
||||
|
||||
return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s");
|
||||
}
|
||||
|
||||
// No dependencies - schedule it now.
|
||||
// Using a separate thread because this function should run asynchronously
|
||||
// and not block the main thread.
|
||||
std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); },
|
||||
CallbackData)
|
||||
.detach();
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
@@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
|
||||
return Plugin::success();
|
||||
}
|
||||
|
||||
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
|
||||
Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData,
|
||||
AsyncInfoWrapperTy &AsyncInfo) override {
|
||||
AMDGPUStreamTy *Stream = nullptr;
|
||||
if (auto Err = getStream(AsyncInfo, Stream))
|
||||
|
||||
Reference in New Issue
Block a user