fix: improve task count handling in tbx download path

Related-To: HSD-18039789178

Signed-off-by: Bartosz Dunajski <bartosz.dunajski@intel.com>
This commit is contained in:
Bartosz Dunajski 2024-08-28 13:59:25 +00:00 committed by Compute-Runtime-Automation
parent 496012d82f
commit db611962f7
11 changed files with 124 additions and 32 deletions

View File

@ -64,6 +64,7 @@ struct EventImp : public Event {
bool handlePreQueryStatusOperationsAndCheckCompletion();
bool tbxDownload(NEO::CommandStreamReceiver &csr, bool &downloadedAllocation, bool &downloadedInOrdedAllocation);
void tbxDownload(NEO::Device &device, bool &downloadedAllocation, bool &downloadedInOrdedAllocation);
TaskCountType getTaskCount(const NEO::CommandStreamReceiver &csr) const;
ze_result_t calculateProfilingData();
ze_result_t queryStatusEventPackets();

View File

@ -250,29 +250,39 @@ ze_result_t EventImp<TagSizeT>::queryCounterBasedEventStatus() {
return ZE_RESULT_SUCCESS;
}
// Returns the task count recorded for this event on the OS context owned by `csr`.
//
// The value starts from the event pool allocation's task count for that context
// (0 when the event has no pool allocation). When in-order execution info is
// present and exposes a device counter allocation, the larger of the two task
// counts is returned. When the in-order info has no device counter allocation
// (external allocation), the value is left unchanged and a debug break is raised.
template <typename TagSizeT>
TaskCountType EventImp<TagSizeT>::getTaskCount(const NEO::CommandStreamReceiver &csr) const {
    const auto osContextId = csr.getOsContext().getContextId();

    TaskCountType result = 0;
    if (auto poolAlloc = getPoolAllocation(this->device)) {
        result = poolAlloc->getTaskCount(osContextId);
    }

    if (!inOrderExecInfo) {
        return result;
    }

    auto counterAlloc = inOrderExecInfo->getDeviceCounterAllocation();
    if (!counterAlloc) {
        DEBUG_BREAK_IF(true); // external allocation - not able to download
        return result;
    }

    return std::max(result, counterAlloc->getTaskCount(osContextId));
}
template <typename TagSizeT>
void EventImp<TagSizeT>::downloadAllTbxAllocations() {
for (auto &csr : csrs) {
csr->downloadAllocations(true);
auto taskCount = getTaskCount(*csr);
if (taskCount == NEO::GraphicsAllocation::objectNotUsed) {
taskCount = csr->peekLatestFlushedTaskCount();
}
csr->downloadAllocations(true, taskCount);
}
for (auto &subDevice : this->device->getNEODevice()->getRootDevice()->getSubDevices()) {
for (auto const &engine : subDevice->getAllEngines()) {
auto osContextId = engine.commandStreamReceiver->getOsContext().getContextId();
auto taskCount = getTaskCount(*engine.commandStreamReceiver);
auto poolAllocation = getPoolAllocation(this->device);
bool isUsed = (poolAllocation && poolAllocation->isUsedByOsContext(osContextId));
if (inOrderExecInfo) {
if (inOrderExecInfo->getDeviceCounterAllocation()) {
isUsed |= inOrderExecInfo->getDeviceCounterAllocation()->isUsedByOsContext(osContextId);
} else {
DEBUG_BREAK_IF(true); // external allocation - not able to download
}
}
if (isUsed) {
engine.commandStreamReceiver->downloadAllocations(false);
if (taskCount != NEO::GraphicsAllocation::objectNotUsed) {
engine.commandStreamReceiver->downloadAllocations(false, taskCount);
}
}
}

View File

@ -4399,11 +4399,11 @@ HWTEST2_F(EventMultiTileDynamicPacketUseTest, givenEventCounterBasedUsedCreatedO
event1->hostSynchronize(1);
EXPECT_EQ(1u, ultCsr0->downloadAllocationsCalledCount);
EXPECT_EQ(2u, ultCsr0->downloadAllocationsCalledCount);
EXPECT_FALSE(ultCsr0->latestDownloadAllocationsBlocking);
EXPECT_EQ(2u, ultCsr1->downloadAllocationsCalledCount);
EXPECT_TRUE(ultCsr1->latestDownloadAllocationsBlocking);
EXPECT_EQ(3u, ultCsr1->downloadAllocationsCalledCount);
EXPECT_FALSE(ultCsr1->latestDownloadAllocationsBlocking);
event0->destroy();
event1->destroy();

View File

@ -229,7 +229,8 @@ class CommandStreamReceiver {
virtual WaitStatus waitForCompletionWithTimeout(const WaitParams &params, TaskCountType taskCountToWait);
WaitStatus baseWaitFunction(volatile TagAddressType *pollAddress, const WaitParams &params, TaskCountType taskCountToWait);
MOCKABLE_VIRTUAL bool testTaskCountReady(volatile TagAddressType *pollAddress, TaskCountType taskCountToWait);
virtual void downloadAllocations(bool blockingWait){};
// Convenience overload: forwards to the two-argument overload, passing
// this->latestFlushedTaskCount as the task count.
void downloadAllocations(bool blockingWait) { downloadAllocations(blockingWait, this->latestFlushedTaskCount); };
// Virtual hook taking an explicit task count. The base implementation is empty
// (no-op); both parameters are unused here.
virtual void downloadAllocations(bool blockingWait, TaskCountType taskCount){};
// Virtual hook with an empty base implementation (no-op); `alloc` is unused here.
virtual void removeDownloadAllocation(GraphicsAllocation *alloc){};
// Setter: stores `value` in this->samplerCacheFlushRequired.
void setSamplerCacheFlushRequired(SamplerCacheFlushState value) { this->samplerCacheFlushRequired = value; }

View File

@ -28,6 +28,7 @@ class CommandStreamReceiverSimulatedHw : public CommandStreamReceiverSimulatedCo
using CommandStreamReceiverSimulatedCommonHw<GfxFamily>::aubManager;
using CommandStreamReceiverSimulatedCommonHw<GfxFamily>::hardwareContextController;
using CommandStreamReceiverSimulatedCommonHw<GfxFamily>::writeMemory;
using CommandStreamReceiverSimulatedCommonHw<GfxFamily>::downloadAllocations;
public:
uint32_t getMemoryBank(GraphicsAllocation *allocation) const {

View File

@ -45,7 +45,7 @@ class TbxCommandStreamReceiverHw : public CommandStreamReceiverSimulatedHw<GfxFa
WaitStatus waitForTaskCountWithKmdNotifyFallback(TaskCountType taskCountToWait, FlushStamp flushStampToWait, bool useQuickKmdSleep, QueueThrottle throttle) override;
WaitStatus waitForCompletionWithTimeout(const WaitParams &params, TaskCountType taskCountToWait) override;
void downloadAllocations(bool blockingWait) override;
void downloadAllocations(bool blockingWait, TaskCountType taskCount) override;
void downloadAllocationTbx(GraphicsAllocation &gfxAllocation);
void removeDownloadAllocation(GraphicsAllocation *alloc) override;

View File

@ -576,18 +576,31 @@ void TbxCommandStreamReceiverHw<GfxFamily>::downloadAllocationTbx(GraphicsAlloca
}
template <typename GfxFamily>
void TbxCommandStreamReceiverHw<GfxFamily>::downloadAllocations(bool blockingWait) {
TaskCountType taskCountToWait = this->latestFlushedTaskCount;
void TbxCommandStreamReceiverHw<GfxFamily>::downloadAllocations(bool blockingWait, TaskCountType taskCount) {
volatile TagAddressType *pollAddress = this->getTagAddress();
constexpr uint64_t timeoutMs = 1000 * 2; // 2s
auto waitTaskCount = std::min(taskCount, this->latestFlushedTaskCount.load());
for (uint32_t i = 0; i < this->activePartitions; i++) {
while (*pollAddress < taskCountToWait) {
if (!blockingWait) {
return;
}
if (*pollAddress < waitTaskCount) {
this->downloadAllocation(*this->getTagAllocation());
auto startTime = std::chrono::high_resolution_clock::now();
uint64_t timeDiff = 0;
while (*pollAddress < waitTaskCount) {
if (!blockingWait) {
// Additional delay to reach PC in case of Event wait
timeDiff = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - startTime).count();
if (timeDiff > timeoutMs) {
return;
}
}
this->downloadAllocation(*this->getTagAllocation());
}
}
pollAddress = ptrOffset(pollAddress, this->immWritePostSyncWriteOffset);
}
auto lockCSR = this->obtainUniqueOwnership();
@ -598,7 +611,7 @@ void TbxCommandStreamReceiverHw<GfxFamily>::downloadAllocations(bool blockingWai
this->downloadAllocation(*graphicsAllocation);
// Used again while waiting for completion. Another download will be needed.
if (graphicsAllocation->getTaskCount(this->osContext->getContextId()) > taskCountToWait) {
if (graphicsAllocation->getTaskCount(this->osContext->getContextId()) > taskCount) {
notReadyAllocations.push_back(graphicsAllocation);
}
}

View File

@ -273,7 +273,7 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily>, publ
}
void setPreemptionAllocation(GraphicsAllocation *allocation) { this->preemptionAllocation = allocation; }
void downloadAllocations(bool blockingWait) override {
// Test double: performs no download. Records the blocking flag of the most
// recent call and counts how many times the method was invoked; `taskCount`
// is ignored.
void downloadAllocations(bool blockingWait, TaskCountType taskCount) override {
    latestDownloadAllocationsBlocking = blockingWait;
    ++downloadAllocationsCalledCount;
}

View File

@ -177,7 +177,7 @@ class MockCommandStreamReceiver : public CommandStreamReceiver {
return commandStreamReceiverType;
}
void downloadAllocations(bool blockingWait) override {
// Mock override: only counts invocations; both arguments are ignored.
void downloadAllocations(bool blockingWait, TaskCountType taskCount) override {
    ++downloadAllocationsCalledCount;
}

View File

@ -79,6 +79,8 @@ class MockTbxCsr : public TbxCommandStreamReceiverHw<GfxFamily> {
template <typename GfxFamily>
struct MockTbxCsrRegisterDownloadedAllocations : TbxCommandStreamReceiverHw<GfxFamily> {
using CommandStreamReceiver::downloadAllocationImpl;
using CommandStreamReceiver::downloadAllocations;
using CommandStreamReceiver::latestFlushedTaskCount;
using CommandStreamReceiver::tagsMultiAllocation;
using TbxCommandStreamReceiverHw<GfxFamily>::flushSubmissionsAndDownloadAllocations;

View File

@ -489,9 +489,14 @@ HWTEST_F(TbxCommandSteamSimpleTest, givenTbxCsrWhenUpdatingTaskCountDuringWaitTh
MockTbxCsrRegisterDownloadedAllocations<FamilyType> tbxCsr{*pDevice->executionEnvironment, pDevice->getRootDeviceIndex(), pDevice->getDeviceBitfield()};
MockOsContext osContext(0, EngineDescriptorHelper::getDefaultDescriptor(pDevice->getDeviceBitfield()));
tbxCsr.downloadAllocationImpl = nullptr;
tbxCsr.setupContext(osContext);
tbxCsr.initializeTagAllocation();
*tbxCsr.getTagAddress() = 0u;
auto tagAddress = tbxCsr.getTagAddress();
*tagAddress = 0u;
tbxCsr.latestFlushedTaskCount = 1;
MockGraphicsAllocation allocation1, allocation2, allocation3;
@ -507,7 +512,7 @@ HWTEST_F(TbxCommandSteamSimpleTest, givenTbxCsrWhenUpdatingTaskCountDuringWaitTh
EXPECT_EQ(0u, tbxCsr.obtainUniqueOwnershipCalled);
EXPECT_EQ(3u, tbxCsr.allocationsForDownload.size());
*tbxCsr.getTagAddress() = 1u;
*tagAddress = 1u;
tbxCsr.downloadAllocations(false);
EXPECT_EQ(1u, tbxCsr.obtainUniqueOwnershipCalled);
@ -515,6 +520,65 @@ HWTEST_F(TbxCommandSteamSimpleTest, givenTbxCsrWhenUpdatingTaskCountDuringWaitTh
EXPECT_NE(tbxCsr.allocationsForDownload.find(&allocation2), tbxCsr.allocationsForDownload.end());
}
// Checks that a download requested for task count 1 keeps in
// allocationsForDownload every allocation whose task count on this context is
// greater than 1.
HWTEST_F(TbxCommandSteamSimpleTest, givenAllocationWithBiggerTaskCountThanWaitingTaskCountThenDontRemoveFromContainer) {
MockTbxCsrRegisterDownloadedAllocations<FamilyType> tbxCsr{*pDevice->executionEnvironment, pDevice->getRootDeviceIndex(), pDevice->getDeviceBitfield()};
MockOsContext osContext(0, EngineDescriptorHelper::getDefaultDescriptor(pDevice->getDeviceBitfield()));
tbxCsr.setupContext(osContext);
tbxCsr.initializeTagAllocation();
// Start with tag value 0 and one flushed task.
*tbxCsr.getTagAddress() = 0u;
tbxCsr.latestFlushedTaskCount = 1;
MockGraphicsAllocation allocation1, allocation2, allocation3;
// Register all three allocations for download and make them resident.
tbxCsr.allocationsForDownload = {&allocation1, &allocation2, &allocation3};
tbxCsr.makeResident(allocation1);
tbxCsr.makeResident(allocation2);
tbxCsr.makeResident(allocation3);
auto contextId = tbxCsr.getOsContext().getContextId();
// allocation1 and allocation3 are marked with task count 2, i.e. above the
// task count passed to downloadAllocations below; allocation2 matches it.
allocation1.updateTaskCount(2, contextId);
allocation2.updateTaskCount(1, contextId);
allocation3.updateTaskCount(2, contextId);
// Tag reaches 1, then a non-blocking download is requested for task count 1.
*tbxCsr.getTagAddress() = 1u;
tbxCsr.downloadAllocations(false, 1);
// Ownership was taken once, and only the two allocations with task count 2
// are expected to remain in the container.
EXPECT_EQ(1u, tbxCsr.obtainUniqueOwnershipCalled);
EXPECT_EQ(2u, tbxCsr.allocationsForDownload.size());
EXPECT_NE(tbxCsr.allocationsForDownload.find(&allocation1), tbxCsr.allocationsForDownload.end());
EXPECT_NE(tbxCsr.allocationsForDownload.find(&allocation3), tbxCsr.allocationsForDownload.end());
}
// Exercises downloadAllocations(false, 2) against different combinations of
// tag value and latestFlushedTaskCount. obtainUniqueOwnershipCalled is used as
// the indicator of whether the call got past the wait and performed a download.
// NOTE(review): per the test name, the expectations assume the CSR waits for the
// smaller of the requested task count and latestFlushedTaskCount — confirm
// against the TBX implementation.
HWTEST_F(TbxCommandSteamSimpleTest, givenDifferentTaskCountThanLatestFlushedWhenDownloadingThenPickSmallest) {
MockTbxCsrRegisterDownloadedAllocations<FamilyType> tbxCsr{*pDevice->executionEnvironment, pDevice->getRootDeviceIndex(), pDevice->getDeviceBitfield()};
MockOsContext osContext(0, EngineDescriptorHelper::getDefaultDescriptor(pDevice->getDeviceBitfield()));
tbxCsr.downloadAllocationImpl = nullptr;
tbxCsr.setupContext(osContext);
tbxCsr.initializeTagAllocation();
// Case 1: tag 0, latest flushed 1, requested 2 -> no ownership taken.
*tbxCsr.getTagAddress() = 0;
tbxCsr.latestFlushedTaskCount = 1;
tbxCsr.downloadAllocations(false, 2);
EXPECT_EQ(0u, tbxCsr.obtainUniqueOwnershipCalled);
// Case 2: tag raised to 1 (equal to latest flushed) -> ownership taken once.
*tbxCsr.getTagAddress() = 1;
tbxCsr.downloadAllocations(false, 2);
EXPECT_EQ(1u, tbxCsr.obtainUniqueOwnershipCalled);
// Case 3: latest flushed raised to 3 while tag is still 1 and requested is 2
// -> counter unchanged.
tbxCsr.latestFlushedTaskCount = 3;
tbxCsr.downloadAllocations(false, 2);
EXPECT_EQ(1u, tbxCsr.obtainUniqueOwnershipCalled);
// Case 4: tag raised to 3 (above the requested 2) -> ownership taken again.
*tbxCsr.getTagAddress() = 3;
tbxCsr.downloadAllocations(false, 2);
EXPECT_EQ(2u, tbxCsr.obtainUniqueOwnershipCalled);
}
HWTEST_F(TbxCommandSteamSimpleTest, whenTbxCommandStreamReceiverIsCreatedThenPPGTTAndGGTTCreatedHavePhysicalAddressAllocatorSet) {
MockTbxCsr<FamilyType> tbxCsr(*pDevice->executionEnvironment, pDevice->getDeviceBitfield());