diff --git a/runtime/api/api.cpp b/runtime/api/api.cpp index 4d133856f8..5e90aa0d77 100644 --- a/runtime/api/api.cpp +++ b/runtime/api/api.cpp @@ -3412,6 +3412,11 @@ void *clHostMemAllocINTEL( return nullptr; } + if (size > neoContext->getDevice(0u)->getDeviceInfo().maxMemAllocSize) { + err.set(CL_INVALID_BUFFER_SIZE); + return nullptr; + } + SVMAllocsManager::UnifiedMemoryProperties unifiedMemoryProperties(InternalMemoryType::HOST_UNIFIED_MEMORY); if (!MemObjHelper::parseUnifiedMemoryProperties(properties, unifiedMemoryProperties)) { err.set(CL_INVALID_VALUE); @@ -3429,23 +3434,29 @@ void *clDeviceMemAllocINTEL( cl_uint alignment, cl_int *errcodeRet) { Context *neoContext = nullptr; + Device *neoDevice = nullptr; ErrorCodeHelper err(errcodeRet, CL_SUCCESS); - auto retVal = validateObjects(WithCastToInternal(context, &neoContext)); + auto retVal = validateObjects(WithCastToInternal(context, &neoContext), WithCastToInternal(device, &neoDevice)); if (retVal != CL_SUCCESS) { err.set(retVal); return nullptr; } + if (size > neoDevice->getDeviceInfo().maxMemAllocSize) { + err.set(CL_INVALID_BUFFER_SIZE); + return nullptr; + } + SVMAllocsManager::UnifiedMemoryProperties unifiedMemoryProperties(InternalMemoryType::DEVICE_UNIFIED_MEMORY); if (!MemObjHelper::parseUnifiedMemoryProperties(properties, unifiedMemoryProperties)) { err.set(CL_INVALID_VALUE); return nullptr; } - return neoContext->getSVMAllocsManager()->createUnifiedMemoryAllocation(neoContext->getDevice(0)->getRootDeviceIndex(), size, unifiedMemoryProperties); + return neoContext->getSVMAllocsManager()->createUnifiedMemoryAllocation(neoDevice->getRootDeviceIndex(), size, unifiedMemoryProperties); } void *clSharedMemAllocINTEL( @@ -3456,16 +3467,22 @@ void *clSharedMemAllocINTEL( cl_uint alignment, cl_int *errcodeRet) { Context *neoContext = nullptr; + Device *neoDevice = nullptr; ErrorCodeHelper err(errcodeRet, CL_SUCCESS); - auto retVal = validateObjects(WithCastToInternal(context, &neoContext)); + auto retVal = validateObjects(WithCastToInternal(context, &neoContext), WithCastToInternal(device, &neoDevice)); if (retVal != CL_SUCCESS) { err.set(retVal); return nullptr; } + if (size > neoDevice->getDeviceInfo().maxMemAllocSize) { + err.set(CL_INVALID_BUFFER_SIZE); + return nullptr; + } + SVMAllocsManager::UnifiedMemoryProperties unifiedMemoryProperties(InternalMemoryType::SHARED_UNIFIED_MEMORY); if (!MemObjHelper::parseUnifiedMemoryProperties(properties, unifiedMemoryProperties)) { err.set(CL_INVALID_VALUE); diff --git a/unit_tests/api/cl_unified_shared_memory_tests.inl b/unit_tests/api/cl_unified_shared_memory_tests.inl index af278e81e9..e697fc67a4 100644 --- a/unit_tests/api/cl_unified_shared_memory_tests.inl +++ b/unit_tests/api/cl_unified_shared_memory_tests.inl @@ -80,6 +80,23 @@ TEST(clUnifiedSharedMemoryTests, whenClDeviceMemAllocIntelIsCalledThenItAllocate EXPECT_EQ(CL_SUCCESS, retVal); } +TEST(clUnifiedSharedMemoryTests, whenUnifiedSharedMemoryAllocationCallsAreCalledWithSizeGreaterThenMaxMemAllocSizeThenErrorIsReturned) { + MockContext mockContext; + cl_int retVal = CL_SUCCESS; + auto maxMemAllocSize = mockContext.getDevice(0u)->getDeviceInfo().maxMemAllocSize; + size_t requestedSize = static_cast(maxMemAllocSize) + 1u; + + auto unfiedMemoryDeviceAllocation = clDeviceMemAllocINTEL(&mockContext, mockContext.getDevice(0u), nullptr, requestedSize, 0, &retVal); + EXPECT_EQ(CL_INVALID_BUFFER_SIZE, retVal); + EXPECT_EQ(nullptr, unfiedMemoryDeviceAllocation); + unfiedMemoryDeviceAllocation = clSharedMemAllocINTEL(&mockContext, mockContext.getDevice(0u), nullptr, requestedSize, 0, &retVal); + EXPECT_EQ(CL_INVALID_BUFFER_SIZE, retVal); + EXPECT_EQ(nullptr, unfiedMemoryDeviceAllocation); + unfiedMemoryDeviceAllocation = clHostMemAllocINTEL(&mockContext, nullptr, requestedSize, 0, &retVal); + EXPECT_EQ(CL_INVALID_BUFFER_SIZE, retVal); + EXPECT_EQ(nullptr, unfiedMemoryDeviceAllocation); +} + TEST(clUnifiedSharedMemoryTests, whenClSharedMemAllocINTELisCalledWithWrongContextThenInvalidContextErrorIsReturned) { cl_int retVal = CL_SUCCESS; auto ptr = clSharedMemAllocINTEL(0, 0, nullptr, 0, 0, &retVal);