fix: allow querying kernel timestamps after launching a cooperative kernel within the same cmdlist

Related-To: NEO-10191
Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2024-01-30 11:45:12 +00:00
committed by Compute-Runtime-Automation
parent 25a3a63632
commit 2db441a0e0
2 changed files with 68 additions and 0 deletions

View File

@@ -2839,6 +2839,7 @@ ze_result_t CommandListCoreFamily<gfxCoreFamily>::appendQueryKernelTimestamps(
auto dstAllocationType = dstPtrAllocationStruct.alloc->getAllocationType();
CmdListKernelLaunchParams launchParams = {};
launchParams.isBuiltInKernel = true;
launchParams.isCooperative = containsCooperativeKernelsFlag;
launchParams.isDestinationAllocationInSystemMemory =
(dstAllocationType == NEO::AllocationType::bufferHostMemory) ||
(dstAllocationType == NEO::AllocationType::externalHostPtr);

View File

@@ -485,6 +485,73 @@ HWTEST2_F(CommandListAppendLaunchKernel, givenKernelUsingSyncBufferWhenAppendLau
}
}
// Appends a cooperative kernel launch followed by a kernel-timestamp query to the
// same command list, then parses the resulting command stream and verifies that no
// FrontEndStateCommand was programmed between the two walker commands.
//
// Preconditions that later statements depend on use ASSERT_* so that a failure stops
// the test instead of dereferencing a null object or advancing an end iterator.
HWTEST2_F(CommandListAppendLaunchKernel, whenAppendLaunchCooperativeKernelAndQueryKernelTimestampsToTheSameCmdlistThenFronEndStateIsNotChanged, IsAtLeastSkl) {
    Mock<::L0::KernelImp> kernel;
    auto pMockModule = std::unique_ptr<Module>(new Mock<Module>(device, nullptr));
    kernel.module = pMockModule.get();

    kernel.setGroupSize(4, 1, 1);
    ze_group_count_t groupCount{8, 1, 1};

    // Mark the kernel as using the sync buffer and the default GRF count.
    auto &kernelAttributes = kernel.immutableData.kernelDescriptor->kernelAttributes;
    kernelAttributes.flags.usesSyncBuffer = true;
    kernelAttributes.numGrfRequired = GrfConfig::defaultGrfNumber;

    auto pCommandList = std::make_unique<WhiteBox<::L0::CommandListCoreFamily<gfxCoreFamily>>>();
    auto &productHelper = device->getProductHelper();
    auto &gfxCoreHelper = device->getGfxCoreHelper();

    // Use the cooperative engine group when the product reports support for it,
    // otherwise stay on the regular compute group.
    auto engineGroupType = NEO::EngineGroupType::compute;
    if (productHelper.isCooperativeEngineSupported(*defaultHwInfo)) {
        engineGroupType = gfxCoreHelper.getEngineGroupType(aub_stream::EngineType::ENGINE_CCS, EngineUsage::cooperative, *defaultHwInfo);
    }
    pCommandList->initialize(device, engineGroupType, 0u);

    ze_event_pool_desc_t eventPoolDesc = {};
    eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_KERNEL_TIMESTAMP;
    eventPoolDesc.count = 1;

    ze_event_desc_t eventDesc = {};
    eventDesc.index = 0;

    ze_result_t returnValue = ZE_RESULT_ERROR_NOT_AVAILABLE;
    auto eventPool = std::unique_ptr<L0::EventPool>(EventPool::create(driverHandle.get(), context, 0, nullptr, &eventPoolDesc, returnValue));
    // The pool is passed to Event::create below, so creation must have succeeded.
    ASSERT_EQ(ZE_RESULT_SUCCESS, returnValue);
    auto event = std::unique_ptr<L0::Event>(Event::create<typename FamilyType::TimestampPacketType>(eventPool.get(), &eventDesc, device));
    ASSERT_NE(nullptr, event.get());

    returnValue = pCommandList->appendLaunchCooperativeKernel(kernel.toHandle(), groupCount, event->toHandle(), 0, nullptr, false);
    EXPECT_EQ(ZE_RESULT_SUCCESS, returnValue);

    void *alloc = nullptr;
    ze_device_mem_alloc_desc_t deviceDesc = {};
    auto result = context->allocDeviceMem(device, &deviceDesc, 128, 1, &alloc);
    // Nothing was allocated on failure, so bailing out here cannot leak.
    ASSERT_EQ(result, ZE_RESULT_SUCCESS);

    // Release the device allocation on every exit path, including early returns
    // triggered by the fatal assertions below.
    auto *allocOwner = context;
    auto releaseAlloc = [allocOwner](void *ptr) { allocOwner->freeMem(ptr); };
    std::unique_ptr<void, decltype(releaseAlloc)> allocGuard(alloc, releaseAlloc);

    auto eventHandle = event->toHandle();
    result = pCommandList->appendQueryKernelTimestamps(1u, &eventHandle, alloc, nullptr, nullptr, 1u, &eventHandle);
    EXPECT_EQ(ZE_RESULT_SUCCESS, result);
    pCommandList->close();

    GenCmdList cmdList;
    ASSERT_TRUE(FamilyType::Parse::parseCommandBuffer(
        cmdList, ptrOffset(pCommandList->getCmdContainer().getCommandStream()->getCpuBase(), 0), pCommandList->getCmdContainer().getCommandStream()->getUsed()));

    // Locate the first walker; it must exist before the iterator may be advanced.
    auto itor = find<typename FamilyType::DefaultWalkerType *>(cmdList.begin(), cmdList.end());
    ASSERT_NE(itor, cmdList.end());
    auto firstWalker = itor;

    // Locate the second walker, searching from the command after the first one.
    ++itor;
    itor = find<typename FamilyType::DefaultWalkerType *>(itor, cmdList.end());
    EXPECT_NE(itor, cmdList.end());
    auto secondWalker = itor;

    // No front-end state command may appear between the two walkers.
    itor = find<typename FamilyType::FrontEndStateCommand *>(firstWalker, secondWalker);
    EXPECT_EQ(itor, secondWalker);
}
HWTEST2_F(CommandListAppendLaunchKernel, givenDisableOverdispatchPropertyWhenUpdateStreamPropertiesIsCalledThenRequiredStateAndFinalStateAreCorrectlySet, IsAtLeastSkl) {
Mock<::L0::KernelImp> kernel;
auto pMockModule = std::unique_ptr<Module>(new Mock<Module>(device, nullptr));