fix: add fallback for invalid handles in extension functions

handle context, commandlist, driver, device, event, image and kernel handles

Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski 2024-08-28 12:23:22 +00:00 committed by Compute-Runtime-Automation
parent 066282e15b
commit d45c16dfc2
17 changed files with 75 additions and 28 deletions

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2022-2023 Intel Corporation * Copyright (C) 2022-2024 Intel Corporation
* *
* SPDX-License-Identifier: MIT * SPDX-License-Identifier: MIT
* *
@ -19,6 +19,7 @@ zexCommandListAppendWaitOnMemory(
zex_event_handle_t hSignalEvent) { zex_event_handle_t hSignalEvent) {
try { try {
{ {
hCommandList = toInternalType(hCommandList);
if (nullptr == hCommandList) if (nullptr == hCommandList)
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
@ -40,6 +41,7 @@ zexCommandListAppendWaitOnMemory64(
uint64_t data, uint64_t data,
zex_event_handle_t hSignalEvent) { zex_event_handle_t hSignalEvent) {
hCommandList = toInternalType(hCommandList);
if (!hCommandList) { if (!hCommandList) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
@ -55,6 +57,7 @@ zexCommandListAppendWriteToMemory(
uint64_t data) { uint64_t data) {
try { try {
{ {
hCommandList = toInternalType(hCommandList);
if (nullptr == hCommandList) if (nullptr == hCommandList)
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }

View File

@ -13,7 +13,7 @@
namespace L0 { namespace L0 {
ZE_APIEXPORT ze_result_t ZE_APICALL zeIntelMediaCommunicationCreate(ze_context_handle_t hContext, ze_device_handle_t hDevice, ze_intel_media_communication_desc_t *desc, ze_intel_media_doorbell_handle_desc_t *phDoorbell) { ZE_APIEXPORT ze_result_t ZE_APICALL zeIntelMediaCommunicationCreate(ze_context_handle_t hContext, ze_device_handle_t hDevice, ze_intel_media_communication_desc_t *desc, ze_intel_media_doorbell_handle_desc_t *phDoorbell) {
auto device = Device::fromHandle(hDevice); auto device = Device::fromHandle(toInternalType(hDevice));
if (!device || !desc || !phDoorbell) { if (!device || !desc || !phDoorbell) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
@ -28,7 +28,7 @@ ZE_APIEXPORT ze_result_t ZE_APICALL zeIntelMediaCommunicationCreate(ze_context_h
} }
ZE_APIEXPORT ze_result_t ZE_APICALL zeIntelMediaCommunicationDestroy(ze_context_handle_t hContext, ze_device_handle_t hDevice, ze_intel_media_doorbell_handle_desc_t *phDoorbell) { ZE_APIEXPORT ze_result_t ZE_APICALL zeIntelMediaCommunicationDestroy(ze_context_handle_t hContext, ze_device_handle_t hDevice, ze_intel_media_doorbell_handle_desc_t *phDoorbell) {
auto device = Device::fromHandle(hDevice); auto device = Device::fromHandle(toInternalType(hDevice));
if (!device || !phDoorbell) { if (!device || !phDoorbell) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;

View File

@ -23,14 +23,14 @@ zexDriverImportExternalPointer(
ze_driver_handle_t hDriver, ze_driver_handle_t hDriver,
void *ptr, void *ptr,
size_t size) { size_t size) {
return L0::DriverHandle::fromHandle(hDriver)->importExternalPointer(ptr, size); return L0::DriverHandle::fromHandle(toInternalType(hDriver))->importExternalPointer(ptr, size);
} }
ze_result_t ZE_APICALL ze_result_t ZE_APICALL
zexDriverReleaseImportedPointer( zexDriverReleaseImportedPointer(
ze_driver_handle_t hDriver, ze_driver_handle_t hDriver,
void *ptr) { void *ptr) {
return L0::DriverHandle::fromHandle(hDriver)->releaseImportedPointer(ptr); return L0::DriverHandle::fromHandle(toInternalType(hDriver))->releaseImportedPointer(ptr);
} }
ze_result_t ZE_APICALL ze_result_t ZE_APICALL
@ -38,7 +38,7 @@ zexDriverGetHostPointerBaseAddress(
ze_driver_handle_t hDriver, ze_driver_handle_t hDriver,
void *ptr, void *ptr,
void **baseAddress) { void **baseAddress) {
return L0::DriverHandle::fromHandle(hDriver)->getHostPointerBaseAddress(ptr, baseAddress); return L0::DriverHandle::fromHandle(toInternalType(hDriver))->getHostPointerBaseAddress(ptr, baseAddress);
} }
} // namespace L0 } // namespace L0
@ -49,7 +49,7 @@ zeIntelGetDriverVersionString(
char *pDriverVersion, char *pDriverVersion,
size_t *pVersionSize) { size_t *pVersionSize) {
ze_api_version_t apiVersion; ze_api_version_t apiVersion;
L0::DriverHandle::fromHandle(hDriver)->getApiVersion(&apiVersion); L0::DriverHandle::fromHandle(toInternalType(hDriver))->getApiVersion(&apiVersion);
std::string driverVersionString = std::to_string(ZE_MAJOR_VERSION(apiVersion)) + "." + std::to_string(ZE_MINOR_VERSION(apiVersion)) + "." + std::to_string(NEO_VERSION_BUILD); std::string driverVersionString = std::to_string(ZE_MAJOR_VERSION(apiVersion)) + "." + std::to_string(ZE_MINOR_VERSION(apiVersion)) + "." + std::to_string(NEO_VERSION_BUILD);
if (NEO_VERSION_HOTFIX > 0) { if (NEO_VERSION_HOTFIX > 0) {
driverVersionString += "+" + std::to_string(NEO_VERSION_HOTFIX); driverVersionString += "+" + std::to_string(NEO_VERSION_HOTFIX);

View File

@ -20,7 +20,7 @@ namespace L0 {
ZE_APIEXPORT ze_result_t ZE_APICALL ZE_APIEXPORT ze_result_t ZE_APICALL
zexEventGetDeviceAddress(ze_event_handle_t event, uint64_t *completionValue, uint64_t *address) { zexEventGetDeviceAddress(ze_event_handle_t event, uint64_t *completionValue, uint64_t *address) {
auto eventObj = Event::fromHandle(event); auto eventObj = Event::fromHandle(toInternalType(event));
if (!eventObj || !completionValue || !address) { if (!eventObj || !completionValue || !address) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
@ -58,7 +58,7 @@ zexCounterBasedEventCreate(ze_context_handle_t hContext, ze_device_handle_t hDev
false, // ipcPool false, // ipcPool
}; };
auto device = Device::fromHandle(hDevice); auto device = Device::fromHandle(toInternalType(hDevice));
if (!hDevice || !desc || !phEvent) { if (!hDevice || !desc || !phEvent) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
@ -85,7 +85,7 @@ zexCounterBasedEventCreate(ze_context_handle_t hContext, ze_device_handle_t hDev
} }
ZE_APIEXPORT ze_result_t ZE_APICALL zexIntelAllocateNetworkInterrupt(ze_context_handle_t hContext, uint32_t &networkInterruptId) { ZE_APIEXPORT ze_result_t ZE_APICALL zexIntelAllocateNetworkInterrupt(ze_context_handle_t hContext, uint32_t &networkInterruptId) {
auto context = static_cast<ContextImp *>(L0::Context::fromHandle(hContext)); auto context = static_cast<ContextImp *>(L0::Context::fromHandle(toInternalType(hContext)));
if (!context) { if (!context) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
@ -99,7 +99,7 @@ ZE_APIEXPORT ze_result_t ZE_APICALL zexIntelAllocateNetworkInterrupt(ze_context_
} }
ZE_APIEXPORT ze_result_t ZE_APICALL zexIntelReleaseNetworkInterrupt(ze_context_handle_t hContext, uint32_t networkInterruptId) { ZE_APIEXPORT ze_result_t ZE_APICALL zexIntelReleaseNetworkInterrupt(ze_context_handle_t hContext, uint32_t networkInterruptId) {
auto context = static_cast<ContextImp *>(L0::Context::fromHandle(hContext)); auto context = static_cast<ContextImp *>(L0::Context::fromHandle(toInternalType(hContext)));
if (!context || !context->getDriverHandle()->getMemoryManager()->releaseInterrupt(networkInterruptId, context->rootDeviceIndices[0])) { if (!context || !context->getDriverHandle()->getMemoryManager()->releaseInterrupt(networkInterruptId, context->rootDeviceIndices[0])) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2020-2022 Intel Corporation * Copyright (C) 2020-2024 Intel Corporation
* *
* SPDX-License-Identifier: MIT * SPDX-License-Identifier: MIT
* *
@ -7,6 +7,7 @@
#include "level_zero/api/driver_experimental/public/zex_api.h" #include "level_zero/api/driver_experimental/public/zex_api.h"
#include "level_zero/core/source/context/context.h" #include "level_zero/core/source/context/context.h"
#include "level_zero/core/source/device/device.h"
namespace L0 { namespace L0 {
@ -16,7 +17,7 @@ zexMemGetIpcHandles(
const void *ptr, const void *ptr,
uint32_t *numIpcHandles, uint32_t *numIpcHandles,
ze_ipc_mem_handle_t *pIpcHandles) { ze_ipc_mem_handle_t *pIpcHandles) {
return L0::Context::fromHandle(hContext)->getIpcMemHandles(ptr, numIpcHandles, pIpcHandles); return L0::Context::fromHandle(toInternalType(hContext))->getIpcMemHandles(ptr, numIpcHandles, pIpcHandles);
} }
ze_result_t ZE_APICALL ze_result_t ZE_APICALL
@ -27,7 +28,7 @@ zexMemOpenIpcHandles(
ze_ipc_mem_handle_t *pIpcHandles, ze_ipc_mem_handle_t *pIpcHandles,
ze_ipc_memory_flags_t flags, ze_ipc_memory_flags_t flags,
void **pptr) { void **pptr) {
return L0::Context::fromHandle(hContext)->openIpcMemHandles(hDevice, numIpcHandles, pIpcHandles, flags, pptr); return L0::Context::fromHandle(toInternalType(hContext))->openIpcMemHandles(toInternalType(hDevice), numIpcHandles, pIpcHandles, flags, pptr);
} }
} // namespace L0 } // namespace L0

View File

@ -20,7 +20,7 @@ zexDeviceGetConcurrentMetricGroups(
uint32_t *pConcurrentGroupCount, uint32_t *pConcurrentGroupCount,
uint32_t *pCountPerConcurrentGroup) { uint32_t *pCountPerConcurrentGroup) {
auto device = Device::fromHandle(hDevice); auto device = Device::fromHandle(toInternalType(hDevice));
return static_cast<MetricDeviceContext &>(device->getMetricDeviceContext()).getConcurrentMetricGroups(metricGroupCount, phMetricGroups, pConcurrentGroupCount, pCountPerConcurrentGroup); return static_cast<MetricDeviceContext &>(device->getMetricDeviceContext()).getConcurrentMetricGroups(metricGroupCount, phMetricGroups, pConcurrentGroupCount, pCountPerConcurrentGroup);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (C) 2020-2022 Intel Corporation * Copyright (C) 2020-2024 Intel Corporation
* *
* SPDX-License-Identifier: MIT * SPDX-License-Identifier: MIT
* *
@ -15,7 +15,7 @@ ze_result_t ZE_APICALL
zexKernelGetBaseAddress( zexKernelGetBaseAddress(
ze_kernel_handle_t hKernel, ze_kernel_handle_t hKernel,
uint64_t *baseAddress) { uint64_t *baseAddress) {
return L0::Kernel::fromHandle(hKernel)->getBaseAddress(baseAddress); return L0::Kernel::fromHandle(toInternalType(hKernel))->getBaseAddress(baseAddress);
} }
} // namespace L0 } // namespace L0

View File

@ -191,13 +191,13 @@ ze_result_t ZE_APICALL zeMemGetPitchFor2dImage(
size_t imageHeight, size_t imageHeight,
unsigned int elementSizeInBytes, unsigned int elementSizeInBytes,
size_t *rowPitch) { size_t *rowPitch) {
return L0::Context::fromHandle(hContext)->getPitchFor2dImage(hDevice, imageWidth, imageHeight, elementSizeInBytes, rowPitch); return L0::Context::fromHandle(toInternalType(hContext))->getPitchFor2dImage(toInternalType(hDevice), imageWidth, imageHeight, elementSizeInBytes, rowPitch);
} }
ze_result_t ZE_APICALL zeImageGetDeviceOffsetExp( ze_result_t ZE_APICALL zeImageGetDeviceOffsetExp(
ze_image_handle_t hImage, ze_image_handle_t hImage,
uint64_t *pDeviceOffset) { uint64_t *pDeviceOffset) {
return L0::Image::fromHandle(hImage)->getDeviceOffset(pDeviceOffset); return L0::Image::fromHandle(toInternalType(hImage))->getDeviceOffset(pDeviceOffset);
} }
} // namespace L0 } // namespace L0

View File

@ -20,6 +20,7 @@
#include "shared/source/utilities/stackvec.h" #include "shared/source/utilities/stackvec.h"
#include "level_zero/core/source/cmdlist/cmdlist_launch_params.h" #include "level_zero/core/source/cmdlist/cmdlist_launch_params.h"
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <level_zero/zet_api.h> #include <level_zero/zet_api.h>
@ -28,7 +29,9 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
struct _ze_command_list_handle_t {}; struct _ze_command_list_handle_t {
const uint64_t objMagic = objMagicValue;
};
namespace NEO { namespace NEO {
class ScratchSpaceController; class ScratchSpaceController;

View File

@ -10,10 +10,12 @@
#include "shared/source/memory_manager/allocation_type.h" #include "shared/source/memory_manager/allocation_type.h"
#include "shared/source/unified_memory/unified_memory.h" #include "shared/source/unified_memory/unified_memory.h"
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <level_zero/zet_api.h> #include <level_zero/zet_api.h>
struct _ze_context_handle_t { struct _ze_context_handle_t {
const uint64_t objMagic = objMagicValue;
virtual ~_ze_context_handle_t() = default; virtual ~_ze_context_handle_t() = default;
}; };

View File

@ -11,6 +11,7 @@
#include "shared/source/os_interface/product_helper.h" #include "shared/source/os_interface/product_helper.h"
#include "shared/source/utilities/tag_allocator.h" #include "shared/source/utilities/tag_allocator.h"
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <level_zero/zet_api.h> #include <level_zero/zet_api.h>
@ -19,7 +20,9 @@
static_assert(NEO::ProductHelper::uuidSize == ZE_MAX_DEVICE_UUID_SIZE); static_assert(NEO::ProductHelper::uuidSize == ZE_MAX_DEVICE_UUID_SIZE);
struct _ze_device_handle_t {}; struct _ze_device_handle_t {
const uint64_t objMagic = objMagicValue;
};
namespace NEO { namespace NEO {
class CommandStreamReceiver; class CommandStreamReceiver;
class DebuggerL0; class DebuggerL0;

View File

@ -7,6 +7,7 @@
#pragma once #pragma once
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <level_zero/zes_api.h> #include <level_zero/zes_api.h>
@ -15,6 +16,7 @@
#include <vector> #include <vector>
struct _ze_driver_handle_t { struct _ze_driver_handle_t {
const uint64_t objMagic = objMagicValue;
virtual ~_ze_driver_handle_t() = default; virtual ~_ze_driver_handle_t() = default;
}; };

View File

@ -12,6 +12,7 @@
#include "shared/source/memory_manager/multi_graphics_allocation.h" #include "shared/source/memory_manager/multi_graphics_allocation.h"
#include "shared/source/os_interface/os_time.h" #include "shared/source/os_interface/os_time.h"
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <atomic> #include <atomic>
@ -22,7 +23,9 @@
#include <mutex> #include <mutex>
#include <vector> #include <vector>
struct _ze_event_handle_t {}; struct _ze_event_handle_t {
const uint64_t objMagic = objMagicValue;
};
struct _ze_event_pool_handle_t {}; struct _ze_event_pool_handle_t {};

View File

@ -1,5 +1,5 @@
# #
# Copyright (C) 2022-2023 Intel Corporation # Copyright (C) 2022-2024 Intel Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
# #
@ -7,6 +7,7 @@
target_sources(${L0_STATIC_LIB_NAME} target_sources(${L0_STATIC_LIB_NAME}
PRIVATE PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt ${CMAKE_CURRENT_SOURCE_DIR}/CMakeLists.txt
${CMAKE_CURRENT_SOURCE_DIR}/api_handle_helper.h
${CMAKE_CURRENT_SOURCE_DIR}/api_specific_config_l0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/api_specific_config_l0.cpp
${CMAKE_CURRENT_SOURCE_DIR}/error_code_helper_l0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/error_code_helper_l0.cpp
${CMAKE_CURRENT_SOURCE_DIR}/error_code_helper_l0.h ${CMAKE_CURRENT_SOURCE_DIR}/error_code_helper_l0.h

View File

@ -0,0 +1,23 @@
/*
* Copyright (C) 2024 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include <cstdint>
constexpr uint64_t objMagicValue = 0x8D7E6A5D4B3E2E1FULL;
template <typename T>
inline T toInternalType(T input) {
if (!input || input->objMagic == objMagicValue) {
return input;
}
input = *reinterpret_cast<T *>(input);
if (input->objMagic == objMagicValue) {
return input;
}
return nullptr;
}

View File

@ -7,9 +7,12 @@
#pragma once #pragma once
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
struct _ze_image_handle_t {}; struct _ze_image_handle_t {
const uint64_t objMagic = objMagicValue;
};
namespace NEO { namespace NEO {
struct ImageInfo; struct ImageInfo;

View File

@ -13,13 +13,16 @@
#include "shared/source/memory_manager/unified_memory_manager.h" #include "shared/source/memory_manager/unified_memory_manager.h"
#include "shared/source/unified_memory/unified_memory.h" #include "shared/source/unified_memory/unified_memory.h"
#include "level_zero/core/source/helpers/api_handle_helper.h"
#include <level_zero/ze_api.h> #include <level_zero/ze_api.h>
#include <level_zero/zet_api.h> #include <level_zero/zet_api.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
struct _ze_kernel_handle_t {}; struct _ze_kernel_handle_t {
const uint64_t objMagic = objMagicValue;
};
namespace NEO { namespace NEO {
class Device; class Device;