diff --git a/level_zero/core/source/cmdlist/cmdlist_hw.inl b/level_zero/core/source/cmdlist/cmdlist_hw.inl index e43cd6181b..b469345bd2 100644 --- a/level_zero/core/source/cmdlist/cmdlist_hw.inl +++ b/level_zero/core/source/cmdlist/cmdlist_hw.inl @@ -1864,6 +1864,7 @@ void CommandListCoreFamily::dispatchHostFunction( auto csr = getCsr(false); csr->ensureHostFunctionDataInitialization(); this->commandContainer.addToResidencyContainer(csr->getHostFunctionDataAllocation()); + csr->signalHostFunctionWorker(); NEO::HostFunctionHelper::programHostFunction(*this->commandContainer.getCommandStream(), csr->getHostFunctionData(), userHostFunctionAddress, userDataAddress); } else { addHostFunctionToPatchCommands(userHostFunctionAddress, userDataAddress); diff --git a/level_zero/core/source/cmdqueue/cmdqueue_hw_gen12lp.inl b/level_zero/core/source/cmdqueue/cmdqueue_hw_gen12lp.inl index 5b773615ce..c219cd4385 100644 --- a/level_zero/core/source/cmdqueue/cmdqueue_hw_gen12lp.inl +++ b/level_zero/core/source/cmdqueue/cmdqueue_hw_gen12lp.inl @@ -199,6 +199,7 @@ void CommandQueueHw::patchCommands(CommandList &commandList, uint case CommandToPatch::HostFunctionEntry: csr->ensureHostFunctionDataInitialization(); csr->makeResidentHostFunctionAllocation(); + csr->signalHostFunctionWorker(); NEO::HostFunctionHelper::programHostFunctionAddress(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress); break; diff --git a/level_zero/core/source/cmdqueue/cmdqueue_xe_hp_core_and_later.inl b/level_zero/core/source/cmdqueue/cmdqueue_xe_hp_core_and_later.inl index 63bd507a1e..d313b4ad62 100644 --- a/level_zero/core/source/cmdqueue/cmdqueue_xe_hp_core_and_later.inl +++ b/level_zero/core/source/cmdqueue/cmdqueue_xe_hp_core_and_later.inl @@ -287,6 +287,7 @@ void CommandQueueHw::patchCommands(CommandList &commandList, uint case CommandToPatch::HostFunctionEntry: csr->ensureHostFunctionDataInitialization(); csr->makeResidentHostFunctionAllocation(); + csr->signalHostFunctionWorker(); NEO::HostFunctionHelper::programHostFunctionAddress(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress); break; diff --git a/level_zero/core/test/unit_tests/sources/cmdqueue/test_cmdqueue_2.cpp b/level_zero/core/test/unit_tests/sources/cmdqueue/test_cmdqueue_2.cpp index bfe4940d4a..9d285a8c30 100644 --- a/level_zero/core/test/unit_tests/sources/cmdqueue/test_cmdqueue_2.cpp +++ b/level_zero/core/test/unit_tests/sources/cmdqueue/test_cmdqueue_2.cpp @@ -1156,6 +1156,10 @@ HWTEST_F(HostFunctionsCmdPatchTests, givenHostFunctionPatchCommandsWhenPatchComm NEO::CommandStreamReceiver *csr = nullptr; device->getCsrForOrdinalAndIndex(&csr, 0u, 0u, ZE_COMMAND_QUEUE_PRIORITY_NORMAL, 0, false); auto commandQueue = std::make_unique>(device, csr, &desc); + MockCommandStreamReceiver mockCsr(*neoDevice->executionEnvironment, neoDevice->getRootDeviceIndex(), neoDevice->getDeviceBitfield()); + const auto oldCsr = commandQueue->csr; + commandQueue->csr = &mockCsr; + auto commandList = std::make_unique>>(); commandList->commandsToPatch.clear(); @@ -1197,6 +1201,8 @@ HWTEST_F(HostFunctionsCmdPatchTests, givenHostFunctionPatchCommandsWhenPatchComm commandQueue->patchCommands(*commandList, 0, false, nullptr); EXPECT_NE(nullptr, commandQueue->csr->getHostFunctionDataAllocation()); + EXPECT_EQ(1u, mockCsr.createHostFunctionWorkerCounter); + EXPECT_EQ(1u, mockCsr.signalHostFunctionWorkerCounter); auto &hostFunctionDataFromCsr = commandQueue->csr->getHostFunctionData(); @@ -1220,6 +1226,8 @@ HWTEST_F(HostFunctionsCmdPatchTests, givenHostFunctionPatchCommandsWhenPatchComm // internal tag wait - semaphore wait EXPECT_EQ(static_cast(HostFunctionTagStatus::completed), internalTagMiWait.getSemaphoreDataDword()); EXPECT_EQ(reinterpret_cast(hostFunctionDataFromCsr.internalTag), internalTagMiWait.getSemaphoreGraphicsAddress()); + + commandQueue->csr = oldCsr; } } // namespace ult diff --git a/shared/source/command_stream/CMakeLists.txt b/shared/source/command_stream/CMakeLists.txt index ab55819178..e82b16c3da 100644 --- a/shared/source/command_stream/CMakeLists.txt +++ b/shared/source/command_stream/CMakeLists.txt @@ -30,8 +30,17 @@ set(NEO_CORE_COMMAND_STREAM ${CMAKE_CURRENT_SOURCE_DIR}/definitions${BRANCH_DIR_SUFFIX}stream_properties.inl ${CMAKE_CURRENT_SOURCE_DIR}/device_command_stream.h ${CMAKE_CURRENT_SOURCE_DIR}/host_function.h + ${CMAKE_CURRENT_SOURCE_DIR}/host_function.cpp ${CMAKE_CURRENT_SOURCE_DIR}/host_function.inl ${CMAKE_CURRENT_SOURCE_DIR}/host_function_enablers.inl + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_cv.h + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_cv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_interface.h + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_interface.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_counting_semaphore.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_counting_semaphore.h + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_atomic.h + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_atomic.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linear_stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/linear_stream.h ${CMAKE_CURRENT_SOURCE_DIR}/preemption.cpp diff --git a/shared/source/command_stream/command_stream_receiver.cpp b/shared/source/command_stream/command_stream_receiver.cpp index d1d3713cdb..bfcaad10a4 100644 --- a/shared/source/command_stream/command_stream_receiver.cpp +++ b/shared/source/command_stream/command_stream_receiver.cpp @@ -9,6 +9,7 @@ #include "shared/source/command_container/implicit_scaling.h" #include "shared/source/command_stream/aub_subcapture_status.h" +#include "shared/source/command_stream/host_function_worker_interface.h" #include "shared/source/command_stream/scratch_space_controller.h" #include "shared/source/command_stream/submission_status.h" #include "shared/source/command_stream/submissions_aggregator.h" @@ -97,6 +98,7 @@ CommandStreamReceiver::CommandStreamReceiver(ExecutionEnvironment &executionEnvi auto &compilerProductHelper = rootDeviceEnvironment.getHelper(); this->heaplessModeEnabled = compilerProductHelper.isHeaplessModeEnabled(hwInfo); this->heaplessStateInitEnabled = compilerProductHelper.isHeaplessStateInitEnabled(heaplessModeEnabled); + this->hostFunctionWorkerMode = debugManager.flags.HostFunctionWorkMode.get(); } CommandStreamReceiver::~CommandStreamReceiver() { @@ -235,6 +237,25 @@ WaitStatus CommandStreamReceiver::waitForTaskCountAndCleanTemporaryAllocationLis return waitForTaskCountAndCleanAllocationList(requiredTaskCount, TEMPORARY_ALLOCATION); } +void CommandStreamReceiver::createHostFunctionWorker() { + + if (this->hostFunctionWorker != nullptr) { + return; + } + + this->hostFunctionWorker = HostFunctionFactory::createHostFunctionWorker(this->hostFunctionWorkerMode, + this->isAubMode(), + this->downloadAllocationImpl, + this->getHostFunctionDataAllocation(), + &this->getHostFunctionData()); + + this->hostFunctionWorker->start(); +} + +IHostFunctionWorker *CommandStreamReceiver::getHostFunctionWorker() { + return this->hostFunctionWorker; +} + void CommandStreamReceiver::ensureCommandBufferAllocation(LinearStream &commandStream, size_t minimumRequiredSize, size_t additionalAllocationSize) { if (commandStream.getAvailableSpace() >= minimumRequiredSize) { return; @@ -419,6 +440,10 @@ void CommandStreamReceiver::cleanupResources() { tagsMultiAllocation = nullptr; } + if (hostFunctionWorker) { + cleanupHostFunctionWorker(); + } + if (hostFunctionDataMultiAllocation) { hostFunctionDataAllocation = nullptr; @@ -464,6 +489,12 @@ void CommandStreamReceiver::cleanupResources() { ownedPrivateAllocations.clear(); } +void CommandStreamReceiver ::cleanupHostFunctionWorker() { + hostFunctionWorker->finish(); + delete hostFunctionWorker; + hostFunctionWorker = nullptr; +} + WaitStatus CommandStreamReceiver::waitForCompletionWithTimeout(const WaitParams ¶ms, TaskCountType taskCountToWait) { bool printWaitForCompletion = debugManager.flags.LogWaitingForCompletion.get(); if (printWaitForCompletion) { @@ -697,6 +728,10 @@ void *CommandStreamReceiver::getIndirectHeapCurrentPtr(IndirectHeapType heapType return nullptr; } +void CommandStreamReceiver::signalHostFunctionWorker() { + hostFunctionWorker->submit(); +} + void CommandStreamReceiver::ensureHostFunctionDataInitialization() { if (!this->hostFunctionInitialized.load(std::memory_order_acquire)) { initializeHostFunctionData(); @@ -717,6 +752,9 @@ void CommandStreamReceiver::initializeHostFunctionData() { this->hostFunctionData.entry = reinterpret_cast(ptrOffset(hostFunctionBuffer, HostFunctionHelper::entryOffset)); this->hostFunctionData.userData = reinterpret_cast(ptrOffset(hostFunctionBuffer, HostFunctionHelper::userDataOffset)); this->hostFunctionData.internalTag = reinterpret_cast(ptrOffset(hostFunctionBuffer, HostFunctionHelper::internalTagOffset)); + + createHostFunctionWorker(); + this->hostFunctionInitialized.store(true, std::memory_order_release); } @@ -968,12 +1006,14 @@ bool CommandStreamReceiver::createPreemptionAllocation() { std::unique_lock CommandStreamReceiver::obtainUniqueOwnership() { return std::unique_lock(this->ownershipMutex); } + std::unique_lock CommandStreamReceiver::tryObtainUniqueOwnership() { return std::unique_lock(this->ownershipMutex, std::try_to_lock); } std::unique_lock CommandStreamReceiver::obtainHostPtrSurfaceCreationLock() { return std::unique_lock(this->hostPtrSurfaceCreationMutex); } + AllocationsList &CommandStreamReceiver::getTemporaryAllocations() { return internalAllocationStorage->getTemporaryAllocations(); } AllocationsList &CommandStreamReceiver::getAllocationsForReuse() { return internalAllocationStorage->getAllocationsForReuse(); } AllocationsList &CommandStreamReceiver::getDeferredAllocations() { return internalAllocationStorage->getDeferredAllocations(); } diff --git a/shared/source/command_stream/command_stream_receiver.h b/shared/source/command_stream/command_stream_receiver.h index fa6732ce47..a99f575d2a 100644 --- a/shared/source/command_stream/command_stream_receiver.h +++ b/shared/source/command_stream/command_stream_receiver.h @@ -67,6 +67,7 @@ class KmdNotifyHelper; class GfxCoreHelper; class ProductHelper; class ReleaseHelper; +class IHostFunctionWorker; enum class WaitStatus; struct AubSubCaptureStatus; class SharedPoolAllocation; @@ -148,6 +149,8 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { MOCKABLE_VIRTUAL WaitStatus waitForTaskCount(TaskCountType requiredTaskCount); WaitStatus waitForTaskCountAndCleanAllocationList(TaskCountType requiredTaskCount, uint32_t allocationUsage); MOCKABLE_VIRTUAL WaitStatus waitForTaskCountAndCleanTemporaryAllocationList(TaskCountType requiredTaskCount); + MOCKABLE_VIRTUAL void createHostFunctionWorker(); + IHostFunctionWorker *getHostFunctionWorker(); LinearStream &getCS(size_t minRequiredSize = 1024u); OSInterface *getOSInterface() const; @@ -270,7 +273,6 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { MOCKABLE_VIRTUAL bool createPerDssBackedBuffer(Device &device); [[nodiscard]] MOCKABLE_VIRTUAL std::unique_lock obtainUniqueOwnership(); [[nodiscard]] MOCKABLE_VIRTUAL std::unique_lock tryObtainUniqueOwnership(); - bool peekTimestampPacketWriteEnabled() const { return timestampPacketWriteEnabled; } bool isLatestTaskCountFlushed() { @@ -567,6 +569,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { bool isLatestFlushIsTaskCountUpdateOnly() const { return latestFlushIsTaskCountUpdateOnly; } MOCKABLE_VIRTUAL uint32_t getContextGroupId() const; + MOCKABLE_VIRTUAL void signalHostFunctionWorker(); void ensureHostFunctionDataInitialization(); HostFunctionData &getHostFunctionData(); @@ -584,6 +587,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { TaskCountType taskLevel, DispatchFlags &dispatchFlags, Device &device) = 0; void cleanupResources(); + void cleanupHostFunctionWorker(); void printDeviceIndex(); void checkForNewResources(TaskCountType submittedTaskCount, TaskCountType allocationTaskCount, GraphicsAllocation &gfxAllocation); bool checkImplicitFlushForGpuIdle(); @@ -610,6 +614,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { std::unique_ptr timestampPacketAllocator; std::unique_ptr userPauseConfirmation; std::unique_ptr globalStatelessHeap; + IHostFunctionWorker *hostFunctionWorker = nullptr; ResidencyContainer residencyAllocations; PrivateAllocsToReuseContainer ownedPrivateAllocations; @@ -664,7 +669,6 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { std::atomic taskCount{0}; std::atomic numClients = 0u; - DispatchMode dispatchMode = DispatchMode::immediateDispatch; SamplerCacheFlushState samplerCacheFlushRequired = SamplerCacheFlushState::samplerCacheFlushNotRequired; PreemptionMode lastPreemptionMode = PreemptionMode::Initial; @@ -673,7 +677,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass { uint32_t lastSentL3Config = 0; uint32_t latestSentStatelessMocsConfig; uint64_t lastSentSliceCount; - + int32_t hostFunctionWorkerMode = -1; uint32_t requiredScratchSlot0Size = 0; uint32_t requiredScratchSlot1Size = 0; uint32_t lastAdditionalKernelExecInfo; diff --git a/shared/source/command_stream/host_function.cpp b/shared/source/command_stream/host_function.cpp new file mode 100644 index 0000000000..9d9fa57054 --- /dev/null +++ b/shared/source/command_stream/host_function.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function.h" + +#include "shared/source/command_stream/command_stream_receiver.h" +#include "shared/source/command_stream/host_function_worker_atomic.h" +#include "shared/source/command_stream/host_function_worker_counting_semaphore.h" +#include "shared/source/command_stream/host_function_worker_cv.h" +#include "shared/source/command_stream/host_function_worker_interface.h" + +namespace NEO::HostFunctionFactory { + +IHostFunctionWorker *createHostFunctionWorker(int32_t hostFunctionWorkerMode, + bool isAubMode, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data) { + + bool skipHostFunctionExecution = isAubMode; + + switch (hostFunctionWorkerMode) { + default: + case 0: + return new HostFunctionWorkerCountingSemaphore(skipHostFunctionExecution, downloadAllocationImpl, allocation, data); + case 1: + return new HostFunctionWorkerCV(skipHostFunctionExecution, downloadAllocationImpl, allocation, data); + case 2: + return new HostFunctionWorkerAtomic(skipHostFunctionExecution, downloadAllocationImpl, allocation, data); + } +} + +} // namespace NEO::HostFunctionFactory diff --git a/shared/source/command_stream/host_function.h b/shared/source/command_stream/host_function.h index 14f927cd6c..5d03fbf665 100644 --- a/shared/source/command_stream/host_function.h +++ b/shared/source/command_stream/host_function.h @@ -9,10 +9,14 @@ #include #include +#include namespace NEO { class LinearStream; +class CommandStreamReceiver; +class IHostFunctionWorker; +class GraphicsAllocation; struct HostFunctionData { volatile uint64_t *entry = nullptr; @@ -47,4 +51,13 @@ struct HostFunctionHelper { static void programWaitForHostFunctionCompletion(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData); }; +namespace HostFunctionFactory { +IHostFunctionWorker *createHostFunctionWorker(int32_t hostFunctionWorkerMode, + bool isAubMode, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data); + +} + } // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_atomic.cpp b/shared/source/command_stream/host_function_worker_atomic.cpp new file mode 100644 index 0000000000..1b801ded27 --- /dev/null +++ b/shared/source/command_stream/host_function_worker_atomic.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function_worker_atomic.h" + +#include "shared/source/command_stream/host_function.h" + +namespace NEO { +HostFunctionWorkerAtomic::HostFunctionWorkerAtomic(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data) + : IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) { +} + +HostFunctionWorkerAtomic::~HostFunctionWorkerAtomic() = default; + +void HostFunctionWorkerAtomic::start() { + + std::lock_guard lg{workerMutex}; + if (!worker) { + worker = std::make_unique([this](std::stop_token st) { + this->workerLoop(st); + }); + } +} + +void HostFunctionWorkerAtomic::finish() { + std::lock_guard lg{workerMutex}; + if (worker) { + worker->request_stop(); + pending.fetch_add(1u); + pending.notify_one(); + worker.reset(nullptr); + } +} + +void HostFunctionWorkerAtomic::submit() noexcept { + pending.fetch_add(1, std::memory_order_release); + pending.notify_one(); +} + +void HostFunctionWorkerAtomic::workerLoop(std::stop_token st) noexcept { + + while (true) { + + while (pending.load(std::memory_order_acquire) == 0) { + pending.wait(0, std::memory_order_acquire); + } + + if (st.stop_requested()) { + return; + } + + pending.fetch_sub(1, std::memory_order_acq_rel); + + bool sucess = this->runHostFunction(st); + if (!sucess) [[unlikely]] { + return; + } + } +} + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_atomic.h b/shared/source/command_stream/host_function_worker_atomic.h new file mode 100644 index 0000000000..1c22729915 --- /dev/null +++ b/shared/source/command_stream/host_function_worker_atomic.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#pragma once + +#include "shared/source/command_stream/host_function_worker_interface.h" + +#include +#include +#include + +namespace NEO { + +class HostFunctionWorkerAtomic : public IHostFunctionWorker { + public: + HostFunctionWorkerAtomic(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data); + ~HostFunctionWorkerAtomic() override; + + void start() override; + void finish() override; + void submit() noexcept override; + + private: + void workerLoop(std::stop_token st) noexcept; + + std::atomic pending{0}; +}; + +static_assert(NonCopyableAndNonMovable); + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_counting_semaphore.cpp b/shared/source/command_stream/host_function_worker_counting_semaphore.cpp new file mode 100644 index 0000000000..138c0fb31c --- /dev/null +++ b/shared/source/command_stream/host_function_worker_counting_semaphore.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function_worker_counting_semaphore.h" + +namespace NEO { + +HostFunctionWorkerCountingSemaphore::HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution, const std::function &downloadAllocationImpl, GraphicsAllocation *allocation, HostFunctionData *data) + : IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) { +} + +HostFunctionWorkerCountingSemaphore::~HostFunctionWorkerCountingSemaphore() = default; + +void HostFunctionWorkerCountingSemaphore::start() { + std::lock_guard lg{workerMutex}; + + if (!worker) { + worker = std::make_unique([this](std::stop_token st) { + this->workerLoop(st); + }); + } +} + +void HostFunctionWorkerCountingSemaphore::finish() { + std::lock_guard lg{workerMutex}; + + if (worker) { + worker->request_stop(); + semaphore.release(); + worker.reset(nullptr); + } +} + +void HostFunctionWorkerCountingSemaphore::submit() noexcept { + semaphore.release(); +} + +void HostFunctionWorkerCountingSemaphore::workerLoop(std::stop_token st) noexcept { + + while (true) { + + semaphore.acquire(); + + if (st.stop_requested()) [[unlikely]] { + return; + } + + bool success = runHostFunction(st); + if (!success) [[unlikely]] { + return; + } + } +} + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_counting_semaphore.h b/shared/source/command_stream/host_function_worker_counting_semaphore.h new file mode 100644 index 0000000000..46eb3e4aed --- /dev/null +++ b/shared/source/command_stream/host_function_worker_counting_semaphore.h @@ -0,0 +1,38 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#pragma once + +#include "shared/source/command_stream/host_function_worker_interface.h" + +#include +#include +#include + +namespace NEO { + +class HostFunctionWorkerCountingSemaphore : public IHostFunctionWorker { + public: + HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data); + ~HostFunctionWorkerCountingSemaphore() override; + + void start() override; + void finish() override; + void submit() noexcept override; + + private: + void workerLoop(std::stop_token st) noexcept; + + std::counting_semaphore<> semaphore{0}; +}; + +static_assert(NonCopyableAndNonMovable); + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_cv.cpp b/shared/source/command_stream/host_function_worker_cv.cpp new file mode 100644 index 0000000000..b0e17b2663 --- /dev/null +++ b/shared/source/command_stream/host_function_worker_cv.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function_worker_cv.h" + +#include "shared/source/command_stream/command_stream_receiver.h" +#include "shared/source/command_stream/host_function.h" +#include "shared/source/utilities/wait_util.h" + +namespace NEO { +HostFunctionWorkerCV::HostFunctionWorkerCV(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data) + : IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) { +} + +HostFunctionWorkerCV::~HostFunctionWorkerCV() = default; + +void HostFunctionWorkerCV::start() { + std::lock_guard lg{workerMutex}; + if (!worker) { + worker = std::make_unique([this](std::stop_token st) { + this->workerLoop(st); + }); + } +} + +void HostFunctionWorkerCV::finish() { + std::lock_guard lg{workerMutex}; + if (worker) { + worker->request_stop(); + cv.notify_one(); + worker.reset(nullptr); + } +} + +void HostFunctionWorkerCV::submit() noexcept { + { + std::lock_guard lock{pendingAccessMutex}; + ++pending; + } + cv.notify_one(); +} + +void HostFunctionWorkerCV::workerLoop(std::stop_token st) noexcept { + + std::unique_lock lock{pendingAccessMutex, std::defer_lock}; + + while (true) { + lock.lock(); + cv.wait(lock, [&]() { + return pending > 0 || st.stop_requested(); + }); + + if (st.stop_requested()) [[unlikely]] { + return; + } + + --pending; + + lock.unlock(); + + bool sucess = this->runHostFunction(st); + if (!sucess) [[unlikely]] { + return; + } + } +} + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_cv.h b/shared/source/command_stream/host_function_worker_cv.h new file mode 100644 index 0000000000..9dd6780797 --- /dev/null +++ b/shared/source/command_stream/host_function_worker_cv.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#pragma once +#include "shared/source/command_stream/host_function_worker_interface.h" + +#include +#include +#include + +namespace NEO { + +class HostFunctionWorkerCV : public IHostFunctionWorker { + public: + HostFunctionWorkerCV(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data); + ~HostFunctionWorkerCV() override; + + void start() override; + void finish() override; + void submit() noexcept override; + + private: + void workerLoop(std::stop_token st) noexcept; + + std::mutex pendingAccessMutex; + std::condition_variable cv; + uint32_t pending{0}; +}; + +static_assert(NonCopyableAndNonMovable); + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_interface.cpp b/shared/source/command_stream/host_function_worker_interface.cpp new file mode 100644 index 0000000000..9236883774 --- /dev/null +++ b/shared/source/command_stream/host_function_worker_interface.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function_worker_interface.h" + +#include "shared/source/command_stream/host_function.h" +#include "shared/source/utilities/wait_util.h" + +#include +#include + +namespace NEO { +IHostFunctionWorker::IHostFunctionWorker(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data) + : downloadAllocationImpl(downloadAllocationImpl), + allocation(allocation), + data(data), + skipHostFunctionExecution(skipHostFunctionExecution) { +} + +IHostFunctionWorker::~IHostFunctionWorker() = default; + +bool IHostFunctionWorker::runHostFunction(std::stop_token st) noexcept { + + using tagStatusT = std::underlying_type_t; + const auto start = std::chrono::steady_clock::now(); + std::chrono::microseconds waitTime{0}; + + if (!this->skipHostFunctionExecution) { + + while (true) { + if (this->downloadAllocationImpl) [[unlikely]] { + this->downloadAllocationImpl(*this->allocation); + } + const volatile uint32_t *hostFuntionTagAddress = this->data->internalTag; + waitTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + bool pendingJobFound = WaitUtils::waitFunctionWithPredicate(hostFuntionTagAddress, + static_cast(HostFunctionTagStatus::pending), + std::equal_to(), + waitTime.count()); + if (pendingJobFound) { + break; + } + + if (st.stop_requested()) { + return false; + } + } + + using CallbackT = void (*)(void *); + CallbackT callback = reinterpret_cast(*this->data->entry); + void *callbackData = reinterpret_cast(*this->data->userData); + + callback(callbackData); + } + + *this->data->internalTag = static_cast(HostFunctionTagStatus::completed); + + return true; +} + +} // namespace NEO diff --git a/shared/source/command_stream/host_function_worker_interface.h b/shared/source/command_stream/host_function_worker_interface.h new file mode 100644 index 0000000000..f7f52d363e --- /dev/null +++ b/shared/source/command_stream/host_function_worker_interface.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#pragma once + +#include "shared/source/helpers/non_copyable_or_moveable.h" + +#include +#include +#include +#include +#include + +namespace NEO { + +class GraphicsAllocation; +struct HostFunctionData; + +class IHostFunctionWorker : public NonCopyableAndNonMovableClass { + public: + IHostFunctionWorker(bool skipHostFunctionExecution, + const std::function &downloadAllocationImpl, + GraphicsAllocation *allocation, + HostFunctionData *data); + virtual ~IHostFunctionWorker() = 0; + + virtual void start() = 0; + virtual void finish() = 0; + virtual void submit() noexcept = 0; + + protected: + MOCKABLE_VIRTUAL bool runHostFunction(std::stop_token st) noexcept; + std::unique_ptr worker; + std::mutex workerMutex; + + private: + std::function downloadAllocationImpl; + GraphicsAllocation *allocation = nullptr; + HostFunctionData *data = nullptr; + bool skipHostFunctionExecution = false; +}; + +static_assert(NonCopyableAndNonMovable); + +} // namespace NEO diff --git a/shared/source/debug_settings/debug_variables_base.inl b/shared/source/debug_settings/debug_variables_base.inl index 3533e6e1ba..eef44d5baa 100644 --- a/shared/source/debug_settings/debug_variables_base.inl +++ b/shared/source/debug_settings/debug_variables_base.inl @@ -320,6 +320,7 @@ DECLARE_DEBUG_VARIABLE(int32_t, OverrideCopyOffloadMode, -1, "-1: default, 0: di DECLARE_DEBUG_VARIABLE(int32_t, UseSingleListForTemporaryAllocations, -1, "-1: default, 0: disabled, 0: enabled. If enabled, use single list, instead of per CSR for tracking temporary allocations") DECLARE_DEBUG_VARIABLE(int32_t, OverrideMaxMemAllocSizeMb, -1, "-1: default, >=0 override reported max mem alloc size in MB") DECLARE_DEBUG_VARIABLE(int32_t, DetectIncorrectPointersOnSetArgCalls, -1, "-1: default do not detect, 0: do not detect, 1: detect incorrect pointers and return error") +DECLARE_DEBUG_VARIABLE(int32_t, HostFunctionWorkMode, -1, "-1: default - counting semaphore based, 0: counting semaphore based, 1: condition variable base, 2: atomics based") DECLARE_DEBUG_VARIABLE(int32_t, ForceDisableGraphPatchPreamble, -1, "-1: default, 0: enable patch preamble, 1: disable graph patch preamble. If disabled, do not patch preamble graph internal command lists") /*LOGGING FLAGS*/ diff --git a/shared/test/common/libult/ult_command_stream_receiver.h b/shared/test/common/libult/ult_command_stream_receiver.h index 30934a5c6b..92d10ab47c 100644 --- a/shared/test/common/libult/ult_command_stream_receiver.h +++ b/shared/test/common/libult/ult_command_stream_receiver.h @@ -563,6 +563,14 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw { BaseClass::setupContext(osContext); } + void signalHostFunctionWorker() override { + signalHostFunctionWorkerCounter++; + } + + void createHostFunctionWorker() override { + createHostFunctionWorkerCounter++; + } + bool waitUserFence(TaskCountType waitValue, uint64_t hostAddress, int64_t timeout, bool userInterrupt, uint32_t externalInterruptId, GraphicsAllocation *allocForInterruptWait) override { waitUserFenceParams.callCount++; waitUserFenceParams.latestWaitedAddress = hostAddress; @@ -651,6 +659,8 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw { uint32_t flushHandlerCalled = 0; uint32_t obtainUniqueOwnershipCalledTimes = 0; uint32_t walkerWithProfilingEnqueuedTimes = 0; + uint32_t createHostFunctionWorkerCounter = 0; + uint32_t signalHostFunctionWorkerCounter = 0; mutable uint32_t checkGpuHangDetectedCalled = 0; int ensureCommandBufferAllocationCalled = 0; DispatchFlags recordedDispatchFlags; diff --git a/shared/test/common/mocks/mock_command_stream_receiver.h b/shared/test/common/mocks/mock_command_stream_receiver.h index d9c9b317d2..1ef796e92e 100644 --- a/shared/test/common/mocks/mock_command_stream_receiver.h +++ b/shared/test/common/mocks/mock_command_stream_receiver.h @@ -285,6 +285,14 @@ class MockCommandStreamReceiver : public CommandStreamReceiver { BaseClass::initializeHostFunctionData(); } + void signalHostFunctionWorker() override { + signalHostFunctionWorkerCounter++; + } + + void createHostFunctionWorker() override { + createHostFunctionWorkerCounter++; + } + static constexpr size_t tagSize = 256; static volatile TagAddressType mockTagAddress[tagSize]; std::vector instructionHeapReserveredData; @@ -298,6 +306,8 @@ class MockCommandStreamReceiver : public CommandStreamReceiver { uint32_t submitDependencyUpdateCalledTimes = 0; uint32_t stopDirectSubmissionCalledTimes = 0; uint32_t initializeHostFunctionDataCalledTimes = 0; + uint32_t createHostFunctionWorkerCounter = 0; + uint32_t signalHostFunctionWorkerCounter = 0; int hostPtrSurfaceCreationMutexLockCount = 0; bool multiOsContextCapable = false; bool memoryCompressionEnabled = false; diff --git a/shared/test/common/test_files/igdrcl.config b/shared/test/common/test_files/igdrcl.config index 725611cadc..c769d63349 100644 --- a/shared/test/common/test_files/igdrcl.config +++ b/shared/test/common/test_files/igdrcl.config @@ -674,5 +674,6 @@ EnableUsmAllocationPoolManager = -1 ForceTotalWMTPDataSize = -1 CopyLockedMemoryBeforeWrite = 0 SplitBcsPerEngineMaxSize = -1 +HostFunctionWorkMode = -1 EnableUsmPoolResidencyTracking = -1 # Please don't edit below this line diff --git a/shared/test/unit_test/command_stream/command_stream_receiver_tests.cpp b/shared/test/unit_test/command_stream/command_stream_receiver_tests.cpp index 957add3bce..aa68c98edd 100644 --- a/shared/test/unit_test/command_stream/command_stream_receiver_tests.cpp +++ b/shared/test/unit_test/command_stream/command_stream_receiver_tests.cpp @@ -6493,19 +6493,22 @@ TEST(CommandStreamReceiverHostFunctionsTest, givenDestructedCommandStreamReceive csr->ensureHostFunctionDataInitialization(); EXPECT_NE(nullptr, csr->hostFunctionDataAllocation); EXPECT_NE(nullptr, csr->hostFunctionDataMultiAllocation); + EXPECT_EQ(1u, csr->createHostFunctionWorkerCounter); + csr->cleanupResources(); EXPECT_EQ(nullptr, csr->hostFunctionDataAllocation); EXPECT_EQ(nullptr, csr->hostFunctionDataMultiAllocation); } -TEST(CommandStreamReceiverHostFunctionsTest, givenCommandStreamReceiverWithHostFunctionDataWhenMakeResidentHostFunctionAllocationIsCalledThenHostAllocationIsResident) { - std::unique_ptr device(MockDevice::createWithNewExecutionEnvironment(defaultHwInfo.get(), 0u)); - auto &csr = *device->commandStreamReceivers[0]; +HWTEST_F(CommandStreamReceiverHwTest, givenHostFunctionDataWhenMakeResidentHostFunctionAllocationIsCalledThenHostAllocationIsResident) { + auto &csr = pDevice->getUltCommandStreamReceiver(); + ASSERT_EQ(nullptr, csr.getHostFunctionDataAllocation()); csr.ensureHostFunctionDataInitialization(); auto *hostDataAllocation = csr.getHostFunctionDataAllocation(); ASSERT_NE(nullptr, hostDataAllocation); + EXPECT_EQ(1u, csr.createHostFunctionWorkerCounter); auto csrContextId = csr.getOsContext().getContextId(); EXPECT_FALSE(hostDataAllocation->isResident(csrContextId)); @@ -6516,3 +6519,16 @@ TEST(CommandStreamReceiverHostFunctionsTest, givenCommandStreamReceiverWithHostF csr.makeNonResident(*hostDataAllocation); EXPECT_FALSE(hostDataAllocation->isResident(csrContextId)); } + +HWTEST_F(CommandStreamReceiverHwTest, givenHostFunctionDataWhenSignalHostFunctionWorkerIsCalledThenCounterIsUpdated) { + auto &csr = pDevice->getUltCommandStreamReceiver(); + + ASSERT_EQ(nullptr, csr.getHostFunctionDataAllocation()); + csr.ensureHostFunctionDataInitialization(); + auto *hostDataAllocation = csr.getHostFunctionDataAllocation(); + ASSERT_NE(nullptr, hostDataAllocation); + ASSERT_EQ(1u, csr.createHostFunctionWorkerCounter); + + csr.signalHostFunctionWorker(); + EXPECT_EQ(1u, csr.signalHostFunctionWorkerCounter); +} diff --git a/shared/test/unit_test/mt_tests/host_function/CMakeLists.txt b/shared/test/unit_test/mt_tests/host_function/CMakeLists.txt new file mode 100644 index 0000000000..4c690b4643 --- /dev/null +++ b/shared/test/unit_test/mt_tests/host_function/CMakeLists.txt @@ -0,0 +1,11 @@ +# +# Copyright (C) 2025 Intel Corporation +# +# SPDX-License-Identifier: MIT +# + +set(NEO_SHARED_SRCS_mt_tests_host_function + ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt + ${CMAKE_CURRENT_SOURCE_DIR}/host_function_tests_mt.cpp +) +target_sources(neo_shared_mt_tests PRIVATE ${NEO_SHARED_SRCS_mt_tests_host_function}) diff --git a/shared/test/unit_test/mt_tests/host_function/host_function_tests_mt.cpp b/shared/test/unit_test/mt_tests/host_function/host_function_tests_mt.cpp new file mode 100644 index 0000000000..4bd28d73b6 --- /dev/null +++ b/shared/test/unit_test/mt_tests/host_function/host_function_tests_mt.cpp @@ -0,0 +1,298 @@ +/* + * Copyright (C) 2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/source/command_stream/host_function_worker_counting_semaphore.h" +#include "shared/source/command_stream/host_function_worker_cv.h" +#include "shared/test/common/helpers/debug_manager_state_restore.h" +#include "shared/test/common/mocks/mock_command_stream_receiver.h" +#include "shared/test/common/mocks/mock_execution_environment.h" +#include "shared/test/common/test_macros/test.h" + +#define TSAN_ANNOTATE_IGNORE_BEGIN() \ + do { \ + } while (0) +#define TSAN_ANNOTATE_IGNORE_END() \ + do { \ + } while (0) + +#if defined(__clang__) +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) + +extern "C" void __tsan_ignore_thread_begin(); +extern "C" void __tsan_ignore_thread_end(); + +#undef TSAN_ANNOTATE_IGNORE_BEGIN +#undef TSAN_ANNOTATE_IGNORE_END +#define TSAN_ANNOTATE_IGNORE_BEGIN() \ + do { \ + __tsan_ignore_thread_begin(); \ + } while (0) +#define TSAN_ANNOTATE_IGNORE_END() \ + do { \ + __tsan_ignore_thread_end(); \ + } while (0) + +#endif +#endif +#endif + +namespace { +class MockCommandStreamReceiverHostFunction : public MockCommandStreamReceiver { + public: + using MockCommandStreamReceiver::hostFunctionData; + using MockCommandStreamReceiver::hostFunctionWorker; + using MockCommandStreamReceiver::MockCommandStreamReceiver; + + void createHostFunctionWorker() override { + CommandStreamReceiver::createHostFunctionWorker(); + } + + void signalHostFunctionWorker() override { + CommandStreamReceiver::signalHostFunctionWorker(); + } +}; + +struct Arg { + uint32_t expected = 0; + uint32_t result = 0; + uint32_t counter = 0; +}; + +extern "C" void hostFunctionExample(void *data) { + Arg *arg = static_cast(data); + arg->result = arg->expected; + ++arg->counter; +} + +void createArgs(std::vector &hostFunctionArgs, uint32_t n) { + hostFunctionArgs.reserve(n); + + for (auto i = 0u; i < n; i++) { + hostFunctionArgs.push_back(Arg{.expected = i + 1, .result = 0, .counter = 0}); + } +} + +class HostFunctionMtFixture { + public: + void configureCSRs(uint32_t numberOfCSRs, uint32_t callbacksPerCsr, uint32_t testingMode, uint32_t primaryCSRs) { + this->callbacksPerCsr = callbacksPerCsr; + + executionEnvironment.prepareRootDeviceEnvironments(1); + executionEnvironment.initializeMemoryManager(); + DeviceBitfield deviceBitfield(1); + + createArgs(this->hostFunctionArgs, numberOfCSRs); + for (auto i = 0u; i < numberOfCSRs; i++) { + csrs.push_back(std::make_unique(executionEnvironment, 0, deviceBitfield)); + } + + for (auto &csr : csrs) { + csr->initializeHostFunctionData(); + } + + for (auto i = 0u; i < csrs.size(); i++) { + *csrs[i]->hostFunctionData.entry = reinterpret_cast(hostFunctionExample); + *csrs[i]->hostFunctionData.userData = reinterpret_cast(&hostFunctionArgs[i]); + *csrs[i]->hostFunctionData.internalTag = static_cast(HostFunctionTagStatus::completed); + } + } + + void simulateGpuContexts() { + auto expectedAllCallbacks = csrs.size() * callbacksPerCsr; + auto callbacksCounter = 0u; + std::vector callbacksPerCsrCounter(csrs.size(), 0); + + TSAN_ANNOTATE_IGNORE_BEGIN(); + + while (true) { + for (auto i = 0u; i < csrs.size(); i++) { + if (*csrs[i]->hostFunctionData.internalTag == static_cast(HostFunctionTagStatus::completed)) { + + if (callbacksPerCsrCounter[i] < callbacksPerCsr) { + *csrs[i]->hostFunctionData.internalTag = static_cast(HostFunctionTagStatus::pending); + ++callbacksPerCsrCounter[i]; + ++callbacksCounter; + } + } + } + + if (callbacksCounter == expectedAllCallbacks) { + break; + } + + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + + TSAN_ANNOTATE_IGNORE_END(); + } + + void waitForCallbacksCompletion() { + TSAN_ANNOTATE_IGNORE_BEGIN(); + while (true) { + uint32_t csrsCompleted = 0u; + for (auto i = 0u; i < csrs.size(); i++) { + if (*csrs[i]->hostFunctionData.internalTag == static_cast(HostFunctionTagStatus::completed)) { + ++csrsCompleted; + } + } + + if (csrsCompleted == csrs.size()) { + break; + } + + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + + for (auto i = 0u; i < csrs.size(); i++) { + csrs[i]->hostFunctionWorker->finish(); + } + + TSAN_ANNOTATE_IGNORE_END(); + } + + void checkResults() { + uint32_t expectedCounter = callbacksPerCsr; + TSAN_ANNOTATE_IGNORE_BEGIN(); + + for (auto i = 0u; i < csrs.size(); i++) { + Arg *arg = reinterpret_cast(*csrs[i]->hostFunctionData.userData); + EXPECT_EQ(arg->expected, arg->result); + EXPECT_EQ(uint32_t{i + 1u}, arg->result); + EXPECT_EQ(expectedCounter, arg->counter); + EXPECT_EQ(static_cast(HostFunctionTagStatus::completed), *csrs[i]->hostFunctionData.internalTag); + } + TSAN_ANNOTATE_IGNORE_END(); + } + + void clearResources() { + csrs.clear(); + hostFunctionArgs.clear(); + } + + std::vector hostFunctionArgs; + std::vector> csrs; + DebugManagerStateRestore restorer{}; + uint32_t callbacksPerCsr = 0; + MockExecutionEnvironment executionEnvironment; +}; + +class HostFunctionMtTestP : public ::testing::TestWithParam, public HostFunctionMtFixture { + public: + void SetUp() override { + + auto param = GetParam(); + this->testingMode = static_cast(param); + debugManager.flags.HostFunctionWorkMode.set(this->testingMode); + } + + void TearDown() override { + } + + int primaryCSRs = 0; + DebugManagerStateRestore restorer{}; + int testingMode = 0; +}; + +TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenSequentialCsrJobIsSubmittedThenHostFunctionsWorkIsDoneCorrectly) { + + uint32_t numberOfCSRs = 4; + uint32_t callbacksPerCsr = 6; + + configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs); + + // each csr will enqueue multiple host function callbacks + + for (auto iCallback = 0u; iCallback < callbacksPerCsr; iCallback++) { + for (auto &csr : csrs) { + csr->signalHostFunctionWorker(); + } + } + + simulateGpuContexts(); + waitForCallbacksCompletion(); + checkResults(); + clearResources(); +} + +TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenEachCsrSubmitAllCalbacksPerThreadThenHostFunctionsWorkIsDoneCorrectly) { + uint32_t numberOfCSRs = 4; + uint32_t callbacksPerCsr = 6; + + configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs); + + // each csr gets its own thread to submit all host functions + auto nSubmitters = csrs.size(); + std::vector submitters; + submitters.reserve(nSubmitters); + + auto submitAllCallbacksPerCsr = [&](uint32_t idxCsr) { + auto csr = csrs[idxCsr].get(); + for (auto callbackIdx = 0u; callbackIdx < callbacksPerCsr; callbackIdx++) { + csr->signalHostFunctionWorker(); + } + }; + + for (auto i = 0u; i < nSubmitters; i++) { + submitters.emplace_back([&, idx = i]() { + submitAllCallbacksPerCsr(idx); + }); + } + + for (auto i = 0u; i < nSubmitters; i++) { + submitters[i].join(); + } + + simulateGpuContexts(); + waitForCallbacksCompletion(); + checkResults(); + clearResources(); +} + +TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenCsrJobsAreSubmittedConcurrentlyThenHostFunctionsWorkIsDoneCorrectly) { + + uint32_t numberOfCSRs = 4; + uint32_t callbacksPerCsr = 6; + + configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs); + + auto nSubmitters = callbacksPerCsr; + std::vector submitters; + submitters.reserve(nSubmitters); + + // multiple threads can submit host function in parrarel using the same csr + auto submitOnceCallbackForAllCSRs = [&]() { + for (auto &csr : csrs) { + csr->signalHostFunctionWorker(); + } + }; + + for (auto i = 0u; i < callbacksPerCsr; i++) { + submitters.emplace_back([&]() { + submitOnceCallbackForAllCSRs(); + }); + } + + for (auto i = 0u; i < nSubmitters; i++) { + submitters[i].join(); + } + + simulateGpuContexts(); + waitForCallbacksCompletion(); + checkResults(); + clearResources(); +} + +INSTANTIATE_TEST_SUITE_P(AllModes, + HostFunctionMtTestP, + ::testing::Values( + 0, // Counting Semaphore implementation + 1, // Condition Variable implementation + 2 // Atomics implementation + )); + +} // namespace