diff --git a/runtime/command_queue/command_queue.cpp b/runtime/command_queue/command_queue.cpp index f351fdbe5d..7b0617643a 100644 --- a/runtime/command_queue/command_queue.cpp +++ b/runtime/command_queue/command_queue.cpp @@ -134,6 +134,9 @@ void CommandQueue::waitUntilComplete(uint32_t taskCountToWait, FlushStamp flushS DEBUG_BREAK_IF(getHwTag() < taskCountToWait); latestTaskCountWaited = taskCountToWait; + + getCommandStreamReceiver().waitForTaskCountAndCleanAllocationList(taskCountToWait, TEMPORARY_ALLOCATION); + WAIT_LEAVE() } diff --git a/runtime/command_queue/enqueue_common.h b/runtime/command_queue/enqueue_common.h index e53a4e4809..7c144a3f1e 100644 --- a/runtime/command_queue/enqueue_common.h +++ b/runtime/command_queue/enqueue_common.h @@ -381,7 +381,6 @@ void CommandQueueHw::enqueueHandler(Surface **surfacesForResidency, if (printfHandler) { printfHandler->printEnqueueOutput(); } - getCommandStreamReceiver().waitForTaskCountAndCleanAllocationList(completionStamp.taskCount, TEMPORARY_ALLOCATION); } } } diff --git a/runtime/command_queue/finish.h b/runtime/command_queue/finish.h index 7aa78f55e0..6a618500ad 100644 --- a/runtime/command_queue/finish.h +++ b/runtime/command_queue/finish.h @@ -28,8 +28,6 @@ cl_int CommandQueueHw::finish(bool dcFlush) { // Stall until HW reaches CQ taskCount waitUntilComplete(taskCountToWaitFor, flushStampToWaitFor, false); - getCommandStreamReceiver().waitForTaskCountAndCleanAllocationList(taskCountToWaitFor, TEMPORARY_ALLOCATION); - return CL_SUCCESS; } } // namespace NEO diff --git a/unit_tests/command_queue/enqueue_handler_tests.cpp b/unit_tests/command_queue/enqueue_handler_tests.cpp index c54e39f3ba..477d3c7604 100644 --- a/unit_tests/command_queue/enqueue_handler_tests.cpp +++ b/unit_tests/command_queue/enqueue_handler_tests.cpp @@ -17,6 +17,7 @@ #include "unit_tests/mocks/mock_command_queue.h" #include "unit_tests/mocks/mock_context.h" #include "unit_tests/mocks/mock_csr.h" +#include "unit_tests/mocks/mock_internal_allocation_storage.h" #include "unit_tests/mocks/mock_kernel.h" #include "unit_tests/mocks/mock_mdi.h" @@ -488,20 +489,37 @@ HWTEST_F(EnqueueHandlerTest, givenEnqueueHandlerWhenSubCaptureIsOnThenActivateSu mockCmdQ->release(); } -using EnqueueHandlerTestBasic = ::testing::Test; + +struct EnqueueHandlerTestBasic : public ::testing::Test { + template + std::unique_ptr> setupFixtureAndCreateMockCommandQueue() { + auto executionEnvironment = platformImpl->peekExecutionEnvironment(); + + device.reset(MockDevice::createWithExecutionEnvironment(nullptr, executionEnvironment, 0u)); + context = std::make_unique(device.get()); + + auto mockCmdQ = std::make_unique>(context.get(), device.get(), nullptr); + + auto &ultCsr = static_cast &>(mockCmdQ->getCommandStreamReceiver()); + ultCsr.taskCount = initialTaskCount; + + mockInternalAllocationStorage = new MockInternalAllocationStorage(ultCsr); + ultCsr.internalAllocationStorage.reset(mockInternalAllocationStorage); + + return mockCmdQ; + } + + MockInternalAllocationStorage *mockInternalAllocationStorage = nullptr; + const uint32_t initialTaskCount = 100; + std::unique_ptr device; + std::unique_ptr context; +}; + HWTEST_F(EnqueueHandlerTestBasic, givenEnqueueHandlerWhenCommandIsBlokingThenCompletionStampTaskCountIsPassedToWaitForTaskCountAndCleanAllocationListAsRequiredTaskCount) { - int32_t tag; - auto executionEnvironment = platformImpl->peekExecutionEnvironment(); - auto mockCsr = new MockCsrBase(tag, *executionEnvironment); - executionEnvironment->commandStreamReceivers.resize(1); - std::unique_ptr pDevice(MockDevice::createWithExecutionEnvironment(nullptr, executionEnvironment, 0u)); - pDevice->resetCommandStreamReceiver(mockCsr); - auto context = std::make_unique(pDevice.get()); - MockKernelWithInternals kernelInternals(*pDevice, context.get()); + auto mockCmdQ = setupFixtureAndCreateMockCommandQueue(); + MockKernelWithInternals kernelInternals(*device, context.get()); Kernel *kernel = kernelInternals.mockKernel; MockMultiDispatchInfo multiDispatchInfo(kernel); - auto mockCmdQ = new MockCommandQueueHw(context.get(), pDevice.get(), 0); - mockCmdQ->deltaTaskCount = 100; mockCmdQ->template enqueueHandler(nullptr, 0, true, @@ -509,6 +527,32 @@ HWTEST_F(EnqueueHandlerTestBasic, givenEnqueueHandlerWhenCommandIsBlokingThenCom 0, nullptr, nullptr); - EXPECT_EQ(mockCsr->waitForTaskCountRequiredTaskCount, mockCmdQ->completionStampTaskCount); - mockCmdQ->release(); + EXPECT_EQ(initialTaskCount + 1, mockInternalAllocationStorage->lastCleanAllocationsTaskCount); +} + +HWTEST_F(EnqueueHandlerTestBasic, givenBlockedEnqueueHandlerWhenCommandIsBlokingThenCompletionStampTaskCountIsPassedToWaitForTaskCountAndCleanAllocationListAsRequiredTaskCount) { + auto mockCmdQ = setupFixtureAndCreateMockCommandQueue(); + + MockKernelWithInternals kernelInternals(*device, context.get()); + Kernel *kernel = kernelInternals.mockKernel; + MockMultiDispatchInfo multiDispatchInfo(kernel); + + UserEvent userEvent; + cl_event waitlist[] = {&userEvent}; + + std::thread t0([&mockCmdQ, &userEvent]() { + while (!mockCmdQ->isQueueBlocked()) { + } + userEvent.setStatus(CL_COMPLETE); + }); + mockCmdQ->template enqueueHandler(nullptr, + 0, + true, + multiDispatchInfo, + 1, + waitlist, + nullptr); + EXPECT_EQ(initialTaskCount + 1, mockInternalAllocationStorage->lastCleanAllocationsTaskCount); + + t0.join(); } diff --git a/unit_tests/mocks/mock_command_queue.h b/unit_tests/mocks/mock_command_queue.h index 2073863ddf..1af288888e 100644 --- a/unit_tests/mocks/mock_command_queue.h +++ b/unit_tests/mocks/mock_command_queue.h @@ -146,19 +146,9 @@ class MockCommandQueueHw : public CommandQueueHw { bool notifyEnqueueReadBufferCalled = false; bool notifyEnqueueReadImageCalled = false; bool cpuDataTransferHandlerCalled = false; - uint32_t completionStampTaskCount = 0; - uint32_t deltaTaskCount = 0; LinearStream *peekCommandStream() { return this->commandStream; } - - void updateFromCompletionStamp(const CompletionStamp &completionStamp) override { - BaseClass::updateFromCompletionStamp(completionStamp); - const uint32_t &referenceToCompletionStampTaskCount = completionStamp.taskCount; - uint32_t &nonConstReferenceToCompletionStampTaskCount = const_cast(referenceToCompletionStampTaskCount); - nonConstReferenceToCompletionStampTaskCount += deltaTaskCount; - completionStampTaskCount = referenceToCompletionStampTaskCount; - } }; } // namespace NEO