From 1abaf407998c58375bcfba672ac8a1701a62816a Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Mon, 30 Jun 2025 22:05:51 +0000 Subject: [PATCH] feature: RTAS EXT support - Added Support for the RTAs Extension support replacing the Exp support which will remain for backwards compatability. Related-To: NEO-15257 Signed-off-by: Neil R. Spruit --- level_zero/api/core/ze_core_loader.cpp | 39 ++ .../api/extensions/public/ze_exp_ext.cpp | 164 ++++- level_zero/api/extensions/public/ze_exp_ext.h | 59 +- level_zero/core/source/driver/driver_handle.h | 3 + .../core/source/driver/driver_handle_imp.h | 3 + .../driver/driver_handle_imp_helper.cpp | 1 + .../source/rtas/linux/os_rtas_enumeration.cpp | 4 +- level_zero/core/source/rtas/rtas.cpp | 141 ++++- level_zero/core/source/rtas/rtas.h | 94 ++- .../rtas/windows/os_rtas_enumeration.cpp | 4 +- .../unit_tests/sources/driver/test_driver.cpp | 1 + .../unit_tests/sources/rtas/CMakeLists.txt | 3 +- .../unit_tests/sources/rtas/test_rtas_ext.cpp | 594 ++++++++++++++++++ level_zero/ddi/ze_ddi_tables.cpp | 9 + level_zero/ddi/ze_ddi_tables.h | 2 +- 15 files changed, 1107 insertions(+), 14 deletions(-) create mode 100644 level_zero/core/test/unit_tests/sources/rtas/test_rtas_ext.cpp diff --git a/level_zero/api/core/ze_core_loader.cpp b/level_zero/api/core/ze_core_loader.cpp index f6aae45274..902d74b6e9 100644 --- a/level_zero/api/core/ze_core_loader.cpp +++ b/level_zero/api/core/ze_core_loader.cpp @@ -793,3 +793,42 @@ zeGetRTASBuilderExpProcAddrTable( driverDdiTable.coreDdiTable.RTASBuilderExp = *pDdiTable; return result; } + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeGetRTASBuilderProcAddrTable( + ze_api_version_t version, + ze_rtas_builder_dditable_t *pDdiTable) { + + if (nullptr == pDdiTable) + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + if (ZE_MAJOR_VERSION(L0::globalDriverDispatch.core.version) != ZE_MAJOR_VERSION(version)) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + ze_result_t result = ZE_RESULT_SUCCESS; + fillDdiEntry(pDdiTable->pfnCreateExt, L0::globalDriverDispatch.coreRTASBuilder.pfnCreateExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnGetBuildPropertiesExt, L0::globalDriverDispatch.coreRTASBuilder.pfnGetBuildPropertiesExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnBuildExt, L0::globalDriverDispatch.coreRTASBuilder.pfnBuildExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnCommandListAppendCopyExt, L0::globalDriverDispatch.coreRTASBuilder.pfnCommandListAppendCopyExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnDestroyExt, L0::globalDriverDispatch.coreRTASBuilder.pfnDestroyExt, version, ZE_API_VERSION_1_13); + driverDdiTable.coreDdiTable.RTASBuilder = *pDdiTable; + return result; +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeGetRTASParallelOperationProcAddrTable( + ze_api_version_t version, + ze_rtas_parallel_operation_dditable_t *pDdiTable) { + + if (nullptr == pDdiTable) + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + if (ZE_MAJOR_VERSION(L0::globalDriverDispatch.core.version) != ZE_MAJOR_VERSION(version)) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + ze_result_t result = ZE_RESULT_SUCCESS; + fillDdiEntry(pDdiTable->pfnCreateExt, L0::globalDriverDispatch.coreRTASParallelOperation.pfnCreateExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnGetPropertiesExt, L0::globalDriverDispatch.coreRTASParallelOperation.pfnGetPropertiesExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnJoinExt, L0::globalDriverDispatch.coreRTASParallelOperation.pfnJoinExt, version, ZE_API_VERSION_1_13); + fillDdiEntry(pDdiTable->pfnDestroyExt, L0::globalDriverDispatch.coreRTASParallelOperation.pfnDestroyExt, version, ZE_API_VERSION_1_13); + driverDdiTable.coreDdiTable.RTASParallelOperation = *pDdiTable; + return result; +} diff --git a/level_zero/api/extensions/public/ze_exp_ext.cpp b/level_zero/api/extensions/public/ze_exp_ext.cpp index 4d8cb2d203..f723014ce9 100644 --- a/level_zero/api/extensions/public/ze_exp_ext.cpp +++ b/level_zero/api/extensions/public/ze_exp_ext.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2024 Intel Corporation + * Copyright (C) 2020-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -176,6 +176,80 @@ ze_result_t zeRTASParallelOperationDestroyExp(ze_rtas_parallel_operation_exp_han return L0::RTASParallelOperation::fromHandle(hParallelOperation)->destroy(); } +// RTAs Ext + +ze_result_t zeRTASBuilderCreateExt(ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder) { + return L0::DriverHandle::fromHandle(hDriver)->createRTASBuilderExt(pDescriptor, phBuilder); +} + +ze_result_t zeRTASBuilderGetBuildPropertiesExt(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + ze_rtas_builder_ext_properties_t *pProperties) { + return L0::RTASBuilderExt::fromHandle(hBuilder)->getProperties(pBuildOpDescriptor, pProperties); +} + +ze_result_t zeDriverRTASFormatCompatibilityCheckExt(ze_driver_handle_t hDriver, + ze_rtas_format_ext_t rtasFormatA, + ze_rtas_format_ext_t rtasFormatB) { + return L0::DriverHandle::fromHandle(hDriver)->formatRTASCompatibilityCheckExt(rtasFormatA, rtasFormatB); +} + +ze_result_t zeRTASBuilderBuildExt(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes) { + return L0::RTASBuilderExt::fromHandle(hBuilder)->build(pBuildOpDescriptor, + pScratchBuffer, + scratchBufferSizeBytes, + pRtasBuffer, + rtasBufferSizeBytes, + hParallelOperation, + pBuildUserPtr, + pBounds, + pRtasBufferSizeBytes); +} + +ze_result_t zeRTASBuilderDestroyExt(ze_rtas_builder_ext_handle_t hBuilder) { + return L0::RTASBuilderExt::fromHandle(hBuilder)->destroy(); +} + +ze_result_t zeRTASBuilderCommandListAppendCopyExt( + ze_command_list_handle_t hCommandList, + void *dstptr, + const void *srcptr, + size_t size, + ze_event_handle_t hSignalEvent, + uint32_t numWaitEvents, + ze_event_handle_t *phWaitEvents) { + return zeCommandListAppendMemoryCopy(hCommandList, dstptr, srcptr, size, hSignalEvent, numWaitEvents, phWaitEvents); +} + +ze_result_t zeRTASParallelOperationCreateExt(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) { + return L0::DriverHandle::fromHandle(hDriver)->createRTASParallelOperationExt(phParallelOperation); +} + +ze_result_t zeRTASParallelOperationGetPropertiesExt(ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties) { + return L0::RTASParallelOperationExt::fromHandle(hParallelOperation)->getProperties(pProperties); +} + +ze_result_t zeRTASParallelOperationJoinExt(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + return L0::RTASParallelOperationExt::fromHandle(hParallelOperation)->join(); +} + +ze_result_t zeRTASParallelOperationDestroyExt(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + return L0::RTASParallelOperationExt::fromHandle(hParallelOperation)->destroy(); +} +// end RTAs Ext ze_result_t zeMemSetAtomicAccessAttributeExp(ze_context_handle_t hContext, ze_device_handle_t hDevice, const void *ptr, size_t size, ze_memory_atomic_attr_exp_flags_t attr) { return L0::Context::fromHandle(hContext)->setAtomicAccessAttribute(L0::Device::fromHandle(hDevice), ptr, size, attr); } @@ -391,6 +465,94 @@ zeRTASParallelOperationDestroyExp( return L0::zeRTASParallelOperationDestroyExp(hParallelOperation); } +// RTAS Ext + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderCreateExt( + ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder) { + return L0::zeRTASBuilderCreateExt(hDriver, pDescriptor, phBuilder); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderGetBuildPropertiesExt( + ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + ze_rtas_builder_ext_properties_t *pProperties) { + return L0::zeRTASBuilderGetBuildPropertiesExt(hBuilder, pBuildOpDescriptor, pProperties); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeDriverRTASFormatCompatibilityCheckExt( + ze_driver_handle_t hDriver, + ze_rtas_format_ext_t rtasFormatA, + ze_rtas_format_ext_t rtasFormatB) { + return L0::zeDriverRTASFormatCompatibilityCheckExt(hDriver, rtasFormatA, rtasFormatB); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderCommandListAppendCopyExt( + ze_command_list_handle_t hCommandList, + void *dstptr, + const void *srcptr, + size_t size, + ze_event_handle_t hSignalEvent, + uint32_t numWaitEvents, + ze_event_handle_t *phWaitEvents) { + return L0::zeRTASBuilderCommandListAppendCopyExt(hCommandList, dstptr, srcptr, size, hSignalEvent, numWaitEvents, phWaitEvents); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderBuildExt( + ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes) { + return L0::zeRTASBuilderBuildExt(hBuilder, pBuildOpDescriptor, pScratchBuffer, + scratchBufferSizeBytes, pRtasBuffer, rtasBufferSizeBytes, + hParallelOperation, pBuildUserPtr, pBounds, pRtasBufferSizeBytes); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderDestroyExt( + ze_rtas_builder_ext_handle_t hBuilder) { + return L0::zeRTASBuilderDestroyExt(hBuilder); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationCreateExt( + ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) { + return L0::zeRTASParallelOperationCreateExt(hDriver, phParallelOperation); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationGetPropertiesExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties) { + return L0::zeRTASParallelOperationGetPropertiesExt(hParallelOperation, pProperties); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationJoinExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + return L0::zeRTASParallelOperationJoinExt(hParallelOperation); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationDestroyExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + return L0::zeRTASParallelOperationDestroyExt(hParallelOperation); +} +// End RTAS Ext + ZE_APIEXPORT ze_result_t ZE_APICALL zeMemSetAtomicAccessAttributeExp( ze_context_handle_t hContext, diff --git a/level_zero/api/extensions/public/ze_exp_ext.h b/level_zero/api/extensions/public/ze_exp_ext.h index d26635675f..50779547bb 100644 --- a/level_zero/api/extensions/public/ze_exp_ext.h +++ b/level_zero/api/extensions/public/ze_exp_ext.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2024 Intel Corporation + * Copyright (C) 2020-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -8,6 +8,7 @@ #pragma once #include "level_zero/driver_experimental/ze_bindless_image_exp.h" +#include #include namespace L0 { @@ -144,6 +145,62 @@ ze_result_t zeRTASParallelOperationJoinExp( ze_result_t zeRTASParallelOperationDestroyExp( ze_rtas_parallel_operation_exp_handle_t hParallelOperation); +// RTAs Ext + +ze_result_t zeRTASBuilderCreateExt(ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder); + +ze_result_t zeRTASBuilderGetBuildPropertiesExt( + ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + ze_rtas_builder_ext_properties_t *pProperties); + +ze_result_t zeDriverRTASFormatCompatibilityCheckExt( + ze_driver_handle_t hDriver, + ze_rtas_format_ext_t rtasFormatA, + ze_rtas_format_ext_t rtasFormatB); + +ze_result_t zeRTASBuilderBuildExt( + ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes); + +ze_result_t zeRTASBuilderDestroyExt( + ze_rtas_builder_ext_handle_t hBuilder); + +ze_result_t zeRTASBuilderCommandListAppendCopyExt( + ze_command_list_handle_t hCommandList, + void *dstptr, + const void *srcptr, + size_t size, + ze_event_handle_t hSignalEvent, + uint32_t numWaitEvents, + ze_event_handle_t *phWaitEvents); + +ze_result_t zeRTASParallelOperationCreateExt( + ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation); + +ze_result_t zeRTASParallelOperationGetPropertiesExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties); + +ze_result_t zeRTASParallelOperationJoinExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation); + +ze_result_t zeRTASParallelOperationDestroyExt( + ze_rtas_parallel_operation_ext_handle_t hParallelOperation); + +// End RTAS Ext + ze_result_t zeMemSetAtomicAccessAttributeExp( ze_context_handle_t hContext, ze_device_handle_t hDevice, diff --git a/level_zero/core/source/driver/driver_handle.h b/level_zero/core/source/driver/driver_handle.h index f4c9a0562a..b7d7a90dfd 100644 --- a/level_zero/core/source/driver/driver_handle.h +++ b/level_zero/core/source/driver/driver_handle.h @@ -78,8 +78,11 @@ struct DriverHandle : BaseDriver, NEO::NonCopyableAndNonMovableClass { virtual ze_result_t loadRTASLibrary() = 0; virtual ze_result_t createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, ze_rtas_builder_exp_handle_t *phBuilder) = 0; + virtual ze_result_t createRTASBuilderExt(const ze_rtas_builder_ext_desc_t *desc, ze_rtas_builder_ext_handle_t *phBuilder) = 0; virtual ze_result_t createRTASParallelOperation(ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) = 0; + virtual ze_result_t createRTASParallelOperationExt(ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) = 0; virtual ze_result_t formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, ze_rtas_format_exp_t rtasFormatB) = 0; + virtual ze_result_t formatRTASCompatibilityCheckExt(ze_rtas_format_ext_t rtasFormatA, ze_rtas_format_ext_t rtasFormatB) = 0; virtual int setErrorDescription(const std::string &str) = 0; virtual ze_result_t getErrorDescription(const char **ppString) = 0; diff --git a/level_zero/core/source/driver/driver_handle_imp.h b/level_zero/core/source/driver/driver_handle_imp.h index 06eb46d479..fb65dee277 100644 --- a/level_zero/core/source/driver/driver_handle_imp.h +++ b/level_zero/core/source/driver/driver_handle_imp.h @@ -136,8 +136,11 @@ struct DriverHandleImp : public DriverHandle { ze_result_t loadRTASLibrary() override; ze_result_t createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, ze_rtas_builder_exp_handle_t *phBuilder) override; + ze_result_t createRTASBuilderExt(const ze_rtas_builder_ext_desc_t *desc, ze_rtas_builder_ext_handle_t *phBuilder) override; ze_result_t createRTASParallelOperation(ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) override; + ze_result_t createRTASParallelOperationExt(ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) override; ze_result_t formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, ze_rtas_format_exp_t rtasFormatB) override; + ze_result_t formatRTASCompatibilityCheckExt(ze_rtas_format_ext_t rtasFormatA, ze_rtas_format_ext_t rtasFormatB) override; std::map &getIPCHandleMap() { return this->ipcHandles; }; [[nodiscard]] std::unique_lock lockIPCHandleMap() { return std::unique_lock(this->ipcHandleMapMutex); }; diff --git a/level_zero/core/source/driver/driver_handle_imp_helper.cpp b/level_zero/core/source/driver/driver_handle_imp_helper.cpp index cca41f064e..026ddd18e8 100644 --- a/level_zero/core/source/driver/driver_handle_imp_helper.cpp +++ b/level_zero/core/source/driver/driver_handle_imp_helper.cpp @@ -41,6 +41,7 @@ const std::vector> DriverHandleImp::extensionsS {ZE_CACHELINE_SIZE_EXT_NAME, ZE_DEVICE_CACHE_LINE_SIZE_EXT_VERSION_1_0}, {ZE_DEVICE_VECTOR_SIZES_EXT_NAME, ZE_DEVICE_VECTOR_SIZES_EXT_VERSION_1_0}, {ZE_MUTABLE_COMMAND_LIST_EXP_NAME, ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1}, + {ZE_RTAS_EXT_NAME, ZE_RTAS_BUILDER_EXT_VERSION_1_0}, // Driver experimental extensions {ZE_INTEL_DEVICE_MODULE_DP_PROPERTIES_EXP_NAME, ZE_INTEL_DEVICE_MODULE_DP_PROPERTIES_EXP_VERSION_CURRENT}, diff --git a/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp b/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp index c3d6457090..cfa0c29e60 100644 --- a/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp +++ b/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023 Intel Corporation + * Copyright (C) 2023-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -8,5 +8,5 @@ #include "level_zero/core/source/rtas/rtas.h" namespace L0 { -std::string RTASBuilder::rtasLibraryName = "libze_intel_gpu_raytracing.so"; +std::string rtasLibraryName = "libze_intel_gpu_raytracing.so"; } // namespace L0 diff --git a/level_zero/core/source/rtas/rtas.cpp b/level_zero/core/source/rtas/rtas.cpp index 3179811cc8..9a07e5b36e 100644 --- a/level_zero/core/source/rtas/rtas.cpp +++ b/level_zero/core/source/rtas/rtas.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023-2024 Intel Corporation + * Copyright (C) 2023-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -33,6 +33,26 @@ pRTASParallelOperationCreateExpImpl parallelOperationCreateExpImpl; pRTASParallelOperationDestroyExpImpl parallelOperationDestroyExpImpl; pRTASParallelOperationGetPropertiesExpImpl parallelOperationGetPropertiesExpImpl; pRTASParallelOperationJoinExpImpl parallelOperationJoinExpImpl; +// RTAS Extension function pointers +const std::string zeRTASBuilderCreateExtImpl = "zeRTASBuilderCreateExtImpl"; +const std::string zeRTASBuilderDestroyExtImpl = "zeRTASBuilderDestroyExtImpl"; +const std::string zeRTASBuilderGetBuildPropertiesExtImpl = "zeRTASBuilderGetBuildPropertiesExtImpl"; +const std::string zeRTASBuilderBuildExtImpl = "zeRTASBuilderBuildExtImpl"; +const std::string zeDriverRTASFormatCompatibilityCheckExtImpl = "zeDriverRTASFormatCompatibilityCheckExtImpl"; +const std::string zeRTASParallelOperationCreateExtImpl = "zeRTASParallelOperationCreateExtImpl"; +const std::string zeRTASParallelOperationDestroyExtImpl = "zeRTASParallelOperationDestroyExtImpl"; +const std::string zeRTASParallelOperationGetPropertiesExtImpl = "zeRTASParallelOperationGetPropertiesExtImpl"; +const std::string zeRTASParallelOperationJoinExtImpl = "zeRTASParallelOperationJoinExtImpl"; + +pRTASBuilderCreateExtImpl builderCreateExtImpl; +pRTASBuilderDestroyExtImpl builderDestroyExtImpl; +pDriverRTASFormatCompatibilityCheckExtImpl formatCompatibilityCheckExtImpl; +pRTASBuilderGetBuildPropertiesExtImpl builderGetBuildPropertiesExtImpl; +pRTASBuilderBuildExtImpl builderBuildExtImpl; +pRTASParallelOperationCreateExtImpl parallelOperationCreateExtImpl; +pRTASParallelOperationDestroyExtImpl parallelOperationDestroyExtImpl; +pRTASParallelOperationGetPropertiesExtImpl parallelOperationGetPropertiesExtImpl; +pRTASParallelOperationJoinExtImpl parallelOperationJoinExtImpl; bool RTASBuilder::loadEntryPoints(NEO::OsLibrary *libraryHandle) { bool ok = getSymbolAddr(libraryHandle, zeRTASBuilderCreateExpImpl, builderCreateExpImpl); @@ -44,7 +64,19 @@ bool RTASBuilder::loadEntryPoints(NEO::OsLibrary *libraryHandle) { ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationDestroyExpImpl, parallelOperationDestroyExpImpl); ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationGetPropertiesExpImpl, parallelOperationGetPropertiesExpImpl); ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationJoinExpImpl, parallelOperationJoinExpImpl); + return ok; +} +bool RTASBuilderExt::loadEntryPoints(NEO::OsLibrary *libraryHandle) { + bool ok = getSymbolAddr(libraryHandle, zeRTASBuilderCreateExtImpl, builderCreateExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderDestroyExtImpl, builderDestroyExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderGetBuildPropertiesExtImpl, builderGetBuildPropertiesExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderBuildExtImpl, builderBuildExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeDriverRTASFormatCompatibilityCheckExtImpl, formatCompatibilityCheckExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationCreateExtImpl, parallelOperationCreateExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationDestroyExtImpl, parallelOperationDestroyExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationGetPropertiesExtImpl, parallelOperationGetPropertiesExtImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationJoinExtImpl, parallelOperationJoinExtImpl); return ok; } @@ -53,6 +85,11 @@ ze_result_t RTASBuilder::getProperties(const ze_rtas_builder_build_op_exp_desc_t return builderGetBuildPropertiesExpImpl(this->handleImpl, args, pProp); } +ze_result_t RTASBuilderExt::getProperties(const ze_rtas_builder_build_op_ext_desc_t *args, + ze_rtas_builder_ext_properties_t *pProp) { + return builderGetBuildPropertiesExtImpl(this->handleImpl, args, pProp); +} + ze_result_t RTASBuilder::build(const ze_rtas_builder_build_op_exp_desc_t *args, void *pScratchBuffer, size_t scratchBufferSizeBytes, void *pRtasBuffer, size_t rtasBufferSizeBytes, @@ -70,6 +107,23 @@ ze_result_t RTASBuilder::build(const ze_rtas_builder_build_op_exp_desc_t *args, pRtasBufferSizeBytes); } +ze_result_t RTASBuilderExt::build(const ze_rtas_builder_build_op_ext_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes) { + RTASParallelOperationExt *parallelOperation = RTASParallelOperationExt::fromHandle(hParallelOperation); + + return builderBuildExtImpl(this->handleImpl, + args, + pScratchBuffer, scratchBufferSizeBytes, + pRtasBuffer, rtasBufferSizeBytes, + parallelOperation->handleImpl, + pBuildUserPtr, pBounds, + pRtasBufferSizeBytes); +} + ze_result_t DriverHandleImp::formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, ze_rtas_format_exp_t rtasFormatB) { ze_result_t result = this->loadRTASLibrary(); @@ -80,6 +134,16 @@ ze_result_t DriverHandleImp::formatRTASCompatibilityCheck(ze_rtas_format_exp_t r return formatCompatibilityCheckExpImpl(this->toHandle(), rtasFormatA, rtasFormatB); } +ze_result_t DriverHandleImp::formatRTASCompatibilityCheckExt(ze_rtas_format_ext_t rtasFormatA, + ze_rtas_format_ext_t rtasFormatB) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + return formatCompatibilityCheckExtImpl(this->toHandle(), rtasFormatA, rtasFormatB); +} + ze_result_t DriverHandleImp::createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, ze_rtas_builder_exp_handle_t *phBuilder) { ze_result_t result = this->loadRTASLibrary(); @@ -98,6 +162,24 @@ ze_result_t DriverHandleImp::createRTASBuilder(const ze_rtas_builder_exp_desc_t return ZE_RESULT_SUCCESS; } +ze_result_t DriverHandleImp::createRTASBuilderExt(const ze_rtas_builder_ext_desc_t *desc, + ze_rtas_builder_ext_handle_t *phBuilder) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + auto pRTASBuilder = std::make_unique(); + + result = builderCreateExtImpl(this->toHandle(), desc, &pRTASBuilder->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + *phBuilder = pRTASBuilder.release(); + return ZE_RESULT_SUCCESS; +} + ze_result_t RTASBuilder::destroy() { ze_result_t result = builderDestroyExpImpl(this->handleImpl); if (result != ZE_RESULT_SUCCESS) { @@ -108,6 +190,16 @@ ze_result_t RTASBuilder::destroy() { return ZE_RESULT_SUCCESS; } +ze_result_t RTASBuilderExt::destroy() { + ze_result_t result = builderDestroyExtImpl(this->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + delete this; + return ZE_RESULT_SUCCESS; +} + ze_result_t DriverHandleImp::loadRTASLibrary() { std::lock_guard lock(this->rtasLock); @@ -116,11 +208,17 @@ ze_result_t DriverHandleImp::loadRTASLibrary() { } if (this->rtasLibraryHandle == nullptr) { - this->rtasLibraryHandle = std::unique_ptr(NEO::OsLibrary::loadFunc(RTASBuilder::rtasLibraryName)); - if (this->rtasLibraryHandle == nullptr || RTASBuilder::loadEntryPoints(this->rtasLibraryHandle.get()) == false) { + this->rtasLibraryHandle = std::unique_ptr(NEO::OsLibrary::loadFunc(L0::rtasLibraryName)); + bool rtasExpAvailable = false; + bool rtasExtAvailable = false; + if (this->rtasLibraryHandle != nullptr) { + rtasExpAvailable = RTASBuilder::loadEntryPoints(this->rtasLibraryHandle.get()); + rtasExtAvailable = RTASBuilderExt::loadEntryPoints(this->rtasLibraryHandle.get()); + } + if (this->rtasLibraryHandle == nullptr || (!rtasExpAvailable && !rtasExtAvailable)) { this->rtasLibraryUnavailable = true; - PRINT_DEBUG_STRING(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Failed to load Ray Tracing Support Library %s\n", RTASBuilder::rtasLibraryName.c_str()); + PRINT_DEBUG_STRING(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Failed to load Ray Tracing Support Library %s\n", L0::rtasLibraryName.c_str()); return ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; } } @@ -145,6 +243,23 @@ ze_result_t DriverHandleImp::createRTASParallelOperation(ze_rtas_parallel_operat return ZE_RESULT_SUCCESS; } +ze_result_t DriverHandleImp::createRTASParallelOperationExt(ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + auto pRTASParallelOperation = std::make_unique(); + + result = parallelOperationCreateExtImpl(this->toHandle(), &pRTASParallelOperation->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + *phParallelOperation = pRTASParallelOperation.release(); + return ZE_RESULT_SUCCESS; +} + ze_result_t RTASParallelOperation::destroy() { ze_result_t result = parallelOperationDestroyExpImpl(this->handleImpl); if (result != ZE_RESULT_SUCCESS) { @@ -155,12 +270,30 @@ ze_result_t RTASParallelOperation::destroy() { return ZE_RESULT_SUCCESS; } +ze_result_t RTASParallelOperationExt::destroy() { + ze_result_t result = parallelOperationDestroyExtImpl(this->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + delete this; + return ZE_RESULT_SUCCESS; +} + ze_result_t RTASParallelOperation::getProperties(ze_rtas_parallel_operation_exp_properties_t *pProperties) { return parallelOperationGetPropertiesExpImpl(this->handleImpl, pProperties); } +ze_result_t RTASParallelOperationExt::getProperties(ze_rtas_parallel_operation_ext_properties_t *pProperties) { + return parallelOperationGetPropertiesExtImpl(this->handleImpl, pProperties); +} + ze_result_t RTASParallelOperation::join() { return parallelOperationJoinExpImpl(this->handleImpl); } +ze_result_t RTASParallelOperationExt::join() { + return parallelOperationJoinExtImpl(this->handleImpl); +} + } // namespace L0 diff --git a/level_zero/core/source/rtas/rtas.h b/level_zero/core/source/rtas/rtas.h index 90c930fb7b..7b1522cee6 100644 --- a/level_zero/core/source/rtas/rtas.h +++ b/level_zero/core/source/rtas/rtas.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023-2024 Intel Corporation + * Copyright (C) 2023-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -15,12 +15,16 @@ struct _ze_rtas_builder_exp_handle_t {}; struct _ze_rtas_parallel_operation_exp_handle_t {}; +struct _ze_rtas_builder_ext_handle_t {}; +struct _ze_rtas_parallel_operation_ext_handle_t {}; namespace L0 { /* * Note: RTAS Library using same headers as Level Zero, but using function * pointers to access these symbols in the external Library. */ + +// Exp Functions typedef ze_result_t (*pRTASBuilderCreateExpImpl)(ze_driver_handle_t hDriver, const ze_rtas_builder_exp_desc_t *pDescriptor, ze_rtas_builder_exp_handle_t *phBuilder); @@ -53,6 +57,39 @@ typedef ze_result_t (*pRTASParallelOperationGetPropertiesExpImpl)(ze_rtas_parall typedef ze_result_t (*pRTASParallelOperationJoinExpImpl)(ze_rtas_parallel_operation_exp_handle_t hParallelOperation); +// Ext Functions +typedef ze_result_t (*pRTASBuilderCreateExtImpl)(ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder); + +typedef ze_result_t (*pRTASBuilderDestroyExtImpl)(ze_rtas_builder_ext_handle_t hBuilder); + +typedef ze_result_t (*pRTASBuilderGetBuildPropertiesExtImpl)(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + ze_rtas_builder_ext_properties_t *pProp); + +typedef ze_result_t (*pRTASBuilderBuildExtImpl)(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes); + +typedef ze_result_t (*pDriverRTASFormatCompatibilityCheckExtImpl)(ze_driver_handle_t hDriver, + const ze_rtas_format_ext_t accelFormat, + const ze_rtas_format_ext_t otherAccelFormat); + +typedef ze_result_t (*pRTASParallelOperationCreateExtImpl)(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation); + +typedef ze_result_t (*pRTASParallelOperationDestroyExtImpl)(ze_rtas_parallel_operation_ext_handle_t hParallelOperation); + +typedef ze_result_t (*pRTASParallelOperationGetPropertiesExtImpl)(ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties); + +typedef ze_result_t (*pRTASParallelOperationJoinExtImpl)(ze_rtas_parallel_operation_ext_handle_t hParallelOperation); + extern pRTASBuilderCreateExpImpl builderCreateExpImpl; extern pRTASBuilderDestroyExpImpl builderDestroyExpImpl; extern pRTASBuilderGetBuildPropertiesExpImpl builderGetBuildPropertiesExpImpl; @@ -62,6 +99,17 @@ extern pRTASParallelOperationCreateExpImpl parallelOperationCreateExpImpl; extern pRTASParallelOperationDestroyExpImpl parallelOperationDestroyExpImpl; extern pRTASParallelOperationGetPropertiesExpImpl parallelOperationGetPropertiesExpImpl; extern pRTASParallelOperationJoinExpImpl parallelOperationJoinExpImpl; +// RTAS Extension function pointers +extern pRTASBuilderCreateExtImpl builderCreateExtImpl; +extern pRTASBuilderDestroyExtImpl builderDestroyExtImpl; +extern pRTASBuilderGetBuildPropertiesExtImpl builderGetBuildPropertiesExtImpl; +extern pRTASBuilderBuildExtImpl builderBuildExtImpl; +extern pDriverRTASFormatCompatibilityCheckExtImpl formatCompatibilityCheckExtImpl; +extern pRTASParallelOperationCreateExtImpl parallelOperationCreateExtImpl; +extern pRTASParallelOperationDestroyExtImpl parallelOperationDestroyExtImpl; +extern pRTASParallelOperationGetPropertiesExtImpl parallelOperationGetPropertiesExtImpl; +extern pRTASParallelOperationJoinExtImpl parallelOperationJoinExtImpl; +extern std::string rtasLibraryName; struct RTASBuilder : _ze_rtas_builder_exp_handle_t { public: @@ -80,7 +128,6 @@ struct RTASBuilder : _ze_rtas_builder_exp_handle_t { static RTASBuilder *fromHandle(ze_rtas_builder_exp_handle_t handle) { return static_cast(handle); } inline ze_rtas_builder_exp_handle_t toHandle() { return this; } - static std::string rtasLibraryName; static bool loadEntryPoints(NEO::OsLibrary *libraryHandle); template @@ -107,4 +154,47 @@ struct RTASParallelOperation : _ze_rtas_parallel_operation_exp_handle_t { ze_rtas_parallel_operation_exp_handle_t handleImpl; }; +struct RTASBuilderExt : _ze_rtas_builder_ext_handle_t { + public: + virtual ~RTASBuilderExt() = default; + + ze_result_t destroy(); + ze_result_t getProperties(const ze_rtas_builder_build_op_ext_desc_t *args, + ze_rtas_builder_ext_properties_t *pProp); + ze_result_t build(const ze_rtas_builder_build_op_ext_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes); + + static RTASBuilderExt *fromHandle(ze_rtas_builder_ext_handle_t handle) { return static_cast(handle); } + inline ze_rtas_builder_ext_handle_t toHandle() { return this; } + + static bool loadEntryPoints(NEO::OsLibrary *libraryHandle); + + template + static bool getSymbolAddr(NEO::OsLibrary *libraryHandle, const std::string name, T &proc) { + void *addr = libraryHandle->getProcAddress(name); + proc = reinterpret_cast(addr); + return nullptr != proc; + } + + ze_rtas_builder_ext_handle_t handleImpl; +}; + +struct RTASParallelOperationExt : _ze_rtas_parallel_operation_ext_handle_t { + public: + virtual ~RTASParallelOperationExt() = default; + + ze_result_t destroy(); + ze_result_t getProperties(ze_rtas_parallel_operation_ext_properties_t *pProperties); + ze_result_t join(); + + static RTASParallelOperationExt *fromHandle(ze_rtas_parallel_operation_ext_handle_t handle) { return static_cast(handle); } + inline ze_rtas_parallel_operation_ext_handle_t toHandle() { return this; } + + ze_rtas_parallel_operation_ext_handle_t handleImpl; +}; + } // namespace L0 diff --git a/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp b/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp index ce4c4d4b82..2430195cf3 100644 --- a/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp +++ b/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023 Intel Corporation + * Copyright (C) 2023-2025 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -8,5 +8,5 @@ #include "level_zero/core/source/rtas/rtas.h" namespace L0 { -std::string RTASBuilder::rtasLibraryName = "ze_intel_gpu_raytracing.dll"; +std::string rtasLibraryName = "ze_intel_gpu_raytracing.dll"; } // namespace L0 diff --git a/level_zero/core/test/unit_tests/sources/driver/test_driver.cpp b/level_zero/core/test/unit_tests/sources/driver/test_driver.cpp index f10a7cd98b..fca9d043f9 100644 --- a/level_zero/core/test/unit_tests/sources/driver/test_driver.cpp +++ b/level_zero/core/test/unit_tests/sources/driver/test_driver.cpp @@ -1705,6 +1705,7 @@ TEST_F(DriverExtensionsTest, givenDriverHandleWhenAskingForExtensionsThenReturnC verifyExtensionDefinition(ZE_CACHELINE_SIZE_EXT_NAME, ZE_DEVICE_CACHE_LINE_SIZE_EXT_VERSION_1_0); verifyExtensionDefinition(ZE_DEVICE_VECTOR_SIZES_EXT_NAME, ZE_DEVICE_VECTOR_SIZES_EXT_VERSION_1_0); verifyExtensionDefinition(ZE_MUTABLE_COMMAND_LIST_EXP_NAME, ZE_MUTABLE_COMMAND_LIST_EXP_VERSION_1_1); + verifyExtensionDefinition(ZE_RTAS_EXT_NAME, ZE_RTAS_BUILDER_EXT_VERSION_1_0); // Driver experimental extensions verifyExtensionDefinition(ZE_INTEL_DEVICE_MODULE_DP_PROPERTIES_EXP_NAME, ZE_INTEL_DEVICE_MODULE_DP_PROPERTIES_EXP_VERSION_CURRENT); diff --git a/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt b/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt index 0f93d5bfd8..bf39f900fe 100644 --- a/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt +++ b/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright (C) 2022-2023 Intel Corporation +# Copyright (C) 2022-2025 Intel Corporation # # SPDX-License-Identifier: MIT # @@ -7,5 +7,6 @@ target_sources(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt ${CMAKE_CURRENT_SOURCE_DIR}/test_rtas.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_rtas_ext.cpp ) add_subdirectories() \ No newline at end of file diff --git a/level_zero/core/test/unit_tests/sources/rtas/test_rtas_ext.cpp b/level_zero/core/test/unit_tests/sources/rtas/test_rtas_ext.cpp new file mode 100644 index 0000000000..75bba0b8b0 --- /dev/null +++ b/level_zero/core/test/unit_tests/sources/rtas/test_rtas_ext.cpp @@ -0,0 +1,594 @@ +/* + * Copyright (C) 2023-2025 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/test/common/test_macros/test.h" + +#include "level_zero/core/source/rtas/rtas.h" +#include "level_zero/core/test/unit_tests/fixtures/device_fixture.h" + +namespace L0 { +namespace ult { + +struct RTASFixtureExt : public DeviceFixture { + void setUp() { + DeviceFixture::setUp(); + builderCreateCalled = 0; + builderCreateCalled = 0; + builderCreateFailCalled = 0; + builderDestroyCalled = 0; + builderDestroyFailCalled = 0; + builderGetBuildPropertiesCalled = 0; + builderGetBuildPropertiesFailCalled = 0; + builderBuildCalled = 0; + builderBuildFailCalled = 0; + formatCompatibilityCheckCalled = 0; + formatCompatibilityCheckFailCalled = 0; + parallelOperationDestroyCalled = 0; + parallelOperationDestroyFailCalled = 0; + parallelOperationCreateCalled = 0; + parallelOperationCreateFailCalled = 0; + parallelOperationGetPropertiesCalled = 0; + parallelOperationGetPropertiesFailCalled = 0; + parallelOperationJoinCalled = 0; + parallelOperationJoinFailCalled = 0; + } + + void tearDown() { + DeviceFixture::tearDown(); + } + + static ze_result_t builderCreate(ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder) { + builderCreateCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderCreateFail(ze_driver_handle_t hDriver, + const ze_rtas_builder_ext_desc_t *pDescriptor, + ze_rtas_builder_ext_handle_t *phBuilder) { + builderCreateFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderDestroy(ze_rtas_builder_ext_handle_t hBuilder) { + builderDestroyCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderDestroyFail(ze_rtas_builder_ext_handle_t hBuilder) { + builderDestroyFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderGetBuildProperties(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + ze_rtas_builder_ext_properties_t *pProp) { + builderGetBuildPropertiesCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderGetBuildPropertiesFail(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + ze_rtas_builder_ext_properties_t *pProp) { + builderGetBuildPropertiesFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderBuild(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes) { + builderBuildCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderBuildFail(ze_rtas_builder_ext_handle_t hBuilder, + const ze_rtas_builder_build_op_ext_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_ext_t *pBounds, + size_t *pRtasBufferSizeBytes) { + builderBuildFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t formatCompatibilityCheck(ze_driver_handle_t hDriver, + const ze_rtas_format_ext_t accelFormat, + const ze_rtas_format_ext_t otherAccelFormat) { + formatCompatibilityCheckCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t formatCompatibilityCheckFail(ze_driver_handle_t hDriver, + const ze_rtas_format_ext_t accelFormat, + const ze_rtas_format_ext_t otherAccelFormat) { + formatCompatibilityCheckFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationDestroy(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + parallelOperationDestroyCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationDestroyFail(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + parallelOperationDestroyFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationCreate(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) { + parallelOperationCreateCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationCreateFail(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_ext_handle_t *phParallelOperation) { + parallelOperationCreateFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationGetProperties(ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties) { + parallelOperationGetPropertiesCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationGetPropertiesFail(ze_rtas_parallel_operation_ext_handle_t hParallelOperation, + ze_rtas_parallel_operation_ext_properties_t *pProperties) { + parallelOperationGetPropertiesFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationJoin(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + parallelOperationJoinCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationJoinFail(ze_rtas_parallel_operation_ext_handle_t hParallelOperation) { + parallelOperationJoinFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static uint32_t builderCreateCalled; + static uint32_t builderCreateFailCalled; + static uint32_t builderDestroyCalled; + static uint32_t builderDestroyFailCalled; + static uint32_t builderGetBuildPropertiesCalled; + static uint32_t builderGetBuildPropertiesFailCalled; + static uint32_t builderBuildCalled; + static uint32_t builderBuildFailCalled; + static uint32_t formatCompatibilityCheckCalled; + static uint32_t formatCompatibilityCheckFailCalled; + static uint32_t parallelOperationDestroyCalled; + static uint32_t parallelOperationDestroyFailCalled; + static uint32_t parallelOperationCreateCalled; + static uint32_t parallelOperationCreateFailCalled; + static uint32_t parallelOperationGetPropertiesCalled; + static uint32_t parallelOperationGetPropertiesFailCalled; + static uint32_t parallelOperationJoinCalled; + static uint32_t parallelOperationJoinFailCalled; +}; + +uint32_t RTASFixtureExt::builderCreateCalled = 0; +uint32_t RTASFixtureExt::builderCreateFailCalled = 0; +uint32_t RTASFixtureExt::builderDestroyCalled = 0; +uint32_t RTASFixtureExt::builderDestroyFailCalled = 0; +uint32_t RTASFixtureExt::builderGetBuildPropertiesCalled = 0; +uint32_t RTASFixtureExt::builderGetBuildPropertiesFailCalled = 0; +uint32_t RTASFixtureExt::builderBuildCalled = 0; +uint32_t RTASFixtureExt::builderBuildFailCalled = 0; +uint32_t RTASFixtureExt::formatCompatibilityCheckCalled = 0; +uint32_t RTASFixtureExt::formatCompatibilityCheckFailCalled = 0; +uint32_t RTASFixtureExt::parallelOperationDestroyCalled = 0; +uint32_t RTASFixtureExt::parallelOperationDestroyFailCalled = 0; +uint32_t RTASFixtureExt::parallelOperationCreateCalled = 0; +uint32_t RTASFixtureExt::parallelOperationCreateFailCalled = 0; +uint32_t RTASFixtureExt::parallelOperationGetPropertiesCalled = 0; +uint32_t RTASFixtureExt::parallelOperationGetPropertiesFailCalled = 0; +uint32_t RTASFixtureExt::parallelOperationJoinCalled = 0; +uint32_t RTASFixtureExt::parallelOperationJoinFailCalled = 0; + +using RTASTestExt = Test; + +struct MockRTASExtOsLibrary : public OsLibrary { + public: + static bool mockLoad; + MockRTASExtOsLibrary(const std::string &name, std::string *errorValue) { + } + MockRTASExtOsLibrary() {} + ~MockRTASExtOsLibrary() override = default; + void *getProcAddress(const std::string &procName) override { + auto it = funcMap.find(procName); + if (funcMap.end() == it) { + return nullptr; + } else { + return it->second; + } + } + bool isLoaded() override { + return false; + } + std::string getFullPath() override { + return std::string(); + } + static OsLibrary *load(const OsLibraryCreateProperties &properties) { + if (mockLoad == true) { + auto ptr = new (std::nothrow) MockRTASExtOsLibrary(); + return ptr; + } else { + return nullptr; + } + } + std::map funcMap; +}; + +bool MockRTASExtOsLibrary::mockLoad = true; + +TEST_F(RTASTestExt, GivenLibraryLoadsSymbolsAndUnderlyingFunctionsSucceedThenSuccessIsReturned_Ext) { + struct MockSymbolsLoadedOsLibrary : public OsLibrary { + public: + MockSymbolsLoadedOsLibrary(const std::string &name, std::string *errorValue) {} + MockSymbolsLoadedOsLibrary() {} + ~MockSymbolsLoadedOsLibrary() override = default; + void *getProcAddress(const std::string &procName) override { + funcMap["zeRTASBuilderCreateExtImpl"] = reinterpret_cast(&builderCreate); + funcMap["zeRTASBuilderDestroyExtImpl"] = reinterpret_cast(&builderDestroy); + funcMap["zeRTASBuilderGetBuildPropertiesExtImpl"] = reinterpret_cast(&builderGetBuildProperties); + funcMap["zeRTASBuilderBuildExtImpl"] = reinterpret_cast(&builderBuild); + funcMap["zeDriverRTASFormatCompatibilityCheckExtImpl"] = reinterpret_cast(&formatCompatibilityCheck); + funcMap["zeRTASParallelOperationCreateExtImpl"] = reinterpret_cast(¶llelOperationCreate); + funcMap["zeRTASParallelOperationDestroyExtImpl"] = reinterpret_cast(¶llelOperationDestroy); + funcMap["zeRTASParallelOperationGetPropertiesExtImpl"] = reinterpret_cast(¶llelOperationGetProperties); + funcMap["zeRTASParallelOperationJoinExtImpl"] = reinterpret_cast(¶llelOperationJoin); + auto it = funcMap.find(procName); + if (funcMap.end() == it) { + return nullptr; + } else { + return it->second; + } + } + bool isLoaded() override { return true; } + std::string getFullPath() override { return std::string(); } + static OsLibrary *load(const OsLibraryCreateProperties &properties) { + auto ptr = new (std::nothrow) MockSymbolsLoadedOsLibrary(); + return ptr; + } + std::map funcMap; + }; + ze_rtas_builder_ext_handle_t hBuilder; + ze_rtas_parallel_operation_ext_handle_t hParallelOperation; + const ze_rtas_format_ext_t accelFormatA = {}; + const ze_rtas_format_ext_t accelFormatB = {}; + VariableBackup funcBackup{&NEO::OsLibrary::loadFunc, MockSymbolsLoadedOsLibrary::load}; + driverHandle->rtasLibraryHandle.reset(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderCreateExt(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExt(hBuilder)); + EXPECT_EQ(1u, builderDestroyCalled); + driverHandle->rtasLibraryHandle.reset(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationCreateExt(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExt(hParallelOperation)); + EXPECT_EQ(1u, parallelOperationDestroyCalled); + driverHandle->rtasLibraryHandle.reset(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeDriverRTASFormatCompatibilityCheckExt(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckCalled); + driverHandle->rtasLibraryHandle.reset(); +} + +TEST_F(RTASTestExt, GivenLibraryFailedToLoadSymbolsThenErrorIsReturned_Ext) { + ze_rtas_builder_ext_handle_t hBuilder; + ze_rtas_parallel_operation_ext_handle_t hParallelOperation; + const ze_rtas_format_ext_t accelFormatA = {}; + const ze_rtas_format_ext_t accelFormatB = {}; + VariableBackup funcBackup{&NEO::OsLibrary::loadFunc, MockRTASExtOsLibrary::load}; + driverHandle->rtasLibraryHandle.reset(); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASBuilderCreateExt(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASParallelOperationCreateExt(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeDriverRTASFormatCompatibilityCheckExt(driverHandle->toHandle(), accelFormatA, accelFormatB)); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderCommandListAppendCopySucceedsThenSuccessIsReturned_Ext) { + // Mock implementation for zeRTASBuilderCommandListAppendCopyExt + static uint32_t builderCommandListAppendCopyCalled = 0; + auto builderCommandListAppendCopy = [](ze_command_list_handle_t, ze_rtas_builder_ext_handle_t, const void *, size_t, void *, size_t, ze_rtas_parallel_operation_ext_handle_t, void *, ze_rtas_aabb_ext_t *, size_t *) -> ze_result_t { + builderCommandListAppendCopyCalled++; + return ZE_RESULT_SUCCESS; + }; + + // Simulate symbol loading + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASExtOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + osLibHandle->funcMap["zeRTASBuilderCommandListAppendCopyExtImpl"] = reinterpret_cast(+builderCommandListAppendCopy); + + // Patch the entry point loader to pick up our symbol + L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get()); + + // Call the function under test + ze_command_list_handle_t hCommandList = reinterpret_cast(0x1234); + ze_rtas_builder_ext_handle_t hBuilder = reinterpret_cast(0x5678); + void *pScratchBuffer = nullptr; + size_t scratchBufferSizeBytes = 0; + void *pRtasBuffer = nullptr; + size_t rtasBufferSizeBytes = 0; + ze_rtas_parallel_operation_ext_handle_t hParallelOperation = reinterpret_cast(0x9abc); + void *pBuildUserPtr = nullptr; + ze_rtas_aabb_ext_t *pBounds = nullptr; + size_t *pRtasBufferSizeBytes = nullptr; + + // If the function pointer is available, call it + auto func = reinterpret_cast( + osLibHandle->funcMap["zeRTASBuilderCommandListAppendCopyExtImpl"]); + ASSERT_NE(nullptr, func); + EXPECT_EQ(ZE_RESULT_SUCCESS, func(hCommandList, hBuilder, pScratchBuffer, scratchBufferSizeBytes, pRtasBuffer, rtasBufferSizeBytes, hParallelOperation, pBuildUserPtr, pBounds, pRtasBufferSizeBytes)); + EXPECT_NE(0u, builderCommandListAppendCopyCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderCommandListAppendCopyFailsThenErrorIsReturned_Ext) { + // Mock implementation for zeRTASBuilderCommandListAppendCopyExt that fails + static uint32_t builderCommandListAppendCopyFailCalled = 0; + auto builderCommandListAppendCopyFail = [](ze_command_list_handle_t, ze_rtas_builder_ext_handle_t, const void *, size_t, void *, size_t, ze_rtas_parallel_operation_ext_handle_t, void *, ze_rtas_aabb_ext_t *, size_t *) -> ze_result_t { + builderCommandListAppendCopyFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + }; + + // Simulate symbol loading + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASExtOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + osLibHandle->funcMap["zeRTASBuilderCommandListAppendCopyExtImpl"] = reinterpret_cast(+builderCommandListAppendCopyFail); + + // Patch the entry point loader to pick up our symbol + L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get()); + + // Call the function under test + ze_command_list_handle_t hCommandList = reinterpret_cast(0x1234); + ze_rtas_builder_ext_handle_t hBuilder = reinterpret_cast(0x5678); + void *pScratchBuffer = nullptr; + size_t scratchBufferSizeBytes = 0; + void *pRtasBuffer = nullptr; + size_t rtasBufferSizeBytes = 0; + ze_rtas_parallel_operation_ext_handle_t hParallelOperation = reinterpret_cast(0x9abc); + void *pBuildUserPtr = nullptr; + ze_rtas_aabb_ext_t *pBounds = nullptr; + size_t *pRtasBufferSizeBytes = nullptr; + + // If the function pointer is available, call it + auto func = reinterpret_cast( + osLibHandle->funcMap["zeRTASBuilderCommandListAppendCopyExtImpl"]); + ASSERT_NE(nullptr, func); + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, func(hCommandList, hBuilder, pScratchBuffer, scratchBufferSizeBytes, pRtasBuffer, rtasBufferSizeBytes, hParallelOperation, pBuildUserPtr, pBounds, pRtasBufferSizeBytes)); + EXPECT_NE(0u, builderCommandListAppendCopyFailCalled); +} + +TEST_F(RTASTestExt, GivenLibraryPreLoadedAndUnderlyingBuilderCreateSucceedsThenSuccessIsReturned_Ext) { + ze_rtas_builder_ext_handle_t hBuilder; + builderCreateExtImpl = &builderCreate; + builderDestroyExtImpl = &builderDestroy; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderCreateExt(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExt(hBuilder)); + EXPECT_EQ(1u, builderDestroyCalled); +} + +TEST_F(RTASTestExt, GivenLibraryPreLoadedAndUnderlyingBuilderCreateFailsThenErrorIsReturned_Ext) { + ze_rtas_builder_ext_handle_t hBuilder; + builderCreateExtImpl = &builderCreateFail; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderCreateExt(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateFailCalled); +} + +TEST_F(RTASTestExt, GivenLibraryFailsToLoadThenBuilderCreateReturnsError_Ext) { + ze_rtas_builder_ext_handle_t hBuilder; + VariableBackup funcBackup{&NEO::OsLibrary::loadFunc, MockRTASExtOsLibrary::load}; + MockRTASExtOsLibrary::mockLoad = false; + driverHandle->rtasLibraryHandle.reset(); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASBuilderCreateExt(driverHandle->toHandle(), nullptr, &hBuilder)); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderDestroySucceedsThenSuccessIsReturned_Ext) { + auto pRTASBuilderExt = std::make_unique(); + builderDestroyExtImpl = &builderDestroy; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExt(pRTASBuilderExt.release()->toHandle())); + EXPECT_EQ(1u, builderDestroyCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderDestroyFailsThenErrorIsReturned_Ext) { + auto pRTASBuilderExt = std::make_unique(); + builderDestroyExtImpl = &builderDestroyFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderDestroyExt(pRTASBuilderExt.get()->toHandle())); + EXPECT_EQ(1u, builderDestroyFailCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderGetBuildPropertiesSucceedsThenSuccessIsReturned_Ext) { + RTASBuilderExt pRTASBuilderExt; + builderGetBuildPropertiesExtImpl = &builderGetBuildProperties; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderGetBuildPropertiesExt(pRTASBuilderExt.toHandle(), nullptr, nullptr)); + EXPECT_EQ(1u, builderGetBuildPropertiesCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderGetBuildPropertiesFailsThenErrorIsReturned_Ext) { + RTASBuilderExt pRTASBuilderExt; + builderGetBuildPropertiesExtImpl = &builderGetBuildPropertiesFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderGetBuildPropertiesExt(pRTASBuilderExt.toHandle(), nullptr, nullptr)); + EXPECT_EQ(1u, builderGetBuildPropertiesFailCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderBuildSucceedsThenSuccessIsReturned_Ext) { + RTASBuilderExt pRTASBuilderExt; + RTASParallelOperationExt pParallelOperation; + builderBuildExtImpl = &builderBuild; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderBuildExt(pRTASBuilderExt.toHandle(), + nullptr, + nullptr, 0, + nullptr, 0, + pParallelOperation.toHandle(), + nullptr, nullptr, + nullptr)); + EXPECT_EQ(1u, builderBuildCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingBuilderBuildFailsThenErrorIsReturned_Ext) { + RTASBuilderExt pRTASBuilderExt; + RTASParallelOperationExt pParallelOperation; + builderBuildExtImpl = &builderBuildFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderBuildExt(pRTASBuilderExt.toHandle(), + nullptr, + nullptr, 0, + nullptr, 0, + pParallelOperation.toHandle(), + nullptr, nullptr, + nullptr)); + EXPECT_EQ(1u, builderBuildFailCalled); +} + +TEST_F(RTASTestExt, GivenLibraryPreLoadedAndUnderlyingFormatCompatibilitySucceedsThenSuccessIsReturned_Ext) { + formatCompatibilityCheckExtImpl = &formatCompatibilityCheck; + const ze_rtas_format_ext_t accelFormatA = {}; + const ze_rtas_format_ext_t accelFormatB = {}; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeDriverRTASFormatCompatibilityCheckExt(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingFormatCompatibilityFailsThenErrorIsReturned_Ext) { + formatCompatibilityCheckExtImpl = &formatCompatibilityCheckFail; + const ze_rtas_format_ext_t accelFormatA = {}; + const ze_rtas_format_ext_t accelFormatB = {}; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeDriverRTASFormatCompatibilityCheckExt(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckFailCalled); +} + +TEST_F(RTASTestExt, GivenLibraryPreLoadedAndUnderlyingParallelOperationCreateSucceedsThenSuccessIsReturned_Ext) { + ze_rtas_parallel_operation_ext_handle_t hParallelOperation; + parallelOperationCreateExtImpl = ¶llelOperationCreate; + parallelOperationDestroyExtImpl = ¶llelOperationDestroy; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationCreateExt(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExt(hParallelOperation)); + EXPECT_EQ(1u, parallelOperationDestroyCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationCreateFailsThenErrorIsReturned_Ext) { + ze_rtas_parallel_operation_ext_handle_t hParallelOperation; + parallelOperationCreateExtImpl = ¶llelOperationCreateFail; + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationCreateExt(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateFailCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationDestroySucceedsThenSuccessIsReturned_Ext) { + auto pParallelOperation = std::make_unique(); + parallelOperationDestroyExtImpl = ¶llelOperationDestroy; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExt(pParallelOperation.release()->toHandle())); + EXPECT_EQ(1u, parallelOperationDestroyCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationDestroyFailsThenErrorIsReturned_Ext) { + auto pParallelOperation = std::make_unique(); + parallelOperationDestroyExtImpl = ¶llelOperationDestroyFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationDestroyExt(pParallelOperation.get()->toHandle())); + EXPECT_EQ(1u, parallelOperationDestroyFailCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationGetPropertiesSucceedsThenSuccessIsReturned_Ext) { + RTASParallelOperationExt pParallelOperation; + parallelOperationGetPropertiesExtImpl = ¶llelOperationGetProperties; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationGetPropertiesExt(pParallelOperation.toHandle(), nullptr)); + EXPECT_EQ(1u, parallelOperationGetPropertiesCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationGetPropertiesFailsThenErrorIsReturned_Ext) { + RTASParallelOperationExt pParallelOperation; + parallelOperationGetPropertiesExtImpl = ¶llelOperationGetPropertiesFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationGetPropertiesExt(pParallelOperation.toHandle(), nullptr)); + EXPECT_EQ(1u, parallelOperationGetPropertiesFailCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationJoinSucceedsThenSuccessIsReturned_Ext) { + RTASParallelOperationExt pParallelOperation; + parallelOperationJoinExtImpl = ¶llelOperationJoin; + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationJoinExt(pParallelOperation.toHandle())); + EXPECT_EQ(1u, parallelOperationJoinCalled); +} + +TEST_F(RTASTestExt, GivenUnderlyingParallelOperationJoinFailsThenErrorIsReturned_Ext) { + RTASParallelOperationExt pParallelOperation; + parallelOperationJoinExtImpl = ¶llelOperationJoinFail; + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationJoinExt(pParallelOperation.toHandle())); + EXPECT_EQ(1u, parallelOperationJoinFailCalled); +} + +TEST_F(RTASTestExt, GivenNoSymbolAvailableInLibraryThenLoadEntryPointsReturnsFalse_Ext) { + driverHandle->rtasLibraryHandle = std::make_unique(); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +TEST_F(RTASTestExt, GivenOnlySingleSymbolAvailableThenLoadEntryPointsReturnsFalse_Ext) { + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASExtOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + osLibHandle->funcMap["zeRTASBuilderCreateExtImpl"] = reinterpret_cast(&builderCreate); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderDestroyExtImpl"] = reinterpret_cast(&builderDestroy); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderGetBuildPropertiesExtImpl"] = reinterpret_cast(&builderGetBuildProperties); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderBuildExtImpl"] = reinterpret_cast(&builderBuild); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeDriverRTASFormatCompatibilityCheckExtImpl"] = reinterpret_cast(&formatCompatibilityCheck); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationCreateExtImpl"] = reinterpret_cast(¶llelOperationCreate); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationDestroyExtImpl"] = reinterpret_cast(¶llelOperationDestroy); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationGetPropertiesExtImpl"] = reinterpret_cast(¶llelOperationGetProperties); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationJoinExtImpl"] = reinterpret_cast(¶llelOperationJoin); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +TEST_F(RTASTestExt, GivenMissingSymbolsThenLoadEntryPointsReturnsFalse_Ext) { + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASExtOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderCreateExtImpl"] = reinterpret_cast(&builderCreate); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderDestroyExtImpl"] = reinterpret_cast(&builderDestroy); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderGetBuildPropertiesExtImpl"] = reinterpret_cast(&builderGetBuildProperties); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderBuildExtImpl"] = reinterpret_cast(&builderBuild); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeDriverRTASFormatCompatibilityCheckExtImpl"] = reinterpret_cast(&formatCompatibilityCheck); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationCreateExtImpl"] = reinterpret_cast(¶llelOperationCreate); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationDestroyExtImpl"] = reinterpret_cast(¶llelOperationDestroy); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationGetPropertiesExtImpl"] = reinterpret_cast(¶llelOperationGetProperties); + EXPECT_EQ(false, L0::RTASBuilderExt::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +} // namespace ult +} // namespace L0 diff --git a/level_zero/ddi/ze_ddi_tables.cpp b/level_zero/ddi/ze_ddi_tables.cpp index d0cc5e0362..4adb7a2919 100644 --- a/level_zero/ddi/ze_ddi_tables.cpp +++ b/level_zero/ddi/ze_ddi_tables.cpp @@ -96,6 +96,11 @@ DriverDispatch::DriverDispatch() { this->sysman.Diagnostics = &this->sysmanDiagnostics; this->sysman.VFManagementExp = &this->sysmanVFManagementExp; + this->coreRTASBuilder.pfnCreateExt = L0::zeRTASBuilderCreateExt; + this->coreRTASBuilder.pfnGetBuildPropertiesExt = L0::zeRTASBuilderGetBuildPropertiesExt; + this->coreRTASBuilder.pfnBuildExt = L0::zeRTASBuilderBuildExt; + this->coreRTASBuilder.pfnDestroyExt = L0::zeRTASBuilderDestroyExt; + this->coreRTASBuilder.pfnCommandListAppendCopyExt = L0::zeRTASBuilderCommandListAppendCopyExt; this->coreRTASBuilderExp.pfnCreateExp = L0::zeRTASBuilderCreateExp; this->coreRTASBuilderExp.pfnGetBuildPropertiesExp = L0::zeRTASBuilderGetBuildPropertiesExp; this->coreRTASBuilderExp.pfnBuildExp = L0::zeRTASBuilderBuildExp; @@ -104,6 +109,10 @@ DriverDispatch::DriverDispatch() { this->coreRTASParallelOperationExp.pfnGetPropertiesExp = L0::zeRTASParallelOperationGetPropertiesExp; this->coreRTASParallelOperationExp.pfnJoinExp = L0::zeRTASParallelOperationJoinExp; this->coreRTASParallelOperationExp.pfnDestroyExp = L0::zeRTASParallelOperationDestroyExp; + this->coreRTASParallelOperation.pfnCreateExt = L0::zeRTASParallelOperationCreateExt; + this->coreRTASParallelOperation.pfnGetPropertiesExt = L0::zeRTASParallelOperationGetPropertiesExt; + this->coreRTASParallelOperation.pfnJoinExt = L0::zeRTASParallelOperationJoinExt; + this->coreRTASParallelOperation.pfnDestroyExt = L0::zeRTASParallelOperationDestroyExt; this->coreGlobal.pfnInit = L0::zeInit; this->coreGlobal.pfnInitDrivers = L0::zeInitDrivers; this->coreDriver.pfnGet = L0::zeDriverGet; diff --git a/level_zero/ddi/ze_ddi_tables.h b/level_zero/ddi/ze_ddi_tables.h index df18cd80b5..2b8f946df4 100644 --- a/level_zero/ddi/ze_ddi_tables.h +++ b/level_zero/ddi/ze_ddi_tables.h @@ -39,8 +39,8 @@ struct DriverDispatch { zet_dditable_driver_t tools{}; zes_dditable_driver_t sysman{}; - ze_rtas_builder_exp_dditable_t coreRTASBuilderExp{}; ze_rtas_builder_dditable_t coreRTASBuilder{}; + ze_rtas_builder_exp_dditable_t coreRTASBuilderExp{}; ze_rtas_parallel_operation_exp_dditable_t coreRTASParallelOperationExp{}; ze_rtas_parallel_operation_dditable_t coreRTASParallelOperation{}; ze_global_dditable_t coreGlobal{};