diff --git a/level_zero/api/core/ze_core_loader.cpp b/level_zero/api/core/ze_core_loader.cpp index 349df9dc0b..823745255a 100644 --- a/level_zero/api/core/ze_core_loader.cpp +++ b/level_zero/api/core/ze_core_loader.cpp @@ -716,3 +716,60 @@ zeGetFabricEdgeExpProcAddrTable( driverDdiTable.coreDdiTable.FabricEdgeExp = *pDdiTable; return result; } + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeGetDriverExpProcAddrTable( + ze_api_version_t version, + ze_driver_exp_dditable_t *pDdiTable) { + + if (nullptr == pDdiTable) + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + if (ZE_MAJOR_VERSION(driverDdiTable.version) != ZE_MAJOR_VERSION(version) || + ZE_MINOR_VERSION(driverDdiTable.version) > ZE_MINOR_VERSION(version)) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + ze_result_t result = ZE_RESULT_SUCCESS; + pDdiTable->pfnRTASFormatCompatibilityCheckExp = L0::zeDriverRTASFormatCompatibilityCheckExp; + driverDdiTable.coreDdiTable.DriverExp = *pDdiTable; + return result; +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeGetRTASParallelOperationExpProcAddrTable( + ze_api_version_t version, + ze_rtas_parallel_operation_exp_dditable_t *pDdiTable) { + + if (nullptr == pDdiTable) + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + if (ZE_MAJOR_VERSION(driverDdiTable.version) != ZE_MAJOR_VERSION(version) || + ZE_MINOR_VERSION(driverDdiTable.version) > ZE_MINOR_VERSION(version)) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + ze_result_t result = ZE_RESULT_SUCCESS; + pDdiTable->pfnCreateExp = L0::zeRTASParallelOperationCreateExp; + pDdiTable->pfnGetPropertiesExp = L0::zeRTASParallelOperationGetPropertiesExp; + pDdiTable->pfnJoinExp = L0::zeRTASParallelOperationJoinExp; + pDdiTable->pfnDestroyExp = L0::zeRTASParallelOperationDestroyExp; + driverDdiTable.coreDdiTable.RTASParallelOperationExp = *pDdiTable; + return result; +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeGetRTASBuilderExpProcAddrTable( + ze_api_version_t version, + ze_rtas_builder_exp_dditable_t *pDdiTable) { + + if (nullptr == pDdiTable) + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + if (ZE_MAJOR_VERSION(driverDdiTable.version) != ZE_MAJOR_VERSION(version) || + ZE_MINOR_VERSION(driverDdiTable.version) > ZE_MINOR_VERSION(version)) + return ZE_RESULT_ERROR_UNSUPPORTED_VERSION; + + ze_result_t result = ZE_RESULT_SUCCESS; + pDdiTable->pfnCreateExp = L0::zeRTASBuilderCreateExp; + pDdiTable->pfnGetBuildPropertiesExp = L0::zeRTASBuilderGetBuildPropertiesExp; + pDdiTable->pfnBuildExp = L0::zeRTASBuilderBuildExp; + pDdiTable->pfnDestroyExp = L0::zeRTASBuilderDestroyExp; + driverDdiTable.coreDdiTable.RTASBuilderExp = *pDdiTable; + return result; +} diff --git a/level_zero/api/extensions/public/ze_exp_ext.cpp b/level_zero/api/extensions/public/ze_exp_ext.cpp index a05d824249..c2fc9443a4 100644 --- a/level_zero/api/extensions/public/ze_exp_ext.cpp +++ b/level_zero/api/extensions/public/ze_exp_ext.cpp @@ -14,6 +14,7 @@ #include "level_zero/core/source/fabric/fabric.h" #include "level_zero/core/source/image/image.h" #include "level_zero/core/source/kernel/kernel.h" +#include "level_zero/core/source/rtas/rtas.h" namespace L0 { ze_result_t zeKernelSetGlobalOffsetExp( @@ -114,6 +115,67 @@ ze_result_t zeFabricEdgeGetPropertiesExp(ze_fabric_edge_handle_t hEdge, return L0::FabricEdge::fromHandle(hEdge)->getProperties(pEdgeProperties); } +ze_result_t zeRTASBuilderCreateExp(ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder) { + return L0::DriverHandle::fromHandle(hDriver)->createRTASBuilder(pDescriptor, phBuilder); +} + +ze_result_t zeRTASBuilderGetBuildPropertiesExp(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + ze_rtas_builder_exp_properties_t *pProperties) { + return L0::RTASBuilder::fromHandle(hBuilder)->getProperties(pBuildOpDescriptor, pProperties); +} + +ze_result_t zeDriverRTASFormatCompatibilityCheckExp(ze_driver_handle_t hDriver, + ze_rtas_format_exp_t rtasFormatA, + ze_rtas_format_exp_t rtasFormatB) { + return L0::DriverHandle::fromHandle(hDriver)->formatRTASCompatibilityCheck(rtasFormatA, rtasFormatB); +} + +ze_result_t zeRTASBuilderBuildExp(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes) { + return L0::RTASBuilder::fromHandle(hBuilder)->build(pBuildOpDescriptor, + pScratchBuffer, + scratchBufferSizeBytes, + pRtasBuffer, + rtasBufferSizeBytes, + hParallelOperation, + pBuildUserPtr, + pBounds, + pRtasBufferSizeBytes); +} + +ze_result_t zeRTASBuilderDestroyExp(ze_rtas_builder_exp_handle_t hBuilder) { + return L0::RTASBuilder::fromHandle(hBuilder)->destroy(); +} + +ze_result_t zeRTASParallelOperationCreateExp(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) { + return L0::DriverHandle::fromHandle(hDriver)->createRTASParallelOperation(phParallelOperation); +} + +ze_result_t zeRTASParallelOperationGetPropertiesExp(ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties) { + return L0::RTASParallelOperation::fromHandle(hParallelOperation)->getProperties(pProperties); +} + +ze_result_t zeRTASParallelOperationJoinExp(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + return L0::RTASParallelOperation::fromHandle(hParallelOperation)->join(); +} + +ze_result_t zeRTASParallelOperationDestroyExp(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + return L0::RTASParallelOperation::fromHandle(hParallelOperation)->destroy(); +} + } // namespace L0 extern "C" { @@ -232,4 +294,77 @@ zeFabricEdgeGetPropertiesExp( return L0::zeFabricEdgeGetPropertiesExp(hEdge, pEdgeProperties); } +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderCreateExp( + ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder) { + return L0::zeRTASBuilderCreateExp(hDriver, pDescriptor, phBuilder); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderGetBuildPropertiesExp( + ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + ze_rtas_builder_exp_properties_t *pProperties) { + return L0::zeRTASBuilderGetBuildPropertiesExp(hBuilder, pBuildOpDescriptor, pProperties); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeDriverRTASFormatCompatibilityCheckExp( + ze_driver_handle_t hDriver, + ze_rtas_format_exp_t rtasFormatA, + ze_rtas_format_exp_t rtasFormatB) { + return L0::zeDriverRTASFormatCompatibilityCheckExp(hDriver, rtasFormatA, rtasFormatB); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderBuildExp( + ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes) { + return L0::zeRTASBuilderBuildExp(hBuilder, pBuildOpDescriptor, pScratchBuffer, + scratchBufferSizeBytes, pRtasBuffer, rtasBufferSizeBytes, + hParallelOperation, pBuildUserPtr, pBounds, pRtasBufferSizeBytes); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASBuilderDestroyExp( + ze_rtas_builder_exp_handle_t hBuilder) { + return L0::zeRTASBuilderDestroyExp(hBuilder); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationCreateExp( + ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) { + return L0::zeRTASParallelOperationCreateExp(hDriver, phParallelOperation); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationGetPropertiesExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties) { + return L0::zeRTASParallelOperationGetPropertiesExp(hParallelOperation, pProperties); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationJoinExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + return L0::zeRTASParallelOperationJoinExp(hParallelOperation); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeRTASParallelOperationDestroyExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + return L0::zeRTASParallelOperationDestroyExp(hParallelOperation); +} + } // extern "C" diff --git a/level_zero/api/extensions/public/ze_exp_ext.h b/level_zero/api/extensions/public/ze_exp_ext.h index e209d6653f..013b1a80b7 100644 --- a/level_zero/api/extensions/public/ze_exp_ext.h +++ b/level_zero/api/extensions/public/ze_exp_ext.h @@ -99,4 +99,48 @@ ze_result_t zeFabricEdgeGetPropertiesExp( ze_fabric_edge_handle_t hEdge, ze_fabric_edge_exp_properties_t *pEdgeProperties); +ze_result_t zeRTASBuilderCreateExp( + ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder); + +ze_result_t zeRTASBuilderGetBuildPropertiesExp( + ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + ze_rtas_builder_exp_properties_t *pProperties); + +ze_result_t zeDriverRTASFormatCompatibilityCheckExp( + ze_driver_handle_t hDriver, + ze_rtas_format_exp_t rtasFormatA, + ze_rtas_format_exp_t rtasFormatB); + +ze_result_t zeRTASBuilderBuildExp( + ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *pBuildOpDescriptor, + void *pScratchBuffer, + size_t scratchBufferSizeBytes, + void *pRtasBuffer, + size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, + ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes); + +ze_result_t zeRTASBuilderDestroyExp( + ze_rtas_builder_exp_handle_t hBuilder); + +ze_result_t zeRTASParallelOperationCreateExp( + ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation); + +ze_result_t zeRTASParallelOperationGetPropertiesExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties); + +ze_result_t zeRTASParallelOperationJoinExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation); + +ze_result_t zeRTASParallelOperationDestroyExp( + ze_rtas_parallel_operation_exp_handle_t hParallelOperation); + } // namespace L0 diff --git a/level_zero/core/source/device/device_imp.cpp b/level_zero/core/source/device/device_imp.cpp index fcd6dcd54c..464b9140cd 100644 --- a/level_zero/core/source/device/device_imp.cpp +++ b/level_zero/core/source/device/device_imp.cpp @@ -49,6 +49,7 @@ #include "level_zero/core/source/module/module.h" #include "level_zero/core/source/module/module_build_log.h" #include "level_zero/core/source/printf_handler/printf_handler.h" +#include "level_zero/core/source/rtas/rtas.h" #include "level_zero/core/source/sampler/sampler.h" #include "level_zero/tools/source/debug/debug_session.h" #include "level_zero/tools/source/debug/debug_session_imp.h" @@ -784,6 +785,7 @@ ze_result_t DeviceImp::getProperties(ze_device_properties_t *pDeviceProperties) const auto &deviceInfo = this->neoDevice->getDeviceInfo(); const auto &hardwareInfo = this->neoDevice->getHardwareInfo(); auto &gfxCoreHelper = this->neoDevice->getGfxCoreHelper(); + const auto &l0GfxCoreHelper = this->getL0GfxCoreHelper(); pDeviceProperties->type = ZE_DEVICE_TYPE_GPU; @@ -897,6 +899,21 @@ ze_result_t DeviceImp::getProperties(ze_device_properties_t *pDeviceProperties) } else if (extendedProperties->stype == ZE_STRUCTURE_TYPE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_PROPERTIES) { ze_event_query_kernel_timestamps_ext_properties_t *kernelTimestampExtProperties = reinterpret_cast(extendedProperties); kernelTimestampExtProperties->flags = ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_FLAG_KERNEL | ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_FLAG_SYNCHRONIZED; + } else if (extendedProperties->stype == ZE_STRUCTURE_TYPE_RTAS_DEVICE_EXP_PROPERTIES) { + ze_rtas_device_exp_properties_t *rtasProperties = reinterpret_cast(extendedProperties); + rtasProperties->flags = 0; + rtasProperties->rtasFormat = l0GfxCoreHelper.getSupportedRTASFormat(); + rtasProperties->rtasBufferAlignment = 128; + + if (l0GfxCoreHelper.platformSupportsRayTracing()) { + auto driverHandle = this->getDriverHandle(); + DriverHandleImp *driverHandleImp = static_cast(driverHandle); + + ze_result_t result = driverHandleImp->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + rtasProperties->rtasFormat = ZE_RTAS_FORMAT_EXP_INVALID; + } + } } extendedProperties = static_cast(extendedProperties->pNext); } diff --git a/level_zero/core/source/driver/driver_handle.h b/level_zero/core/source/driver/driver_handle.h index 8447edfe9b..17e647e5ed 100644 --- a/level_zero/core/source/driver/driver_handle.h +++ b/level_zero/core/source/driver/driver_handle.h @@ -73,6 +73,11 @@ struct DriverHandle : BaseDriver { virtual uint32_t getEventMaxPacketCount(uint32_t numDevices, ze_device_handle_t *deviceHandles) const = 0; virtual uint32_t getEventMaxKernelCount(uint32_t numDevices, ze_device_handle_t *deviceHandles) const = 0; + virtual ze_result_t loadRTASLibrary() = 0; + virtual ze_result_t createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, ze_rtas_builder_exp_handle_t *phBuilder) = 0; + virtual ze_result_t createRTASParallelOperation(ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) = 0; + virtual ze_result_t formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, ze_rtas_format_exp_t rtasFormatB) = 0; + virtual int setErrorDescription(const char *fmt, ...) = 0; virtual ze_result_t getErrorDescription(const char **ppString) = 0; virtual ze_result_t clearErrorDescription() = 0; diff --git a/level_zero/core/source/driver/driver_handle_imp.h b/level_zero/core/source/driver/driver_handle_imp.h index b300233844..78519f753a 100644 --- a/level_zero/core/source/driver/driver_handle_imp.h +++ b/level_zero/core/source/driver/driver_handle_imp.h @@ -9,6 +9,7 @@ #include "shared/source/debugger/debugger.h" #include "shared/source/memory_manager/graphics_allocation.h" +#include "shared/source/os_interface/os_library.h" #include "level_zero/api/extensions/public/ze_exp_ext.h" #include "level_zero/core/source/driver/driver_handle.h" @@ -98,6 +99,11 @@ struct DriverHandleImp : public DriverHandle { uint32_t getEventMaxPacketCount(uint32_t numDevices, ze_device_handle_t *deviceHandles) const override; uint32_t getEventMaxKernelCount(uint32_t numDevices, ze_device_handle_t *deviceHandles) const override; + ze_result_t loadRTASLibrary() override; + ze_result_t createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, ze_rtas_builder_exp_handle_t *phBuilder) override; + ze_result_t createRTASParallelOperation(ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) override; + ze_result_t formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, ze_rtas_format_exp_t rtasFormatB) override; + ze_result_t parseAffinityMaskCombined(uint32_t *pCount, ze_device_handle_t *phDevices); std::unique_ptr hostPointerManager; @@ -110,6 +116,9 @@ struct DriverHandleImp : public DriverHandle { std::vector devices; std::vector fabricVertices; std::vector fabricEdges; + + std::mutex rtasLock; + // Spec extensions const std::vector> extensionsSupported = { {ZE_FLOAT_ATOMICS_EXT_NAME, ZE_FLOAT_ATOMICS_EXT_VERSION_CURRENT}, @@ -128,13 +137,17 @@ struct DriverHandleImp : public DriverHandle { {ZE_CACHE_RESERVATION_EXT_NAME, ZE_CACHE_RESERVATION_EXT_VERSION_CURRENT}, {ZE_IMAGE_VIEW_EXT_NAME, ZE_IMAGE_VIEW_EXP_VERSION_CURRENT}, {ZE_IMAGE_VIEW_PLANAR_EXT_NAME, ZE_IMAGE_VIEW_PLANAR_EXP_VERSION_CURRENT}, - {ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_NAME, ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_VERSION_CURRENT}}; + {ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_NAME, ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_VERSION_CURRENT}, + {ZE_RTAS_BUILDER_EXP_NAME, ZE_RTAS_BUILDER_EXP_VERSION_CURRENT}}; uint64_t uuidTimestamp = 0u; NEO::MemoryManager *memoryManager = nullptr; NEO::SVMAllocsManager *svmAllocsManager = nullptr; + std::unique_ptr rtasLibraryHandle; + bool rtasLibraryUnavailable = false; + uint32_t numDevices = 0; RootDeviceIndicesContainer rootDeviceIndices; diff --git a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper.h b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper.h index 3911cf971a..163554f44d 100644 --- a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper.h +++ b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper.h @@ -29,6 +29,13 @@ class Debugger; namespace L0 { +typedef enum _ze_rtas_device_format_internal_t { + ZE_RTAS_DEVICE_FORMAT_EXP_INVALID = 0, // invalid acceleration structure format + ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_1 = 1, // acceleration structure format version 1 + ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_2 = 2, // acceleration structure format version 2 + ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_MAX = 2 +} ze_rtas_device_format_internal_t; + struct Event; struct Device; struct EventPool; @@ -78,6 +85,7 @@ class L0GfxCoreHelper : public NEO::ApiGfxCoreHelper { virtual uint32_t getEventBaseMaxPacketCount(const NEO::RootDeviceEnvironment &rootDeviceEnvironment) const = 0; virtual NEO::HeapAddressModel getPlatformHeapAddressModel() const = 0; virtual std::vector getSupportedNumGrfs() const = 0; + virtual ze_rtas_format_exp_t getSupportedRTASFormat() const = 0; virtual bool platformSupportsImmediateComputeFlushTask() const = 0; virtual zet_debug_regset_type_intel_gpu_t getRegsetTypeForLargeGrfDetection() const = 0; @@ -116,6 +124,7 @@ class L0GfxCoreHelperHw : public L0GfxCoreHelper { uint32_t getEventBaseMaxPacketCount(const NEO::RootDeviceEnvironment &rootDeviceEnvironment) const override; NEO::HeapAddressModel getPlatformHeapAddressModel() const override; std::vector getSupportedNumGrfs() const override; + ze_rtas_format_exp_t getSupportedRTASFormat() const override; bool platformSupportsImmediateComputeFlushTask() const override; zet_debug_regset_type_intel_gpu_t getRegsetTypeForLargeGrfDetection() const override; diff --git a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_skl_to_tgllp.inl b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_skl_to_tgllp.inl index 92a52384d9..7b5d729e20 100644 --- a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_skl_to_tgllp.inl +++ b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_skl_to_tgllp.inl @@ -64,6 +64,11 @@ std::vector L0GfxCoreHelperHw::getSupportedNumGrfs() const { return {128u}; } +template +ze_rtas_format_exp_t L0GfxCoreHelperHw::getSupportedRTASFormat() const { + return ZE_RTAS_FORMAT_EXP_INVALID; +} + template bool L0GfxCoreHelperHw::platformSupportsPrimaryBatchBufferCmdList() const { return true; diff --git a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_xehp_and_later.inl b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_xehp_and_later.inl index 736197becf..66efe5ff5a 100644 --- a/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_xehp_and_later.inl +++ b/level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper_xehp_and_later.inl @@ -77,6 +77,11 @@ std::vector L0GfxCoreHelperHw::getSupportedNumGrfs() const { return {128u, 256u}; } +template +ze_rtas_format_exp_t L0GfxCoreHelperHw::getSupportedRTASFormat() const { + return static_cast(ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_1); +} + template bool L0GfxCoreHelperHw::platformSupportsPrimaryBatchBufferCmdList() const { return true; diff --git a/level_zero/core/source/rtas/CMakeLists.txt b/level_zero/core/source/rtas/CMakeLists.txt new file mode 100644 index 0000000000..993a9d534a --- /dev/null +++ b/level_zero/core/source/rtas/CMakeLists.txt @@ -0,0 +1,13 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT +# + +target_sources(${L0_STATIC_LIB_NAME} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt + ${CMAKE_CURRENT_SOURCE_DIR}/rtas.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/rtas.h +) +add_subdirectories() diff --git a/level_zero/core/source/rtas/linux/CMakeLists.txt b/level_zero/core/source/rtas/linux/CMakeLists.txt new file mode 100644 index 0000000000..cf3d47aa78 --- /dev/null +++ b/level_zero/core/source/rtas/linux/CMakeLists.txt @@ -0,0 +1,13 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT +# + +if(UNIX) + target_sources(${L0_STATIC_LIB_NAME} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt + ${CMAKE_CURRENT_SOURCE_DIR}/os_rtas_enumeration.cpp + ) +endif() diff --git a/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp b/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp new file mode 100644 index 0000000000..c3d6457090 --- /dev/null +++ b/level_zero/core/source/rtas/linux/os_rtas_enumeration.cpp @@ -0,0 +1,12 @@ +/* + * Copyright (C) 2023 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "level_zero/core/source/rtas/rtas.h" + +namespace L0 { +std::string RTASBuilder::rtasLibraryName = "libze_intel_gpu_raytracing.so"; +} // namespace L0 diff --git a/level_zero/core/source/rtas/rtas.cpp b/level_zero/core/source/rtas/rtas.cpp new file mode 100644 index 0000000000..453c4ca083 --- /dev/null +++ b/level_zero/core/source/rtas/rtas.cpp @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2023 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "level_zero/core/source/rtas/rtas.h" + +#include "shared/source/debug_settings/debug_settings_manager.h" +#include "shared/source/helpers/string.h" + +#include "level_zero/core/source/driver/driver_handle_imp.h" + +namespace L0 { + +const std::string zeRTASBuilderCreateExpImpl = "zeRTASBuilderCreateExpImpl"; +const std::string zeRTASBuilderDestroyExpImpl = "zeRTASBuilderDestroyExpImpl"; +const std::string zeRTASBuilderGetBuildPropertiesExpImpl = "zeRTASBuilderGetBuildPropertiesExpImpl"; +const std::string zeRTASBuilderBuildExpImpl = "zeRTASBuilderBuildExpImpl"; +const std::string zeDriverRTASFormatCompatibilityCheckExpImpl = "zeDriverRTASFormatCompatibilityCheckExpImpl"; +const std::string zeRTASParallelOperationCreateExpImpl = "zeRTASParallelOperationCreateExpImpl"; +const std::string zeRTASParallelOperationDestroyExpImpl = "zeRTASParallelOperationDestroyExpImpl"; +const std::string zeRTASParallelOperationGetPropertiesExpImpl = "zeRTASParallelOperationGetPropertiesExpImpl"; +const std::string zeRTASParallelOperationJoinExpImpl = "zeRTASParallelOperationJoinExpImpl"; + +pRTASBuilderCreateExpImpl builderCreateExpImpl; +pRTASBuilderDestroyExpImpl builderDestroyExpImpl; +pRTASBuilderGetBuildPropertiesExpImpl builderGetBuildPropertiesExpImpl; +pRTASBuilderBuildExpImpl builderBuildExpImpl; +pDriverRTASFormatCompatibilityCheckExpImpl formatCompatibilityCheckExpImpl; +pRTASParallelOperationCreateExpImpl parallelOperationCreateExpImpl; +pRTASParallelOperationDestroyExpImpl parallelOperationDestroyExpImpl; +pRTASParallelOperationGetPropertiesExpImpl parallelOperationGetPropertiesExpImpl; +pRTASParallelOperationJoinExpImpl parallelOperationJoinExpImpl; + +RTASBuilder::OsLibraryLoadPtr RTASBuilder::osLibraryLoadFunction(NEO::OsLibrary::load); + +bool RTASBuilder::loadEntryPoints(NEO::OsLibrary *libraryHandle) { + bool ok = getSymbolAddr(libraryHandle, zeRTASBuilderCreateExpImpl, builderCreateExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderDestroyExpImpl, builderDestroyExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderGetBuildPropertiesExpImpl, builderGetBuildPropertiesExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASBuilderBuildExpImpl, builderBuildExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeDriverRTASFormatCompatibilityCheckExpImpl, formatCompatibilityCheckExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationCreateExpImpl, parallelOperationCreateExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationDestroyExpImpl, parallelOperationDestroyExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationGetPropertiesExpImpl, parallelOperationGetPropertiesExpImpl); + ok = ok && getSymbolAddr(libraryHandle, zeRTASParallelOperationJoinExpImpl, parallelOperationJoinExpImpl); + + return ok; +} + +ze_result_t RTASBuilder::getProperties(const ze_rtas_builder_build_op_exp_desc_t *args, + ze_rtas_builder_exp_properties_t *pProp) { + return builderGetBuildPropertiesExpImpl(this->handleImpl, args, pProp); +} + +ze_result_t RTASBuilder::build(const ze_rtas_builder_build_op_exp_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes) { + RTASParallelOperation *parallelOperation = RTASParallelOperation::fromHandle(hParallelOperation); + + return builderBuildExpImpl(this->handleImpl, + args, + pScratchBuffer, scratchBufferSizeBytes, + pRtasBuffer, rtasBufferSizeBytes, + parallelOperation->handleImpl, + pBuildUserPtr, pBounds, + pRtasBufferSizeBytes); +} + +ze_result_t DriverHandleImp::formatRTASCompatibilityCheck(ze_rtas_format_exp_t rtasFormatA, + ze_rtas_format_exp_t rtasFormatB) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + return formatCompatibilityCheckExpImpl(this->toHandle(), rtasFormatA, rtasFormatB); +} + +ze_result_t DriverHandleImp::createRTASBuilder(const ze_rtas_builder_exp_desc_t *desc, + ze_rtas_builder_exp_handle_t *phBuilder) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + auto pRTASBuilder = std::make_unique(); + + result = builderCreateExpImpl(this->toHandle(), desc, &pRTASBuilder->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + *phBuilder = pRTASBuilder.release(); + return ZE_RESULT_SUCCESS; +} + +ze_result_t RTASBuilder::destroy() { + ze_result_t result = builderDestroyExpImpl(this->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + delete this; + return ZE_RESULT_SUCCESS; +} + +ze_result_t DriverHandleImp::loadRTASLibrary() { + std::lock_guard lock(this->rtasLock); + + if (this->rtasLibraryUnavailable == true) { + return ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + + if (this->rtasLibraryHandle == nullptr) { + this->rtasLibraryHandle = std::unique_ptr(RTASBuilder::osLibraryLoadFunction(RTASBuilder::rtasLibraryName)); + if (this->rtasLibraryHandle == nullptr || RTASBuilder::loadEntryPoints(this->rtasLibraryHandle.get()) == false) { + this->rtasLibraryUnavailable = true; + + PRINT_DEBUG_STRING(NEO::DebugManager.flags.PrintDebugMessages.get(), stderr, "Failed to load Ray Tracing Support Library %s\n", RTASBuilder::rtasLibraryName.c_str()); + return ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE; + } + } + + return ZE_RESULT_SUCCESS; +} + +ze_result_t DriverHandleImp::createRTASParallelOperation(ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) { + ze_result_t result = this->loadRTASLibrary(); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + auto pRTASParallelOperation = std::make_unique(); + + result = parallelOperationCreateExpImpl(this->toHandle(), &pRTASParallelOperation->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + *phParallelOperation = pRTASParallelOperation.release(); + return ZE_RESULT_SUCCESS; +} + +ze_result_t RTASParallelOperation::destroy() { + ze_result_t result = parallelOperationDestroyExpImpl(this->handleImpl); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + delete this; + return ZE_RESULT_SUCCESS; +} + +ze_result_t RTASParallelOperation::getProperties(ze_rtas_parallel_operation_exp_properties_t *pProperties) { + return parallelOperationGetPropertiesExpImpl(this->handleImpl, pProperties); +} + +ze_result_t RTASParallelOperation::join() { + return parallelOperationJoinExpImpl(this->handleImpl); +} + +} // namespace L0 diff --git a/level_zero/core/source/rtas/rtas.h b/level_zero/core/source/rtas/rtas.h new file mode 100644 index 0000000000..aae6d3ca2c --- /dev/null +++ b/level_zero/core/source/rtas/rtas.h @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2023 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#pragma once + +#include "shared/source/os_interface/os_library.h" + +#include + +#include + +struct _ze_rtas_builder_exp_handle_t {}; +struct _ze_rtas_parallel_operation_exp_handle_t {}; + +namespace L0 { +/* + * Note: RTAS Library using same headers as Level Zero, but using function + * pointers to access these symbols in the external Library. + */ +typedef ze_result_t (*pRTASBuilderCreateExpImpl)(ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder); + +typedef ze_result_t (*pRTASBuilderDestroyExpImpl)(ze_rtas_builder_exp_handle_t hBuilder); + +typedef ze_result_t (*pRTASBuilderGetBuildPropertiesExpImpl)(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + ze_rtas_builder_exp_properties_t *pProp); + +typedef ze_result_t (*pRTASBuilderBuildExpImpl)(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes); + +typedef ze_result_t (*pDriverRTASFormatCompatibilityCheckExpImpl)(ze_driver_handle_t hDriver, + const ze_rtas_format_exp_t accelFormat, + const ze_rtas_format_exp_t otherAccelFormat); + +typedef ze_result_t (*pRTASParallelOperationCreateExpImpl)(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation); + +typedef ze_result_t (*pRTASParallelOperationDestroyExpImpl)(ze_rtas_parallel_operation_exp_handle_t hParallelOperation); + +typedef ze_result_t (*pRTASParallelOperationGetPropertiesExpImpl)(ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties); + +typedef ze_result_t (*pRTASParallelOperationJoinExpImpl)(ze_rtas_parallel_operation_exp_handle_t hParallelOperation); + +extern pRTASBuilderCreateExpImpl builderCreateExpImpl; +extern pRTASBuilderDestroyExpImpl builderDestroyExpImpl; +extern pRTASBuilderGetBuildPropertiesExpImpl builderGetBuildPropertiesExpImpl; +extern pRTASBuilderBuildExpImpl builderBuildExpImpl; +extern pDriverRTASFormatCompatibilityCheckExpImpl formatCompatibilityCheckExpImpl; +extern pRTASParallelOperationCreateExpImpl parallelOperationCreateExpImpl; +extern pRTASParallelOperationDestroyExpImpl parallelOperationDestroyExpImpl; +extern pRTASParallelOperationGetPropertiesExpImpl parallelOperationGetPropertiesExpImpl; +extern pRTASParallelOperationJoinExpImpl parallelOperationJoinExpImpl; + +struct RTASBuilder : _ze_rtas_builder_exp_handle_t { + public: + virtual ~RTASBuilder() = default; + + ze_result_t destroy(); + ze_result_t getProperties(const ze_rtas_builder_build_op_exp_desc_t *args, + ze_rtas_builder_exp_properties_t *pProp); + ze_result_t build(const ze_rtas_builder_build_op_exp_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes); + + static RTASBuilder *fromHandle(ze_rtas_builder_exp_handle_t handle) { return static_cast(handle); } + inline ze_rtas_builder_exp_handle_t toHandle() { return this; } + + using OsLibraryLoadPtr = std::add_pointer::type; + static OsLibraryLoadPtr osLibraryLoadFunction; + static std::string rtasLibraryName; + static bool loadEntryPoints(NEO::OsLibrary *libraryHandle); + + template + static bool getSymbolAddr(NEO::OsLibrary *libraryHandle, const std::string name, T &proc) { + void *addr = libraryHandle->getProcAddress(name); + proc = reinterpret_cast(addr); + return nullptr != proc; + } + + ze_rtas_builder_exp_handle_t handleImpl; +}; + +struct RTASParallelOperation : _ze_rtas_parallel_operation_exp_handle_t { + public: + virtual ~RTASParallelOperation() = default; + + ze_result_t destroy(); + ze_result_t getProperties(ze_rtas_parallel_operation_exp_properties_t *pProperties); + ze_result_t join(); + + static RTASParallelOperation *fromHandle(ze_rtas_parallel_operation_exp_handle_t handle) { return static_cast(handle); } + inline ze_rtas_parallel_operation_exp_handle_t toHandle() { return this; } + + ze_rtas_parallel_operation_exp_handle_t handleImpl; +}; + +} // namespace L0 diff --git a/level_zero/core/source/rtas/windows/CMakeLists.txt b/level_zero/core/source/rtas/windows/CMakeLists.txt new file mode 100644 index 0000000000..d43260f2e3 --- /dev/null +++ b/level_zero/core/source/rtas/windows/CMakeLists.txt @@ -0,0 +1,13 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT +# + +if(WIN32) + target_sources(${L0_STATIC_LIB_NAME} + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt + ${CMAKE_CURRENT_SOURCE_DIR}/os_rtas_enumeration.cpp + ) +endif() diff --git a/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp b/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp new file mode 100644 index 0000000000..ce4c4d4b82 --- /dev/null +++ b/level_zero/core/source/rtas/windows/os_rtas_enumeration.cpp @@ -0,0 +1,12 @@ +/* + * Copyright (C) 2023 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "level_zero/core/source/rtas/rtas.h" + +namespace L0 { +std::string RTASBuilder::rtasLibraryName = "ze_intel_gpu_raytracing.dll"; +} // namespace L0 diff --git a/level_zero/core/test/unit_tests/gen11/test_l0_gfx_core_helper_gen11.cpp b/level_zero/core/test/unit_tests/gen11/test_l0_gfx_core_helper_gen11.cpp index 402cc020a2..3563e4dc6e 100644 --- a/level_zero/core/test/unit_tests/gen11/test_l0_gfx_core_helper_gen11.cpp +++ b/level_zero/core/test/unit_tests/gen11/test_l0_gfx_core_helper_gen11.cpp @@ -56,5 +56,10 @@ GEN11TEST_F(L0GfxCoreHelperTestGen11, GivenGen11WhenCheckingL0HelperForCmdlistPr EXPECT_TRUE(l0GfxCoreHelper.platformSupportsPrimaryBatchBufferCmdList()); } +GEN11TEST_F(L0GfxCoreHelperTestGen11, GivenGen11WhenGettingSupportedRTASFormatThenExpectedFormatIsReturned) { + const auto &l0GfxCoreHelper = getHelper(); + EXPECT_EQ(ZE_RTAS_FORMAT_EXP_INVALID, l0GfxCoreHelper.getSupportedRTASFormat()); +} + } // namespace ult } // namespace L0 diff --git a/level_zero/core/test/unit_tests/gen12lp/test_l0_gfx_core_helper_gen12lp.cpp b/level_zero/core/test/unit_tests/gen12lp/test_l0_gfx_core_helper_gen12lp.cpp index 60f17db918..523afe812f 100644 --- a/level_zero/core/test/unit_tests/gen12lp/test_l0_gfx_core_helper_gen12lp.cpp +++ b/level_zero/core/test/unit_tests/gen12lp/test_l0_gfx_core_helper_gen12lp.cpp @@ -73,5 +73,10 @@ GEN12LPTEST_F(L0GfxCoreHelperTestGen12Lp, GivenGen12LpWhenCheckingL0HelperForCmd EXPECT_TRUE(l0GfxCoreHelper.platformSupportsPrimaryBatchBufferCmdList()); } +GEN12LPTEST_F(L0GfxCoreHelperTestGen12Lp, GivenGen12LpWhenGettingSupportedRTASFormatThenExpectedFormatIsReturned) { + const auto &l0GfxCoreHelper = getHelper(); + EXPECT_EQ(ZE_RTAS_FORMAT_EXP_INVALID, l0GfxCoreHelper.getSupportedRTASFormat()); +} + } // namespace ult } // namespace L0 diff --git a/level_zero/core/test/unit_tests/gen9/test_l0_gfx_core_helper_gen9.cpp b/level_zero/core/test/unit_tests/gen9/test_l0_gfx_core_helper_gen9.cpp index dd9dc50b47..d06f049a3b 100644 --- a/level_zero/core/test/unit_tests/gen9/test_l0_gfx_core_helper_gen9.cpp +++ b/level_zero/core/test/unit_tests/gen9/test_l0_gfx_core_helper_gen9.cpp @@ -56,5 +56,10 @@ GEN9TEST_F(L0GfxCoreHelperTestGen9, GivenGen9WhenCheckingL0HelperForCmdlistPrima EXPECT_TRUE(l0GfxCoreHelper.platformSupportsPrimaryBatchBufferCmdList()); } +GEN9TEST_F(L0GfxCoreHelperTestGen9, GivenGen9WhenGettingSupportedRTASFormatThenExpectedFormatIsReturned) { + const auto &l0GfxCoreHelper = getHelper(); + EXPECT_EQ(ZE_RTAS_FORMAT_EXP_INVALID, l0GfxCoreHelper.getSupportedRTASFormat()); +} + } // namespace ult } // namespace L0 diff --git a/level_zero/core/test/unit_tests/sources/device/test_l0_device.cpp b/level_zero/core/test/unit_tests/sources/device/test_l0_device.cpp index 585108e5e6..6781e58bf7 100644 --- a/level_zero/core/test/unit_tests/sources/device/test_l0_device.cpp +++ b/level_zero/core/test/unit_tests/sources/device/test_l0_device.cpp @@ -41,6 +41,7 @@ #include "level_zero/core/source/fabric/fabric.h" #include "level_zero/core/source/gfx_core_helpers/l0_gfx_core_helper.h" #include "level_zero/core/source/image/image.h" +#include "level_zero/core/source/rtas/rtas.h" #include "level_zero/core/test/unit_tests/fixtures/device_fixture.h" #include "level_zero/core/test/unit_tests/mocks/mock_built_ins.h" #include "level_zero/core/test/unit_tests/mocks/mock_cmdlist.h" @@ -4934,5 +4935,160 @@ TEST_F(DeviceTest, GivenValidDeviceWhenQueryingKernelTimestampsProptertiesThenCo EXPECT_NE(0u, tsProps.flags & ZE_EVENT_QUERY_KERNEL_TIMESTAMPS_EXT_FLAG_SYNCHRONIZED); } +struct RTASDeviceTest : public ::testing::Test { + void SetUp() override { + DebugManager.flags.CreateMultipleRootDevices.set(numRootDevices); + neoDevice = NEO::MockDevice::createWithNewExecutionEnvironment(NEO::defaultHwInfo.get(), rootDeviceIndex); + execEnv = neoDevice->getExecutionEnvironment(); + execEnv->incRefInternal(); + NEO::DeviceVector devices; + devices.push_back(std::unique_ptr(neoDevice)); + driverHandle = std::make_unique>(); + driverHandle->initialize(std::move(devices)); + device = driverHandle->devices[0]; + } + + void TearDown() override { + driverHandle.reset(nullptr); + execEnv->decRefInternal(); + } + + struct MockOsLibrary : public OsLibrary { + public: + MockOsLibrary(const std::string &name, std::string *errorValue) { + } + MockOsLibrary() {} + ~MockOsLibrary() override = default; + void *getProcAddress(const std::string &procName) override { + if (failGetProcAddress) { + return nullptr; + } + return reinterpret_cast(0x1234); + } + bool isLoaded() override { + return libraryLoaded; + } + std::string getFullPath() override { + return std::string(); + } + static OsLibrary *load(const std::string &name) { + if (failLibraryLoad) { + return nullptr; + } + auto ptr = new (std::nothrow) MockOsLibrary(); + return ptr; + } + + static bool libraryLoaded; + static bool failLibraryLoad; + static bool failGetProcAddress; + }; + + DebugManagerStateRestore restorer; + std::unique_ptr> driverHandle; + NEO::ExecutionEnvironment *execEnv; + NEO::Device *neoDevice = nullptr; + L0::Device *device = nullptr; + const uint32_t rootDeviceIndex = 1u; + const uint32_t numRootDevices = 2u; +}; + +bool RTASDeviceTest::MockOsLibrary::libraryLoaded = false; +bool RTASDeviceTest::MockOsLibrary::failLibraryLoad = false; +bool RTASDeviceTest::MockOsLibrary::failGetProcAddress = false; + +HWTEST2_F(RTASDeviceTest, GivenValidRTASLibraryWhenQueryingRTASProptertiesThenCorrectPropertiesIsReturned, MatchAny) { + MockOsLibrary::libraryLoaded = false; + MockOsLibrary::failLibraryLoad = false; + MockOsLibrary::failGetProcAddress = false; + + ze_device_properties_t devProps = {}; + ze_rtas_device_exp_properties_t rtasProperties = {}; + L0::RTASBuilder::osLibraryLoadFunction = MockOsLibrary::load; + driverHandle->rtasLibraryHandle.reset(); + + devProps.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + devProps.pNext = &rtasProperties; + + rtasProperties.stype = ZE_STRUCTURE_TYPE_RTAS_DEVICE_EXP_PROPERTIES; + rtasProperties.pNext = nullptr; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGetProperties(device, &devProps)); + EXPECT_EQ(128u, rtasProperties.rtasBufferAlignment); + + auto &l0GfxCoreHelper = this->neoDevice->getRootDeviceEnvironment().getHelper(); + + if (l0GfxCoreHelper.platformSupportsRayTracing()) { + EXPECT_NE(ZE_RTAS_FORMAT_EXP_INVALID, rtasProperties.rtasFormat); + } +} + +HWTEST2_F(RTASDeviceTest, GivenRTASLibraryPreLoadedWhenQueryingRTASProptertiesThenCorrectPropertiesIsReturned, MatchAny) { + MockOsLibrary::libraryLoaded = false; + MockOsLibrary::failLibraryLoad = false; + MockOsLibrary::failGetProcAddress = false; + + ze_device_properties_t devProps = {}; + ze_rtas_device_exp_properties_t rtasProperties = {}; + driverHandle->rtasLibraryHandle = std::make_unique(); + + devProps.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + devProps.pNext = &rtasProperties; + + rtasProperties.stype = ZE_STRUCTURE_TYPE_RTAS_DEVICE_EXP_PROPERTIES; + rtasProperties.pNext = nullptr; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGetProperties(device, &devProps)); + EXPECT_EQ(128u, rtasProperties.rtasBufferAlignment); + + auto &l0GfxCoreHelper = this->neoDevice->getRootDeviceEnvironment().getHelper(); + + if (l0GfxCoreHelper.platformSupportsRayTracing()) { + EXPECT_NE(ZE_RTAS_FORMAT_EXP_INVALID, rtasProperties.rtasFormat); + } +} + +HWTEST2_F(RTASDeviceTest, GivenInvalidRTASLibraryWhenQueryingRTASProptertiesThenCorrectPropertiesIsReturned, MatchAny) { + MockOsLibrary::libraryLoaded = false; + MockOsLibrary::failLibraryLoad = true; + MockOsLibrary::failGetProcAddress = true; + + ze_device_properties_t devProps = {}; + ze_rtas_device_exp_properties_t rtasProperties = {}; + L0::RTASBuilder::osLibraryLoadFunction = MockOsLibrary::load; + driverHandle->rtasLibraryHandle.reset(); + + devProps.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + devProps.pNext = &rtasProperties; + + rtasProperties.stype = ZE_STRUCTURE_TYPE_RTAS_DEVICE_EXP_PROPERTIES; + rtasProperties.pNext = nullptr; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGetProperties(device, &devProps)); + EXPECT_EQ(128u, rtasProperties.rtasBufferAlignment); + EXPECT_EQ(ZE_RTAS_FORMAT_EXP_INVALID, rtasProperties.rtasFormat); +} + +HWTEST2_F(RTASDeviceTest, GivenMissingSymbolsInRTASLibraryWhenQueryingRTASProptertiesThenCorrectPropertiesIsReturned, MatchAny) { + MockOsLibrary::libraryLoaded = false; + MockOsLibrary::failLibraryLoad = false; + MockOsLibrary::failGetProcAddress = true; + + ze_device_properties_t devProps = {}; + ze_rtas_device_exp_properties_t rtasProperties = {}; + L0::RTASBuilder::osLibraryLoadFunction = MockOsLibrary::load; + driverHandle->rtasLibraryHandle.reset(); + + devProps.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + devProps.pNext = &rtasProperties; + + rtasProperties.stype = ZE_STRUCTURE_TYPE_RTAS_DEVICE_EXP_PROPERTIES; + rtasProperties.pNext = nullptr; + + EXPECT_EQ(ZE_RESULT_SUCCESS, zeDeviceGetProperties(device, &devProps)); + EXPECT_EQ(128u, rtasProperties.rtasBufferAlignment); + EXPECT_EQ(ZE_RTAS_FORMAT_EXP_INVALID, rtasProperties.rtasFormat); +} + } // namespace ult } // namespace L0 diff --git a/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt b/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt new file mode 100644 index 0000000000..0f93d5bfd8 --- /dev/null +++ b/level_zero/core/test/unit_tests/sources/rtas/CMakeLists.txt @@ -0,0 +1,11 @@ +# +# Copyright (C) 2022-2023 Intel Corporation +# +# SPDX-License-Identifier: MIT +# + +target_sources(${TARGET_NAME} PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt + ${CMAKE_CURRENT_SOURCE_DIR}/test_rtas.cpp +) +add_subdirectories() \ No newline at end of file diff --git a/level_zero/core/test/unit_tests/sources/rtas/test_rtas.cpp b/level_zero/core/test/unit_tests/sources/rtas/test_rtas.cpp new file mode 100644 index 0000000000..86f8814c1a --- /dev/null +++ b/level_zero/core/test/unit_tests/sources/rtas/test_rtas.cpp @@ -0,0 +1,585 @@ +/* + * Copyright (C) 2023 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/test/common/test_macros/test.h" + +#include "level_zero/core/source/rtas/rtas.h" +#include "level_zero/core/test/unit_tests/fixtures/device_fixture.h" + +namespace L0 { +namespace ult { + +struct RTASFixture : public DeviceFixture { + void setUp() { + DeviceFixture::setUp(); + builderCreateCalled = 0; + builderCreateCalled = 0; + builderCreateFailCalled = 0; + builderDestroyCalled = 0; + builderDestroyFailCalled = 0; + builderGetBuildPropertiesCalled = 0; + builderGetBuildPropertiesFailCalled = 0; + builderBuildCalled = 0; + builderBuildFailCalled = 0; + formatCompatibilityCheckCalled = 0; + formatCompatibilityCheckFailCalled = 0; + parallelOperationDestroyCalled = 0; + parallelOperationDestroyFailCalled = 0; + parallelOperationCreateCalled = 0; + parallelOperationCreateFailCalled = 0; + parallelOperationGetPropertiesCalled = 0; + parallelOperationGetPropertiesFailCalled = 0; + parallelOperationJoinCalled = 0; + parallelOperationJoinFailCalled = 0; + } + + void tearDown() { + DeviceFixture::tearDown(); + } + + static ze_result_t builderCreate(ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder) { + builderCreateCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderCreateFail(ze_driver_handle_t hDriver, + const ze_rtas_builder_exp_desc_t *pDescriptor, + ze_rtas_builder_exp_handle_t *phBuilder) { + builderCreateFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderDestroy(ze_rtas_builder_exp_handle_t hBuilder) { + builderDestroyCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderDestroyFail(ze_rtas_builder_exp_handle_t hBuilder) { + builderDestroyFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderGetBuildProperties(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + ze_rtas_builder_exp_properties_t *pProp) { + builderGetBuildPropertiesCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderGetBuildPropertiesFail(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + ze_rtas_builder_exp_properties_t *pProp) { + builderGetBuildPropertiesFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t builderBuild(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes) { + builderBuildCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t builderBuildFail(ze_rtas_builder_exp_handle_t hBuilder, + const ze_rtas_builder_build_op_exp_desc_t *args, + void *pScratchBuffer, size_t scratchBufferSizeBytes, + void *pRtasBuffer, size_t rtasBufferSizeBytes, + ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + void *pBuildUserPtr, ze_rtas_aabb_exp_t *pBounds, + size_t *pRtasBufferSizeBytes) { + builderBuildFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t formatCompatibilityCheck(ze_driver_handle_t hDriver, + const ze_rtas_format_exp_t accelFormat, + const ze_rtas_format_exp_t otherAccelFormat) { + formatCompatibilityCheckCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t formatCompatibilityCheckFail(ze_driver_handle_t hDriver, + const ze_rtas_format_exp_t accelFormat, + const ze_rtas_format_exp_t otherAccelFormat) { + formatCompatibilityCheckFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationDestroy(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + parallelOperationDestroyCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationDestroyFail(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + parallelOperationDestroyFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationCreate(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) { + parallelOperationCreateCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationCreateFail(ze_driver_handle_t hDriver, + ze_rtas_parallel_operation_exp_handle_t *phParallelOperation) { + parallelOperationCreateFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationGetProperties(ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties) { + parallelOperationGetPropertiesCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationGetPropertiesFail(ze_rtas_parallel_operation_exp_handle_t hParallelOperation, + ze_rtas_parallel_operation_exp_properties_t *pProperties) { + parallelOperationGetPropertiesFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static ze_result_t parallelOperationJoin(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + parallelOperationJoinCalled++; + return ZE_RESULT_SUCCESS; + } + + static ze_result_t parallelOperationJoinFail(ze_rtas_parallel_operation_exp_handle_t hParallelOperation) { + parallelOperationJoinFailCalled++; + return ZE_RESULT_ERROR_UNKNOWN; + } + + static uint32_t builderCreateCalled; + static uint32_t builderCreateFailCalled; + static uint32_t builderDestroyCalled; + static uint32_t builderDestroyFailCalled; + static uint32_t builderGetBuildPropertiesCalled; + static uint32_t builderGetBuildPropertiesFailCalled; + static uint32_t builderBuildCalled; + static uint32_t builderBuildFailCalled; + static uint32_t formatCompatibilityCheckCalled; + static uint32_t formatCompatibilityCheckFailCalled; + static uint32_t parallelOperationDestroyCalled; + static uint32_t parallelOperationDestroyFailCalled; + static uint32_t parallelOperationCreateCalled; + static uint32_t parallelOperationCreateFailCalled; + static uint32_t parallelOperationGetPropertiesCalled; + static uint32_t parallelOperationGetPropertiesFailCalled; + static uint32_t parallelOperationJoinCalled; + static uint32_t parallelOperationJoinFailCalled; +}; + +uint32_t RTASFixture::builderCreateCalled = 0; +uint32_t RTASFixture::builderCreateFailCalled = 0; +uint32_t RTASFixture::builderDestroyCalled = 0; +uint32_t RTASFixture::builderDestroyFailCalled = 0; +uint32_t RTASFixture::builderGetBuildPropertiesCalled = 0; +uint32_t RTASFixture::builderGetBuildPropertiesFailCalled = 0; +uint32_t RTASFixture::builderBuildCalled = 0; +uint32_t RTASFixture::builderBuildFailCalled = 0; +uint32_t RTASFixture::formatCompatibilityCheckCalled = 0; +uint32_t RTASFixture::formatCompatibilityCheckFailCalled = 0; +uint32_t RTASFixture::parallelOperationDestroyCalled = 0; +uint32_t RTASFixture::parallelOperationDestroyFailCalled = 0; +uint32_t RTASFixture::parallelOperationCreateCalled = 0; +uint32_t RTASFixture::parallelOperationCreateFailCalled = 0; +uint32_t RTASFixture::parallelOperationGetPropertiesCalled = 0; +uint32_t RTASFixture::parallelOperationGetPropertiesFailCalled = 0; +uint32_t RTASFixture::parallelOperationJoinCalled = 0; +uint32_t RTASFixture::parallelOperationJoinFailCalled = 0; + +using RTASTest = Test; + +struct MockRTASOsLibrary : public OsLibrary { + public: + static bool mockLoad; + MockRTASOsLibrary(const std::string &name, std::string *errorValue) { + } + MockRTASOsLibrary() {} + ~MockRTASOsLibrary() override = default; + void *getProcAddress(const std::string &procName) override { + auto it = funcMap.find(procName); + if (funcMap.end() == it) { + return nullptr; + } else { + return it->second; + } + } + bool isLoaded() override { + return false; + } + std::string getFullPath() override { + return std::string(); + } + static OsLibrary *load(const std::string &name) { + if (mockLoad == true) { + auto ptr = new (std::nothrow) MockRTASOsLibrary(); + return ptr; + } else { + return nullptr; + } + } + std::map funcMap; +}; + +bool MockRTASOsLibrary::mockLoad = true; + +TEST_F(RTASTest, GivenLibraryLoadsSymbolsAndUnderlyingFunctionsSucceedThenSuccessIsReturned) { + struct MockSymbolsLoadedOsLibrary : public OsLibrary { + public: + MockSymbolsLoadedOsLibrary(const std::string &name, std::string *errorValue) { + } + MockSymbolsLoadedOsLibrary() {} + ~MockSymbolsLoadedOsLibrary() override = default; + void *getProcAddress(const std::string &procName) override { + funcMap["zeRTASBuilderCreateExpImpl"] = reinterpret_cast(&builderCreate); + funcMap["zeRTASBuilderDestroyExpImpl"] = reinterpret_cast(&builderDestroy); + funcMap["zeRTASBuilderGetBuildPropertiesExpImpl"] = reinterpret_cast(&builderGetBuildProperties); + funcMap["zeRTASBuilderBuildExpImpl"] = reinterpret_cast(&builderBuild); + funcMap["zeDriverRTASFormatCompatibilityCheckExpImpl"] = reinterpret_cast(&formatCompatibilityCheck); + funcMap["zeRTASParallelOperationCreateExpImpl"] = reinterpret_cast(¶llelOperationCreate); + funcMap["zeRTASParallelOperationDestroyExpImpl"] = reinterpret_cast(¶llelOperationDestroy); + funcMap["zeRTASParallelOperationGetPropertiesExpImpl"] = reinterpret_cast(¶llelOperationGetProperties); + funcMap["zeRTASParallelOperationJoinExpImpl"] = reinterpret_cast(¶llelOperationJoin); + auto it = funcMap.find(procName); + if (funcMap.end() == it) { + return nullptr; + } else { + return it->second; + } + } + bool isLoaded() override { + return true; + } + std::string getFullPath() override { + return std::string(); + } + static OsLibrary *load(const std::string &name) { + auto ptr = new (std::nothrow) MockSymbolsLoadedOsLibrary(); + return ptr; + } + std::map funcMap; + }; + + ze_rtas_builder_exp_handle_t hBuilder; + ze_rtas_parallel_operation_exp_handle_t hParallelOperation; + const ze_rtas_format_exp_t accelFormatA = {}; + const ze_rtas_format_exp_t accelFormatB = {}; + L0::RTASBuilder::osLibraryLoadFunction = MockSymbolsLoadedOsLibrary::load; + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderCreateExp(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExp(hBuilder)); + EXPECT_EQ(1u, builderDestroyCalled); + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationCreateExp(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateCalled); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExp(hParallelOperation)); + EXPECT_EQ(1u, parallelOperationDestroyCalled); + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeDriverRTASFormatCompatibilityCheckExp(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckCalled); + + driverHandle->rtasLibraryHandle.reset(); +} + +TEST_F(RTASTest, GivenLibraryFailedToLoadSymbolsThenErrorIsReturned) { + ze_rtas_builder_exp_handle_t hBuilder; + ze_rtas_parallel_operation_exp_handle_t hParallelOperation; + const ze_rtas_format_exp_t accelFormatA = {}; + const ze_rtas_format_exp_t accelFormatB = {}; + L0::RTASBuilder::osLibraryLoadFunction = MockRTASOsLibrary::load; + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASBuilderCreateExp(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASParallelOperationCreateExp(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeDriverRTASFormatCompatibilityCheckExp(driverHandle->toHandle(), accelFormatA, accelFormatB)); +} + +TEST_F(RTASTest, GivenLibraryPreLoadedAndUnderlyingBuilderCreateSucceedsThenSuccessIsReturned) { + ze_rtas_builder_exp_handle_t hBuilder; + builderCreateExpImpl = &builderCreate; + builderDestroyExpImpl = &builderDestroy; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderCreateExp(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateCalled); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExp(hBuilder)); + EXPECT_EQ(1u, builderDestroyCalled); +} + +TEST_F(RTASTest, GivenLibraryPreLoadedAndUnderlyingBuilderCreateFailsThenErrorIsReturned) { + ze_rtas_builder_exp_handle_t hBuilder; + builderCreateExpImpl = &builderCreateFail; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderCreateExp(driverHandle->toHandle(), nullptr, &hBuilder)); + EXPECT_EQ(1u, builderCreateFailCalled); +} + +TEST_F(RTASTest, GivenLibraryFailsToLoadThenBuilderCreateReturnsError) { + ze_rtas_builder_exp_handle_t hBuilder; + L0::RTASBuilder::osLibraryLoadFunction = MockRTASOsLibrary::load; + MockRTASOsLibrary::mockLoad = false; + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, L0::zeRTASBuilderCreateExp(driverHandle->toHandle(), nullptr, &hBuilder)); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderDestroySucceedsThenSuccessIsReturned) { + auto pRTASBuilder = std::make_unique(); + builderDestroyExpImpl = &builderDestroy; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderDestroyExp(pRTASBuilder.release())); + EXPECT_EQ(1u, builderDestroyCalled); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderDestroyFailsThenErrorIsReturned) { + RTASBuilder pRTASBuilder; + builderDestroyExpImpl = &builderDestroyFail; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderDestroyExp(pRTASBuilder.toHandle())); + EXPECT_EQ(1u, builderDestroyFailCalled); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderGetBuildPropertiesSucceedsThenSuccessIsReturned) { + RTASBuilder pRTASBuilder; + builderGetBuildPropertiesExpImpl = &builderGetBuildProperties; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderGetBuildPropertiesExp(pRTASBuilder.toHandle(), nullptr, nullptr)); + EXPECT_EQ(1u, builderGetBuildPropertiesCalled); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderGetBuildPropertiesFailsThenErrorIsReturned) { + RTASBuilder pRTASBuilder; + builderGetBuildPropertiesExpImpl = &builderGetBuildPropertiesFail; + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderGetBuildPropertiesExp(pRTASBuilder.toHandle(), nullptr, nullptr)); + EXPECT_EQ(1u, builderGetBuildPropertiesFailCalled); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderBuildSucceedsThenSuccessIsReturned) { + RTASBuilder pRTASBuilder; + RTASParallelOperation pParallelOperation; + builderBuildExpImpl = &builderBuild; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASBuilderBuildExp(pRTASBuilder.toHandle(), + nullptr, + nullptr, 0, + nullptr, 0, + pParallelOperation.toHandle(), + nullptr, nullptr, + nullptr)); + EXPECT_EQ(1u, builderBuildCalled); +} + +TEST_F(RTASTest, GivenUnderlyingBuilderBuildFailsThenErrorIsReturned) { + RTASBuilder pRTASBuilder; + RTASParallelOperation pParallelOperation; + builderBuildExpImpl = &builderBuildFail; + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASBuilderBuildExp(pRTASBuilder.toHandle(), + nullptr, + nullptr, 0, + nullptr, 0, + pParallelOperation.toHandle(), + nullptr, nullptr, + nullptr)); + EXPECT_EQ(1u, builderBuildFailCalled); +} + +TEST_F(RTASTest, GivenLibraryPreLoadedAndUnderlyingFormatCompatibilitySucceedsThenSuccessIsReturned) { + formatCompatibilityCheckExpImpl = &formatCompatibilityCheck; + const ze_rtas_format_exp_t accelFormatA = {}; + const ze_rtas_format_exp_t accelFormatB = {}; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeDriverRTASFormatCompatibilityCheckExp(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckCalled); +} + +TEST_F(RTASTest, GivenUnderlyingFormatCompatibilityFailsThenErrorIsReturned) { + formatCompatibilityCheckExpImpl = &formatCompatibilityCheckFail; + const ze_rtas_format_exp_t accelFormatA = {}; + const ze_rtas_format_exp_t accelFormatB = {}; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeDriverRTASFormatCompatibilityCheckExp(driverHandle->toHandle(), accelFormatA, accelFormatB)); + EXPECT_EQ(1u, formatCompatibilityCheckFailCalled); +} + +TEST_F(RTASTest, GivenLibraryPreLoadedAndUnderlyingParallelOperationCreateSucceedsThenSuccessIsReturned) { + ze_rtas_parallel_operation_exp_handle_t hParallelOperation; + parallelOperationCreateExpImpl = ¶llelOperationCreate; + parallelOperationDestroyExpImpl = ¶llelOperationDestroy; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationCreateExp(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateCalled); + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExp(hParallelOperation)); + EXPECT_EQ(1u, parallelOperationDestroyCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationCreateFailsThenErrorIsReturned) { + ze_rtas_parallel_operation_exp_handle_t hParallelOperation; + parallelOperationCreateExpImpl = ¶llelOperationCreateFail; + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationCreateExp(driverHandle->toHandle(), &hParallelOperation)); + EXPECT_EQ(1u, parallelOperationCreateFailCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationDestroySucceedsThenSuccessIsReturned) { + auto pParallelOperation = std::make_unique(); + parallelOperationDestroyExpImpl = ¶llelOperationDestroy; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationDestroyExp(pParallelOperation.release())); + EXPECT_EQ(1u, parallelOperationDestroyCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationDestroyFailsThenErrorIsReturned) { + RTASParallelOperation pParallelOperation; + parallelOperationDestroyExpImpl = ¶llelOperationDestroyFail; + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationDestroyExp(pParallelOperation.toHandle())); + EXPECT_EQ(1u, parallelOperationDestroyFailCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationGetPropertiesSucceedsThenSuccessIsReturned) { + RTASParallelOperation pParallelOperation; + parallelOperationGetPropertiesExpImpl = ¶llelOperationGetProperties; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationGetPropertiesExp(pParallelOperation.toHandle(), nullptr)); + EXPECT_EQ(1u, parallelOperationGetPropertiesCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationGetPropertiesFailsThenErrorIsReturned) { + RTASParallelOperation pParallelOperation; + parallelOperationGetPropertiesExpImpl = ¶llelOperationGetPropertiesFail; + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationGetPropertiesExp(pParallelOperation.toHandle(), nullptr)); + EXPECT_EQ(1u, parallelOperationGetPropertiesFailCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationJoinSucceedsThenSuccessIsReturned) { + RTASParallelOperation pParallelOperation; + parallelOperationJoinExpImpl = ¶llelOperationJoin; + + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeRTASParallelOperationJoinExp(pParallelOperation.toHandle())); + EXPECT_EQ(1u, parallelOperationJoinCalled); +} + +TEST_F(RTASTest, GivenUnderlyingParallelOperationJoinFailsThenErrorIsReturned) { + RTASParallelOperation pParallelOperation; + parallelOperationJoinExpImpl = ¶llelOperationJoinFail; + + EXPECT_EQ(ZE_RESULT_ERROR_UNKNOWN, L0::zeRTASParallelOperationJoinExp(pParallelOperation.toHandle())); + EXPECT_EQ(1u, parallelOperationJoinFailCalled); +} + +TEST_F(RTASTest, GivenNoSymbolAvailableInLibraryThenLoadEntryPointsReturnsFalse) { + driverHandle->rtasLibraryHandle = std::make_unique(); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +TEST_F(RTASTest, GivenRTASLibraryHandleUnavailableThenDependencyUnavailableErrorIsReturned) { + L0::RTASBuilder::osLibraryLoadFunction = MockRTASOsLibrary::load; + MockRTASOsLibrary::mockLoad = false; + driverHandle->rtasLibraryHandle.reset(); + + EXPECT_EQ(false, driverHandle->rtasLibraryUnavailable); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, driverHandle->loadRTASLibrary()); + EXPECT_EQ(true, driverHandle->rtasLibraryUnavailable); + EXPECT_EQ(ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE, driverHandle->loadRTASLibrary()); +} + +TEST_F(RTASTest, GivenOnlySingleSymbolAvailableThenLoadEntryPointsReturnsFalse) { + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + + osLibHandle->funcMap["zeRTASBuilderCreateExpImpl"] = reinterpret_cast(&builderCreate); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderDestroyExpImpl"] = reinterpret_cast(&builderDestroy); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderGetBuildPropertiesExpImpl"] = reinterpret_cast(&builderGetBuildProperties); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASBuilderBuildExpImpl"] = reinterpret_cast(&builderBuild); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeDriverRTASFormatCompatibilityCheckExpImpl"] = reinterpret_cast(&formatCompatibilityCheck); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationCreateExpImpl"] = reinterpret_cast(¶llelOperationCreate); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationDestroyExpImpl"] = reinterpret_cast(¶llelOperationDestroy); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationGetPropertiesExpImpl"] = reinterpret_cast(¶llelOperationGetProperties); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + + osLibHandle->funcMap.clear(); + osLibHandle->funcMap["zeRTASParallelOperationJoinExpImpl"] = reinterpret_cast(¶llelOperationJoin); + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +TEST_F(RTASTest, GivenMissingSymbolsThenLoadEntryPointsReturnsFalse) { + driverHandle->rtasLibraryHandle = std::make_unique(); + MockRTASOsLibrary *osLibHandle = static_cast(driverHandle->rtasLibraryHandle.get()); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderCreateExpImpl"] = reinterpret_cast(&builderCreate); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderDestroyExpImpl"] = reinterpret_cast(&builderDestroy); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderGetBuildPropertiesExpImpl"] = reinterpret_cast(&builderGetBuildProperties); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASBuilderBuildExpImpl"] = reinterpret_cast(&builderBuild); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeDriverRTASFormatCompatibilityCheckExpImpl"] = reinterpret_cast(&formatCompatibilityCheck); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationCreateExpImpl"] = reinterpret_cast(¶llelOperationCreate); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationDestroyExpImpl"] = reinterpret_cast(¶llelOperationDestroy); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); + osLibHandle->funcMap["zeRTASParallelOperationGetPropertiesExpImpl"] = reinterpret_cast(¶llelOperationGetProperties); + + EXPECT_EQ(false, L0::RTASBuilder::loadEntryPoints(driverHandle->rtasLibraryHandle.get())); +} + +} // namespace ult +} // namespace L0 \ No newline at end of file diff --git a/level_zero/core/test/unit_tests/xe_hpc_core/test_l0_gfx_core_helper_xe_hpc_core.cpp b/level_zero/core/test/unit_tests/xe_hpc_core/test_l0_gfx_core_helper_xe_hpc_core.cpp index cff693b39e..fa3093ed92 100644 --- a/level_zero/core/test/unit_tests/xe_hpc_core/test_l0_gfx_core_helper_xe_hpc_core.cpp +++ b/level_zero/core/test/unit_tests/xe_hpc_core/test_l0_gfx_core_helper_xe_hpc_core.cpp @@ -79,5 +79,10 @@ XE_HPC_CORETEST_F(L0GfxCoreHelperTestXeHpc, GivenXeHpcWhenGetRegsetTypeForLargeG EXPECT_EQ(ZET_DEBUG_REGSET_TYPE_CR_INTEL_GPU, l0GfxCoreHelper.getRegsetTypeForLargeGrfDetection()); } +XE_HPC_CORETEST_F(L0GfxCoreHelperTestXeHpc, GivenXeHpcWhenGettingSupportedRTASFormatThenExpectedFormatIsReturned) { + const auto &l0GfxCoreHelper = getHelper(); + EXPECT_EQ(ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_1, static_cast(l0GfxCoreHelper.getSupportedRTASFormat())); +} + } // namespace ult } // namespace L0 diff --git a/level_zero/core/test/unit_tests/xe_hpg_core/test_l0_gfx_core_helper_xe_hpg_core.cpp b/level_zero/core/test/unit_tests/xe_hpg_core/test_l0_gfx_core_helper_xe_hpg_core.cpp index cff6b39766..796ab3a31b 100644 --- a/level_zero/core/test/unit_tests/xe_hpg_core/test_l0_gfx_core_helper_xe_hpg_core.cpp +++ b/level_zero/core/test/unit_tests/xe_hpg_core/test_l0_gfx_core_helper_xe_hpg_core.cpp @@ -81,5 +81,10 @@ XE_HPG_CORETEST_F(L0GfxCoreHelperTestXeHpg, GivenXeHpgWhenCheckingL0HelperForPla EXPECT_TRUE(l0GfxCoreHelper.platformSupportsImmediateComputeFlushTask()); } +XE_HPG_CORETEST_F(L0GfxCoreHelperTestXeHpg, GivenXeHpgWhenGettingSupportedRTASFormatThenExpectedFormatIsReturned) { + const auto &l0GfxCoreHelper = getHelper(); + EXPECT_EQ(ZE_RTAS_DEVICE_FORMAT_EXP_VERSION_1, static_cast(l0GfxCoreHelper.getSupportedRTASFormat())); +} + } // namespace ult } // namespace L0