diff --git a/opencl/source/command_queue/command_queue.cpp b/opencl/source/command_queue/command_queue.cpp index c12db990f6..e7c1f7081a 100644 --- a/opencl/source/command_queue/command_queue.cpp +++ b/opencl/source/command_queue/command_queue.cpp @@ -576,6 +576,46 @@ bool CommandQueue::validateCapabilityForOperation(cl_command_queue_capabilities_ return operationValid && waitListValid && outEventValid; } +void CommandQueue::waitForEventsFromDifferentRootDeviceIndex(cl_uint numEventsInWaitList, const cl_event *eventWaitList, + StackVec &waitListCurrentRootDeviceIndex, bool &isEventWaitListFromPreviousRootDevice) { + isEventWaitListFromPreviousRootDevice = false; + + for (auto &rootDeviceIndex : context->getRootDeviceIndices()) { + CommandQueue *commandQueuePreviousRootDevice = nullptr; + auto maxTaskCountPreviousRootDevice = 0u; + + if (this->getDevice().getRootDeviceIndex() != rootDeviceIndex) { + for (auto eventId = 0u; eventId < numEventsInWaitList; eventId++) { + auto event = castToObject(eventWaitList[eventId]); + + if (event->getCommandQueue() && event->getCommandQueue()->getDevice().getRootDeviceIndex() == rootDeviceIndex) { + maxTaskCountPreviousRootDevice = std::max(maxTaskCountPreviousRootDevice, event->peekTaskCount()); + commandQueuePreviousRootDevice = event->getCommandQueue(); + isEventWaitListFromPreviousRootDevice = true; + } + } + + if (maxTaskCountPreviousRootDevice) { + commandQueuePreviousRootDevice->getCommandStreamReceiver(false).waitForCompletionWithTimeout(false, 0, maxTaskCountPreviousRootDevice); + } + } + } + + if (isEventWaitListFromPreviousRootDevice) { + for (auto eventId = 0u; eventId < numEventsInWaitList; eventId++) { + auto event = castToObject(eventWaitList[eventId]); + + if (event->getCommandQueue()) { + if (event->getCommandQueue()->getDevice().getRootDeviceIndex() == this->getDevice().getRootDeviceIndex()) { + waitListCurrentRootDeviceIndex.push_back(static_cast(eventWaitList[eventId])); + } + } else { + waitListCurrentRootDeviceIndex.push_back(static_cast(eventWaitList[eventId])); + } + } + } +} + cl_uint CommandQueue::getQueueFamilyIndex() const { if (isQueueFamilySelected()) { return queueFamilyIndex; diff --git a/opencl/source/command_queue/command_queue.h b/opencl/source/command_queue/command_queue.h index 1b744bfa78..019721a08b 100644 --- a/opencl/source/command_queue/command_queue.h +++ b/opencl/source/command_queue/command_queue.h @@ -303,6 +303,8 @@ class CommandQueue : public BaseObject<_cl_command_queue> { bool validateCapability(cl_command_queue_capabilities_intel capability) const; bool validateCapabilitiesForEventWaitList(cl_uint numEventsInWaitList, const cl_event *waitList) const; bool validateCapabilityForOperation(cl_command_queue_capabilities_intel capability, cl_uint numEventsInWaitList, const cl_event *waitList, const cl_event *outEvent) const; + void waitForEventsFromDifferentRootDeviceIndex(cl_uint numEventsInWaitList, const cl_event *eventWaitList, + StackVec &waitListCurrentRootDeviceIndex, bool &isEventWaitListFromPreviousRootDevice); cl_uint getQueueFamilyIndex() const; cl_uint getQueueIndexWithinFamily() const { return queueIndexWithinFamily; } bool isQueueFamilySelected() const { return queueFamilySelected; } diff --git a/opencl/source/command_queue/enqueue_common.h b/opencl/source/command_queue/enqueue_common.h index 160d63bd16..101c342c74 100644 --- a/opencl/source/command_queue/enqueue_common.h +++ b/opencl/source/command_queue/enqueue_common.h @@ -147,6 +147,16 @@ void CommandQueueHw::enqueueHandler(Surface **surfacesForResidency, return; } + StackVec waitListCurrentRootDeviceIndex; + bool isEventWaitListFromPreviousRootDevice = false; + + if (context->getRootDeviceIndices().size() > 1u) { + waitForEventsFromDifferentRootDeviceIndex(numEventsInWaitList, eventWaitList, waitListCurrentRootDeviceIndex, isEventWaitListFromPreviousRootDevice); + } + + const cl_event *eventWaitListCurrentRootDevice = isEventWaitListFromPreviousRootDevice ? waitListCurrentRootDeviceIndex.data() : eventWaitList; + cl_uint numEventsInWaitListCurrentRootDevice = isEventWaitListFromPreviousRootDevice ? static_cast(waitListCurrentRootDeviceIndex.size()) : numEventsInWaitList; + Kernel *parentKernel = multiDispatchInfo.peekParentKernel(); auto devQueue = this->getContext().getDefaultDeviceQueue(); DeviceQueueHw *devQueueHw = castToObject>(devQueue); @@ -165,7 +175,7 @@ void CommandQueueHw::enqueueHandler(Surface **surfacesForResidency, auto blockQueue = false; auto taskLevel = 0u; - obtainTaskLevelAndBlockedStatus(taskLevel, numEventsInWaitList, eventWaitList, blockQueue, commandType); + obtainTaskLevelAndBlockedStatus(taskLevel, numEventsInWaitListCurrentRootDevice, eventWaitListCurrentRootDevice, blockQueue, commandType); if (parentKernel && !blockQueue) { while (!devQueueHw->isEMCriticalSectionFree()) @@ -181,7 +191,7 @@ void CommandQueueHw::enqueueHandler(Surface **surfacesForResidency, } TimestampPacketDependencies timestampPacketDependencies; - EventsRequest eventsRequest(numEventsInWaitList, eventWaitList, event); + EventsRequest eventsRequest(numEventsInWaitListCurrentRootDevice, eventWaitListCurrentRootDevice, event); CsrDependencies csrDeps; BlitPropertiesContainer blitPropertiesContainer; @@ -300,17 +310,20 @@ void CommandQueueHw::enqueueHandler(Surface **surfacesForResidency, taskLevel); } else { UNRECOVERABLE_IF(enqueueProperties.operation != EnqueueProperties::Operation::EnqueueWithoutSubmission); - auto maxTaskCount = this->taskCount; - for (auto eventId = 0u; eventId < numEventsInWaitList; eventId++) { - auto event = castToObject(eventWaitList[eventId]); + + auto maxTaskCountCurrentRootDevice = this->taskCount; + + for (auto eventId = 0u; eventId < numEventsInWaitListCurrentRootDevice; eventId++) { + auto event = castToObject(eventWaitListCurrentRootDevice[eventId]); + if (!event->isUserEvent() && !event->isExternallySynchronized()) { - maxTaskCount = std::max(maxTaskCount, event->peekTaskCount()); + maxTaskCountCurrentRootDevice = std::max(maxTaskCountCurrentRootDevice, event->peekTaskCount()); } } //inherit data from event_wait_list and previous packets completionStamp.flushStamp = this->flushStamp->peekStamp(); - completionStamp.taskCount = maxTaskCount; + completionStamp.taskCount = maxTaskCountCurrentRootDevice; completionStamp.taskLevel = taskLevel; if (eventBuilder.getEvent() && isProfilingEnabled()) { diff --git a/opencl/test/unit_test/command_stream/command_stream_receiver_flush_task_3_tests.cpp b/opencl/test/unit_test/command_stream/command_stream_receiver_flush_task_3_tests.cpp index 1222b18de0..c0dca161c1 100644 --- a/opencl/test/unit_test/command_stream/command_stream_receiver_flush_task_3_tests.cpp +++ b/opencl/test/unit_test/command_stream/command_stream_receiver_flush_task_3_tests.cpp @@ -14,6 +14,7 @@ #include "opencl/source/helpers/hardware_commands_helper.h" #include "opencl/source/mem_obj/buffer.h" #include "opencl/source/platform/platform.h" +#include "opencl/test/unit_test/fixtures/multi_root_device_fixture.h" #include "opencl/test/unit_test/fixtures/ult_command_stream_receiver_fixture.h" #include "opencl/test/unit_test/mocks/mock_allocation_properties.h" #include "opencl/test/unit_test/mocks/mock_command_queue.h" @@ -1856,3 +1857,87 @@ HWTEST_F(CommandStreamReceiverFlushTaskTests, GivenGpuIsIdleWhenCsrIsEnabledToFl *commandStreamReceiver.getTagAddress() = 2u; } + +TEST(MultiRootDeviceCommandStreamReceiverTests, givenMultipleEventInMultiRootDeviceEnvironmentWhenTheyArePassedToMarkerThenCsrsAreWaitingForEventsFromPreviousDevices) { + auto deviceFactory = std::make_unique(4, 0); + auto device1 = deviceFactory->rootDevices[1]; + auto device2 = deviceFactory->rootDevices[2]; + auto device3 = deviceFactory->rootDevices[3]; + + auto mockCsr1 = new MockCommandStreamReceiver(*device1->executionEnvironment, device1->getRootDeviceIndex(), device1->getDeviceBitfield()); + auto mockCsr2 = new MockCommandStreamReceiver(*device2->executionEnvironment, device2->getRootDeviceIndex(), device2->getDeviceBitfield()); + auto mockCsr3 = new MockCommandStreamReceiver(*device3->executionEnvironment, device3->getRootDeviceIndex(), device3->getDeviceBitfield()); + + device1->resetCommandStreamReceiver(mockCsr1); + device2->resetCommandStreamReceiver(mockCsr2); + device3->resetCommandStreamReceiver(mockCsr3); + + cl_device_id devices[] = {device1, device2, device3}; + + auto context = std::make_unique(ClDeviceVector(devices, 3), false); + + auto pCmdQ1 = context.get()->getSpecialQueue(1u); + auto pCmdQ2 = context.get()->getSpecialQueue(2u); + auto pCmdQ3 = context.get()->getSpecialQueue(3u); + + Event event1(pCmdQ1, CL_COMMAND_NDRANGE_KERNEL, 5, 15); + Event event2(nullptr, CL_COMMAND_NDRANGE_KERNEL, 6, 16); + Event event3(pCmdQ1, CL_COMMAND_NDRANGE_KERNEL, 1, 6); + Event event4(pCmdQ1, CL_COMMAND_NDRANGE_KERNEL, 4, 20); + Event event5(pCmdQ2, CL_COMMAND_NDRANGE_KERNEL, 3, 4); + Event event6(pCmdQ3, CL_COMMAND_NDRANGE_KERNEL, 7, 21); + Event event7(pCmdQ2, CL_COMMAND_NDRANGE_KERNEL, 2, 7); + UserEvent userEvent1(&pCmdQ1->getContext()); + UserEvent userEvent2(&pCmdQ2->getContext()); + + userEvent1.setStatus(CL_COMPLETE); + userEvent2.setStatus(CL_COMPLETE); + + cl_event eventWaitList[] = + { + &event1, + &event2, + &event3, + &event4, + &event5, + &event6, + &event7, + &userEvent1, + &userEvent2, + }; + + cl_uint numEventsInWaitList = sizeof(eventWaitList) / sizeof(eventWaitList[0]); + + { + pCmdQ1->enqueueMarkerWithWaitList( + numEventsInWaitList, + eventWaitList, + nullptr); + + EXPECT_EQ(0u, mockCsr1->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(1u, mockCsr2->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(1u, mockCsr3->waitForCompletionWithTimeoutCalled); + } + + { + pCmdQ2->enqueueMarkerWithWaitList( + numEventsInWaitList, + eventWaitList, + nullptr); + + EXPECT_EQ(1u, mockCsr1->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(1u, mockCsr2->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(2u, mockCsr3->waitForCompletionWithTimeoutCalled); + } + + { + pCmdQ3->enqueueMarkerWithWaitList( + numEventsInWaitList, + eventWaitList, + nullptr); + + EXPECT_EQ(2u, mockCsr1->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(2u, mockCsr2->waitForCompletionWithTimeoutCalled); + EXPECT_EQ(2u, mockCsr3->waitForCompletionWithTimeoutCalled); + } +}