diff --git a/shared/source/memory_manager/unified_memory_manager.cpp b/shared/source/memory_manager/unified_memory_manager.cpp index 7443133b70..b36e55e37a 100644 --- a/shared/source/memory_manager/unified_memory_manager.cpp +++ b/shared/source/memory_manager/unified_memory_manager.cpp @@ -85,7 +85,7 @@ bool SVMAllocsManager::SvmAllocationCache::insert(size_t size, void *ptr, SvmAll } } if (isSuccess) { - allocations.emplace(std::lower_bound(allocations.begin(), allocations.end(), size), size, ptr); + allocations.emplace(std::lower_bound(allocations.begin(), allocations.end(), size), size, ptr, svmData); } if (enablePerformanceLogging) { logCacheOperation({.allocationSize = size, @@ -129,15 +129,14 @@ void *SVMAllocsManager::SvmAllocationCache::get(size_t size, const UnifiedMemory break; } void *allocationPtr = allocationIter->allocation; - SvmAllocationData *svmData = svmAllocsManager->getSVMAlloc(allocationPtr); - UNRECOVERABLE_IF(nullptr == svmData); - if (svmData->device == unifiedMemoryProperties.device && - svmData->allocationFlagsProperty.allFlags == unifiedMemoryProperties.allocationFlags.allFlags && - svmData->allocationFlagsProperty.allAllocFlags == unifiedMemoryProperties.allocationFlags.allAllocFlags && - false == isInUse(svmData)) { - if (svmData->device) { - auto lock = svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); - svmData->device->usmReuseInfo.recordAllocationGetFromReuse(allocationIter->allocationSize); + DEBUG_BREAK_IF(nullptr == allocationIter->svmData); + if (allocationIter->svmData->device == unifiedMemoryProperties.device && + allocationIter->svmData->allocationFlagsProperty.allFlags == unifiedMemoryProperties.allocationFlags.allFlags && + allocationIter->svmData->allocationFlagsProperty.allAllocFlags == unifiedMemoryProperties.allocationFlags.allAllocFlags && + false == isInUse(allocationIter->svmData)) { + if (allocationIter->svmData->device) { + auto lock = allocationIter->svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); + allocationIter->svmData->device->usmReuseInfo.recordAllocationGetFromReuse(allocationIter->allocationSize); } else { auto lock = memoryManager->usmReuseInfo.obtainAllocationsReuseLock(); memoryManager->usmReuseInfo.recordAllocationGetFromReuse(allocationIter->allocationSize); @@ -145,12 +144,12 @@ void *SVMAllocsManager::SvmAllocationCache::get(size_t size, const UnifiedMemory if (enablePerformanceLogging) { logCacheOperation({.allocationSize = allocationIter->allocationSize, .timePoint = std::chrono::high_resolution_clock::now(), - .allocationType = svmData->memoryType, + .allocationType = allocationIter->svmData->memoryType, .operationType = CacheOperationType::get, .isSuccess = true}); } + allocationIter->svmData->size = size; allocations.erase(allocationIter); - svmData->size = size; return allocationPtr; } } @@ -167,11 +166,10 @@ void *SVMAllocsManager::SvmAllocationCache::get(size_t size, const UnifiedMemory void SVMAllocsManager::SvmAllocationCache::trim() { std::lock_guard lock(this->mtx); for (auto &cachedAllocationInfo : this->allocations) { - SvmAllocationData *svmData = svmAllocsManager->getSVMAlloc(cachedAllocationInfo.allocation); - UNRECOVERABLE_IF(nullptr == svmData); - if (svmData->device) { - auto lock = svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); - svmData->device->usmReuseInfo.recordAllocationGetFromReuse(cachedAllocationInfo.allocationSize); + DEBUG_BREAK_IF(nullptr == cachedAllocationInfo.svmData); + if (cachedAllocationInfo.svmData->device) { + auto lock = cachedAllocationInfo.svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); + cachedAllocationInfo.svmData->device->usmReuseInfo.recordAllocationGetFromReuse(cachedAllocationInfo.allocationSize); } else { auto lock = memoryManager->usmReuseInfo.obtainAllocationsReuseLock(); memoryManager->usmReuseInfo.recordAllocationGetFromReuse(cachedAllocationInfo.allocationSize); @@ -179,11 +177,11 @@ void SVMAllocsManager::SvmAllocationCache::trim() { if (enablePerformanceLogging) { logCacheOperation({.allocationSize = cachedAllocationInfo.allocationSize, .timePoint = std::chrono::high_resolution_clock::now(), - .allocationType = svmData->memoryType, + .allocationType = cachedAllocationInfo.svmData->memoryType, .operationType = CacheOperationType::trim, .isSuccess = true}); } - svmAllocsManager->freeSVMAllocImpl(cachedAllocationInfo.allocation, FreePolicyType::none, svmData); + svmAllocsManager->freeSVMAllocImpl(cachedAllocationInfo.allocation, FreePolicyType::none, cachedAllocationInfo.svmData); } this->allocations.clear(); } @@ -247,12 +245,10 @@ void SVMAllocsManager::SvmAllocationCache::trimOldAllocs(std::chrono::high_resol ++allocCleanCandidate; continue; } - void *allocationPtr = allocCleanCandidate->allocation; - SvmAllocationData *svmData = svmAllocsManager->getSVMAlloc(allocationPtr); - UNRECOVERABLE_IF(nullptr == svmData); - if (svmData->device) { - auto lock = svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); - svmData->device->usmReuseInfo.recordAllocationGetFromReuse(allocCleanCandidate->allocationSize); + DEBUG_BREAK_IF(nullptr == allocCleanCandidate->svmData); + if (allocCleanCandidate->svmData->device) { + auto lock = allocCleanCandidate->svmData->device->usmReuseInfo.obtainAllocationsReuseLock(); + allocCleanCandidate->svmData->device->usmReuseInfo.recordAllocationGetFromReuse(allocCleanCandidate->allocationSize); } else { auto lock = memoryManager->usmReuseInfo.obtainAllocationsReuseLock(); memoryManager->usmReuseInfo.recordAllocationGetFromReuse(allocCleanCandidate->allocationSize); @@ -260,11 +256,11 @@ void SVMAllocsManager::SvmAllocationCache::trimOldAllocs(std::chrono::high_resol if (enablePerformanceLogging) { logCacheOperation({.allocationSize = allocCleanCandidate->allocationSize, .timePoint = std::chrono::high_resolution_clock::now(), - .allocationType = svmData->memoryType, + .allocationType = allocCleanCandidate->svmData->memoryType, .operationType = CacheOperationType::trimOld, .isSuccess = true}); } - svmAllocsManager->freeSVMAllocImpl(allocCleanCandidate->allocation, FreePolicyType::defer, svmData); + svmAllocsManager->freeSVMAllocImpl(allocCleanCandidate->allocation, FreePolicyType::defer, allocCleanCandidate->svmData); if (shouldLimitReuse) { allocCleanCandidate = allocations.erase(allocCleanCandidate); } else { diff --git a/shared/source/memory_manager/unified_memory_manager.h b/shared/source/memory_manager/unified_memory_manager.h index 4de2995acd..ebbe6daf15 100644 --- a/shared/source/memory_manager/unified_memory_manager.h +++ b/shared/source/memory_manager/unified_memory_manager.h @@ -154,8 +154,9 @@ class SVMAllocsManager { struct SvmCacheAllocationInfo { size_t allocationSize; void *allocation; + SvmAllocationData *svmData; std::chrono::high_resolution_clock::time_point saveTime; - SvmCacheAllocationInfo(size_t allocationSize, void *allocation) : allocationSize(allocationSize), allocation(allocation) { + SvmCacheAllocationInfo(size_t allocationSize, void *allocation, SvmAllocationData *svmData) : allocationSize(allocationSize), allocation(allocation), svmData(svmData) { saveTime = std::chrono::high_resolution_clock::now(); } bool operator<(SvmCacheAllocationInfo const &other) const { diff --git a/shared/test/unit_test/memory_manager/unified_memory_manager_cache_tests.cpp b/shared/test/unit_test/memory_manager/unified_memory_manager_cache_tests.cpp index a6630e47f3..37873b65bb 100644 --- a/shared/test/unit_test/memory_manager/unified_memory_manager_cache_tests.cpp +++ b/shared/test/unit_test/memory_manager/unified_memory_manager_cache_tests.cpp @@ -304,6 +304,9 @@ TEST_F(SvmDeviceAllocationCacheTest, givenAllocationCacheEnabledWhenFreeingDevic for (auto i = 0u; i < svmManager->usmDeviceAllocationsCache->allocations.size(); ++i) { if (svmManager->usmDeviceAllocationsCache->allocations[i].allocation == testData.allocation) { foundInCache = true; + auto svmData = svmManager->getSVMAlloc(testData.allocation); + EXPECT_NE(nullptr, svmData); + EXPECT_EQ(svmData, svmManager->usmDeviceAllocationsCache->allocations[i].svmData); break; } } @@ -1153,6 +1156,9 @@ TEST_F(SvmHostAllocationCacheTest, givenAllocationCacheEnabledWhenFreeingHostAll for (auto i = 0u; i < svmManager->usmHostAllocationsCache->allocations.size(); ++i) { if (svmManager->usmHostAllocationsCache->allocations[i].allocation == testData.allocation) { foundInCache = true; + auto svmData = svmManager->getSVMAlloc(testData.allocation); + EXPECT_NE(nullptr, svmData); + EXPECT_EQ(svmData, svmManager->usmHostAllocationsCache->allocations[i].svmData); break; } }