feature: redesign host function workers

Each host function gets a unique ID within a CSR.
One MI_STORE_DATA_IMM writes the ID to signal that the host function is ready,
and one MI_SEMAPHORE_WAIT waits for the ID to be cleared.
Bit 0 of the ID is used as the pending/completed flag:
host function IDs start at 1 and are incremented by 2,
so every ID always has bit 0 set.
This is required because a semaphore wait can only wait on 4 bytes.

Adjust command buffer programming and patching logic to use IDs.

Add a HostFunction callable class, executed through its invoke method,
which stores the required information about the callback.

Add a host function streamer, which stores all host function data
for a given CSR.
All user-provided host functions are stored in an unordered map,
where the key is the host function ID.

Add a host function scheduler and a thread pool, under a debug flag.
The single-threaded scheduler loops over all registered host function streamers
and dispatches ready-to-execute host functions to the thread pool.

Allow out-of-order host function execution for OOQ, under a debug flag.
Each host function has a bool isInOrder flag which indicates whether it must be
executed in order or may be executed out of order. In out-of-order mode the ID tag
is cleared immediately, so the semaphore wait unblocks before the host function executes.

Remove the CV-based and atomics-based host function worker implementations.

Rename classes

Related-To: NEO-14577
Signed-off-by: Kamil Kopryk <kamil.kopryk@intel.com>
This commit is contained in:
Kamil Kopryk
2025-11-24 20:35:14 +00:00
committed by Compute-Runtime-Automation
parent 4f7d2f0315
commit 56b30d1803
41 changed files with 1248 additions and 783 deletions

View File

@@ -44,6 +44,7 @@ class TagNodeBase;
struct EncodeDispatchKernelArgs;
class CommandStreamReceiver;
class GraphicsAllocation;
struct HostFunction;
} // namespace NEO
namespace L0 {
@@ -553,7 +554,7 @@ struct CommandList : _ze_command_list_handle_t {
virtual void dispatchHostFunction(void *pHostFunction,
void *pUserData) = 0;
virtual void addHostFunctionToPatchCommands(uint64_t userHostFunctionAddress, uint64_t userDataAddress) = 0;
virtual void addHostFunctionToPatchCommands(const NEO::HostFunction &hostFunction) = 0;
NEO::GraphicsAllocation *getAllocationFromHostPtrMap(const void *buffer, uint64_t bufferSize, bool copyOffload);
NEO::GraphicsAllocation *getHostPtrAlloc(const void *buffer, uint64_t bufferSize, bool hostCopyAllowed, bool copyOffload);

View File

@@ -261,7 +261,7 @@ struct CommandListCoreFamily : public CommandListImp {
protected:
void dispatchHostFunction(void *pHostFunction,
void *pUserData) override;
void addHostFunctionToPatchCommands(uint64_t userHostFunctionAddress, uint64_t userDataAddress) override;
void addHostFunctionToPatchCommands(const NEO::HostFunction &hostFunction) override;
MOCKABLE_VIRTUAL ze_result_t appendMemoryCopyKernelWithGA(uintptr_t dstPtr, NEO::GraphicsAllocation *dstPtrAlloc,
uint64_t dstOffset, uintptr_t srcPtr,

View File

@@ -8,6 +8,7 @@
#include "shared/source/command_container/cmdcontainer.h"
#include "shared/source/command_container/encode_surface_state.h"
#include "shared/source/command_stream/command_stream_receiver.h"
#include "shared/source/command_stream/host_function.h"
#include "shared/source/command_stream/linear_stream.h"
#include "shared/source/command_stream/transfer_direction.h"
#include "shared/source/device/device.h"
@@ -1884,37 +1885,45 @@ void CommandListCoreFamily<gfxCoreFamily>::dispatchHostFunction(
uint64_t userHostFunctionAddress = reinterpret_cast<uint64_t>(pHostFunction);
uint64_t userDataAddress = reinterpret_cast<uint64_t>(pUserData);
NEO::HostFunction hostFunction{
.hostFunctionAddress = userHostFunctionAddress,
.userDataAddress = userDataAddress,
.isInOrder = true};
if (NEO::debugManager.flags.AllowForOutOfOrderHostFunctionExecution.get() != 0) {
hostFunction.isInOrder = isInSynchronousMode();
}
if (isImmediateType()) {
auto csr = getCsr(false);
csr->ensureHostFunctionWorkerStarted();
csr->signalHostFunctionWorker();
NEO::HostFunctionHelper::programHostFunction<GfxFamily>(*this->commandContainer.getCommandStream(), csr->getHostFunctionData(), userHostFunctionAddress, userDataAddress);
NEO::HostFunctionHelper<GfxFamily>::programHostFunction(*this->commandContainer.getCommandStream(),
csr->getHostFunctionStreamer(),
std::move(hostFunction));
csr->signalHostFunctionWorker(1u);
} else {
addHostFunctionToPatchCommands(userHostFunctionAddress, userDataAddress);
addHostFunctionToPatchCommands(hostFunction);
}
}
template <GFXCORE_FAMILY gfxCoreFamily>
void CommandListCoreFamily<gfxCoreFamily>::addHostFunctionToPatchCommands(uint64_t userHostFunctionAddress, uint64_t userDataAddress) {
void CommandListCoreFamily<gfxCoreFamily>::addHostFunctionToPatchCommands(const NEO::HostFunction &hostFunction) {
using MI_STORE_DATA_IMM = typename GfxFamily::MI_STORE_DATA_IMM;
using MI_SEMAPHORE_WAIT = typename GfxFamily::MI_SEMAPHORE_WAIT;
commandsToPatch.reserve(commandsToPatch.size() + 4);
auto additionalSize = 2u;
commandsToPatch.reserve(commandsToPatch.size() + additionalSize);
commandsToPatch.push_back({.pCommand = commandContainer.getCommandStream()->getSpace(sizeof(MI_STORE_DATA_IMM)),
.baseAddress = userHostFunctionAddress,
.type = CommandToPatch::HostFunctionEntry});
commandsToPatch.push_back({.pCommand = commandContainer.getCommandStream()->getSpace(sizeof(MI_STORE_DATA_IMM)),
.baseAddress = userDataAddress,
.type = CommandToPatch::HostFunctionUserData});
commandsToPatch.push_back({.pCommand = commandContainer.getCommandStream()->getSpace(sizeof(MI_STORE_DATA_IMM)),
.type = CommandToPatch::HostFunctionSignalInternalTag});
.baseAddress = hostFunction.hostFunctionAddress,
.gpuAddress = hostFunction.userDataAddress,
.type = CommandToPatch::HostFunctionId,
.isInOrder = hostFunction.isInOrder});
commandsToPatch.push_back({.pCommand = commandContainer.getCommandStream()->getSpace(sizeof(MI_SEMAPHORE_WAIT)),
.type = CommandToPatch::HostFunctionWaitInternalTag});
.type = CommandToPatch::HostFunctionWait});
}
template <GFXCORE_FAMILY gfxCoreFamily>
@@ -4083,10 +4092,8 @@ void CommandListCoreFamily<gfxCoreFamily>::clearCommandsToPatch() {
case CommandToPatch::PauseOnEnqueueSemaphoreEnd:
case CommandToPatch::PauseOnEnqueuePipeControlStart:
case CommandToPatch::PauseOnEnqueuePipeControlEnd:
case CommandToPatch::HostFunctionEntry:
case CommandToPatch::HostFunctionUserData:
case CommandToPatch::HostFunctionSignalInternalTag:
case CommandToPatch::HostFunctionWaitInternalTag:
case CommandToPatch::HostFunctionId:
case CommandToPatch::HostFunctionWait:
UNRECOVERABLE_IF(commandToPatch.pCommand == nullptr);
break;
case CommandToPatch::ComputeWalkerInlineDataScratch:
@@ -4111,10 +4118,8 @@ void CommandListCoreFamily<gfxCoreFamily>::clearCommandsToPatch() {
case CommandToPatch::PauseOnEnqueueSemaphoreEnd:
case CommandToPatch::PauseOnEnqueuePipeControlStart:
case CommandToPatch::PauseOnEnqueuePipeControlEnd:
case CommandToPatch::HostFunctionEntry:
case CommandToPatch::HostFunctionUserData:
case CommandToPatch::HostFunctionSignalInternalTag:
case CommandToPatch::HostFunctionWaitInternalTag:
case CommandToPatch::HostFunctionId:
case CommandToPatch::HostFunctionWait:
UNRECOVERABLE_IF(commandToPatch.pCommand == nullptr);
break;
case CommandToPatch::ComputeWalkerInlineDataScratch:

View File

@@ -32,10 +32,8 @@ struct CommandToPatch {
ComputeWalkerImplicitArgsScratch,
NoopSpace,
PrefetchKernelMemory,
HostFunctionEntry,
HostFunctionUserData,
HostFunctionSignalInternalTag,
HostFunctionWaitInternalTag,
HostFunctionId,
HostFunctionWait,
Invalid
};
void *pDestination = nullptr;
@@ -47,6 +45,7 @@ struct CommandToPatch {
size_t inOrderPatchListIndex = 0;
size_t patchSize = 0;
CommandType type = Invalid;
bool isInOrder = false;
};
using CommandToPatchContainer = std::vector<CommandToPatch>;

View File

@@ -137,6 +137,8 @@ void CommandQueueHw<gfxCoreFamily>::patchCommands(CommandList &commandList, uint
using MI_SEMAPHORE_WAIT = typename GfxFamily::MI_SEMAPHORE_WAIT;
using COMPARE_OPERATION = typename GfxFamily::MI_SEMAPHORE_WAIT::COMPARE_OPERATION;
uint32_t hostFunctionsCounter = 0;
auto &commandsToPatch = commandList.getCommandsToPatch();
for (auto &commandToPatch : commandsToPatch) {
switch (commandToPatch.type) {
@@ -196,29 +198,40 @@ void CommandQueueHw<gfxCoreFamily>::patchCommands(CommandList &commandList, uint
}
break;
}
case CommandToPatch::HostFunctionEntry:
case CommandToPatch::HostFunctionId: {
auto callbackAddress = commandToPatch.baseAddress;
auto userDataAddress = commandToPatch.gpuAddress;
bool isInOrder = commandToPatch.isInOrder;
NEO::HostFunction hostFunction = {.hostFunctionAddress = callbackAddress,
.userDataAddress = userDataAddress,
.isInOrder = isInOrder};
csr->ensureHostFunctionWorkerStarted();
csr->signalHostFunctionWorker();
NEO::HostFunctionHelper::programHostFunctionAddress<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress);
break;
case CommandToPatch::HostFunctionUserData:
NEO::HostFunctionHelper::programHostFunctionUserData<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress);
NEO::HostFunctionHelper<GfxFamily>::programHostFunctionId(nullptr,
commandToPatch.pCommand,
csr->getHostFunctionStreamer(),
std::move(hostFunction));
hostFunctionsCounter++;
break;
}
case CommandToPatch::HostFunctionWait: {
NEO::HostFunctionHelper<GfxFamily>::programHostFunctionWaitForCompletion(nullptr,
commandToPatch.pCommand,
csr->getHostFunctionStreamer());
case CommandToPatch::HostFunctionSignalInternalTag:
NEO::HostFunctionHelper::programSignalHostFunctionStart<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData());
break;
case CommandToPatch::HostFunctionWaitInternalTag:
NEO::HostFunctionHelper::programWaitForHostFunctionCompletion<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData());
break;
}
default: {
UNRECOVERABLE_IF(true);
}
}
}
if (hostFunctionsCounter > 0) {
csr->signalHostFunctionWorker(hostFunctionsCounter);
}
}
} // namespace L0

View File

@@ -170,6 +170,8 @@ void CommandQueueHw<gfxCoreFamily>::patchCommands(CommandList &commandList, uint
using MI_SEMAPHORE_WAIT = typename GfxFamily::MI_SEMAPHORE_WAIT;
using COMPARE_OPERATION = typename GfxFamily::MI_SEMAPHORE_WAIT::COMPARE_OPERATION;
uint32_t hostFunctionsCounter = 0;
auto &commandsToPatch = commandList.getCommandsToPatch();
for (auto &commandToPatch : commandsToPatch) {
switch (commandToPatch.type) {
@@ -284,27 +286,39 @@ void CommandQueueHw<gfxCoreFamily>::patchCommands(CommandList &commandList, uint
}
break;
}
case CommandToPatch::HostFunctionEntry:
case CommandToPatch::HostFunctionId: {
auto callbackAddress = commandToPatch.baseAddress;
auto userDataAddress = commandToPatch.gpuAddress;
bool isInOrder = commandToPatch.isInOrder;
NEO::HostFunction hostFunction = {.hostFunctionAddress = callbackAddress,
.userDataAddress = userDataAddress,
.isInOrder = isInOrder};
csr->ensureHostFunctionWorkerStarted();
csr->signalHostFunctionWorker();
NEO::HostFunctionHelper::programHostFunctionAddress<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress);
break;
case CommandToPatch::HostFunctionUserData:
NEO::HostFunctionHelper::programHostFunctionUserData<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData(), commandToPatch.baseAddress);
NEO::HostFunctionHelper<GfxFamily>::programHostFunctionId(nullptr,
commandToPatch.pCommand,
csr->getHostFunctionStreamer(),
std::move(hostFunction));
hostFunctionsCounter++;
break;
}
case CommandToPatch::HostFunctionWait: {
NEO::HostFunctionHelper<GfxFamily>::programHostFunctionWaitForCompletion(nullptr,
commandToPatch.pCommand,
csr->getHostFunctionStreamer());
case CommandToPatch::HostFunctionSignalInternalTag:
NEO::HostFunctionHelper::programSignalHostFunctionStart<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData());
break;
case CommandToPatch::HostFunctionWaitInternalTag:
NEO::HostFunctionHelper::programWaitForHostFunctionCompletion<GfxFamily>(nullptr, commandToPatch.pCommand, csr->getHostFunctionData());
break;
}
default:
UNRECOVERABLE_IF(true);
}
}
if (hostFunctionsCounter > 0) {
csr->signalHostFunctionWorker(hostFunctionsCounter);
}
}
} // namespace L0

View File

@@ -656,7 +656,7 @@ struct Mock<CommandList> : public CommandList {
void *pNext,
ze_event_handle_t hSignalEvent, uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents, CmdListHostFunctionParameters &parameters));
ADDMETHOD_NOBASE_VOIDRETURN(dispatchHostFunction, (void *pHostFunction, void *pUserData));
ADDMETHOD_NOBASE_VOIDRETURN(addHostFunctionToPatchCommands, (uint64_t userHostFunctionAddress, uint64_t userDataAddress));
ADDMETHOD_NOBASE_VOIDRETURN(addHostFunctionToPatchCommands, (const NEO::HostFunction &hostFunction));
uint8_t *batchBuffer = nullptr;
NEO::GraphicsAllocation *mockAllocation = nullptr;
};

View File

@@ -1462,22 +1462,12 @@ HWTEST_F(CommandListCreateTests, givenNonEmptyCommandsToPatchWhenClearCommandsTo
EXPECT_NO_THROW(pCommandList->clearCommandsToPatch());
EXPECT_TRUE(pCommandList->commandsToPatch.empty());
commandToPatch.type = CommandToPatch::HostFunctionEntry;
commandToPatch.type = CommandToPatch::HostFunctionId;
pCommandList->commandsToPatch.push_back(commandToPatch);
EXPECT_NO_THROW(pCommandList->clearCommandsToPatch());
EXPECT_TRUE(pCommandList->commandsToPatch.empty());
commandToPatch.type = CommandToPatch::HostFunctionUserData;
pCommandList->commandsToPatch.push_back(commandToPatch);
EXPECT_NO_THROW(pCommandList->clearCommandsToPatch());
EXPECT_TRUE(pCommandList->commandsToPatch.empty());
commandToPatch.type = CommandToPatch::HostFunctionSignalInternalTag;
pCommandList->commandsToPatch.push_back(commandToPatch);
EXPECT_NO_THROW(pCommandList->clearCommandsToPatch());
EXPECT_TRUE(pCommandList->commandsToPatch.empty());
commandToPatch.type = CommandToPatch::HostFunctionWaitInternalTag;
commandToPatch.type = CommandToPatch::HostFunctionWait;
pCommandList->commandsToPatch.push_back(commandToPatch);
EXPECT_NO_THROW(pCommandList->clearCommandsToPatch());
EXPECT_TRUE(pCommandList->commandsToPatch.empty());

View File

@@ -166,30 +166,35 @@ HWTEST_F(HostFunctionTests, givenRegularCmdListWhenDispatchHostFunctionIsCalledT
void *pUserData = reinterpret_cast<void *>(0xd'0000);
commandList->dispatchHostFunction(pHostFunction, pUserData);
ASSERT_EQ(4u, commandList->commandsToPatch.size());
ASSERT_EQ(2u, commandList->commandsToPatch.size());
EXPECT_EQ(CommandToPatch::HostFunctionEntry, commandList->commandsToPatch[0].type);
EXPECT_EQ(CommandToPatch::HostFunctionId, commandList->commandsToPatch[0].type);
EXPECT_EQ(reinterpret_cast<uint64_t>(pHostFunction), commandList->commandsToPatch[0].baseAddress);
EXPECT_EQ(reinterpret_cast<uint64_t>(pUserData), commandList->commandsToPatch[0].gpuAddress);
EXPECT_EQ(true, commandList->commandsToPatch[0].isInOrder);
EXPECT_NE(nullptr, commandList->commandsToPatch[0].pCommand);
EXPECT_EQ(CommandToPatch::HostFunctionUserData, commandList->commandsToPatch[1].type);
EXPECT_EQ(reinterpret_cast<uint64_t>(pUserData), commandList->commandsToPatch[1].baseAddress);
EXPECT_EQ(CommandToPatch::HostFunctionWait, commandList->commandsToPatch[1].type);
EXPECT_NE(nullptr, commandList->commandsToPatch[1].pCommand);
EXPECT_EQ(CommandToPatch::HostFunctionSignalInternalTag, commandList->commandsToPatch[2].type);
EXPECT_NE(nullptr, commandList->commandsToPatch[2].pCommand);
EXPECT_EQ(CommandToPatch::HostFunctionWaitInternalTag, commandList->commandsToPatch[3].type);
EXPECT_NE(nullptr, commandList->commandsToPatch[3].pCommand);
}
HWTEST_F(HostFunctionTests, givenImmediateCmdListWhenDispatchHostFunctionIscalledThenCorrectCommandsAreProgrammedAndHostFunctionDataWasInitializedInCsr) {
using HostFunctionTestsImmediateCmdListParams = std::tuple<bool, ze_command_queue_mode_t>;
class HostFunctionTestsImmediateCmdListTest : public HostFunctionTests,
public ::testing::WithParamInterface<HostFunctionTestsImmediateCmdListParams> {
};
HWTEST_P(HostFunctionTestsImmediateCmdListTest, givenImmediateCmdListWhenDispatchHostFunctionIscalledThenCorrectCommandsAreProgrammedAndHostFunctionWasInitializedInCsr) {
using MI_STORE_DATA_IMM = typename FamilyType::MI_STORE_DATA_IMM;
using MI_SEMAPHORE_WAIT = typename FamilyType::MI_SEMAPHORE_WAIT;
auto [allowForOutOfOrderHostFunctionExecution, queueMode] = GetParam();
DebugManagerStateRestore restorer;
NEO::debugManager.flags.AllowForOutOfOrderHostFunctionExecution.set(allowForOutOfOrderHostFunctionExecution);
ze_result_t returnValue;
ze_command_queue_desc_t queueDesc = {};
queueDesc.mode = ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS;
queueDesc.mode = queueMode;
std::unique_ptr<L0::ult::CommandList> commandList(CommandList::whiteboxCast(CommandList::createImmediate(productFamily, device, &queueDesc, false, NEO::EngineGroupType::renderCompute, returnValue)));
void *pHostFunction = reinterpret_cast<void *>(0xa'0000);
@@ -199,53 +204,57 @@ HWTEST_F(HostFunctionTests, givenImmediateCmdListWhenDispatchHostFunctionIscalle
uint64_t userDataAddress = reinterpret_cast<uint64_t>(pUserData);
auto *cmdStream = commandList->commandContainer.getCommandStream();
auto offset = cmdStream->getUsed();
commandList->dispatchHostFunction(pHostFunction, pUserData);
// different csr
auto csr = commandList->getCsr(false);
auto *hostFunctionAllocation = csr->getHostFunctionDataAllocation();
auto *hostFunctionAllocation = csr->getHostFunctionStreamer().getHostFunctionIdAllocation();
ASSERT_NE(nullptr, hostFunctionAllocation);
auto &hostFunctionData = csr->getHostFunctionData();
auto hostFunctionIdAddress = csr->getHostFunctionStreamer().getHostFunctionIdGpuAddress();
HardwareParse hwParser;
hwParser.parseCommands<FamilyType>(*cmdStream, 0);
hwParser.parseCommands<FamilyType>(*cmdStream, offset);
auto miStores = findAll<MI_STORE_DATA_IMM *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(3u, miStores.size());
EXPECT_EQ(1u, miStores.size());
auto miWait = findAll<MI_SEMAPHORE_WAIT *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(1u, miWait.size());
// program callback address
// program callback
uint64_t expectedHostFunctionId = 1u;
auto miStoreUserHostFunction = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionData.entry), miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(hostFunctionAddress), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(hostFunctionAddress), miStoreUserHostFunction->getDataDword1());
EXPECT_EQ(hostFunctionIdAddress, miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword1());
EXPECT_TRUE(miStoreUserHostFunction->getStoreQword());
// program callback data
auto miStoreUserData = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[1]);
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionData.userData), miStoreUserData->getAddress());
EXPECT_EQ(getLowPart(userDataAddress), miStoreUserData->getDataDword0());
EXPECT_EQ(getHighPart(userDataAddress), miStoreUserData->getDataDword1());
EXPECT_TRUE(miStoreUserData->getStoreQword());
// signal pending job
auto miStoreSignalTag = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[2]);
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionData.internalTag), miStoreSignalTag->getAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::pending), miStoreSignalTag->getDataDword0());
EXPECT_FALSE(miStoreSignalTag->getStoreQword());
// wait for completion
auto miWaitTag = genCmdCast<MI_SEMAPHORE_WAIT *>(*miWait[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionData.internalTag), miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(hostFunctionIdAddress, miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(MI_SEMAPHORE_WAIT::COMPARE_OPERATION_SAD_EQUAL_SDD, miWaitTag->getCompareOperation());
EXPECT_EQ(MI_SEMAPHORE_WAIT::WAIT_MODE_POLLING_MODE, miWaitTag->getWaitMode());
*csr->getHostFunctionStreamer().getHostFunctionIdPtr() = expectedHostFunctionId;
auto hostFunction = csr->getHostFunctionStreamer().getHostFunction();
EXPECT_EQ(hostFunctionAddress, hostFunction.hostFunctionAddress);
EXPECT_EQ(userDataAddress, hostFunction.userDataAddress);
auto isInOrderExpected = (queueMode == ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS) || (allowForOutOfOrderHostFunctionExecution == false);
EXPECT_EQ(isInOrderExpected, hostFunction.isInOrder);
}
INSTANTIATE_TEST_SUITE_P(HostFunctionTestsImmediateCmdListTestValues,
HostFunctionTestsImmediateCmdListTest,
::testing::Combine(::testing::Values(true, false),
::testing::Values(ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS)));
using HostFunctionsInOrderCmdListTests = InOrderCmdListFixture;
HWTEST_F(HostFunctionsInOrderCmdListTests, givenInOrderModeWhenAppendHostFunctionThenWaitAndSignalDependenciesAreProgrammed) {
@@ -298,13 +307,7 @@ HWTEST_F(HostFunctionsInOrderCmdListTests, givenInOrderModeWhenAppendHostFunctio
auto storeDataImmIt1 = find<MI_STORE_DATA_IMM *>(itor, cmdList2.end());
ASSERT_NE(cmdList2.end(), storeDataImmIt1);
auto storeDataImmIt2 = find<MI_STORE_DATA_IMM *>(storeDataImmIt1, cmdList2.end());
ASSERT_NE(cmdList2.end(), storeDataImmIt2);
auto storeDataImmIt3 = find<MI_STORE_DATA_IMM *>(storeDataImmIt2, cmdList2.end());
ASSERT_NE(cmdList2.end(), storeDataImmIt3);
auto semaphoreWait2 = find<MI_SEMAPHORE_WAIT *>(storeDataImmIt3, cmdList2.end());
auto semaphoreWait2 = find<MI_SEMAPHORE_WAIT *>(storeDataImmIt1, cmdList2.end());
ASSERT_NE(cmdList2.end(), semaphoreWait2);
// verify signal event

View File

@@ -1159,75 +1159,100 @@ HWTEST_F(HostFunctionsCmdPatchTests, givenHostFunctionPatchCommandsWhenPatchComm
auto commandQueue = std::make_unique<MockCommandQueueHw<FamilyType::gfxCoreFamily>>(device, csr, &desc);
MockCommandStreamReceiver mockCsr(*neoDevice->executionEnvironment, neoDevice->getRootDeviceIndex(), neoDevice->getDeviceBitfield());
mockCsr.initializeTagAllocation();
mockCsr.createHostFunctionStreamer();
const auto oldCsr = commandQueue->csr;
commandQueue->csr = &mockCsr;
auto commandList = std::make_unique<WhiteBox<::L0::CommandListCoreFamily<FamilyType::gfxCoreFamily>>>();
commandList->commandsToPatch.clear();
constexpr uint64_t pHostFunction = std::numeric_limits<uint64_t>::max() - 1024u;
constexpr uint64_t pUserData = std::numeric_limits<uint64_t>::max() - 4096u;
MI_STORE_DATA_IMM callbackAddressMiStore{};
MI_STORE_DATA_IMM userDataMiStore{};
MI_STORE_DATA_IMM internalTagMiStore{};
MI_SEMAPHORE_WAIT internalTagMiWait{};
constexpr uint64_t pHostFunction1 = std::numeric_limits<uint64_t>::max() - 1024u;
constexpr uint64_t pUserData1 = std::numeric_limits<uint64_t>::max() - 4096u;
MI_STORE_DATA_IMM miStore1{};
MI_SEMAPHORE_WAIT miWait1{};
{
CommandToPatch commandToPatch;
commandToPatch.type = CommandToPatch::HostFunctionEntry;
commandToPatch.baseAddress = pHostFunction;
commandToPatch.pCommand = reinterpret_cast<void *>(&callbackAddressMiStore);
commandToPatch.type = CommandToPatch::HostFunctionId;
commandToPatch.baseAddress = pHostFunction1;
commandToPatch.gpuAddress = pUserData1;
commandToPatch.isInOrder = false;
commandToPatch.pCommand = reinterpret_cast<void *>(&miStore1);
commandList->commandsToPatch.push_back(commandToPatch);
}
{
CommandToPatch commandToPatch;
commandToPatch.type = CommandToPatch::HostFunctionUserData;
commandToPatch.baseAddress = pUserData;
commandToPatch.pCommand = reinterpret_cast<void *>(&userDataMiStore);
commandToPatch.type = CommandToPatch::HostFunctionWait;
commandToPatch.pCommand = reinterpret_cast<void *>(&miWait1);
commandList->commandsToPatch.push_back(commandToPatch);
}
constexpr uint64_t pHostFunction2 = std::numeric_limits<uint64_t>::max() - 1024u - 8;
constexpr uint64_t pUserData2 = std::numeric_limits<uint64_t>::max() - 4096u - 8;
MI_STORE_DATA_IMM miStore2{};
MI_SEMAPHORE_WAIT miWait2{};
{
CommandToPatch commandToPatch;
commandToPatch.type = CommandToPatch::HostFunctionId;
commandToPatch.baseAddress = pHostFunction2;
commandToPatch.gpuAddress = pUserData2;
commandToPatch.isInOrder = true;
commandToPatch.pCommand = reinterpret_cast<void *>(&miStore2);
commandList->commandsToPatch.push_back(commandToPatch);
}
{
CommandToPatch commandToPatch;
commandToPatch.type = CommandToPatch::HostFunctionSignalInternalTag;
commandToPatch.pCommand = reinterpret_cast<void *>(&internalTagMiStore);
commandList->commandsToPatch.push_back(commandToPatch);
}
{
CommandToPatch commandToPatch;
commandToPatch.type = CommandToPatch::HostFunctionWaitInternalTag;
commandToPatch.pCommand = reinterpret_cast<void *>(&internalTagMiWait);
commandToPatch.type = CommandToPatch::HostFunctionWait;
commandToPatch.pCommand = reinterpret_cast<void *>(&miWait2);
commandList->commandsToPatch.push_back(commandToPatch);
}
commandQueue->patchCommands(*commandList, 0, false, nullptr);
EXPECT_NE(nullptr, commandQueue->csr->getHostFunctionDataAllocation());
EXPECT_EQ(1u, mockCsr.createHostFunctionWorkerCounter);
EXPECT_EQ(1u, mockCsr.signalHostFunctionWorkerCounter);
EXPECT_EQ(2u, mockCsr.signalHostFunctionWorkerCounter);
auto &hostFunctionDataFromCsr = commandQueue->csr->getHostFunctionData();
auto &hostFunctionStreamer = commandQueue->csr->getHostFunctionStreamer();
uint64_t hostFunctionIdGpuAddress = hostFunctionStreamer.getHostFunctionIdGpuAddress();
// callback address - mi store
EXPECT_EQ(getLowPart(pHostFunction), callbackAddressMiStore.getDataDword0());
EXPECT_EQ(getHighPart(pHostFunction), callbackAddressMiStore.getDataDword1());
EXPECT_TRUE(callbackAddressMiStore.getStoreQword());
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionDataFromCsr.entry), callbackAddressMiStore.getAddress());
{
// callback id - mi store
uint64_t expectedId = 1u;
EXPECT_EQ(getLowPart(expectedId), miStore1.getDataDword0());
EXPECT_EQ(getHighPart(expectedId), miStore1.getDataDword1());
EXPECT_TRUE(miStore1.getStoreQword());
EXPECT_EQ(hostFunctionIdGpuAddress, miStore1.getAddress());
// userData address - mi store
EXPECT_EQ(getLowPart(pUserData), userDataMiStore.getDataDword0());
EXPECT_EQ(getHighPart(pUserData), userDataMiStore.getDataDword1());
EXPECT_TRUE(userDataMiStore.getStoreQword());
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionDataFromCsr.userData), userDataMiStore.getAddress());
// semaphore wait
EXPECT_EQ(static_cast<uint32_t>(HostFunctionStatus::completed), miWait1.getSemaphoreDataDword());
EXPECT_EQ(hostFunctionIdGpuAddress, miWait1.getSemaphoreGraphicsAddress());
// internal tag signal - mi store
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::pending), internalTagMiStore.getDataDword0());
EXPECT_FALSE(internalTagMiStore.getStoreQword());
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionDataFromCsr.internalTag), internalTagMiStore.getAddress());
// host function data programmed in host function streamer
*hostFunctionStreamer.getHostFunctionIdPtr() = expectedId;
auto hostFunction = hostFunctionStreamer.getHostFunction();
EXPECT_EQ(pHostFunction1, hostFunction.hostFunctionAddress);
EXPECT_EQ(pUserData1, hostFunction.userDataAddress);
EXPECT_FALSE(hostFunction.isInOrder);
}
{
// callback id - mi store
uint64_t expectedId = 3u;
EXPECT_EQ(getLowPart(expectedId), miStore2.getDataDword0());
EXPECT_EQ(getHighPart(expectedId), miStore2.getDataDword1());
EXPECT_TRUE(miStore2.getStoreQword());
EXPECT_EQ(hostFunctionIdGpuAddress, miStore2.getAddress());
// internal tag wait - semaphore wait
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::completed), internalTagMiWait.getSemaphoreDataDword());
EXPECT_EQ(reinterpret_cast<uint64_t>(hostFunctionDataFromCsr.internalTag), internalTagMiWait.getSemaphoreGraphicsAddress());
// semaphore wait
EXPECT_EQ(static_cast<uint32_t>(HostFunctionStatus::completed), miWait2.getSemaphoreDataDword());
EXPECT_EQ(hostFunctionIdGpuAddress, miWait2.getSemaphoreGraphicsAddress());
// host function data programmed in host function streamer
*hostFunctionStreamer.getHostFunctionIdPtr() = expectedId;
auto hostFunction = hostFunctionStreamer.getHostFunction();
EXPECT_EQ(pHostFunction2, hostFunction.hostFunctionAddress);
EXPECT_EQ(pUserData2, hostFunction.userDataAddress);
EXPECT_TRUE(hostFunction.isInOrder);
}
commandQueue->csr = oldCsr;
}

View File

@@ -33,14 +33,15 @@ set(NEO_CORE_COMMAND_STREAM
${CMAKE_CURRENT_SOURCE_DIR}/host_function.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function.inl
${CMAKE_CURRENT_SOURCE_DIR}/host_function_enablers.inl
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_cv.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_cv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_interface.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_interface.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_counting_semaphore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_counting_semaphore.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_atomic.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_scheduler.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_thread_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/host_function_worker_thread_pool.h
${CMAKE_CURRENT_SOURCE_DIR}/host_function_interface.h
${CMAKE_CURRENT_SOURCE_DIR}/linear_stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linear_stream.h
${CMAKE_CURRENT_SOURCE_DIR}/preemption.cpp

View File

@@ -98,7 +98,7 @@ CommandStreamReceiver::CommandStreamReceiver(ExecutionEnvironment &executionEnvi
auto &compilerProductHelper = rootDeviceEnvironment.getHelper<CompilerProductHelper>();
this->heaplessModeEnabled = compilerProductHelper.isHeaplessModeEnabled(hwInfo);
this->heaplessStateInitEnabled = compilerProductHelper.isHeaplessStateInitEnabled(heaplessModeEnabled);
this->hostFunctionWorkerMode = debugManager.flags.HostFunctionWorkMode.get();
this->hostFunctionWorkerMode = static_cast<HostFunctionWorkerMode>(debugManager.flags.HostFunctionWorkMode.get());
}
CommandStreamReceiver::~CommandStreamReceiver() {
@@ -243,16 +243,21 @@ void CommandStreamReceiver::createHostFunctionWorker() {
return;
}
this->hostFunctionWorker = HostFunctionFactory::createHostFunctionWorker(this->hostFunctionWorkerMode,
this->isAubMode(),
this->downloadAllocationImpl,
this->getHostFunctionDataAllocation(),
&this->getHostFunctionData());
bool skipHostFunctionExecution = getType() == NEO::CommandStreamReceiverType::aub ||
getType() == NEO::CommandStreamReceiverType::nullAub;
this->hostFunctionWorker->start();
auto *rootDeviceEnvironment = this->executionEnvironment.rootDeviceEnvironments[rootDeviceIndex].get();
HostFunctionFactory::createAndSetHostFunctionWorker(this->hostFunctionWorkerMode,
skipHostFunctionExecution,
this,
rootDeviceEnvironment);
auto *streamer = &this->getHostFunctionStreamer();
this->hostFunctionWorker->start(streamer);
}
IHostFunctionWorker *CommandStreamReceiver::getHostFunctionWorker() {
HostFunctionWorker *CommandStreamReceiver::getHostFunctionWorker() {
return this->hostFunctionWorker;
}
@@ -437,8 +442,6 @@ void CommandStreamReceiver::cleanupResources() {
DEBUG_BREAK_IF(tagAllocation != nullptr);
DEBUG_BREAK_IF(tagAddress != nullptr);
hostFunctionDataAllocation = nullptr;
for (auto graphicsAllocation : tagsMultiAllocation->getGraphicsAllocations()) {
getMemoryManager()->freeGraphicsMemory(graphicsAllocation);
}
@@ -482,9 +485,16 @@ void CommandStreamReceiver::cleanupResources() {
}
void CommandStreamReceiver ::cleanupHostFunctionWorker() {
hostFunctionWorker->finish();
delete hostFunctionWorker;
hostFunctionWorker = nullptr;
if (hostFunctionWorker) {
hostFunctionWorker->finish();
if (hostFunctionWorkerMode != HostFunctionWorkerMode::schedulerWithThreadPool) {
delete hostFunctionWorker;
}
hostFunctionWorker = nullptr;
}
}
WaitStatus CommandStreamReceiver::waitForCompletionWithTimeout(const WaitParams &params, TaskCountType taskCountToWait) {
@@ -720,8 +730,8 @@ void *CommandStreamReceiver::getIndirectHeapCurrentPtr(IndirectHeapType heapType
return nullptr;
}
void CommandStreamReceiver::signalHostFunctionWorker() {
hostFunctionWorker->submit();
void CommandStreamReceiver::signalHostFunctionWorker(uint32_t nHostFunctions) {
hostFunctionWorker->submit(nHostFunctions);
}
void CommandStreamReceiver::ensureHostFunctionWorkerStarted() {
@@ -736,33 +746,26 @@ void CommandStreamReceiver::startHostFunctionWorker() {
return;
}
createHostFunctionStreamer();
createHostFunctionWorker();
this->hostFunctionWorkerStarted.store(true, std::memory_order_release);
}
void CommandStreamReceiver::initializeHostFunctionData() {
void CommandStreamReceiver::createHostFunctionStreamer() {
auto tagAddress = this->tagAllocation->getUnderlyingBuffer();
auto offset = TagAllocationLayout::hostFunctionDataOffset + this->immWritePostSyncWriteOffset;
auto hostFunctionIdAddress = ptrOffset(tagAddress, static_cast<size_t>(offset));
auto entryAddress = ptrOffset(tagAddress, HostFunctionHelper::entryOffset + TagAllocationLayout::hostFunctionDataOffset);
auto userDataAddress = ptrOffset(tagAddress, HostFunctionHelper::userDataOffset + TagAllocationLayout::hostFunctionDataOffset);
auto internalTagAddress = ptrOffset(tagAddress, HostFunctionHelper::internalTagOffset + TagAllocationLayout::hostFunctionDataOffset);
this->hostFunctionData.entry = reinterpret_cast<uint64_t *>(entryAddress);
this->hostFunctionData.userData = reinterpret_cast<uint64_t *>(userDataAddress);
this->hostFunctionData.internalTag = reinterpret_cast<uint32_t *>(internalTagAddress);
*this->hostFunctionData.entry = 0;
*this->hostFunctionData.userData = 0;
*this->hostFunctionData.internalTag = 0;
this->hostFunctionStreamer = std::make_unique<HostFunctionStreamer>(this->tagAllocation,
hostFunctionIdAddress,
this->downloadAllocationImpl,
isTbxMode());
}
HostFunctionData &CommandStreamReceiver::getHostFunctionData() {
return hostFunctionData;
}
GraphicsAllocation *CommandStreamReceiver::getHostFunctionDataAllocation() {
return tagAllocation;
HostFunctionStreamer &CommandStreamReceiver::getHostFunctionStreamer() {
return *hostFunctionStreamer.get();
}
IndirectHeap &CommandStreamReceiver::getIndirectHeap(IndirectHeap::Type heapType,
@@ -910,6 +913,7 @@ bool CommandStreamReceiver::initializeTagAllocation() {
auto tagAddress = this->tagAddress;
auto ucTagAddress = this->ucTagAddress;
auto completionFence = reinterpret_cast<TaskCountType *>(getCompletionAddress());
auto hostFunctionDataAddress = reinterpret_cast<uint64_t *>(ptrOffset(this->tagAllocation->getUnderlyingBuffer(), TagAllocationLayout::hostFunctionDataOffset));
UNRECOVERABLE_IF(!completionFence);
uint32_t subDevices = static_cast<uint32_t>(this->deviceBitfield.count());
for (uint32_t i = 0; i < subDevices; i++) {
@@ -919,6 +923,8 @@ bool CommandStreamReceiver::initializeTagAllocation() {
ucTagAddress = ptrOffset(ucTagAddress, this->immWritePostSyncWriteOffset);
*completionFence = 0;
completionFence = ptrOffset(completionFence, this->immWritePostSyncWriteOffset);
*hostFunctionDataAddress = 0u;
hostFunctionDataAddress = ptrOffset(hostFunctionDataAddress, this->immWritePostSyncWriteOffset);
}
*this->debugPauseStateAddress = debugManager.flags.EnableNullHardware.get() ? DebugPauseState::disabled : DebugPauseState::waitingForFirstSemaphore;
@@ -932,8 +938,6 @@ bool CommandStreamReceiver::initializeTagAllocation() {
this->barrierCountTagAddress = ptrOffset(this->tagAddress, TagAllocationLayout::barrierCountOffset);
initializeHostFunctionData();
return true;
}

View File

@@ -67,7 +67,7 @@ class KmdNotifyHelper;
class GfxCoreHelper;
class ProductHelper;
class ReleaseHelper;
class IHostFunctionWorker;
class HostFunctionWorker;
enum class WaitStatus;
struct AubSubCaptureStatus;
class SharedPoolAllocation;
@@ -150,7 +150,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass {
WaitStatus waitForTaskCountAndCleanAllocationList(TaskCountType requiredTaskCount, uint32_t allocationUsage);
MOCKABLE_VIRTUAL WaitStatus waitForTaskCountAndCleanTemporaryAllocationList(TaskCountType requiredTaskCount);
MOCKABLE_VIRTUAL void createHostFunctionWorker();
IHostFunctionWorker *getHostFunctionWorker();
HostFunctionWorker *getHostFunctionWorker();
LinearStream &getCS(size_t minRequiredSize = 1024u);
OSInterface *getOSInterface() const;
@@ -570,15 +570,19 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass {
bool isLatestFlushIsTaskCountUpdateOnly() const { return latestFlushIsTaskCountUpdateOnly; }
MOCKABLE_VIRTUAL uint32_t getContextGroupId() const;
MOCKABLE_VIRTUAL void signalHostFunctionWorker();
MOCKABLE_VIRTUAL void signalHostFunctionWorker(uint32_t nHostFunctions);
void ensureHostFunctionWorkerStarted();
HostFunctionData &getHostFunctionData();
GraphicsAllocation *getHostFunctionDataAllocation();
void createHostFunctionStreamer();
HostFunctionStreamer &getHostFunctionStreamer();
[[nodiscard]] std::unique_lock<MutexType> obtainHostFunctionWorkerStartLock();
void setHostFunctionWorker(HostFunctionWorker *hostFunctionWorker) {
this->hostFunctionWorker = hostFunctionWorker;
}
protected:
void initializeHostFunctionData();
MOCKABLE_VIRTUAL void startHostFunctionWorker();
virtual CompletionStamp flushTaskHeapless(LinearStream &commandStreamTask, size_t commandStreamTaskStart,
@@ -617,7 +621,8 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass {
std::unique_ptr<TagAllocatorBase> timestampPacketAllocator;
std::unique_ptr<Thread> userPauseConfirmation;
std::unique_ptr<IndirectHeap> globalStatelessHeap;
IHostFunctionWorker *hostFunctionWorker = nullptr;
std::unique_ptr<HostFunctionStreamer> hostFunctionStreamer;
HostFunctionWorker *hostFunctionWorker = nullptr;
ResidencyContainer residencyAllocations;
PrivateAllocsToReuseContainer ownedPrivateAllocations;
@@ -654,14 +659,12 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass {
GraphicsAllocation *clearColorAllocation = nullptr;
GraphicsAllocation *workPartitionAllocation = nullptr;
GraphicsAllocation *globalStatelessHeapAllocation = nullptr;
MultiGraphicsAllocation *tagsMultiAllocation = nullptr;
GraphicsAllocation *hostFunctionDataAllocation = nullptr;
IndirectHeap *indirectHeap[IndirectHeapType::numTypes];
OsContext *osContext = nullptr;
CommandStreamReceiver *primaryCsr = nullptr;
TaskCountType *completionFenceValuePointer = nullptr;
HostFunctionData hostFunctionData;
std::atomic<TaskCountType> barrierCount{0};
// current taskLevel. Used for determining if a PIPE_CONTROL is needed.
@@ -680,7 +683,7 @@ class CommandStreamReceiver : NEO::NonCopyableAndNonMovableClass {
uint32_t lastSentL3Config = 0;
uint32_t latestSentStatelessMocsConfig;
uint64_t lastSentSliceCount;
int32_t hostFunctionWorkerMode = -1;
HostFunctionWorkerMode hostFunctionWorkerMode = HostFunctionWorkerMode::countingSemaphore;
uint32_t requiredScratchSlot0Size = 0;
uint32_t requiredScratchSlot1Size = 0;
uint32_t lastAdditionalKernelExecInfo;

View File

@@ -8,30 +8,151 @@
#include "shared/source/command_stream/host_function.h"
#include "shared/source/command_stream/command_stream_receiver.h"
#include "shared/source/command_stream/host_function_worker_atomic.h"
#include "shared/source/command_stream/host_function_interface.h"
#include "shared/source/command_stream/host_function_scheduler.h"
#include "shared/source/command_stream/host_function_worker_counting_semaphore.h"
#include "shared/source/command_stream/host_function_worker_cv.h"
#include "shared/source/command_stream/host_function_worker_interface.h"
#include "shared/source/debug_settings/debug_settings_manager.h"
#include "shared/source/execution_environment/root_device_environment.h"
#include "shared/source/memory_manager/graphics_allocation.h"
namespace NEO::HostFunctionFactory {
namespace NEO {
HostFunctionStreamer::HostFunctionStreamer(GraphicsAllocation *allocation,
void *hostFunctionIdAddress,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
bool isTbx)
: hostFunctionIdAddress(reinterpret_cast<volatile uint64_t *>(hostFunctionIdAddress)),
allocation(allocation),
downloadAllocationImpl(downloadAllocationImpl),
nextHostFunctionId(1), // start from 1 to keep 0 bit for pending/completed status
isTbx(isTbx) {
}
IHostFunctionWorker *createHostFunctionWorker(int32_t hostFunctionWorkerMode,
bool isAubMode,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data) {
// Returns the ID slot pointer value as a 64-bit integer; the command programming helpers use it
// as the destination of the store and as the semaphore wait address.
uint64_t HostFunctionStreamer::getHostFunctionIdGpuAddress() const {
    const auto idSlotAddress = reinterpret_cast<uint64_t>(hostFunctionIdAddress);
    return idSlotAddress;
}
bool skipHostFunctionExecution = isAubMode;
// Returns the host-dereferenceable pointer to the host function ID slot given to the constructor.
volatile uint64_t *HostFunctionStreamer::getHostFunctionIdPtr() const {
return hostFunctionIdAddress;
}
switch (hostFunctionWorkerMode) {
default:
case 0:
return new HostFunctionWorkerCountingSemaphore(skipHostFunctionExecution, downloadAllocationImpl, allocation, data);
case 1:
return new HostFunctionWorkerCV(skipHostFunctionExecution, downloadAllocationImpl, allocation, data);
case 2:
return new HostFunctionWorkerAtomic(skipHostFunctionExecution, downloadAllocationImpl, allocation, data);
// Hands out the next host function ID and advances the counter atomically.
// The counter starts at 1 and moves in steps of 2, so bit 0 of every returned ID is set.
uint64_t HostFunctionStreamer::getNextHostFunctionIdAndIncrement() {
    constexpr uint64_t idStride = 2;
    const auto assignedId = nextHostFunctionId.fetch_add(idStride, std::memory_order_acq_rel);
    return assignedId;
}
// Reads the current content of the host function ID slot (volatile read through hostFunctionIdAddress).
uint64_t HostFunctionStreamer::getHostFunctionId() const {
return *hostFunctionIdAddress;
}
// Called once a host function has finished running.
// For in-order host functions the ID slot is reset to the completed value and the streamer is
// marked idle again; out-of-order host functions need no action here because their slot was
// already cleared in prepareForExecution().
void HostFunctionStreamer::signalHostFunctionCompletion(const HostFunction &hostFunction) {
    if (!hostFunction.isInOrder) {
        return;
    }

    *hostFunctionIdAddress = HostFunctionStatus::completed;
    isBusy.store(false, std::memory_order_release);
}
} // namespace NEO::HostFunctionFactory
// Bookkeeping done right before a host function is handed over for execution.
// Out-of-order host functions clear the ID slot immediately; in-order ones mark the streamer
// busy until signalHostFunctionCompletion() runs. In both cases the pending counter drops by one.
void HostFunctionStreamer::prepareForExecution(const HostFunction &hostFunction) {
    const bool clearIdBeforeExecution = !hostFunction.isInOrder;

    if (clearIdBeforeExecution) {
        *hostFunctionIdAddress = HostFunctionStatus::completed;
    } else {
        isBusy.store(true, std::memory_order_release);
    }

    pendingHostFunctions.fetch_sub(1, std::memory_order_acq_rel);
}
// Removes and returns the host function registered under the ID currently stored in the ID slot.
// A missing entry is treated as unrecoverable.
HostFunction HostFunctionStreamer::getHostFunction() {
    std::unique_lock lock(hostFunctionsMutex);

    auto extractedNode = hostFunctions.extract(getHostFunctionId());
    if (extractedNode.empty()) {
        UNRECOVERABLE_IF(true);
        return HostFunction{};
    }

    return std::move(extractedNode.mapped());
}
// Removes and returns the host function registered under hostFunctionId.
// A missing entry is treated as unrecoverable.
HostFunction HostFunctionStreamer::getHostFunction(uint64_t hostFunctionId) {
    std::unique_lock lock(hostFunctionsMutex);

    auto extractedNode = hostFunctions.extract(hostFunctionId);
    if (extractedNode.empty()) {
        UNRECOVERABLE_IF(true);
        return HostFunction{};
    }

    return std::move(extractedNode.mapped());
}
// Stores a host function under its ID and then bumps the pending counter.
// The map is only touched while hostFunctionsMutex is held; the counter is updated after the
// lock has been released.
void HostFunctionStreamer::addHostFunction(uint64_t hostFunctionId, HostFunction &&hostFunction) {
    {
        std::lock_guard guard(hostFunctionsMutex);
        hostFunctions.emplace(hostFunctionId, std::move(hostFunction));
    }

    pendingHostFunctions.fetch_add(1, std::memory_order_acq_rel);
}
// Returns the allocation passed to the constructor (the one handed to downloadAllocationImpl in TBX mode).
GraphicsAllocation *HostFunctionStreamer::getHostFunctionIdAllocation() const {
return allocation;
}
// Invokes the download callback on the allocation; does nothing unless the streamer was created
// with isTbx set.
void HostFunctionStreamer::downloadHostFunctionAllocation() const {
    if (!isTbx) {
        return;
    }

    downloadAllocationImpl(*allocation);
}
uint64_t HostFunctionStreamer::isHostFunctionReadyToExecute() const {
if (pendingHostFunctions.load(std::memory_order_acquire) == 0) {
return false;
}
if (isBusy.load(std::memory_order_acquire)) {
return false;
}
downloadHostFunctionAllocation();
auto hostFunctionId = getHostFunctionId();
return hostFunctionId;
}
namespace HostFunctionFactory {
// Attaches a host function worker to csr unless one is already set.
//   schedulerWithThreadPool - the CSR shares the HostFunctionScheduler stored in the root device
//                             environment; it is created on first use, with the thread pool size
//                             taken from the HostFunctionThreadPoolSize debug flag when positive.
//   any other mode          - the CSR gets its own HostFunctionWorkerCountingSemaphore.
void createAndSetHostFunctionWorker(HostFunctionWorkerMode hostFunctionWorkerMode,
                                    bool skipHostFunctionExecution,
                                    CommandStreamReceiver *csr,
                                    RootDeviceEnvironment *rootDeviceEnvironment) {
    if (csr->getHostFunctionWorker() != nullptr) {
        return;
    }

    if (hostFunctionWorkerMode != HostFunctionWorkerMode::schedulerWithThreadPool) {
        // defaultMode, countingSemaphore and unrecognized values all end up here
        csr->setHostFunctionWorker(new HostFunctionWorkerCountingSemaphore(skipHostFunctionExecution));
        return;
    }

    auto scheduler = rootDeviceEnvironment->getHostFunctionScheduler();
    if (scheduler == nullptr) {
        int32_t nWorkers = HostFunctionThreadPoolHelper::unlimitedThreads;
        if (debugManager.flags.HostFunctionThreadPoolSize.get() > 0) {
            nWorkers = debugManager.flags.HostFunctionThreadPoolSize.get();
        }

        auto createdScheduler = std::make_unique<HostFunctionScheduler>(skipHostFunctionExecution,
                                                                        nWorkers);
        rootDeviceEnvironment->setHostFunctionScheduler(std::move(createdScheduler));
    }

    // Re-read from the environment so the CSR always gets the instance the environment holds.
    scheduler = rootDeviceEnvironment->getHostFunctionScheduler();
    csr->setHostFunctionWorker(scheduler);
}
} // namespace HostFunctionFactory
} // namespace NEO

View File

@@ -7,57 +7,96 @@
#pragma once
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <mutex>
#include <unordered_map>
namespace NEO {
class LinearStream;
class CommandStreamReceiver;
class IHostFunctionWorker;
class IHostFunction;
class GraphicsAllocation;
struct RootDeviceEnvironment;
struct HostFunctionData {
volatile uint64_t *entry = nullptr;
volatile uint64_t *userData = nullptr;
volatile uint32_t *internalTag = nullptr;
// Callable description of a user host function: the callback address, the address of its single
// argument, and whether it is an in-order host function.
struct HostFunction {
    uint64_t hostFunctionAddress = 0; // address of a function with signature void(void *)
    uint64_t userDataAddress = 0;     // value passed to the callback as its void * argument
    bool isInOrder = true;

    // Calls the stored callback with the stored user data pointer.
    void invoke() const {
        using UserCallback = void (*)(void *);
        const auto userCallback = reinterpret_cast<UserCallback>(hostFunctionAddress);
        userCallback(reinterpret_cast<void *>(userDataAddress));
    }
};
enum class HostFunctionTagStatus : uint32_t {
completed = 0,
pending = 1
// Values stored in the host function ID slot.
namespace HostFunctionStatus {
// Slot value meaning "nothing pending": the streamer writes it to clear the slot, and the
// programmed semaphore wait uses it as the value to wait for.
inline constexpr uint64_t completed = 0;
} // namespace HostFunctionStatus
namespace HostFunctionThreadPoolHelper {
inline constexpr int32_t unlimitedThreads = -1; // each CSR that uses host function creates worker thread in thread pool
}
// Per-CSR store of user host functions keyed by host function ID, together with the host-visible
// ID slot through which a pending ID is observed and cleared.
class HostFunctionStreamer {
public:
// allocation/downloadAllocationImpl are used only for the TBX download step; hostFunctionIdAddress
// is the address of the ID slot.
HostFunctionStreamer(GraphicsAllocation *allocation, void *hostFunctionIdAddress, const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl, bool isTbx);
~HostFunctionStreamer() = default;
// Returns the current ID slot value when a host function may be dispatched,
// HostFunctionStatus::completed (0) otherwise.
uint64_t isHostFunctionReadyToExecute() const;
GraphicsAllocation *getHostFunctionIdAllocation() const;
// Remove and return a stored host function: by the ID currently in the slot, or by explicit ID.
HostFunction getHostFunction();
HostFunction getHostFunction(uint64_t hostFunctionId);
// Current content of the ID slot.
uint64_t getHostFunctionId() const;
// ID slot pointer value as an integer, used when programming the store and the semaphore wait.
uint64_t getHostFunctionIdGpuAddress() const;
volatile uint64_t *getHostFunctionIdPtr() const;
// Returns the next ID and advances the counter by 2, keeping bit 0 set in every ID.
uint64_t getNextHostFunctionIdAndIncrement();
// Stores the host function under its ID and increments the pending counter.
void addHostFunction(uint64_t hostFunctionId, HostFunction &&hostFunction);
// Invokes downloadAllocationImpl on the allocation when isTbx is set.
void downloadHostFunctionAllocation() const;
// In-order host functions: clears the ID slot and the busy flag after execution.
void signalHostFunctionCompletion(const HostFunction &hostFunction);
// In-order: sets the busy flag; out-of-order: clears the ID slot. Decrements the pending counter.
void prepareForExecution(const HostFunction &hostFunction);
private:
std::mutex hostFunctionsMutex; // guards hostFunctions
std::unordered_map<uint64_t, HostFunction> hostFunctions; // registered host functions keyed by ID
volatile uint64_t *hostFunctionIdAddress = nullptr; // 0 bit - used to signal that host function is pending or completed
GraphicsAllocation *allocation = nullptr;
std::function<void(GraphicsAllocation &)> downloadAllocationImpl;
std::atomic<uint64_t> nextHostFunctionId{1}; // next ID to hand out; always odd
std::atomic<uint32_t> pendingHostFunctions{0}; // added but not yet prepared for execution
std::atomic<bool> isBusy{false}; // true while an in-order host function is in flight
const bool isTbx = false;
};
// Selects which host function worker implementation a CSR gets from the factory.
enum class HostFunctionWorkerMode : int32_t {
defaultMode = -1, // handled like countingSemaphore by the factory
countingSemaphore = 0, // factory creates a HostFunctionWorkerCountingSemaphore for the CSR
schedulerWithThreadPool = 1, // factory attaches the HostFunctionScheduler held by the root device environment
};
template <typename GfxFamily>
struct HostFunctionHelper {
constexpr static size_t entryOffset = offsetof(HostFunctionData, entry);
constexpr static size_t userDataOffset = offsetof(HostFunctionData, userData);
constexpr static size_t internalTagOffset = offsetof(HostFunctionData, internalTag);
template <typename GfxFamily>
static void programHostFunction(LinearStream &commandStream, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress, uint64_t userDataAddress);
template <typename GfxFamily>
static void programHostFunctionAddress(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress);
template <typename GfxFamily>
static void programHostFunctionUserData(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userDataAddress);
template <typename GfxFamily>
static void programSignalHostFunctionStart(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData);
template <typename GfxFamily>
static void programWaitForHostFunctionCompletion(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData);
static void programHostFunction(LinearStream &commandStream, HostFunctionStreamer &streamer, HostFunction &&hostFunction);
static void programHostFunctionId(LinearStream *commandStream, void *cmdBuffer, HostFunctionStreamer &streamer, HostFunction &&hostFunction);
static void programHostFunctionWaitForCompletion(LinearStream *commandStream, void *cmdBuffer, const HostFunctionStreamer &streamer);
};
namespace HostFunctionFactory {
IHostFunctionWorker *createHostFunctionWorker(int32_t hostFunctionWorkerMode,
bool isAubMode,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data);
}
void createAndSetHostFunctionWorker(HostFunctionWorkerMode hostFunctionWorkerMode,
bool skipHostFunctionExecution,
CommandStreamReceiver *csr,
RootDeviceEnvironment *rootDeviceEnvironment);
} // namespace HostFunctionFactory
} // namespace NEO

View File

@@ -12,67 +12,43 @@
namespace NEO {
template <typename GfxFamily>
void HostFunctionHelper::programHostFunction(LinearStream &commandStream, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress, uint64_t userDataAddress) {
HostFunctionHelper::programHostFunctionAddress<GfxFamily>(&commandStream, nullptr, hostFunctionData, userHostFunctionAddress);
HostFunctionHelper::programHostFunctionUserData<GfxFamily>(&commandStream, nullptr, hostFunctionData, userDataAddress);
HostFunctionHelper::programSignalHostFunctionStart<GfxFamily>(&commandStream, nullptr, hostFunctionData);
HostFunctionHelper::programWaitForHostFunctionCompletion<GfxFamily>(&commandStream, nullptr, hostFunctionData);
void HostFunctionHelper<GfxFamily>::programHostFunction(LinearStream &commandStream, HostFunctionStreamer &streamer, HostFunction &&hostFunction) {
HostFunctionHelper<GfxFamily>::programHostFunctionId(&commandStream, nullptr, streamer, std::move(hostFunction));
HostFunctionHelper<GfxFamily>::programHostFunctionWaitForCompletion(&commandStream, nullptr, streamer);
}
template <typename GfxFamily>
void HostFunctionHelper::programHostFunctionAddress(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress) {
void HostFunctionHelper<GfxFamily>::programHostFunctionId(LinearStream *commandStream, void *cmdBuffer, HostFunctionStreamer &streamer, HostFunction &&hostFunction) {
using MI_STORE_DATA_IMM = typename GfxFamily::MI_STORE_DATA_IMM;
auto hostFunctionAddressDst = reinterpret_cast<uint64_t>(hostFunctionData.entry);
auto idGpuAddress = streamer.getHostFunctionIdGpuAddress();
auto hostFunctionId = streamer.getNextHostFunctionIdAndIncrement();
streamer.addHostFunction(hostFunctionId, std::move(hostFunction));
auto lowPart = getLowPart(hostFunctionId);
auto highPart = getHighPart(hostFunctionId);
bool storeQword = true;
EncodeStoreMemory<GfxFamily>::programStoreDataImmCommand(commandStream,
static_cast<MI_STORE_DATA_IMM *>(cmdBuffer),
hostFunctionAddressDst,
getLowPart(userHostFunctionAddress),
getHighPart(userHostFunctionAddress),
true,
idGpuAddress,
lowPart,
highPart,
storeQword,
false);
}
template <typename GfxFamily>
void HostFunctionHelper::programHostFunctionUserData(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userDataAddress) {
using MI_STORE_DATA_IMM = typename GfxFamily::MI_STORE_DATA_IMM;
auto userDataAddressDst = reinterpret_cast<uint64_t>(hostFunctionData.userData);
EncodeStoreMemory<GfxFamily>::programStoreDataImmCommand(commandStream,
static_cast<MI_STORE_DATA_IMM *>(cmdBuffer),
userDataAddressDst,
getLowPart(userDataAddress),
getHighPart(userDataAddress),
true,
false);
}
template <typename GfxFamily>
void HostFunctionHelper::programSignalHostFunctionStart(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData) {
using MI_STORE_DATA_IMM = typename GfxFamily::MI_STORE_DATA_IMM;
auto internalTagAddress = reinterpret_cast<uint64_t>(hostFunctionData.internalTag);
EncodeStoreMemory<GfxFamily>::programStoreDataImmCommand(commandStream,
static_cast<MI_STORE_DATA_IMM *>(cmdBuffer),
internalTagAddress,
static_cast<uint32_t>(HostFunctionTagStatus::pending),
0u,
false,
false);
}
template <typename GfxFamily>
void HostFunctionHelper::programWaitForHostFunctionCompletion(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData) {
void HostFunctionHelper<GfxFamily>::programHostFunctionWaitForCompletion(LinearStream *commandStream, void *cmdBuffer, const HostFunctionStreamer &streamer) {
using MI_SEMAPHORE_WAIT = typename GfxFamily::MI_SEMAPHORE_WAIT;
auto internalTagAddress = reinterpret_cast<uint64_t>(hostFunctionData.internalTag);
auto idGpuAddress = streamer.getHostFunctionIdGpuAddress();
auto waitValue = HostFunctionStatus::completed;
EncodeSemaphore<GfxFamily>::programMiSemaphoreWaitCommand(commandStream,
static_cast<MI_SEMAPHORE_WAIT *>(cmdBuffer),
internalTagAddress,
static_cast<uint32_t>(HostFunctionTagStatus::completed),
idGpuAddress,
waitValue,
GfxFamily::MI_SEMAPHORE_WAIT::COMPARE_OPERATION::COMPARE_OPERATION_SAD_EQUAL_SDD,
false,
true,

View File

@@ -8,9 +8,6 @@
#include "shared/source/command_stream/host_function.h"
namespace NEO {
template void HostFunctionHelper::programHostFunction<Family>(LinearStream &commandStream, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress, uint64_t userDataAddress);
template void HostFunctionHelper::programHostFunctionAddress<Family>(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userHostFunctionAddress);
template void HostFunctionHelper::programHostFunctionUserData<Family>(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData, uint64_t userDataAddress);
template void HostFunctionHelper::programSignalHostFunctionStart<Family>(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData);
template void HostFunctionHelper::programWaitForHostFunctionCompletion<Family>(LinearStream *commandStream, void *cmdBuffer, const HostFunctionData &hostFunctionData);
template struct HostFunctionHelper<Family>;
} // namespace NEO

View File

@@ -0,0 +1,43 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "shared/source/helpers/non_copyable_or_moveable.h"
#include <functional>
#include <memory>
#include <mutex>
#include <stop_token>
#include <thread>
namespace NEO {
class GraphicsAllocation;
class HostFunctionStreamer;
struct HostFunction;
// Abstract base of the host function workers; holds the worker thread handle and the mutex used
// by the implementations when creating or stopping it.
class HostFunctionWorker : public NonCopyableAndNonMovableClass {
public:
explicit HostFunctionWorker(bool skipHostFunctionExecution)
: skipHostFunctionExecution(skipHostFunctionExecution) {
}
virtual ~HostFunctionWorker() = default;
// Begins serving the given streamer.
virtual void start(HostFunctionStreamer *streamer) = 0;
// Stops the worker.
virtual void finish() = 0;
// Notifies the worker that nHostFunctions host functions were submitted.
virtual void submit(uint32_t nHostFunctions) noexcept = 0;
protected:
std::unique_ptr<std::jthread> worker;
std::mutex workerMutex;
// Set from the constructor argument; presumably makes implementations skip the user callback - confirm in the derived classes.
bool skipHostFunctionExecution = false;
};
static_assert(NonCopyableAndNonMovable<HostFunctionWorker>);
} // namespace NEO

View File

@@ -0,0 +1,122 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/command_stream/host_function_scheduler.h"
#include "shared/source/command_stream/host_function.h"
#include "shared/source/memory_manager/graphics_allocation.h"
#include "shared/source/utilities/wait_util.h"
#include <chrono>
#include <iostream>
#include <type_traits>
namespace NEO {
// Creates the scheduler and its thread pool; threadsInThreadPoolLimit is forwarded unchanged to
// the thread pool constructor. The scheduler thread itself is created later, in start().
HostFunctionScheduler::HostFunctionScheduler(bool skipHostFunctionExecution,
int32_t threadsInThreadPoolLimit)
: HostFunctionWorker(skipHostFunctionExecution),
threadPool(threadsInThreadPoolLimit) {
}
// Stops the scheduler before any member is torn down.
// The members of this class (semaphore, thread pool, streamer list) are destroyed before the
// base-class worker thread handle, so the scheduler thread must be stopped and joined first;
// otherwise it could still be blocked on, or polling through, members that no longer exist.
// finish() is guarded by std::call_once, so this is a no-op when it was already called.
HostFunctionScheduler::~HostFunctionScheduler() {
    finish();
}
// Registers the streamer with the scheduler, registers one more thread with the thread pool and
// makes sure the single scheduler thread is running.
// May be called once per CSR that shares this scheduler; the worker handle is therefore only
// inspected while workerMutex is held (an unlocked pre-check would race with a concurrent
// start() assigning the handle).
void HostFunctionScheduler::start(HostFunctionStreamer *streamer) {
    this->registerHostFunctionStreamer(streamer);
    this->threadPool.registerThread();

    std::lock_guard<std::mutex> lock(workerMutex);
    if (worker == nullptr) {
        worker = std::make_unique<std::jthread>([this](std::stop_token st) {
            this->schedulerLoop(st);
        });
    }
}
// Shuts the scheduler down; the body runs at most once (std::call_once), later calls do nothing.
// Order: shut down the thread pool, stop and join the scheduler thread, then forget all streamers.
// NOTE(review): every CSR calls finish() on its worker during cleanup while this scheduler is
// shared per root device environment, so the first CSR to clean up appears to stop scheduling
// for all of them - confirm this is intended.
void HostFunctionScheduler::finish() {
std::call_once(shutdownOnceFlag, [&]() {
threadPool.shutdown();
{
std::unique_lock<std::mutex> lock(workerMutex);
if (worker) {
worker->request_stop();
// Wake the loop in case it is blocked in semaphore.acquire() so it can observe the stop request.
semaphore.release();
worker->join();
worker.reset(nullptr);
}
}
{
std::unique_lock<std::mutex> lock(registeredStreamersMutex);
registeredStreamers.clear();
}
});
}
// Adds one semaphore permit per submitted host function, which lets the scheduler loop proceed.
void HostFunctionScheduler::submit(uint32_t nHostFunctions) noexcept {
    const auto permits = static_cast<ptrdiff_t>(nHostFunctions);
    semaphore.release(permits);
}
// Takes the host function with the given ID out of the streamer, lets the streamer update its
// state for the upcoming execution, and queues the host function on the thread pool.
void HostFunctionScheduler::scheduleHostFunctionToThreadPool(HostFunctionStreamer *streamer, uint64_t id) noexcept {
    auto readyHostFunction = streamer->getHostFunction(id);

    // must happen before the hand-over: the streamer state is updated while we still own the object
    streamer->prepareForExecution(readyHostFunction);

    threadPool.registerHostFunctionToExecute(streamer, std::move(readyHostFunction));
}
// Body of the scheduler thread: while no stop is requested, polls every registered streamer and
// forwards host functions that are ready to the thread pool.
void HostFunctionScheduler::schedulerLoop(std::stop_token st) noexcept {
// Created unlocked; the streamer list is locked only for the duration of each polling pass.
std::unique_lock<std::mutex> registeredStreamersLock(registeredStreamersMutex, std::defer_lock);
// Time of the most recent dispatch; the elapsed time since then is passed to the wait helper below.
auto waitStart = std::chrono::steady_clock::now();
while (st.stop_requested() == false) {
semaphore.acquire(); // wait until there is at least one pending host function
semaphore.release(); // leave count unchanged intentionally
if (st.stop_requested()) {
return;
}
registeredStreamersLock.lock();
for (auto streamer : registeredStreamers) {
// The helper consumes one semaphore permit for every ID it reports as ready.
if (auto id = isHostFunctionReadyToExecute(streamer); id != HostFunctionStatus::completed) {
scheduleHostFunctionToThreadPool(streamer, id);
waitStart = std::chrono::steady_clock::now();
}
}
registeredStreamersLock.unlock();
if (st.stop_requested()) {
return;
}
// presumably a back-off that grows with the time since the last dispatch - confirm against WaitUtils
auto waitTime = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - waitStart);
WaitUtils::waitFunctionWithoutPredicate(waitTime.count());
}
}
// Appends the streamer to the list polled by the scheduler loop; the list is modified only
// while registeredStreamersMutex is held.
void HostFunctionScheduler::registerHostFunctionStreamer(HostFunctionStreamer *streamer) {
    std::unique_lock<std::mutex> guard(registeredStreamersMutex);
    registeredStreamers.emplace_back(streamer);
}
// Asks the streamer for a ready host function ID and, if there is one, tries to take a semaphore
// permit for it. Returns the ID only when the permit was obtained; otherwise returns
// HostFunctionStatus::completed. No permit is requested when the streamer reports nothing ready.
uint64_t HostFunctionScheduler::isHostFunctionReadyToExecute(HostFunctionStreamer *streamer) {
    const auto readyId = streamer->isHostFunctionReadyToExecute();
    if (readyId == HostFunctionStatus::completed) {
        return HostFunctionStatus::completed;
    }

    if (!semaphore.try_acquire()) {
        return HostFunctionStatus::completed;
    }

    return readyId;
}
static_assert(NonCopyableAndNonMovable<HostFunctionScheduler>);
} // namespace NEO

View File

@@ -0,0 +1,52 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "shared/source/command_stream/host_function_interface.h"
#include "shared/source/command_stream/host_function_worker_thread_pool.h"
#include "shared/source/utilities/stackvec.h"
#include <functional>
#include <memory>
#include <mutex>
#include <stop_token>
#include <thread>
#include <vector>
namespace NEO {
class GraphicsAllocation;
struct HostFunction;
class HostFunctionStreamer;
// Host function worker that serves several streamers from one scheduler thread and executes the
// host functions on a thread pool.
class HostFunctionScheduler final : public HostFunctionWorker {
public:
// threadsInThreadPoolLimit is forwarded to the thread pool.
HostFunctionScheduler(bool skipHostFunctionExecution,
int32_t threadsInThreadPoolLimit);
~HostFunctionScheduler() override;
// Registers the streamer and starts the scheduler thread if it is not running yet.
void start(HostFunctionStreamer *streamer) override;
// One-time shutdown of the thread pool and the scheduler thread.
void finish() override;
// Releases nHostFunctions semaphore permits.
void submit(uint32_t nHostFunctions) noexcept override;
private:
void scheduleHostFunctionToThreadPool(HostFunctionStreamer *streamer, uint64_t hostFunctionId) noexcept;
void schedulerLoop(std::stop_token st) noexcept;
void registerHostFunctionStreamer(HostFunctionStreamer *streamer);
// Returns a ready ID for which a semaphore permit was taken, or HostFunctionStatus::completed.
uint64_t isHostFunctionReadyToExecute(HostFunctionStreamer *streamer);
std::mutex registeredStreamersMutex; // guards registeredStreamers
std::once_flag shutdownOnceFlag; // makes finish() run its body only once
std::counting_semaphore<> semaphore{0}; // permits released by submit() and consumed when a host function is dispatched
HostFunctionThreadPool threadPool;
std::vector<HostFunctionStreamer *> registeredStreamers; // non-owning; polled by the scheduler loop
};
} // namespace NEO

View File

@@ -1,68 +0,0 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/command_stream/host_function_worker_atomic.h"
#include "shared/source/command_stream/host_function.h"
namespace NEO {
HostFunctionWorkerAtomic::HostFunctionWorkerAtomic(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data)
: IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) {
}
HostFunctionWorkerAtomic::~HostFunctionWorkerAtomic() = default;
void HostFunctionWorkerAtomic::start() {
std::lock_guard<std::mutex> lg{workerMutex};
if (!worker) {
worker = std::make_unique<std::jthread>([this](std::stop_token st) {
this->workerLoop(std::move(st));
});
}
}
void HostFunctionWorkerAtomic::finish() {
std::lock_guard<std::mutex> lg{workerMutex};
if (worker) {
worker->request_stop();
pending.fetch_add(1u);
pending.notify_one();
worker.reset(nullptr);
}
}
void HostFunctionWorkerAtomic::submit() noexcept {
pending.fetch_add(1, std::memory_order_release);
pending.notify_one();
}
void HostFunctionWorkerAtomic::workerLoop(std::stop_token st) noexcept {
while (true) {
while (pending.load(std::memory_order_acquire) == 0) {
pending.wait(0, std::memory_order_acquire);
}
if (st.stop_requested()) {
return;
}
pending.fetch_sub(1, std::memory_order_acq_rel);
bool sucess = this->runHostFunction(st);
if (!sucess) [[unlikely]] {
return;
}
}
}
} // namespace NEO

View File

@@ -1,38 +0,0 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "shared/source/command_stream/host_function_worker_interface.h"
#include <atomic>
#include <memory>
#include <thread>
namespace NEO {
class HostFunctionWorkerAtomic final : public IHostFunctionWorker {
public:
HostFunctionWorkerAtomic(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data);
~HostFunctionWorkerAtomic() override;
void start() override;
void finish() override;
void submit() noexcept override;
private:
void workerLoop(std::stop_token st) noexcept;
std::atomic<uint32_t> pending{0};
};
static_assert(NonCopyableAndNonMovable<HostFunctionWorkerAtomic>);
} // namespace NEO

View File

@@ -7,20 +7,24 @@
#include "shared/source/command_stream/host_function_worker_counting_semaphore.h"
#include "shared/source/command_stream/host_function.h"
namespace NEO {
HostFunctionWorkerCountingSemaphore::HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution, const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl, GraphicsAllocation *allocation, HostFunctionData *data)
: IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) {
HostFunctionWorkerCountingSemaphore::HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution)
: HostFunctionSingleWorker(skipHostFunctionExecution) {
}
HostFunctionWorkerCountingSemaphore::~HostFunctionWorkerCountingSemaphore() = default;
void HostFunctionWorkerCountingSemaphore::start() {
void HostFunctionWorkerCountingSemaphore::start(HostFunctionStreamer *streamer) {
std::lock_guard<std::mutex> lg{workerMutex};
this->streamer = streamer;
if (!worker) {
worker = std::make_unique<std::jthread>([this](std::stop_token st) {
this->workerLoop(std::move(st));
this->workerLoop(st);
});
}
}
@@ -35,24 +39,16 @@ void HostFunctionWorkerCountingSemaphore::finish() {
}
}
void HostFunctionWorkerCountingSemaphore::submit() noexcept {
semaphore.release();
void HostFunctionWorkerCountingSemaphore::submit(uint32_t nHostFunctions) noexcept {
semaphore.release(static_cast<ptrdiff_t>(nHostFunctions));
}
void HostFunctionWorkerCountingSemaphore::workerLoop(std::stop_token st) noexcept {
while (true) {
while (st.stop_requested() == false) {
semaphore.acquire();
if (st.stop_requested()) [[unlikely]] {
return;
}
bool success = runHostFunction(st);
if (!success) [[unlikely]] {
return;
}
processNextHostFunction(st);
}
}

View File

@@ -15,17 +15,17 @@
namespace NEO {
class HostFunctionWorkerCountingSemaphore final : public IHostFunctionWorker {
class HostFunctionStreamer;
struct HostFunction;
class HostFunctionWorkerCountingSemaphore final : public HostFunctionSingleWorker {
public:
HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data);
HostFunctionWorkerCountingSemaphore(bool skipHostFunctionExecution);
~HostFunctionWorkerCountingSemaphore() override;
void start() override;
void start(HostFunctionStreamer *streamer) override;
void finish() override;
void submit() noexcept override;
void submit(uint32_t nHostFunctions) noexcept override;
private:
void workerLoop(std::stop_token st) noexcept;

View File

@@ -1,75 +0,0 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/command_stream/host_function_worker_cv.h"
#include "shared/source/command_stream/command_stream_receiver.h"
#include "shared/source/command_stream/host_function.h"
#include "shared/source/utilities/wait_util.h"
namespace NEO {
// Constructs the condition-variable based worker; every argument is passed through
// unchanged to the IHostFunctionWorker base class, no thread is started here.
HostFunctionWorkerCV::HostFunctionWorkerCV(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data)
: IHostFunctionWorker(skipHostFunctionExecution, downloadAllocationImpl, allocation, data) {
}
// Defaulted: performs no explicit stop/notify of the worker thread.
// NOTE(review): callers are presumably expected to call finish() beforehand - confirm.
HostFunctionWorkerCV::~HostFunctionWorkerCV() = default;
void HostFunctionWorkerCV::start() {
std::lock_guard<std::mutex> lg{workerMutex};
if (!worker) {
worker = std::make_unique<std::jthread>([this](std::stop_token st) {
this->workerLoop(std::move(st));
});
}
}
void HostFunctionWorkerCV::finish() {
std::lock_guard<std::mutex> lg{workerMutex};
if (worker) {
worker->request_stop();
cv.notify_one();
worker.reset(nullptr);
}
}
void HostFunctionWorkerCV::submit() noexcept {
{
std::lock_guard<std::mutex> lock{pendingAccessMutex};
++pending;
}
cv.notify_one();
}
// Worker thread body: sleeps until work is pending or a stop was requested, then
// consumes one pending entry and runs the host function outside of the lock.
// Exits when a stop is requested or when runHostFunction reports failure.
void HostFunctionWorkerCV::workerLoop(std::stop_token st) noexcept {
    const auto shouldWakeUp = [this, &st]() {
        return st.stop_requested() || pending > 0;
    };

    for (;;) {
        {
            std::unique_lock<std::mutex> guard{pendingAccessMutex};
            cv.wait(guard, shouldWakeUp);

            // A stop request wins over any work that is still pending.
            if (st.stop_requested()) [[unlikely]] {
                return;
            }
            pending -= 1;
        }

        if (this->runHostFunction(st) == false) [[unlikely]] {
            return;
        }
    }
}
} // namespace NEO

View File

@@ -1,39 +0,0 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "shared/source/command_stream/host_function_worker_interface.h"
#include <condition_variable>
#include <mutex>
#include <thread>
namespace NEO {
// Host function worker that parks its thread on a condition variable between
// submissions: submit() increments `pending` and notifies, workerLoop waits until
// `pending > 0` or a stop was requested.
class HostFunctionWorkerCV final : public IHostFunctionWorker {
public:
// All arguments are forwarded to the IHostFunctionWorker base class.
HostFunctionWorkerCV(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data);
~HostFunctionWorkerCV() override;
// Creates the worker thread if it does not exist yet.
void start() override;
// Requests the worker thread to stop and joins it.
void finish() override;
// Adds one pending host function and wakes the worker.
void submit() noexcept override;
private:
// Thread entry point.
void workerLoop(std::stop_token st) noexcept;
// Guards `pending`; also the mutex used with `cv`.
std::mutex pendingAccessMutex;
std::condition_variable cv;
// Number of submitted host functions not yet picked up by the worker.
uint32_t pending{0};
};
static_assert(NonCopyableAndNonMovable<HostFunctionWorkerCV>);
} // namespace NEO

View File

@@ -14,55 +14,47 @@
#include <type_traits>
namespace NEO {
IHostFunctionWorker::IHostFunctionWorker(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data)
: downloadAllocationImpl(downloadAllocationImpl),
allocation(allocation),
data(data),
skipHostFunctionExecution(skipHostFunctionExecution) {
HostFunctionSingleWorker::HostFunctionSingleWorker(bool skipHostFunctionExecution)
: HostFunctionWorker(skipHostFunctionExecution) {
}
IHostFunctionWorker::~IHostFunctionWorker() = default;
HostFunctionSingleWorker::~HostFunctionSingleWorker() = default;
bool IHostFunctionWorker::runHostFunction(std::stop_token st) noexcept {
void HostFunctionSingleWorker::processNextHostFunction(std::stop_token st) noexcept {
if (skipHostFunctionExecution == false) {
auto hostFunctionReady = waitUntilHostFunctionIsReady(st);
if (hostFunctionReady) {
auto hostFunction = streamer->getHostFunction();
streamer->prepareForExecution(hostFunction);
hostFunction.invoke();
streamer->signalHostFunctionCompletion(hostFunction);
}
}
}
bool HostFunctionSingleWorker::waitUntilHostFunctionIsReady(std::stop_token st) noexcept {
using tagStatusT = std::underlying_type_t<HostFunctionTagStatus>;
const auto start = std::chrono::steady_clock::now();
std::chrono::microseconds waitTime{0};
if (!this->skipHostFunctionExecution) {
while (true) {
while (true) {
if (this->downloadAllocationImpl) [[unlikely]] {
this->downloadAllocationImpl(*this->allocation);
}
const volatile uint32_t *hostFuntionTagAddress = this->data->internalTag;
waitTime = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - start);
bool pendingJobFound = WaitUtils::waitFunctionWithPredicate<const tagStatusT>(hostFuntionTagAddress,
static_cast<tagStatusT>(HostFunctionTagStatus::pending),
std::equal_to<tagStatusT>(),
waitTime.count());
if (pendingJobFound) {
break;
}
if (st.stop_requested()) {
return false;
}
if (st.stop_requested()) {
return false;
}
using CallbackT = void (*)(void *);
CallbackT callback = reinterpret_cast<CallbackT>(*this->data->entry);
void *callbackData = reinterpret_cast<void *>(*this->data->userData);
streamer->downloadHostFunctionAllocation();
callback(callbackData);
auto waitTime = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - start);
auto hostFunctionReady = WaitUtils::waitFunctionWithPredicate<uint64_t>(streamer->getHostFunctionIdPtr(),
HostFunctionStatus::completed,
std::greater<uint64_t>(),
waitTime.count());
if (hostFunctionReady) {
return true;
}
}
*this->data->internalTag = static_cast<tagStatusT>(HostFunctionTagStatus::completed);
return true;
}
} // namespace NEO

View File

@@ -7,6 +7,7 @@
#pragma once
#include "shared/source/command_stream/host_function_interface.h"
#include "shared/source/helpers/non_copyable_or_moveable.h"
#include <functional>
@@ -18,32 +19,24 @@
namespace NEO {
class GraphicsAllocation;
struct HostFunctionData;
class HostFunctionStreamer;
struct HostFunction;
class IHostFunctionWorker : public NonCopyableAndNonMovableClass {
class HostFunctionSingleWorker : public HostFunctionWorker {
public:
IHostFunctionWorker(bool skipHostFunctionExecution,
const std::function<void(GraphicsAllocation &)> &downloadAllocationImpl,
GraphicsAllocation *allocation,
HostFunctionData *data);
virtual ~IHostFunctionWorker() = 0;
explicit HostFunctionSingleWorker(bool skipHostFunctionExecution);
~HostFunctionSingleWorker() override = 0;
virtual void start() = 0;
virtual void finish() = 0;
virtual void submit() noexcept = 0;
void start(HostFunctionStreamer *streamer) override = 0;
void finish() override = 0;
void submit(uint32_t nHostFunctions) noexcept override = 0;
protected:
MOCKABLE_VIRTUAL bool runHostFunction(std::stop_token st) noexcept;
std::unique_ptr<std::jthread> worker;
std::mutex workerMutex;
private:
std::function<void(GraphicsAllocation &)> downloadAllocationImpl;
GraphicsAllocation *allocation = nullptr;
HostFunctionData *data = nullptr;
bool skipHostFunctionExecution = false;
MOCKABLE_VIRTUAL void processNextHostFunction(std::stop_token st) noexcept;
bool waitUntilHostFunctionIsReady(std::stop_token st) noexcept;
HostFunctionStreamer *streamer = nullptr;
};
static_assert(NonCopyableAndNonMovable<IHostFunctionWorker>);
static_assert(NonCopyableAndNonMovable<HostFunctionSingleWorker>);
} // namespace NEO

View File

@@ -0,0 +1,89 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/command_stream/host_function_worker_thread_pool.h"
#include "shared/source/command_stream/host_function.h"
namespace NEO {
// Configures the thread limit of the pool; no thread is created here.
// The sentinel HostFunctionThreadPoolHelper::unlimitedThreads removes the limit,
// any other value is stored as the maximum number of threads.
HostFunctionThreadPool::HostFunctionThreadPool(int32_t threadsInThreadPoolLimit) {
    const bool noLimitRequested = (threadsInThreadPoolLimit == HostFunctionThreadPoolHelper::unlimitedThreads);
    if (noLimitRequested) {
        unlimitedThreads = true;
        return;
    }

    threadsLimit = static_cast<uint32_t>(threadsInThreadPoolLimit);
}
// Stops and joins all pool threads before any member is destroyed.
// Workers block in semaphore.acquire(); a plain jthread destructor only requests a
// stop and joins, it does not release the semaphore, so without an explicit
// shutdown() the join could never complete. shutdown() is idempotent, therefore
// owners that already called it are unaffected.
HostFunctionThreadPool::~HostFunctionThreadPool() {
    shutdown();
}
// Adds one worker thread to the pool unless the configured limit is already
// reached. With an unlimited pool a thread is always added.
void HostFunctionThreadPool::registerThread() noexcept {
    const bool limitReached = !unlimitedThreads && (threads.size() >= threadsLimit);
    if (limitReached) {
        return;
    }

    threads.emplace_back([this](std::stop_token stopToken) {
        this->workerLoop(stopToken);
    });
}
// Stops every pool thread, waits for them to exit and drops all queued host
// functions that were not executed.
void HostFunctionThreadPool::shutdown() noexcept {
    // Request the stop first, so that every worker woken below observes it.
    for (auto &poolThread : threads) {
        poolThread.request_stop();
    }

    // One token per worker is enough to unblock all of them from acquire().
    const auto wakeUpTokens = static_cast<ptrdiff_t>(threads.size());
    semaphore.release(wakeUpTokens);

    for (auto &poolThread : threads) {
        poolThread.join();
    }

    {
        std::lock_guard queueGuard{this->hostFunctionsMutex};
        hostFunctions.clear();
    }

    threads.clear();
}
// Queues a host function together with the streamer it belongs to and wakes one
// worker thread to execute it.
void HostFunctionThreadPool::registerHostFunctionToExecute(HostFunctionStreamer *streamer, HostFunction &&hostFunction) {
    std::unique_lock queueGuard{this->hostFunctionsMutex};
    hostFunctions.emplace_back(streamer, std::move(hostFunction));
    // The entry must be in the queue before the token is handed out.
    queueGuard.unlock();

    semaphore.release();
}
void NEO::HostFunctionThreadPool::workerLoop(std::stop_token st) noexcept {
while (st.stop_requested() == false) {
semaphore.acquire();
if (st.stop_requested()) {
return;
}
executeHostFunction();
}
}
void HostFunctionThreadPool::executeHostFunction() noexcept {
std::unique_lock lock{this->hostFunctionsMutex};
auto [streamer, hostFunction] = std::move(hostFunctions.front());
hostFunctions.pop_front();
lock.unlock();
hostFunction.invoke();
streamer->signalHostFunctionCompletion(hostFunction);
}
} // namespace NEO

View File

@@ -0,0 +1,50 @@
/*
* Copyright (C) 2025 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#pragma once
#include "shared/source/command_stream/host_function.h"
#include "shared/source/helpers/non_copyable_or_moveable.h"
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <semaphore>
#include <stop_token>
#include <thread>
namespace NEO {
class GraphicsAllocation;
struct HostFunction;
class HostFunctionStreamer;
// Pool of worker threads executing host functions.
// Host functions are queued with registerHostFunctionToExecute(); every queued
// entry hands one semaphore token to the workers, and each worker executes one
// entry per token, then signals its completion on the associated streamer.
class HostFunctionThreadPool : public NonCopyableAndNonMovableClass {
  public:
    // threadsInThreadPoolLimit: HostFunctionThreadPoolHelper::unlimitedThreads for
    // no limit, otherwise the maximum number of threads registerThread() may create.
    explicit HostFunctionThreadPool(int32_t threadsInThreadPoolLimit);
    ~HostFunctionThreadPool();

    // Queues hostFunction (moved from) for execution on behalf of streamer.
    void registerHostFunctionToExecute(HostFunctionStreamer *streamer, HostFunction &&hostFunction);
    // Adds one worker thread if the thread limit allows it.
    void registerThread() noexcept;
    // Stops and joins all workers, discards host functions still queued.
    void shutdown() noexcept;

  private:
    // Pops and runs the oldest queued host function.
    void executeHostFunction() noexcept;
    // Thread entry point of every worker.
    void workerLoop(std::stop_token st) noexcept;

    // Guards hostFunctions.
    std::mutex hostFunctionsMutex;
    std::deque<std::pair<HostFunctionStreamer *, HostFunction>> hostFunctions;
    // Counts queued host functions; also released by shutdown() to wake workers.
    std::counting_semaphore<> semaphore{0};
    uint32_t threadsLimit = 0;
    bool unlimitedThreads = false;
    // Keep this member last: members are destroyed in reverse declaration order, so
    // the threads are destroyed (joined) while the mutex, queue and semaphore they
    // use are still alive.
    std::deque<std::jthread> threads;
};
static_assert(NonCopyableAndNonMovable<HostFunctionThreadPool>);
} // namespace NEO

View File

@@ -16,7 +16,7 @@ inline constexpr uint64_t debugPauseStateAddressOffset = MemoryConstants::kiloBy
inline constexpr uint64_t ucTagAddressOffset = MemoryConstants::kiloByte + MemoryConstants::cacheLineSize;
inline constexpr uint64_t completionFenceOffset = 2 * MemoryConstants::kiloByte;
inline constexpr uint64_t barrierCountOffset = 3 * MemoryConstants::kiloByte;
inline constexpr uint64_t hostFunctionDataOffset = barrierCountOffset + (8 * MemoryConstants::cacheLineSize);
inline constexpr uint64_t hostFunctionDataOffset = barrierCountOffset + MemoryConstants::cacheLineSize;
} // namespace TagAllocationLayout
} // namespace NEO

View File

@@ -319,7 +319,9 @@ DECLARE_DEBUG_VARIABLE(int32_t, OverrideCopyOffloadMode, -1, "-1: default, 0: di
DECLARE_DEBUG_VARIABLE(int32_t, UseSingleListForTemporaryAllocations, -1, "-1: default, 0: disabled, 1: enabled. If enabled, use single list, instead of per CSR for tracking temporary allocations")
DECLARE_DEBUG_VARIABLE(int32_t, OverrideMaxMemAllocSizeMb, -1, "-1: default, >=0 override reported max mem alloc size in MB")
DECLARE_DEBUG_VARIABLE(int32_t, DetectIncorrectPointersOnSetArgCalls, -1, "-1: default do not detect, 0: do not detect, 1: detect incorrect pointers and return error")
DECLARE_DEBUG_VARIABLE(int32_t, HostFunctionWorkMode, -1, "-1: default - counting semaphore based, 0: counting semaphore based, 1: condition variable base, 2: atomics based")
DECLARE_DEBUG_VARIABLE(int32_t, HostFunctionWorkMode, -1, "-1: default - counting semaphore based, 0: counting semaphore based, 1: scheduler with thread pool")
DECLARE_DEBUG_VARIABLE(int32_t, HostFunctionThreadPoolSize, -1, "-1: default - one thread per CSR that uses host functions, >0: number of threads per host function worker thread pool. Usable only if HostFunctionWorkMode=1 is set ")
DECLARE_DEBUG_VARIABLE(bool, AllowForOutOfOrderHostFunctionExecution, 0, "0: default disabled, 1: enable out-of-order host function execution")
DECLARE_DEBUG_VARIABLE(int32_t, ForceDisableGraphPatchPreamble, -1, "-1: default, 0: enable patch preamble, 1: disable graph patch preamble. If disabled, do not patch preamble graph internal command lists")
DECLARE_DEBUG_VARIABLE(int32_t, EnableStateCacheInvalidationWa, -1, "-1: default, 0: disabled, 1: enabled. When enabled, insert a PIPE_CONTROL with state cache invalidation on the CCS after the walker for kernels that contain stateful access")

View File

@@ -12,6 +12,7 @@
#include "shared/source/aub/aub_center.h"
#include "shared/source/built_ins/built_ins.h"
#include "shared/source/built_ins/sip.h"
#include "shared/source/command_stream/host_function_scheduler.h"
#include "shared/source/compiler_interface/compiler_interface.h"
#include "shared/source/compiler_interface/default_cache_config.h"
#include "shared/source/debugger/debugger.h"
@@ -297,6 +298,14 @@ void RootDeviceEnvironment::releaseDummyAllocation() {
dummyAllocation.reset();
}
// Takes ownership of the given host function scheduler; a previously stored
// scheduler, if any, is destroyed by the unique_ptr assignment.
void RootDeviceEnvironment::setHostFunctionScheduler(std::unique_ptr<HostFunctionWorker> &&scheduler) {
hostFunctionScheduler = std::move(scheduler);
}
// Returns a non-owning pointer to the stored host function scheduler, or nullptr
// when none has been set.
HostFunctionWorker *RootDeviceEnvironment::getHostFunctionScheduler() const {
return hostFunctionScheduler.get();
}
AssertHandler *RootDeviceEnvironment::getAssertHandler(Device *neoDevice) {
if (this->assertHandler.get() == nullptr) {
std::lock_guard<std::mutex> autolock(this->mtx);

View File

@@ -42,6 +42,7 @@ class CompilerProductHelper;
class GraphicsAllocation;
class ReleaseHelper;
class AILConfiguration;
class HostFunctionWorker;
struct AllocationProperties;
struct HardwareInfo;
@@ -109,6 +110,9 @@ struct RootDeviceEnvironment : NonCopyableClass {
return exposeSingleDevice;
}
void setHostFunctionScheduler(std::unique_ptr<HostFunctionWorker> &&scheduler);
HostFunctionWorker *getHostFunctionScheduler() const;
std::unique_ptr<SipKernel> sipKernels[static_cast<uint32_t>(SipKernelType::count)];
std::unique_ptr<GmmHelper> gmmHelper;
std::unique_ptr<OSInterface> osInterface;
@@ -128,7 +132,7 @@ struct RootDeviceEnvironment : NonCopyableClass {
std::unique_ptr<ReleaseHelper> releaseHelper;
std::unique_ptr<AILConfiguration> ailConfiguration;
std::unique_ptr<BindlessHeapsHelper> bindlessHeapsHelper;
std::unique_ptr<HostFunctionWorker> hostFunctionScheduler;
std::unique_ptr<AssertHandler> assertHandler;
ExecutionEnvironment &executionEnvironment;

View File

@@ -79,6 +79,16 @@ inline bool waitFunctionWithPredicate(volatile T const *pollAddress, T expectedV
return false;
}
// Performs a single back-off step of a wait loop without polling any address.
// Uses tpause() when that mode is selected and the wait has already lasted longer
// than waitPkgThresholdInMicroSeconds; otherwise issues waitCount CPU pause hints.
inline void waitFunctionWithoutPredicate(int64_t timeElapsedSinceWaitStarted) {
    const bool thresholdExceeded = timeElapsedSinceWaitStarted > waitPkgThresholdInMicroSeconds;
    if (waitpkgUse == WaitpkgUse::tpause && thresholdExceeded) {
        tpause();
        return;
    }

    for (uint32_t pauseIdx = 0; pauseIdx < waitCount; ++pauseIdx) {
        CpuIntrinsics::pause();
    }
}
// Convenience wrapper over waitFunctionWithPredicate using a greater-or-equal
// comparison between the polled value and expectedValue; returns its result.
inline bool waitFunction(volatile TagAddressType *pollAddress, TaskCountType expectedValue, int64_t timeElapsedSinceWaitStarted) {
return waitFunctionWithPredicate<TaskCountType>(pollAddress, expectedValue, std::greater_equal<TaskCountType>(), timeElapsedSinceWaitStarted);
}

View File

@@ -67,6 +67,7 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
using BaseClass::getCmdSizeForExceptions;
using BaseClass::getCmdSizeForHeaplessPrologue;
using BaseClass::getCmdSizeForPrologue;
using BaseClass::getHostFunctionStreamer;
using BaseClass::getScratchPatchAddress;
using BaseClass::getScratchSpaceController;
using BaseClass::handleAllocationsResidencyForHeaplessProlog;
@@ -564,8 +565,8 @@ class UltCommandStreamReceiver : public CommandStreamReceiverHw<GfxFamily> {
BaseClass::setupContext(osContext);
}
void signalHostFunctionWorker() override {
signalHostFunctionWorkerCounter++;
void signalHostFunctionWorker(uint32_t nHostFunctions) override {
signalHostFunctionWorkerCounter += nHostFunctions;
}
void createHostFunctionWorker() override {

View File

@@ -43,7 +43,6 @@ class MockCommandStreamReceiver : public CommandStreamReceiver {
using CommandStreamReceiver::gpuHangCheckPeriod;
using CommandStreamReceiver::heaplessStateInitEnabled;
using CommandStreamReceiver::heaplessStateInitialized;
using CommandStreamReceiver::hostFunctionDataAllocation;
using CommandStreamReceiver::immWritePostSyncWriteOffset;
using CommandStreamReceiver::internalAllocationStorage;
using CommandStreamReceiver::latestFlushedTaskCount;
@@ -284,8 +283,8 @@ class MockCommandStreamReceiver : public CommandStreamReceiver {
BaseClass::startHostFunctionWorker();
}
void signalHostFunctionWorker() override {
signalHostFunctionWorkerCounter++;
void signalHostFunctionWorker(uint32_t nHostFunction) override {
signalHostFunctionWorkerCounter += nHostFunction;
}
void createHostFunctionWorker() override {

View File

@@ -678,6 +678,8 @@ CopyLockedMemoryBeforeWrite = 0
SplitBcsPerEngineMaxSize = -1
PrintSecondaryContextEngineInfo = 0
HostFunctionWorkMode = -1
HostFunctionThreadPoolSize = -1
AllowForOutOfOrderHostFunctionExecution = 0
Enable512NumGrfs = 1
EnableUsmPoolResidencyTracking = -1
EnableUsmPoolLazyInit = -1

View File

@@ -6448,90 +6448,36 @@ HWTEST_F(CommandStreamReceiverHwTest, givenVariousCsrModeWhenGettingHardwareMode
EXPECT_FALSE(ultCsr.isHardwareMode());
}
TEST(CommandStreamReceiverHostFunctionsTest, givenCommandStreamReceiverWhenEnsureHostFunctionDataInitializationCalledThenHostFunctionAllocationIsBeingAllocatedOnlyOnce) {
MockExecutionEnvironment executionEnvironment(defaultHwInfo.get());
DeviceBitfield devices(0b11);
auto csr = std::make_unique<MockCommandStreamReceiver>(executionEnvironment, 0, devices);
executionEnvironment.memoryManager.reset(new OsAgnosticMemoryManager(executionEnvironment));
using CommandStreamReceiverHostFunctionHwTest = Test<CommandStreamReceiverFixture>;
EXPECT_EQ(nullptr, csr->getHostFunctionDataAllocation());
csr->initializeTagAllocation();
csr->ensureHostFunctionWorkerStarted();
auto *hostDataAllocation = csr->getHostFunctionDataAllocation();
EXPECT_NE(nullptr, hostDataAllocation);
EXPECT_EQ(1u, csr->startHostFunctionWorkerCalledTimes);
csr->ensureHostFunctionWorkerStarted();
EXPECT_EQ(hostDataAllocation, csr->getHostFunctionDataAllocation());
EXPECT_EQ(1u, csr->startHostFunctionWorkerCalledTimes);
csr->startHostFunctionWorker();
EXPECT_EQ(2u, csr->startHostFunctionWorkerCalledTimes); // direct call -> the counter updated but due to an early return allocation didn't change
EXPECT_EQ(hostDataAllocation, csr->getHostFunctionDataAllocation());
EXPECT_EQ(AllocationType::tagBuffer, hostDataAllocation->getAllocationType());
auto expectedHostFunctionAddress = reinterpret_cast<uint64_t>(ptrOffset(hostDataAllocation->getUnderlyingBuffer(),
HostFunctionHelper::entryOffset + TagAllocationLayout::hostFunctionDataOffset));
EXPECT_EQ(expectedHostFunctionAddress, reinterpret_cast<uint64_t>(csr->getHostFunctionData().entry));
auto expectedUserDataAddress = reinterpret_cast<uint64_t>(ptrOffset(hostDataAllocation->getUnderlyingBuffer(),
HostFunctionHelper::userDataOffset + TagAllocationLayout::hostFunctionDataOffset));
EXPECT_EQ(expectedUserDataAddress, reinterpret_cast<uint64_t>(csr->getHostFunctionData().userData));
auto expectedInternalTagAddress = reinterpret_cast<uint64_t>(ptrOffset(hostDataAllocation->getUnderlyingBuffer(),
HostFunctionHelper::internalTagOffset + TagAllocationLayout::hostFunctionDataOffset));
EXPECT_EQ(expectedInternalTagAddress, reinterpret_cast<uint64_t>(csr->getHostFunctionData().internalTag));
}
TEST(CommandStreamReceiverHostFunctionsTest, givenDestructedCommandStreamReceiverWhenEnsureHostFunctionDataInitializationCalledThenHostFunctionAllocationsDeallocated) {
MockExecutionEnvironment executionEnvironment(defaultHwInfo.get());
DeviceBitfield devices(0b11);
auto csr = std::make_unique<MockCommandStreamReceiver>(executionEnvironment, 0, devices);
executionEnvironment.memoryManager.reset(new OsAgnosticMemoryManager(executionEnvironment));
csr->initializeTagAllocation();
EXPECT_NE(nullptr, csr->getHostFunctionDataAllocation());
csr->ensureHostFunctionWorkerStarted();
EXPECT_EQ(1u, csr->createHostFunctionWorkerCounter);
}
HWTEST_F(CommandStreamReceiverHwTest, givenHostFunctionDataWhenMakeResidentHostFunctionAllocationIsCalledThenHostAllocationIsResident) {
HWTEST_F(CommandStreamReceiverHostFunctionHwTest, givenHostFunctionWhenMakeResidentHostFunctionAllocationIsCalledThenHostAllocationIsResident) {
auto &csr = pDevice->getUltCommandStreamReceiver<FamilyType>();
auto *hostDataAllocation = csr.getHostFunctionDataAllocation();
ASSERT_NE(nullptr, hostDataAllocation);
csr.ensureHostFunctionWorkerStarted();
EXPECT_EQ(1u, csr.createHostFunctionWorkerCounter);
auto *hostFunctionIdAllocation = csr.getHostFunctionStreamer().getHostFunctionIdAllocation();
ASSERT_NE(nullptr, hostFunctionIdAllocation);
auto csrContextId = csr.getOsContext().getContextId();
EXPECT_FALSE(hostDataAllocation->isResident(csrContextId));
EXPECT_FALSE(hostFunctionIdAllocation->isResident(csrContextId));
csr.makeResident(*csr.tagAllocation);
EXPECT_TRUE(hostDataAllocation->isResident(csrContextId));
EXPECT_TRUE(hostFunctionIdAllocation->isResident(csrContextId));
csr.makeNonResident(*hostDataAllocation);
EXPECT_FALSE(hostDataAllocation->isResident(csrContextId));
csr.makeNonResident(*hostFunctionIdAllocation);
EXPECT_FALSE(hostFunctionIdAllocation->isResident(csrContextId));
}
HWTEST_F(CommandStreamReceiverHwTest, givenHostFunctionDataWhenSignalHostFunctionWorkerIsCalledThenCounterIsUpdated) {
HWTEST_F(CommandStreamReceiverHostFunctionHwTest, givenHostFunctionWhenSignalHostFunctionWorkerIsCalledThenCounterIsUpdated) {
auto &csr = pDevice->getUltCommandStreamReceiver<FamilyType>();
auto *hostDataAllocation = csr.getHostFunctionDataAllocation();
ASSERT_NE(nullptr, hostDataAllocation);
ASSERT_EQ(0u, csr.createHostFunctionWorkerCounter);
ASSERT_EQ(0u, csr.createHostFunctionWorkerCounter);
csr.ensureHostFunctionWorkerStarted();
csr.signalHostFunctionWorker();
csr.signalHostFunctionWorker(10u);
ASSERT_EQ(1u, csr.createHostFunctionWorkerCounter);
EXPECT_EQ(1u, csr.signalHostFunctionWorkerCounter);
EXPECT_EQ(10u, csr.signalHostFunctionWorkerCounter);
}

View File

@@ -6,8 +6,13 @@
*/
#include "shared/source/command_stream/host_function.h"
#include "shared/source/command_stream/tag_allocation_layout.h"
#include "shared/source/memory_manager/os_agnostic_memory_manager.h"
#include "shared/test/common/cmd_parse/hw_parse.h"
#include "shared/test/common/fixtures/device_fixture.h"
#include "shared/test/common/helpers/default_hw_info.h"
#include "shared/test/common/mocks/mock_command_stream_receiver.h"
#include "shared/test/common/mocks/mock_graphics_allocation.h"
#include "shared/test/common/test_macros/hw_test.h"
#include <cstddef>
@@ -19,125 +24,267 @@ using HostFunctionTests = Test<DeviceFixture>;
HWTEST_F(HostFunctionTests, givenHostFunctionDataStoredWhenProgramHostFunctionIsCalledThenMiStoresAndSemaphoreWaitAreProgrammedCorrectlyInCorrectOrder) {
using MI_STORE_DATA_IMM = typename FamilyType::MI_STORE_DATA_IMM;
using MI_SEMAPHORE_WAIT = typename FamilyType::MI_SEMAPHORE_WAIT;
constexpr auto size = 1024u;
std::byte buff[size] = {};
LinearStream stream(buff, size);
uint64_t userHostFunctionStored = 10u;
uint64_t userDataStored = 20u;
uint32_t tagStored = 0;
uint64_t callbackAddress = 1024;
uint64_t userDataAddress = 2048;
bool isInOrder = true;
HostFunctionData hostFunctionData{
.entry = &userHostFunctionStored,
.userData = &userDataStored,
.internalTag = &tagStored};
HostFunction hostFunction{
.hostFunctionAddress = callbackAddress,
.userDataAddress = userDataAddress,
.isInOrder = true};
uint64_t userCallback = 0xAAAA'0000ull;
uint64_t userCallbackData = 0xBBBB'000ull;
MockGraphicsAllocation allocation;
HostFunctionHelper::programHostFunction<FamilyType>(stream, hostFunctionData, userCallback, userCallbackData);
uint64_t hostFunctionId = 1;
std::function<void(GraphicsAllocation &)> downloadAllocationImpl = [](GraphicsAllocation &) {};
bool isTbx = false;
auto hostFunctionStreamer = std::make_unique<HostFunctionStreamer>(&allocation,
&hostFunctionId,
downloadAllocationImpl,
isTbx);
HostFunctionHelper<FamilyType>::programHostFunction(stream, *hostFunctionStreamer.get(), std::move(hostFunction));
HardwareParse hwParser;
hwParser.parseCommands<FamilyType>(stream, 0);
auto miStores = findAll<MI_STORE_DATA_IMM *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(3u, miStores.size());
EXPECT_EQ(1u, miStores.size());
auto miWait = findAll<MI_SEMAPHORE_WAIT *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(1u, miWait.size());
// program callback address
// program host function id
auto expectedHostFunctionId = 1u;
auto miStoreUserHostFunction = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&userHostFunctionStored), miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(userCallback), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(userCallback), miStoreUserHostFunction->getDataDword1());
EXPECT_EQ(reinterpret_cast<uint64_t>(&hostFunctionId), miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword1());
EXPECT_TRUE(miStoreUserHostFunction->getStoreQword());
// program callback data
auto miStoreUserData = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[1]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&userDataStored), miStoreUserData->getAddress());
EXPECT_EQ(getLowPart(userCallbackData), miStoreUserData->getDataDword0());
EXPECT_EQ(getHighPart(userCallbackData), miStoreUserData->getDataDword1());
EXPECT_TRUE(miStoreUserData->getStoreQword());
// signal pending job
auto miStoreSignalTag = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[2]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&tagStored), miStoreSignalTag->getAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::pending), miStoreSignalTag->getDataDword0());
EXPECT_FALSE(miStoreSignalTag->getStoreQword());
// wait for completion
// program wait for host function completion
auto miWaitTag = genCmdCast<MI_SEMAPHORE_WAIT *>(*miWait[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&tagStored), miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(reinterpret_cast<uint64_t>(&hostFunctionId), miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(MI_SEMAPHORE_WAIT::COMPARE_OPERATION_SAD_EQUAL_SDD, miWaitTag->getCompareOperation());
EXPECT_EQ(MI_SEMAPHORE_WAIT::WAIT_MODE_POLLING_MODE, miWaitTag->getWaitMode());
// host function from host function streamer
auto programmedHostFunction = hostFunctionStreamer->getHostFunction();
EXPECT_EQ(callbackAddress, programmedHostFunction.hostFunctionAddress);
EXPECT_EQ(userDataAddress, programmedHostFunction.userDataAddress);
EXPECT_EQ(isInOrder, programmedHostFunction.isInOrder);
}
HWTEST_F(HostFunctionTests, givenCommandBufferPassedWhenProgramHostFunctionsAreCalledThenMiStoresAndSemaphoreWaitAreProgrammedCorrectlyInCorrectOrder) {
using MI_STORE_DATA_IMM = typename FamilyType::MI_STORE_DATA_IMM;
using MI_SEMAPHORE_WAIT = typename FamilyType::MI_SEMAPHORE_WAIT;
MockGraphicsAllocation allocation;
uint64_t hostFunctionId = 1;
std::function<void(GraphicsAllocation &)> downloadAllocationImpl = [](GraphicsAllocation &) {};
bool isTbx = false;
auto hostFunctionStreamer = std::make_unique<HostFunctionStreamer>(&allocation,
&hostFunctionId,
downloadAllocationImpl,
isTbx);
constexpr auto size = 1024u;
std::byte buff[size] = {};
uint64_t userHostFunctionStored = 10u;
uint64_t userDataStored = 20u;
uint32_t tagStored = 0;
uint64_t callbackAddress = 1024;
uint64_t userDataAddress = 2048;
bool isInOrder = true;
HostFunctionData hostFunctionData{
.entry = &userHostFunctionStored,
.userData = &userDataStored,
.internalTag = &tagStored};
uint64_t userCallback = 0xAAAA'0000ull;
uint64_t userCallbackData = 0xBBBB'000ull;
HostFunction hostFunction{
.hostFunctionAddress = callbackAddress,
.userDataAddress = userDataAddress,
.isInOrder = true};
LinearStream commandStream(buff, size);
auto miStoreDataImmBuffer1 = commandStream.getSpaceForCmd<MI_STORE_DATA_IMM>();
HostFunctionHelper::programHostFunctionAddress<FamilyType>(nullptr, miStoreDataImmBuffer1, hostFunctionData, userCallback);
auto miStoreDataImmBuffer2 = commandStream.getSpaceForCmd<MI_STORE_DATA_IMM>();
HostFunctionHelper::programHostFunctionUserData<FamilyType>(nullptr, miStoreDataImmBuffer2, hostFunctionData, userCallbackData);
auto miStoreDataImmBuffer3 = commandStream.getSpaceForCmd<MI_STORE_DATA_IMM>();
HostFunctionHelper::programSignalHostFunctionStart<FamilyType>(nullptr, miStoreDataImmBuffer3, hostFunctionData);
HostFunctionHelper<FamilyType>::programHostFunctionId(nullptr, miStoreDataImmBuffer1, *hostFunctionStreamer.get(), std::move(hostFunction));
auto semaphoreCommand = commandStream.getSpaceForCmd<MI_SEMAPHORE_WAIT>();
HostFunctionHelper::programWaitForHostFunctionCompletion<FamilyType>(nullptr, semaphoreCommand, hostFunctionData);
HostFunctionHelper<FamilyType>::programHostFunctionWaitForCompletion(nullptr, semaphoreCommand, *hostFunctionStreamer.get());
HardwareParse hwParser;
hwParser.parseCommands<FamilyType>(commandStream, 0);
auto miStores = findAll<MI_STORE_DATA_IMM *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(3u, miStores.size());
EXPECT_EQ(1u, miStores.size());
auto miWait = findAll<MI_SEMAPHORE_WAIT *>(hwParser.cmdList.begin(), hwParser.cmdList.end());
EXPECT_EQ(1u, miWait.size());
// program callback address
// program host function id
auto expectedHostFunctionId = 1u;
auto miStoreUserHostFunction = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&userHostFunctionStored), miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(userCallback), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(userCallback), miStoreUserHostFunction->getDataDword1());
EXPECT_EQ(reinterpret_cast<uint64_t>(&hostFunctionId), miStoreUserHostFunction->getAddress());
EXPECT_EQ(getLowPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword0());
EXPECT_EQ(getHighPart(expectedHostFunctionId), miStoreUserHostFunction->getDataDword1());
EXPECT_TRUE(miStoreUserHostFunction->getStoreQword());
// program callback data
auto miStoreUserData = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[1]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&userDataStored), miStoreUserData->getAddress());
EXPECT_EQ(getLowPart(userCallbackData), miStoreUserData->getDataDword0());
EXPECT_EQ(getHighPart(userCallbackData), miStoreUserData->getDataDword1());
EXPECT_TRUE(miStoreUserData->getStoreQword());
// signal pending job
auto miStoreSignalTag = genCmdCast<MI_STORE_DATA_IMM *>(*miStores[2]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&tagStored), miStoreSignalTag->getAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::pending), miStoreSignalTag->getDataDword0());
EXPECT_FALSE(miStoreSignalTag->getStoreQword());
// wait for completion
// program wait for host function completion
auto miWaitTag = genCmdCast<MI_SEMAPHORE_WAIT *>(*miWait[0]);
EXPECT_EQ(reinterpret_cast<uint64_t>(&tagStored), miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(reinterpret_cast<uint64_t>(&hostFunctionId), miWaitTag->getSemaphoreGraphicsAddress());
EXPECT_EQ(static_cast<uint32_t>(HostFunctionStatus::completed), miWaitTag->getSemaphoreDataDword());
EXPECT_EQ(MI_SEMAPHORE_WAIT::COMPARE_OPERATION_SAD_EQUAL_SDD, miWaitTag->getCompareOperation());
EXPECT_EQ(MI_SEMAPHORE_WAIT::WAIT_MODE_POLLING_MODE, miWaitTag->getWaitMode());
// host function from host function streamer
auto programmedHostFunction = hostFunctionStreamer->getHostFunction();
EXPECT_EQ(callbackAddress, programmedHostFunction.hostFunctionAddress);
EXPECT_EQ(userDataAddress, programmedHostFunction.userDataAddress);
EXPECT_EQ(isInOrder, programmedHostFunction.isInOrder);
}
// Programs two host functions through HostFunctionHelper::programHostFunction - the first
// flagged in-order, the second out-of-order - and checks what the HostFunctionStreamer
// reports after each one. The whole scenario is repeated for isTbx == false and isTbx == true.
HWTEST_F(HostFunctionTests, givenHostFunctionStreamerWhenProgramHostFunctionIsCalledThenHostFunctionStreamerWasUpdatedWithHostFunction) {
uint64_t callbackAddress1 = 1024;
uint64_t userDataAddress1 = 2048;
uint64_t callbackAddress2 = 4096;
uint64_t userDataAddress2 = 8192;
// Command buffer backing storage; shared by both iterations of the isTbx loop.
constexpr auto size = 4096u;
std::byte buff[size] = {};
LinearStream stream(buff, size);
for (bool isTbx : ::testing::Bool()) {
HostFunction hostFunction1{
.hostFunctionAddress = callbackAddress1,
.userDataAddress = userDataAddress1,
.isInOrder = true};
HostFunction hostFunction2{
.hostFunctionAddress = callbackAddress2,
.userDataAddress = userDataAddress2,
.isInOrder = false};
// Local word whose address is handed to the streamer below; the test writes to it
// directly to emulate the ID store that would otherwise come from the command stream.
uint64_t hostFunctionId = HostFunctionStatus::completed;
uint64_t hostFunctionIdAddress = reinterpret_cast<uint64_t>(&hostFunctionId);
MockGraphicsAllocation mockAllocation;
// Records whether the streamer invoked the download callback it was constructed with.
bool downloadAllocationCalled = false;
std::function<void(GraphicsAllocation &)> downloadAllocationImpl = [&](GraphicsAllocation &) { downloadAllocationCalled = true; };
auto hostFunctionStreamer = std::make_unique<HostFunctionStreamer>(&mockAllocation,
&hostFunctionId,
downloadAllocationImpl,
isTbx);
// Nothing programmed yet and the ID word holds "completed".
EXPECT_FALSE(hostFunctionStreamer->isHostFunctionReadyToExecute());
{
// 1st host function in order
HostFunctionHelper<FamilyType>::programHostFunction(stream, *hostFunctionStreamer.get(), std::move(hostFunction1));
hostFunctionId = 1u; // simulate function being processed
auto programmedHostFunction1 = hostFunctionStreamer->getHostFunction();
EXPECT_EQ(&mockAllocation, hostFunctionStreamer->getHostFunctionIdAllocation());
EXPECT_EQ(hostFunctionIdAddress, hostFunctionStreamer->getHostFunctionIdGpuAddress());
// Readiness follows the value currently stored in the ID word.
hostFunctionId = HostFunctionStatus::completed;
EXPECT_FALSE(hostFunctionStreamer->isHostFunctionReadyToExecute());
hostFunctionId = 1u;
EXPECT_TRUE(hostFunctionStreamer->isHostFunctionReadyToExecute());
// The download callback is expected to have fired only when isTbx is true.
EXPECT_EQ(isTbx, downloadAllocationCalled);
hostFunctionStreamer->prepareForExecution(programmedHostFunction1);
// next host function must wait, streamer busy until host function is completed
EXPECT_FALSE(hostFunctionStreamer->isHostFunctionReadyToExecute());
hostFunctionStreamer->signalHostFunctionCompletion(programmedHostFunction1);
EXPECT_EQ(HostFunctionStatus::completed, hostFunctionId); // host function ID should be marked as completed
EXPECT_EQ(callbackAddress1, programmedHostFunction1.hostFunctionAddress);
EXPECT_EQ(userDataAddress1, programmedHostFunction1.userDataAddress);
EXPECT_TRUE(programmedHostFunction1.isInOrder);
}
{
hostFunctionId = HostFunctionStatus::completed;
// 2nd host function out of order
HostFunctionHelper<FamilyType>::programHostFunction(stream, *hostFunctionStreamer.get(), std::move(hostFunction2));
hostFunctionId = 3u; // simulate function being processed
auto programmedHostFunction2 = hostFunctionStreamer->getHostFunction();
EXPECT_EQ(&mockAllocation, hostFunctionStreamer->getHostFunctionIdAllocation());
EXPECT_EQ(hostFunctionIdAddress, hostFunctionStreamer->getHostFunctionIdGpuAddress());
hostFunctionId = HostFunctionStatus::completed;
EXPECT_FALSE(hostFunctionStreamer->isHostFunctionReadyToExecute());
// Here the pending ID is taken from the streamer's own counter instead of a literal.
hostFunctionId = hostFunctionStreamer->getNextHostFunctionIdAndIncrement();
EXPECT_TRUE(hostFunctionStreamer->isHostFunctionReadyToExecute());
// NOTE(review): downloadAllocationCalled is never reset inside the loop body, so by this
// point it may already have been set during the first scope; this check does not prove
// that a new download happened for the second host function.
EXPECT_EQ(isTbx, downloadAllocationCalled);
// Unlike the in-order scope, no "streamer busy" check is made between these two calls.
hostFunctionStreamer->prepareForExecution(programmedHostFunction2);
hostFunctionStreamer->signalHostFunctionCompletion(programmedHostFunction2);
EXPECT_EQ(HostFunctionStatus::completed, hostFunctionId); // host function ID should be marked as completed
EXPECT_EQ(callbackAddress2, programmedHostFunction2.hostFunctionAddress);
EXPECT_EQ(userDataAddress2, programmedHostFunction2.userDataAddress);
EXPECT_FALSE(programmedHostFunction2.isInOrder);
}
{
// no more programmed Host Functions
EXPECT_FALSE(hostFunctionStreamer->isHostFunctionReadyToExecute());
}
}
}
// Checks that the host function streamer obtained from the CSR stays the same object across
// repeated ensureHostFunctionWorkerStarted() and direct startHostFunctionWorker() calls, and
// that its ID allocation is a tag buffer addressed at TagAllocationLayout::hostFunctionDataOffset.
TEST(CommandStreamReceiverHostFunctionsTest, givenCommandStreamReceiverWhenEnsureHostFunctionDataInitializationCalledThenHostFunctionAllocationIsBeingAllocatedOnlyOnce) {
    MockExecutionEnvironment executionEnvironment(defaultHwInfo.get());
    DeviceBitfield deviceBitfield(0b11);
    auto mockCsr = std::make_unique<MockCommandStreamReceiver>(executionEnvironment, 0, deviceBitfield);
    executionEnvironment.memoryManager.reset(new OsAgnosticMemoryManager(executionEnvironment));
    mockCsr->initializeTagAllocation();

    // First request: the streamer becomes available and the start counter reads 1.
    mockCsr->ensureHostFunctionWorkerStarted();
    auto *initialStreamer = &mockCsr->getHostFunctionStreamer();
    EXPECT_NE(nullptr, initialStreamer);
    EXPECT_EQ(1u, mockCsr->startHostFunctionWorkerCalledTimes);

    // Second request: same streamer, start counter unchanged.
    mockCsr->ensureHostFunctionWorkerStarted();
    EXPECT_EQ(initialStreamer, &mockCsr->getHostFunctionStreamer());
    EXPECT_EQ(1u, mockCsr->startHostFunctionWorkerCalledTimes);

    // Direct call bumps the counter, yet the streamer handed out is still the initial one.
    mockCsr->startHostFunctionWorker();
    EXPECT_EQ(2u, mockCsr->startHostFunctionWorkerCalledTimes);
    EXPECT_EQ(initialStreamer, &mockCsr->getHostFunctionStreamer());

    EXPECT_EQ(AllocationType::tagBuffer, initialStreamer->getHostFunctionIdAllocation()->getAllocationType());

    // Expected ID address = CPU pointer of the allocation + hostFunctionDataOffset.
    auto *expectedIdPtr = ptrOffset(initialStreamer->getHostFunctionIdAllocation()->getUnderlyingBuffer(),
                                    TagAllocationLayout::hostFunctionDataOffset);
    const auto expectedIdAddress = reinterpret_cast<uint64_t>(expectedIdPtr);
    EXPECT_EQ(expectedIdAddress, initialStreamer->getHostFunctionIdGpuAddress());
}
// After ensureHostFunctionWorkerStarted() the streamer must expose a non-null ID allocation
// and the mock must have counted exactly one createHostFunctionWorker call.
TEST(CommandStreamReceiverHostFunctionsTest, givenDestructedCommandStreamReceiverWhenEnsureHostFunctionDataInitializationCalledThenHostFunctionAllocationsDeallocated) {
    MockExecutionEnvironment executionEnvironment(defaultHwInfo.get());
    DeviceBitfield deviceBitfield(0b11);
    auto mockCsr = std::make_unique<MockCommandStreamReceiver>(executionEnvironment, 0, deviceBitfield);
    executionEnvironment.memoryManager.reset(new OsAgnosticMemoryManager(executionEnvironment));

    mockCsr->initializeTagAllocation();
    mockCsr->ensureHostFunctionWorkerStarted();

    auto &hostFunctionStreamer = mockCsr->getHostFunctionStreamer();
    EXPECT_NE(nullptr, hostFunctionStreamer.getHostFunctionIdAllocation());
    EXPECT_EQ(1u, mockCsr->createHostFunctionWorkerCounter);
}

View File

@@ -6,12 +6,14 @@
*/
#include "shared/source/command_stream/host_function_worker_counting_semaphore.h"
#include "shared/source/command_stream/host_function_worker_cv.h"
#include "shared/test/common/helpers/debug_manager_state_restore.h"
#include "shared/test/common/mocks/mock_command_stream_receiver.h"
#include "shared/test/common/mocks/mock_execution_environment.h"
#include "shared/test/common/test_macros/test.h"
#include <atomic>
#include <vector>
#if defined(__clang__)
#if defined(__has_feature)
#if __has_feature(thread_sanitizer)
@@ -37,7 +39,6 @@ extern "C" void __tsan_ignore_thread_end();
namespace {
class MockCommandStreamReceiverHostFunction : public MockCommandStreamReceiver {
public:
using MockCommandStreamReceiver::hostFunctionData;
using MockCommandStreamReceiver::hostFunctionWorker;
using MockCommandStreamReceiver::MockCommandStreamReceiver;
@@ -45,28 +46,28 @@ class MockCommandStreamReceiverHostFunction : public MockCommandStreamReceiver {
CommandStreamReceiver::createHostFunctionWorker();
}
void signalHostFunctionWorker() override {
CommandStreamReceiver::signalHostFunctionWorker();
void signalHostFunctionWorker(uint32_t nHostFunctions) override {
CommandStreamReceiver::signalHostFunctionWorker(nHostFunctions);
}
};
struct Arg {
uint32_t expected = 0;
uint32_t result = 0;
uint32_t counter = 0;
std::atomic<uint32_t> counter{0};
};
extern "C" void hostFunctionExample(void *data) {
Arg *arg = static_cast<Arg *>(data);
arg->result = arg->expected;
++arg->counter;
arg->counter.fetch_add(1, std::memory_order_acq_rel);
}
void createArgs(std::vector<Arg> &hostFunctionArgs, uint32_t n) {
void createArgs(std::vector<std::unique_ptr<Arg>> &hostFunctionArgs, uint32_t n) {
hostFunctionArgs.reserve(n);
for (auto i = 0u; i < n; i++) {
hostFunctionArgs.push_back(Arg{.expected = i + 1, .result = 0, .counter = 0});
hostFunctionArgs.emplace_back(std::make_unique<Arg>(i + 1, 0, 0));
}
}
@@ -84,18 +85,44 @@ class HostFunctionMtFixture {
csrs.push_back(std::make_unique<MockCommandStreamReceiverHostFunction>(executionEnvironment, 0, deviceBitfield));
}
if (testingMode == 1) {
// csrs[0] is primary for all other csrs
for (auto i = 1u; i < numberOfCSRs; i++) {
csrs[i]->primaryCsr = csrs[0].get();
}
} else if (testingMode == 2) {
// csrs[0] and csrs[1] are primaries for other csrs
// secondary split between two primaries
for (auto i = 2u; i < numberOfCSRs; i++) {
uint32_t primaryIdx = (i % 2 == 0) ? 0 : 1;
csrs[i]->primaryCsr = csrs[primaryIdx].get();
}
}
for (auto &csr : csrs) {
csr->initializeTagAllocation();
}
for (auto i = 0u; i < csrs.size(); i++) {
*csrs[i]->hostFunctionData.entry = reinterpret_cast<uint64_t>(hostFunctionExample);
*csrs[i]->hostFunctionData.userData = reinterpret_cast<uint64_t>(&hostFunctionArgs[i]);
*csrs[i]->hostFunctionData.internalTag = static_cast<uint32_t>(HostFunctionTagStatus::completed);
for (auto &csr : csrs) {
csr->ensureHostFunctionWorkerStarted();
}
for (auto &csr : csrs) {
csr->startHostFunctionWorker();
for (auto i = 0u; i < csrs.size(); i++) {
auto &streamer = csrs[i]->getHostFunctionStreamer();
for (auto k = 0u; k < callbacksPerCsr; k++) {
bool isOutOfOrder = k < 3; // first 3, 8th and 9th are out of order, rest is in order
isOutOfOrder |= (k == 7) || (k == 8);
HostFunction hostFunction = {
.hostFunctionAddress = reinterpret_cast<uint64_t>(hostFunctionExample),
.userDataAddress = reinterpret_cast<uint64_t>(this->hostFunctionArgs[i].get()),
.isInOrder = !isOutOfOrder};
auto hostFunctionId = streamer.getNextHostFunctionIdAndIncrement();
streamer.addHostFunction(hostFunctionId, std::move(hostFunction));
}
}
}
@@ -108,10 +135,12 @@ class HostFunctionMtFixture {
while (true) {
for (auto i = 0u; i < csrs.size(); i++) {
if (*csrs[i]->hostFunctionData.internalTag == static_cast<uint32_t>(HostFunctionTagStatus::completed)) {
if (csrs[i]->getHostFunctionStreamer().getHostFunctionId() == HostFunctionStatus::completed) {
if (callbacksPerCsrCounter[i] < callbacksPerCsr) {
*csrs[i]->hostFunctionData.internalTag = static_cast<uint32_t>(HostFunctionTagStatus::pending);
auto hostFunctionId = (callbacksPerCsrCounter[i] * 2) + 1;
*csrs[i]->getHostFunctionStreamer().getHostFunctionIdPtr() = hostFunctionId;
++callbacksPerCsrCounter[i];
++callbacksCounter;
}
@@ -130,10 +159,11 @@ class HostFunctionMtFixture {
void waitForCallbacksCompletion() {
TSAN_ANNOTATE_IGNORE_BEGIN();
while (true) {
uint32_t csrsCompleted = 0u;
for (auto i = 0u; i < csrs.size(); i++) {
if (*csrs[i]->hostFunctionData.internalTag == static_cast<uint32_t>(HostFunctionTagStatus::completed)) {
if (csrs[i]->getHostFunctionStreamer().getHostFunctionId() == HostFunctionStatus::completed) {
++csrsCompleted;
}
}
@@ -157,11 +187,13 @@ class HostFunctionMtFixture {
TSAN_ANNOTATE_IGNORE_BEGIN();
for (auto i = 0u; i < csrs.size(); i++) {
Arg *arg = reinterpret_cast<Arg *>(*csrs[i]->hostFunctionData.userData);
EXPECT_EQ(arg->expected, arg->result);
EXPECT_EQ(uint32_t{i + 1u}, arg->result);
EXPECT_EQ(expectedCounter, arg->counter);
EXPECT_EQ(static_cast<uint32_t>(HostFunctionTagStatus::completed), *csrs[i]->hostFunctionData.internalTag);
Arg &arg = *(this->hostFunctionArgs[i].get());
EXPECT_EQ(arg.expected, arg.result);
EXPECT_EQ(uint32_t{i + 1u}, arg.result);
EXPECT_EQ(expectedCounter, arg.counter.load());
auto &streamer = csrs[i]->getHostFunctionStreamer();
EXPECT_EQ(HostFunctionStatus::completed, streamer.getHostFunctionId());
}
TSAN_ANNOTATE_IGNORE_END();
}
@@ -171,7 +203,7 @@ class HostFunctionMtFixture {
hostFunctionArgs.clear();
}
std::vector<Arg> hostFunctionArgs;
std::vector<std::unique_ptr<Arg>> hostFunctionArgs;
std::vector<std::unique_ptr<MockCommandStreamReceiverHostFunction>> csrs;
DebugManagerStateRestore restorer{};
uint32_t callbacksPerCsr = 0;
@@ -189,6 +221,11 @@ class HostFunctionMtTestP : public ::testing::TestWithParam<int>, public HostFun
auto param = GetParam();
this->testingMode = static_cast<int>(param);
debugManager.flags.HostFunctionWorkMode.set(this->testingMode);
if (testingMode == 1 || testingMode == 2) {
debugManager.flags.HostFunctionThreadPoolSize.set(2);
debugManager.flags.HostFunctionWorkMode.set(static_cast<int32_t>(HostFunctionWorkerMode::schedulerWithThreadPool));
}
}
void TearDown() override {
@@ -201,8 +238,8 @@ class HostFunctionMtTestP : public ::testing::TestWithParam<int>, public HostFun
TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenSequentialCsrJobIsSubmittedThenHostFunctionsWorkIsDoneCorrectly) {
uint32_t numberOfCSRs = 4;
uint32_t callbacksPerCsr = 6;
uint32_t numberOfCSRs = 6;
uint32_t callbacksPerCsr = 12;
configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs);
@@ -210,7 +247,7 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenSequentialCsrJobIsSubmit
for (auto iCallback = 0u; iCallback < callbacksPerCsr; iCallback++) {
for (auto &csr : csrs) {
csr->signalHostFunctionWorker();
csr->signalHostFunctionWorker(1u);
}
}
@@ -221,8 +258,8 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenSequentialCsrJobIsSubmit
}
TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenEachCsrSubmitAllCalbacksPerThreadThenHostFunctionsWorkIsDoneCorrectly) {
uint32_t numberOfCSRs = 4;
uint32_t callbacksPerCsr = 6;
uint32_t numberOfCSRs = 6;
uint32_t callbacksPerCsr = 12;
configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs);
@@ -233,9 +270,7 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenEachCsrSubmitAllCalbacks
auto submitAllCallbacksPerCsr = [&](uint32_t idxCsr) {
auto csr = csrs[idxCsr].get();
for (auto callbackIdx = 0u; callbackIdx < callbacksPerCsr; callbackIdx++) {
csr->signalHostFunctionWorker();
}
csr->signalHostFunctionWorker(callbacksPerCsr);
};
for (auto i = 0u; i < nSubmitters; i++) {
@@ -256,8 +291,8 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenEachCsrSubmitAllCalbacks
TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenCsrJobsAreSubmittedConcurrentlyThenHostFunctionsWorkIsDoneCorrectly) {
uint32_t numberOfCSRs = 4;
uint32_t callbacksPerCsr = 6;
uint32_t numberOfCSRs = 6;
uint32_t callbacksPerCsr = 12;
configureCSRs(numberOfCSRs, callbacksPerCsr, testingMode, primaryCSRs);
@@ -265,10 +300,10 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenCsrJobsAreSubmittedConcu
std::vector<std::jthread> submitters;
submitters.reserve(nSubmitters);
// multiple threads can submit host function in parrarel using the same csr
// multiple threads can submit host function in parallel using the same csr
auto submitOnceCallbackForAllCSRs = [&]() {
for (auto &csr : csrs) {
csr->signalHostFunctionWorker();
csr->signalHostFunctionWorker(1u);
}
};
@@ -291,7 +326,9 @@ TEST_P(HostFunctionMtTestP, givenHostFunctionWorkersWhenCsrJobsAreSubmittedConcu
INSTANTIATE_TEST_SUITE_P(AllModes,
HostFunctionMtTestP,
::testing::Values(
0 // Counting Semaphore implementation
0, // Counting Semaphore implementation
1, // Thread Pool implementation, one primary csr
2 // Thread Pool implementation, two primary csrs
));
} // namespace