Files
compute-runtime/shared/source/command_stream/host_function_worker_interface.cpp
Kamil Kopryk f84a5fbee9 feature: add host functions workers
* add common host function worker interface
* add worker as a single thread per csr with 3 modes
* add logic for waiting on internal tag, check gpu hang
* if tag is in pending state, read callback data, run callback
and signal completion
* threads exit the work loop once a stop is requested
in finish
* add multi thread unit tests

Related-To: NEO-14577
Signed-off-by: Kamil Kopryk <kamil.kopryk@intel.com>
2025-11-03 12:11:17 +01:00

69 lines
2.5 KiB
C++

/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/command_stream/host_function_worker_interface.h"
#include "shared/source/command_stream/host_function.h"
#include "shared/source/utilities/wait_util.h"
#include <atomic>
#include <chrono>
#include <type_traits>
namespace NEO {

/*
 * Stores everything runHostFunction() needs:
 *  - skipHostFunctionExecution: when true, runHostFunction() neither polls for a pending
 *    tag nor invokes the callback; it only stores HostFunctionTagStatus::completed.
 *  - downloadAllocationImpl: optional hook invoked with *allocation before every poll of
 *    the tag (presumably refreshes the host view of the allocation — TODO confirm).
 *  - allocation / data: kept as raw pointers; this file does not show them being owned or
 *    released here, so their lifetime is presumably managed by the caller.
 */
IHostFunctionWorker::IHostFunctionWorker(bool skipHostFunctionExecution,
                                         const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
                                         GraphicsAllocation *allocation,
                                         HostFunctionData *data)
    : downloadAllocationImpl(downloadAllocationImpl),
      allocation(allocation),
      data(data),
      skipHostFunctionExecution(skipHostFunctionExecution) {
}

IHostFunctionWorker::~IHostFunctionWorker() = default;

/*
 * Waits for a host function to become pending, runs it, and signals completion.
 *
 * Protocol on *data->internalTag:
 *  1. Poll until the tag equals HostFunctionTagStatus::pending.
 *  2. Read the callback from *data->entry and its argument from *data->userData,
 *     then call callback(userData).
 *  3. Store HostFunctionTagStatus::completed into the tag.
 * Steps 1 and 2 are skipped entirely when skipHostFunctionExecution is set.
 *
 * @param st  stop token, checked after every poll that did not find a pending tag.
 * @return false if a stop was requested before a pending host function was observed
 *         (the tag is left untouched); true once the tag has been set to completed.
 */
bool IHostFunctionWorker::runHostFunction(std::stop_token st) noexcept {
    using tagStatusT = std::underlying_type_t<HostFunctionTagStatus>;
    const auto start = std::chrono::steady_clock::now();

    if (!this->skipHostFunctionExecution) {
        while (true) {
            if (this->downloadAllocationImpl) [[unlikely]] {
                this->downloadAllocationImpl(*this->allocation);
            }

            const volatile uint32_t *hostFunctionTagAddress = this->data->internalTag;

            // Elapsed time since this call began, handed to WaitUtils as the time already
            // spent waiting (presumably used to choose between spinning and backing off).
            const auto waitTime = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - start);
            const bool pendingJobFound = WaitUtils::waitFunctionWithPredicate<const tagStatusT>(hostFunctionTagAddress,
                                                                                                static_cast<tagStatusT>(HostFunctionTagStatus::pending),
                                                                                                std::equal_to<tagStatusT>(),
                                                                                                waitTime.count());
            if (pendingJobFound) {
                break;
            }

            if (st.stop_requested()) {
                // Nothing pending and we were asked to stop: leave the tag as-is.
                return false;
            }
        }

        // entry/userData are expected to be written before the tag flips to pending.
        // Without an acquire barrier a weakly-ordered CPU may satisfy the payload loads
        // below before the tag load above, yielding a stale callback/argument. The tag is
        // accessed through volatile, so this acts as a compiler/CPU ordering barrier.
        std::atomic_thread_fence(std::memory_order_acquire);

        using CallbackT = void (*)(void *);
        CallbackT callback = reinterpret_cast<CallbackT>(*this->data->entry);
        void *callbackData = reinterpret_cast<void *>(*this->data->userData);

        callback(callbackData);
    }

    // Order everything the callback wrote before the completion store, so whoever observes
    // "completed" also observes the callback's side effects.
    std::atomic_thread_fence(std::memory_order_release);
    *this->data->internalTag = static_cast<tagStatusT>(HostFunctionTagStatus::completed);

    return true;
}

} // namespace NEO