fix: ensure thread-safety in zeDeviceSynchronize

Get the task count and flush stamp within a single critical section.

Related-To: NEO-14560
Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2025-07-30 14:09:01 +00:00
committed by Compute-Runtime-Automation
parent b6200738f3
commit abb00a5ce3
3 changed files with 107 additions and 16 deletions

View File

@@ -2249,28 +2249,35 @@ uint32_t DeviceImp::getEventMaxKernelCount() const {
}
ze_result_t DeviceImp::synchronize() {
for (auto &engine : neoDevice->getAllEngines()) {
if (engine.commandStreamReceiver->isInitialized()) {
auto waitStatus = engine.commandStreamReceiver->waitForTaskCountWithKmdNotifyFallback(
engine.commandStreamReceiver->peekTaskCount(),
engine.commandStreamReceiver->obtainCurrentFlushStamp(),
// Waits on a single command stream receiver using the task count and flush stamp
// sampled from it at call time.
// Returns ZE_RESULT_ERROR_DEVICE_LOST when the wait reports a GPU hang and
// ZE_RESULT_SUCCESS otherwise; a CSR that is not initialized is not waited on.
auto waitForCsr = [](NEO::CommandStreamReceiver *csr) -> ze_result_t {
if (csr->isInitialized()) {
// Sample the task count and the flush stamp while holding the CSR ownership lock,
// so that both values are read within one critical section.
auto lock = csr->obtainUniqueOwnership();
auto taskCountToWait = csr->peekTaskCount();
auto flushStampToWait = csr->obtainCurrentFlushStamp();
// The lock is released before the wait; the wait uses only the sampled values.
// NOTE(review): presumably this lets other threads keep submitting while we wait - confirm.
lock.unlock();
auto waitStatus = csr->waitForTaskCountWithKmdNotifyFallback(
taskCountToWait,
flushStampToWait,
false,
NEO::QueueThrottle::MEDIUM);
if (waitStatus == NEO::WaitStatus::gpuHang) {
return ZE_RESULT_ERROR_DEVICE_LOST;
}
}
return ZE_RESULT_SUCCESS;
};
for (auto &engine : neoDevice->getAllEngines()) {
auto ret = waitForCsr(engine.commandStreamReceiver);
if (ret != ZE_RESULT_SUCCESS) {
return ret;
}
}
for (auto &secondaryCsr : neoDevice->getSecondaryCsrs()) {
if (secondaryCsr->isInitialized()) {
auto waitStatus = secondaryCsr->waitForTaskCountWithKmdNotifyFallback(
secondaryCsr->peekTaskCount(),
secondaryCsr->obtainCurrentFlushStamp(),
false,
NEO::QueueThrottle::MEDIUM);
if (waitStatus == NEO::WaitStatus::gpuHang) {
return ZE_RESULT_ERROR_DEVICE_LOST;
}
auto ret = waitForCsr(secondaryCsr.get());
if (ret != ZE_RESULT_SUCCESS) {
return ret;
}
}

View File

@@ -5,10 +5,14 @@
*
*/
#include "shared/source/command_stream/command_stream_receiver.h"
#include "shared/test/common/test_macros/test.h"
#include "shared/test/common/libult/ult_command_stream_receiver.h"
#include "shared/test/common/test_macros/hw_test.h"
#include "level_zero/core/test/unit_tests/fixtures/device_fixture.h"
#include "level_zero/core/test/unit_tests/mocks/mock_cmdlist.h"
#include "level_zero/core/test/unit_tests/mocks/mock_cmdqueue.h"
#include "level_zero/core/test/unit_tests/mocks/mock_kernel.h"
#include "level_zero/core/test/unit_tests/mocks/mock_module.h"
namespace L0 {
namespace ult {
@@ -68,5 +72,80 @@ TEST_F(MultiDeviceMtTest, givenTwoDevicesWhenCanAccessPeerIsCalledManyTimesFromM
EXPECT_GE(2u, std::max(taskCount0, taskCount1));
EXPECT_EQ(0u, std::min(taskCount0, taskCount1));
}
using DeviceMtTest = Test<DeviceFixture>;
// Multi-threaded check of Device::synchronize(): several threads repeatedly execute the
// same command list and then synchronize the device. Every wait captured by the mock CSR
// must carry a task count and a flush stamp that are non-zero and equal to each other.
HWTEST_F(DeviceMtTest, givenMultiThreadsExecutingCmdListAndSynchronizingDeviceWhenSynchronizeIsCalledThenTaskCountAndFlushStampAreTakenWithinSingleCriticalSection) {
    L0::Device *device = driverHandle->devices[0];
    auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(device->getNEODevice()->getDefaultEngine().commandStreamReceiver);

    // Reset the counters so that task count and flush stamp both start from zero.
    csr->latestSentTaskCount = 0u;
    csr->latestFlushedTaskCount = 0u;
    csr->taskCount = 0;
    csr->flushStamp->setStamp(0);

    // Record the arguments of every wait call and make the mocked wait report success.
    csr->captureWaitForTaskCountWithKmdNotifyInputParams = true;
    csr->waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::ready;
    csr->resourcesInitialized = true;
    // The mock CSR bumps the flush stamp by one on each flush when this flag is set.
    csr->incrementFlushStampOnFlush = true;

    const ze_command_queue_desc_t desc = {};
    ze_result_t returnValue;
    auto commandQueue = whiteboxCast(CommandQueue::create(defaultHwInfo->platform.eProductFamily,
                                                          device,
                                                          csr,
                                                          &desc,
                                                          false,
                                                          false,
                                                          false,
                                                          returnValue));
    ASSERT_NE(nullptr, commandQueue);

    Mock<Module> module(device, nullptr, ModuleType::user);
    Mock<KernelImp> kernel;
    kernel.module = &module;
    kernel.immutableData.device = device;

    auto commandList = std::unique_ptr<CommandList>(CommandList::whiteboxCast(CommandList::create(defaultHwInfo->platform.eProductFamily, device, NEO::EngineGroupType::renderCompute, 0u, returnValue, false)));
    ASSERT_NE(nullptr, commandList);

    ze_group_count_t dispatchKernelArguments{1, 1, 1};
    CmdListKernelLaunchParams launchParams = {};
    commandList->appendLaunchKernel(kernel.toHandle(), dispatchKernelArguments, nullptr, 0, nullptr, launchParams);
    commandList->close();

    std::atomic_bool started = false;
    constexpr int numThreads = 8;
    constexpr int iterationCount = 20;

    std::vector<std::thread> threads;
    threads.reserve(numThreads);

    auto threadBody = [&]() {
        // Hold every worker at the start gate until all threads exist, so that
        // submissions and synchronizations from different threads actually overlap.
        while (!started.load()) {
            std::this_thread::yield();
        }
        ze_command_list_handle_t cmdList = commandList->toHandle();
        for (auto i = 0; i < iterationCount; i++) {
            commandQueue->executeCommandLists(1, &cmdList, nullptr, false, nullptr, nullptr);
            device->synchronize();
        }
    };

    for (int i = 0; i < numThreads; ++i) {
        threads.emplace_back(threadBody);
    }
    // Open the gate: all workers begin executing and synchronizing together.
    started = true;

    for (auto &thread : threads) {
        thread.join();
    }

    // One wait is expected per synchronize() call made by the workers.
    auto expectedWaitCalls = numThreads * iterationCount;
    EXPECT_EQ(static_cast<size_t>(expectedWaitCalls), csr->waitForTaskCountWithKmdNotifyInputParams.size());

    for (size_t i = 0; i < csr->waitForTaskCountWithKmdNotifyInputParams.size(); i++) {
        auto &inputParams = csr->waitForTaskCountWithKmdNotifyInputParams[i];
        EXPECT_NE(0u, inputParams.taskCountToWait);
        EXPECT_NE(0u, inputParams.flushStampToWait);
        // A mismatch means the two values were sampled at different points in time.
        EXPECT_EQ(inputParams.taskCountToWait, inputParams.flushStampToWait);
    }

    commandQueue->destroy();
}
} // namespace ult
} // namespace L0

View File

@@ -12,6 +12,7 @@
#include "shared/source/command_stream/wait_status.h"
#include "shared/source/direct_submission/direct_submission_hw.h"
#include "shared/source/helpers/blit_properties.h"
#include "shared/source/helpers/flush_stamp.h"
#include "shared/source/memory_manager/graphics_allocation.h"
#include "shared/source/memory_manager/surface.h"
#include "shared/source/os_interface/os_context.h"
@@ -215,6 +216,9 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
}
NEO::SubmissionStatus flush(BatchBuffer &batchBuffer, ResidencyContainer &allocationsForResidency) override {
if (incrementFlushStampOnFlush) {
this->flushStamp->setStamp(this->obtainCurrentFlushStamp() + 1);
}
if (flushReturnValue) {
return *flushReturnValue;
}
@@ -676,6 +680,7 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
bool isAnyDirectSubmissionEnabledCallBase = true;
bool isAnyDirectSubmissionEnabledResult = true;
std::atomic_bool captureWaitForTaskCountWithKmdNotifyInputParams = false;
bool incrementFlushStampOnFlush = false;
};
} // namespace NEO