// Patch summary (reviewer preamble; it sits before the first "diff --git" header, outside every hunk).
//
// runtime/mem_obj/mem_obj.{cpp,h}
//   - MemObj::~MemObj now calls memoryManager->waitForEnginesCompletion(*graphicsAllocation) when
//     needWait is set and the allocation isUsed().
//   - MemObj::waitForCsrCompletion() is removed. It waited only on getDefaultCommandStreamReceiver(0),
//     using the task count for the default engine's os context id.
//
// runtime/memory_manager/memory_manager.{cpp,h}
//   - Adds MemoryManager::waitForEnginesCompletion(GraphicsAllocation &). It iterates
//     getRegisteredEngines(); for each engine it reads the allocation's task count for that engine's
//     os context id and, when isUsedByOsContext(id) is true and the task count is greater than the value
//     at the engine CSR's tag address, calls
//     waitForCompletionWithTimeout(false, TimeoutControls::maxTimeout, allocationTaskCount).
//
// runtime/os_interface/windows/wddm_memory_manager.cpp
//   - Removes the DEBUG_BREAK_IF in WddmMemoryManager::freeGraphicsMemoryImpl that compared the
//     allocation's task count with the default CSR tag value.
//
// unit_tests/mocks/mock_device.{cpp,h}
//   - Adds MockDevice::resetCommandStreamReceiver(newCsr, engineIndex); the one-argument overload now
//     forwards to it with defaultEngineIndex. Also adds the missing newline at end of mock_device.cpp.
//
// unit_tests/mem_obj/mem_obj_destruction_tests.cpp
//   - The "DestructorWaitsOnTaskCount" test installs two mock CSRs (engine indices 0 and 1), sets both
//     tags to 0, updates the allocation task count for both os context ids, and expects exactly one
//     waitForCompletionWithTimeout call per CSR when hasCallbacks is true; otherwise both tags are set
//     to taskCountReady and zero calls are expected.
//   - The other touched tests set *mockCsr->getTagAddress() = 0 right after resetCommandStreamReceiver.
//
// NOTE(review): waitForEnginesCompletion dereferences getTagAddress() without a null check, whereas the
//   removed DEBUG_BREAK_IF tested getTagAddress() for non-null first - confirm every registered engine's
//   CSR has its tag allocation initialized before a MemObj can be destroyed.
// NOTE(review): the bool result of waitForCompletionWithTimeout is discarded (as it was in the removed
//   MemObj::waitForCsrCompletion) - confirm a timed-out wait may safely fall through to destruction.
// NOTE(review): getTaskCount(osContextId) is read before isUsedByOsContext(osContextId) is checked;
//   presumably harmless - verify getTaskCount has no precondition on the context being in use.
diff --git a/runtime/mem_obj/mem_obj.cpp b/runtime/mem_obj/mem_obj.cpp index 49293d7f37..938a0de847 100644 --- a/runtime/mem_obj/mem_obj.cpp +++ b/runtime/mem_obj/mem_obj.cpp @@ -69,7 +69,7 @@ MemObj::~MemObj() { needWait = true; } if (needWait && graphicsAllocation->isUsed()) { - waitForCsrCompletion(); + memoryManager->waitForEnginesCompletion(*graphicsAllocation); } destroyGraphicsAllocation(graphicsAllocation, doAsyncDestrucions); graphicsAllocation = nullptr; @@ -283,11 +283,6 @@ void MemObj::releaseAllocatedMapPtr() { allocatedMapPtr = nullptr; } -void MemObj::waitForCsrCompletion() { - auto osContextId = context->getDevice(0)->getDefaultEngine().osContext->getContextId(); - memoryManager->getDefaultCommandStreamReceiver(0)->waitForCompletionWithTimeout(false, TimeoutControls::maxTimeout, graphicsAllocation->getTaskCount(osContextId)); -} - void MemObj::destroyGraphicsAllocation(GraphicsAllocation *allocation, bool asyncDestroy) { if (asyncDestroy) { memoryManager->checkGpuUsageAndDestroyGraphicsAllocations(allocation); diff --git a/runtime/mem_obj/mem_obj.h b/runtime/mem_obj/mem_obj.h index 2f923632fd..65a0d81571 100644 --- a/runtime/mem_obj/mem_obj.h +++ b/runtime/mem_obj/mem_obj.h @@ -99,7 +99,6 @@ class MemObj : public BaseObject<_cl_mem> { unsigned int acquireCount = 0; Context *getContext() const { return context; } - void waitForCsrCompletion(); void destroyGraphicsAllocation(GraphicsAllocation *allocation, bool asyncDestroy); bool checkIfMemoryTransferIsRequired(size_t offsetInMemObjest, size_t offsetInHostPtr, const void *ptr, cl_command_type cmdType); bool mappingOnCpuAllowed() const; diff --git a/runtime/memory_manager/memory_manager.cpp b/runtime/memory_manager/memory_manager.cpp index ea8cd98116..f15f9343d5 100644 --- a/runtime/memory_manager/memory_manager.cpp +++ b/runtime/memory_manager/memory_manager.cpp @@ -420,4 +420,16 @@ bool MemoryManager::copyMemoryToAllocation(GraphicsAllocation *graphicsAllocatio 
memcpy_s(graphicsAllocation->getUnderlyingBuffer(), graphicsAllocation->getUnderlyingBufferSize(), memoryToCopy, sizeToCopy); return true; } + +void MemoryManager::waitForEnginesCompletion(GraphicsAllocation &graphicsAllocation) { + for (auto &engine : getRegisteredEngines()) { + auto osContextId = engine.osContext->getContextId(); + auto allocationTaskCount = graphicsAllocation.getTaskCount(osContextId); + if (graphicsAllocation.isUsedByOsContext(osContextId) && + allocationTaskCount > *engine.commandStreamReceiver->getTagAddress()) { + engine.commandStreamReceiver->waitForCompletionWithTimeout(false, TimeoutControls::maxTimeout, allocationTaskCount); + } + } +} + } // namespace OCLRT diff --git a/runtime/memory_manager/memory_manager.h b/runtime/memory_manager/memory_manager.h index c29322f426..fd53675a1e 100644 --- a/runtime/memory_manager/memory_manager.h +++ b/runtime/memory_manager/memory_manager.h @@ -157,6 +157,7 @@ class MemoryManager { } void waitForDeletions(); + void waitForEnginesCompletion(GraphicsAllocation &graphicsAllocation); bool isAsyncDeleterEnabled() const; bool isLocalMemorySupported() const; diff --git a/runtime/os_interface/windows/wddm_memory_manager.cpp b/runtime/os_interface/windows/wddm_memory_manager.cpp index f7dc514959..302b2a0291 100644 --- a/runtime/os_interface/windows/wddm_memory_manager.cpp +++ b/runtime/os_interface/windows/wddm_memory_manager.cpp @@ -311,11 +311,6 @@ void WddmMemoryManager::freeGraphicsMemoryImpl(GraphicsAllocation *gfxAllocation residencyController.removeFromTrimCandidateListIfUsed(input, true); } - DEBUG_BREAK_IF(DebugManager.flags.CreateMultipleDevices.get() == 0 && - gfxAllocation->isUsed() && this->executionEnvironment.commandStreamReceivers.size() > 0 && - this->getDefaultCommandStreamReceiver(0) && this->getDefaultCommandStreamReceiver(0)->getTagAddress() && - gfxAllocation->getTaskCount(defaultEngineIndex) > *this->getDefaultCommandStreamReceiver(0)->getTagAddress()); - auto defaultGmm = 
gfxAllocation->getDefaultGmm(); if (defaultGmm) { if (defaultGmm->isRenderCompressed && wddm->getPageTableManager()) { diff --git a/unit_tests/mem_obj/mem_obj_destruction_tests.cpp b/unit_tests/mem_obj/mem_obj_destruction_tests.cpp index 505b662298..e7a1cf65b5 100644 --- a/unit_tests/mem_obj/mem_obj_destruction_tests.cpp +++ b/unit_tests/mem_obj/mem_obj_destruction_tests.cpp @@ -123,30 +123,49 @@ TEST_P(MemObjAsyncDestructionTest, givenMemObjWithDestructableAllocationWhenAsyn } HWTEST_P(MemObjAsyncDestructionTest, givenUsedMemObjWithAsyncDestructionsEnabledThatHasDestructorCallbacksWhenItIsDestroyedThenDestructorWaitsOnTaskCount) { - makeMemObjUsed(); - bool hasCallbacks = GetParam(); if (hasCallbacks) { memObj->setDestructorCallback(emptyDestructorCallback, nullptr); } - auto mockCsr = new ::testing::NiceMock>(*device->executionEnvironment); - device->resetCommandStreamReceiver(mockCsr); + auto mockCsr0 = new ::testing::NiceMock>(*device->executionEnvironment); + auto mockCsr1 = new ::testing::NiceMock>(*device->executionEnvironment); + device->resetCommandStreamReceiver(mockCsr0, 0); + device->resetCommandStreamReceiver(mockCsr1, 1); + *mockCsr0->getTagAddress() = 0; + *mockCsr1->getTagAddress() = 0; - bool desired = true; + auto waitForCompletionWithTimeoutMock0 = [&mockCsr0](bool enableTimeout, int64_t timeoutMs, uint32_t taskCountToWait) -> bool { + *mockCsr0->getTagAddress() = taskCountReady; + return true; + }; + auto waitForCompletionWithTimeoutMock1 = [&mockCsr1](bool enableTimeout, int64_t timeoutMs, uint32_t taskCountToWait) -> bool { + *mockCsr1->getTagAddress() = taskCountReady; + return true; + }; + auto osContextId0 = mockCsr0->getOsContext().getContextId(); + auto osContextId1 = mockCsr1->getOsContext().getContextId(); - auto waitForCompletionWithTimeoutMock = [=](bool enableTimeout, int64_t timeoutMs, uint32_t taskCountToWait) -> bool { return desired; }; - auto osContextId = mockCsr->getOsContext().getContextId(); + 
memObj->getGraphicsAllocation()->updateTaskCount(taskCountReady, osContextId0); + memObj->getGraphicsAllocation()->updateTaskCount(taskCountReady, osContextId1); - ON_CALL(*mockCsr, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) - .WillByDefault(::testing::Invoke(waitForCompletionWithTimeoutMock)); + ON_CALL(*mockCsr0, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) + .WillByDefault(::testing::Invoke(waitForCompletionWithTimeoutMock0)); + ON_CALL(*mockCsr1, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) + .WillByDefault(::testing::Invoke(waitForCompletionWithTimeoutMock1)); if (hasCallbacks) { - EXPECT_CALL(*mockCsr, waitForCompletionWithTimeout(::testing::_, TimeoutControls::maxTimeout, allocation->getTaskCount(osContextId))) + EXPECT_CALL(*mockCsr0, waitForCompletionWithTimeout(::testing::_, TimeoutControls::maxTimeout, allocation->getTaskCount(osContextId0))) + .Times(1); + EXPECT_CALL(*mockCsr1, waitForCompletionWithTimeout(::testing::_, TimeoutControls::maxTimeout, allocation->getTaskCount(osContextId1))) .Times(1); } else { - EXPECT_CALL(*mockCsr, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) + *mockCsr0->getTagAddress() = taskCountReady; + *mockCsr1->getTagAddress() = taskCountReady; + EXPECT_CALL(*mockCsr0, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) + .Times(0); + EXPECT_CALL(*mockCsr1, waitForCompletionWithTimeout(::testing::_, ::testing::_, ::testing::_)) .Times(0); } delete memObj; @@ -164,6 +183,7 @@ HWTEST_P(MemObjAsyncDestructionTest, givenUsedMemObjWithAsyncDestructionsEnabled auto mockCsr = new ::testing::NiceMock>(*device->executionEnvironment); device->resetCommandStreamReceiver(mockCsr); + *mockCsr->getTagAddress() = 0; auto osContextId = mockCsr->getOsContext().getContextId(); bool desired = true; @@ -206,6 +226,7 @@ HWTEST_P(MemObjAsyncDestructionTest, givenUsedMemObjWithAsyncDestructionsEnabled makeMemObjUsed(); auto 
mockCsr = new ::testing::NiceMock>(*device->executionEnvironment); device->resetCommandStreamReceiver(mockCsr); + *mockCsr->getTagAddress() = 0; bool desired = true; @@ -240,6 +261,7 @@ HWTEST_P(MemObjSyncDestructionTest, givenMemObjWithDestructableAllocationWhenAsy } auto mockCsr = new ::testing::NiceMock>(*device->executionEnvironment); device->resetCommandStreamReceiver(mockCsr); + *mockCsr->getTagAddress() = 0; bool desired = true; @@ -266,6 +288,7 @@ HWTEST_P(MemObjSyncDestructionTest, givenMemObjWithDestructableAllocationWhenAsy } auto mockCsr = new ::testing::NiceMock>(*device->executionEnvironment); device->resetCommandStreamReceiver(mockCsr); + *mockCsr->getTagAddress() = 0; bool desired = true; diff --git a/unit_tests/mocks/mock_device.cpp b/unit_tests/mocks/mock_device.cpp index a4d0eebdc6..d5edbf51dd 100644 --- a/unit_tests/mocks/mock_device.cpp +++ b/unit_tests/mocks/mock_device.cpp @@ -53,14 +53,18 @@ void MockDevice::injectMemoryManager(MemoryManager *memoryManager) { } void MockDevice::resetCommandStreamReceiver(CommandStreamReceiver *newCsr) { - executionEnvironment->commandStreamReceivers[getDeviceIndex()][defaultEngineIndex].reset(newCsr); - executionEnvironment->commandStreamReceivers[getDeviceIndex()][defaultEngineIndex]->initializeTagAllocation(); - executionEnvironment->commandStreamReceivers[getDeviceIndex()][defaultEngineIndex]->setPreemptionCsrAllocation(preemptionAllocation); - this->engines[defaultEngineIndex].commandStreamReceiver = newCsr; + resetCommandStreamReceiver(newCsr, defaultEngineIndex); +} - auto osContext = this->engines[defaultEngineIndex].osContext; +void MockDevice::resetCommandStreamReceiver(CommandStreamReceiver *newCsr, uint32_t engineIndex) { + executionEnvironment->commandStreamReceivers[getDeviceIndex()][engineIndex].reset(newCsr); + executionEnvironment->commandStreamReceivers[getDeviceIndex()][engineIndex]->initializeTagAllocation(); + 
executionEnvironment->commandStreamReceivers[getDeviceIndex()][engineIndex]->setPreemptionCsrAllocation(preemptionAllocation); + this->engines[engineIndex].commandStreamReceiver = newCsr; + + auto osContext = this->engines[engineIndex].osContext; executionEnvironment->memoryManager->getRegisteredEngines()[osContext->getContextId()].commandStreamReceiver = newCsr; - this->engines[defaultEngineIndex].commandStreamReceiver->setupContext(*osContext); + this->engines[engineIndex].commandStreamReceiver->setupContext(*osContext); UNRECOVERABLE_IF(getDeviceIndex() != 0u); } @@ -74,4 +78,4 @@ FailDevice::FailDevice(const HardwareInfo &hwInfo, ExecutionEnvironment *executi FailDeviceAfterOne::FailDeviceAfterOne(const HardwareInfo &hwInfo, ExecutionEnvironment *executionEnvironment, uint32_t deviceIndex) : MockDevice(hwInfo, executionEnvironment, deviceIndex) { this->mockMemoryManager.reset(new FailMemoryManager(1)); -} \ No newline at end of file +} diff --git a/unit_tests/mocks/mock_device.h b/unit_tests/mocks/mock_device.h index ba9c5556a2..ff59937dd3 100644 --- a/unit_tests/mocks/mock_device.h +++ b/unit_tests/mocks/mock_device.h @@ -76,6 +76,7 @@ class MockDevice : public Device { CommandStreamReceiver &getCommandStreamReceiver() const { return *engines[defaultEngineIndex].commandStreamReceiver; } void resetCommandStreamReceiver(CommandStreamReceiver *newCsr); + void resetCommandStreamReceiver(CommandStreamReceiver *newCsr, uint32_t engineIndex); void setSourceLevelDebuggerActive(bool active) { this->deviceInfo.sourceLevelDebuggerActive = active;