feature: add L0 API to wait for completion of all submissions to given device

Related-To: NEO-14560
Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski 2025-05-07 13:20:39 +00:00 committed by Compute-Runtime-Automation
parent ba85f7417d
commit 1b42ebf7fb
6 changed files with 136 additions and 0 deletions

View File

@ -5,6 +5,9 @@
*
*/
#include "shared/source/command_stream/command_stream_receiver.h"
#include "shared/source/device/device.h"
#include "level_zero/core/source/device/device.h"
#include "level_zero/core/source/driver/driver.h"
#include "level_zero/core/source/driver/driver_handle_imp.h"
@ -27,3 +30,29 @@ ze_device_handle_t ZE_APICALL zerIdentifierTranslateToDeviceHandle(uint32_t iden
}
return driverHandle->devicesToExpose[identifier];
}
ze_result_t ZE_APICALL zeDeviceSynchronize(ze_device_handle_t hDevice) {
auto device = L0::Device::fromHandle(hDevice);
for (auto &engine : device->getNEODevice()->getAllEngines()) {
auto waitStatus = engine.commandStreamReceiver->waitForTaskCountWithKmdNotifyFallback(
engine.commandStreamReceiver->peekTaskCount(),
engine.commandStreamReceiver->obtainCurrentFlushStamp(),
false,
NEO::QueueThrottle::MEDIUM);
if (waitStatus == NEO::WaitStatus::gpuHang) {
return ZE_RESULT_ERROR_DEVICE_LOST;
}
}
for (auto &secondaryCsr : device->getNEODevice()->getSecondaryCsrs()) {
auto waitStatus = secondaryCsr->waitForTaskCountWithKmdNotifyFallback(
secondaryCsr->peekTaskCount(),
secondaryCsr->obtainCurrentFlushStamp(),
false,
NEO::QueueThrottle::MEDIUM);
if (waitStatus == NEO::WaitStatus::gpuHang) {
return ZE_RESULT_ERROR_DEVICE_LOST;
}
}
return ZE_RESULT_SUCCESS;
}

View File

@ -34,6 +34,7 @@ void *ExtensionFunctionAddressHelper::getExtensionFunctionAddress(const std::str
RETURN_FUNC_PTR_IF_EXIST(zerDeviceTranslateToIdentifier);
RETURN_FUNC_PTR_IF_EXIST(zerIdentifierTranslateToDeviceHandle);
RETURN_FUNC_PTR_IF_EXIST(zeDeviceSynchronize);
RETURN_FUNC_PTR_IF_EXIST(zexKernelGetBaseAddress);
RETURN_FUNC_PTR_IF_EXIST(zexKernelGetArgumentSize);

View File

@ -6785,5 +6785,81 @@ TEST_F(DeviceSimpleTests, whenWorkgroupSizeCheckedThenSizeLimitIs1kOrLess) {
EXPECT_LE(properties.maxTotalGroupSize, CommonConstants::maxWorkgroupSize);
}
HWTEST_F(DeviceSimpleTests, givenGpuHangWhenSynchronizingDeviceThenErrorIsPropagated) {
auto &csr = neoDevice->getUltCommandStreamReceiver<FamilyType>();
csr.waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::gpuHang;
auto result = zeDeviceSynchronize(device);
EXPECT_EQ(ZE_RESULT_ERROR_DEVICE_LOST, result);
}
HWTEST_F(DeviceSimpleTests, givenNoGpuHangWhenSynchronizingDeviceThenCallWaitForTaskCountWithKmdNotifyFallbackOnEachCsr) {
auto &engines = neoDevice->getAllEngines();
TaskCountType taskCountToWait = 1u;
FlushStamp flushStampToWait = 4u;
for (auto &engine : engines) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(engine.commandStreamReceiver);
csr->latestSentTaskCount = 0u;
csr->latestFlushedTaskCount = 0u;
csr->taskCount = taskCountToWait++;
csr->flushStamp->setStamp(flushStampToWait++);
csr->waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::ready;
}
auto &secondaryCsrs = neoDevice->getSecondaryCsrs();
for (auto &secondaryCsr : secondaryCsrs) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(secondaryCsr.get());
csr->latestSentTaskCount = 0u;
csr->latestFlushedTaskCount = 0u;
csr->taskCount = taskCountToWait++;
csr->flushStamp->setStamp(flushStampToWait++);
csr->waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::ready;
}
auto result = zeDeviceSynchronize(device);
EXPECT_EQ(ZE_RESULT_SUCCESS, result);
for (auto &engine : engines) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(engine.commandStreamReceiver);
EXPECT_EQ(1u, csr->waitForTaskCountWithKmdNotifyInputParams.size());
EXPECT_EQ(csr->taskCount, csr->waitForTaskCountWithKmdNotifyInputParams[0].taskCountToWait);
EXPECT_EQ(csr->flushStamp->peekStamp(), csr->waitForTaskCountWithKmdNotifyInputParams[0].flushStampToWait);
EXPECT_FALSE(csr->waitForTaskCountWithKmdNotifyInputParams[0].useQuickKmdSleep);
EXPECT_EQ(NEO::QueueThrottle::MEDIUM, csr->waitForTaskCountWithKmdNotifyInputParams[0].throttle);
}
for (auto &secondaryCsr : secondaryCsrs) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(secondaryCsr.get());
EXPECT_EQ(1u, csr->waitForTaskCountWithKmdNotifyInputParams.size());
EXPECT_EQ(csr->taskCount, csr->waitForTaskCountWithKmdNotifyInputParams[0].taskCountToWait);
EXPECT_EQ(csr->flushStamp->peekStamp(), csr->waitForTaskCountWithKmdNotifyInputParams[0].flushStampToWait);
EXPECT_FALSE(csr->waitForTaskCountWithKmdNotifyInputParams[0].useQuickKmdSleep);
EXPECT_EQ(NEO::QueueThrottle::MEDIUM, csr->waitForTaskCountWithKmdNotifyInputParams[0].throttle);
}
}
HWTEST_F(DeviceSimpleTests, givenGpuHangOnSecondaryCsrWhenSynchronizingDeviceThenErrorIsPropagated) {
if (neoDevice->getSecondaryCsrs().empty()) {
GTEST_SKIP();
}
auto &engines = neoDevice->getAllEngines();
for (auto &engine : engines) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(engine.commandStreamReceiver);
csr->waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::ready;
}
auto &secondaryCsrs = neoDevice->getSecondaryCsrs();
for (auto &secondaryCsr : secondaryCsrs) {
auto csr = static_cast<UltCommandStreamReceiver<FamilyType> *>(secondaryCsr.get());
csr->waitForTaskCountWithKmdNotifyFallbackReturnValue = WaitStatus::gpuHang;
}
auto result = zeDeviceSynchronize(device);
EXPECT_EQ(ZE_RESULT_ERROR_DEVICE_LOST, result);
}
} // namespace ult
} // namespace L0

View File

@ -1210,6 +1210,7 @@ TEST_F(DriverExperimentalApiTest, whenRetrievingApiFunctionThenExpectProperPoint
decltype(&zerDeviceTranslateToIdentifier) expectedZerDeviceTranslateToIdentifier = zerDeviceTranslateToIdentifier;
decltype(&zerIdentifierTranslateToDeviceHandle) expectedZerIdentifierTranslateToDeviceHandle = zerIdentifierTranslateToDeviceHandle;
decltype(&zeDeviceSynchronize) expectedZeDeviceSynchronize = zeDeviceSynchronize;
decltype(&zexKernelGetBaseAddress) expectedKernelGetBaseAddress = L0::zexKernelGetBaseAddress;
decltype(&zeIntelGetDriverVersionString) expectedIntelGetDriverVersionString = zeIntelGetDriverVersionString;
@ -1246,6 +1247,9 @@ TEST_F(DriverExperimentalApiTest, whenRetrievingApiFunctionThenExpectProperPoint
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetExtensionFunctionAddress(driverHandle, "zerIdentifierTranslateToDeviceHandle", &funPtr));
EXPECT_EQ(expectedZerIdentifierTranslateToDeviceHandle, reinterpret_cast<decltype(&zerIdentifierTranslateToDeviceHandle)>(funPtr));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetExtensionFunctionAddress(driverHandle, "zeDeviceSynchronize", &funPtr));
EXPECT_EQ(expectedZeDeviceSynchronize, reinterpret_cast<decltype(&zeDeviceSynchronize)>(funPtr));
EXPECT_EQ(ZE_RESULT_SUCCESS, zeDriverGetExtensionFunctionAddress(driverHandle, "zexKernelGetBaseAddress", &funPtr));
EXPECT_EQ(expectedKernelGetBaseAddress, reinterpret_cast<decltype(&zexKernelGetBaseAddress)>(funPtr));

View File

@ -316,6 +316,22 @@ uint32_t ZE_APICALL zerDeviceTranslateToIdentifier(ze_device_handle_t hDevice);
/// - device handle associated with the identifier
ze_device_handle_t ZE_APICALL zerIdentifierTranslateToDeviceHandle(uint32_t identifier); ///< [in] integer identifier of the device
/// @brief Global device synchronization
///
/// @details
/// - The application may call this function from simultaneous threads.
/// - The implementation of this function should be lock-free.
/// - Ensures that everything that was submitted to the device is completed.
/// - Ensures that all submissions in all queues on device are completed.
/// - It is not allowed to call this function while some command list are in graph capture mode.
/// - Returns error if error is detected during execution on device.
/// - Hangs indefinitely if GPU execution is blocked on non signaled event.
///
/// @returns
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_DEVICE_LOST
ze_result_t ZE_APICALL zeDeviceSynchronize(ze_device_handle_t hDevice); ///> [in] handle of the device
#if defined(__cplusplus)
} // extern "C"
#endif

View File

@ -330,6 +330,7 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
}
WaitStatus waitForTaskCountWithKmdNotifyFallback(TaskCountType taskCountToWait, FlushStamp flushStampToWait, bool useQuickKmdSleep, QueueThrottle throttle) override {
waitForTaskCountWithKmdNotifyInputParams.push_back({taskCountToWait, flushStampToWait, useQuickKmdSleep, throttle});
if (waitForTaskCountWithKmdNotifyFallbackReturnValue.has_value()) {
return *waitForTaskCountWithKmdNotifyFallbackReturnValue;
}
@ -591,6 +592,15 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
const IndirectHeap *recordedSsh = nullptr;
struct WaitForTaskCountWithKmdNotifyParams {
TaskCountType taskCountToWait;
FlushStamp flushStampToWait;
bool useQuickKmdSleep;
QueueThrottle throttle;
};
std::vector<WaitForTaskCountWithKmdNotifyParams> waitForTaskCountWithKmdNotifyInputParams;
std::mutex mutex;
std::atomic<uint32_t> recursiveLockCounter;
std::atomic<uint32_t> waitForCompletionWithTimeoutTaskCountCalled{0};