diff --git a/level_zero/core/source/context/context_imp.cpp b/level_zero/core/source/context/context_imp.cpp index a20a11eddc..3ced891201 100644 --- a/level_zero/core/source/context/context_imp.cpp +++ b/level_zero/core/source/context/context_imp.cpp @@ -136,8 +136,6 @@ ze_result_t ContextImp::allocDeviceMem(ze_device_handle_t hDevice, return ZE_RESULT_SUCCESS; } - neoDevice = this->driverHandle->devices[0]->getNEODevice(); - if (lookupTable.relaxedSizeAllowed == false && (size > neoDevice->getDeviceInfo().maxMemAllocSize)) { *ptr = nullptr; @@ -155,8 +153,6 @@ ze_result_t ContextImp::allocDeviceMem(ze_device_handle_t hDevice, return ZE_RESULT_ERROR_UNSUPPORTED_SIZE; } - neoDevice = Device::fromHandle(hDevice)->getNEODevice(); - deviceBitfields[rootDeviceIndex] = neoDevice->getDeviceBitfield(); NEO::SVMAllocsManager::UnifiedMemoryProperties unifiedMemoryProperties(InternalMemoryType::DEVICE_UNIFIED_MEMORY, this->driverHandle->rootDeviceIndices, deviceBitfields); @@ -183,6 +179,12 @@ ze_result_t ContextImp::allocSharedMem(ze_device_handle_t hDevice, size_t size, size_t alignment, void **ptr) { + + auto neoDevice = this->devices.begin()->second->getNEODevice(); + if (hDevice != nullptr) { + neoDevice = Device::fromHandle(hDevice)->getNEODevice(); + } + bool relaxedSizeAllowed = NEO::DebugManager.flags.AllowUnrestrictedSize.get(); if (deviceDesc->pNext) { const ze_base_desc_t *extendedDesc = reinterpret_cast(deviceDesc->pNext); @@ -196,14 +198,12 @@ ze_result_t ContextImp::allocSharedMem(ze_device_handle_t hDevice, } } - auto neoDevice = this->devices.begin()->second->getNEODevice(); if (relaxedSizeAllowed == false && (size > neoDevice->getDeviceInfo().maxMemAllocSize)) { *ptr = nullptr; return ZE_RESULT_ERROR_UNSUPPORTED_SIZE; } - neoDevice = this->driverHandle->devices[0]->getNEODevice(); uint64_t globalMemSize = neoDevice->getDeviceInfo().globalMemSize; uint32_t numSubDevices = neoDevice->getNumGenericSubDevices(); @@ -216,7 +216,6 @@ ze_result_t ContextImp::allocSharedMem(ze_device_handle_t hDevice, return ZE_RESULT_ERROR_UNSUPPORTED_SIZE; } - neoDevice = this->devices.begin()->second->getNEODevice(); auto deviceBitfields = this->deviceBitfields; NEO::Device *unifiedMemoryPropertiesDevice = nullptr; if (hDevice) { diff --git a/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp b/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp index ef2cad7141..87d7414d03 100644 --- a/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp +++ b/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp @@ -2886,13 +2886,21 @@ TEST_F(SharedAllocMultiDeviceTests, whenAllocatinSharedMemoryWithNonNullDeviceIn ze_host_mem_alloc_desc_t hostDesc = {}; void *ptr = nullptr; size_t size = 1024; + ze_result_t res = ZE_RESULT_ERROR_UNKNOWN; + ze_memory_allocation_properties_t memoryProperties = {}; + ze_device_handle_t deviceHandle; EXPECT_EQ(currSvmAllocsManager->createHostUnifiedMemoryAllocationTimes, 0u); - ze_result_t res = context->allocSharedMem(driverHandle->devices[0]->toHandle(), &deviceDesc, &hostDesc, size, 0u, &ptr); - EXPECT_EQ(res, ZE_RESULT_SUCCESS); + for (uint32_t i = 0; i < numRootDevices; i++) { + res = context->allocSharedMem(driverHandle->devices[i]->toHandle(), &deviceDesc, &hostDesc, size, 0u, &ptr); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); + res = context->getMemAllocProperties(ptr, &memoryProperties, &deviceHandle); + EXPECT_EQ(ZE_RESULT_SUCCESS, res); + EXPECT_EQ(memoryProperties.type, ZE_MEMORY_TYPE_SHARED); + EXPECT_EQ(deviceHandle, driverHandle->devices[i]->toHandle()); + res = context->freeMem(ptr); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); + } EXPECT_EQ(currSvmAllocsManager->createHostUnifiedMemoryAllocationTimes, 0u); - - res = context->freeMem(ptr); - EXPECT_EQ(res, ZE_RESULT_SUCCESS); } struct MemAllocMultiSubDeviceTests : public ::testing::Test {