feature: Adding fork/join support

Related-To: NEO-15373

Signed-off-by: Chodor, Jaroslaw <jaroslaw.chodor@intel.com>
This commit is contained in:
Chodor, Jaroslaw
2025-07-07 15:12:59 +00:00
committed by Compute-Runtime-Automation
parent 32611736e5
commit 7acb9585af
7 changed files with 805 additions and 78 deletions

View File

@@ -9,10 +9,12 @@
#include "level_zero/core/source/cmdlist/cmdlist.h"
#include "level_zero/core/source/context/context.h"
#include "level_zero/core/source/event/event.h"
namespace L0 {
Graph::~Graph() {
this->unregisterSignallingEvents();
for (auto *sg : subGraphs) {
if (false == sg->wasPreallocated()) {
delete sg;
@@ -32,7 +34,58 @@ void Graph::startCapturingFrom(L0::CommandList &captureSrc, bool isSubGraph) {
}
void Graph::stopCapturing() {
this->unregisterSignallingEvents();
this->captureSrc = nullptr;
this->wasCapturingStopped = true;
}
// Completes a pending fork if childCmdList is one this graph forked to.
//
// childCmdList - command list that recorded the signal of joinEvent
// joinEvent    - event the next command captured by this graph waits on
//
// Does nothing when childCmdList has no unjoined fork registered here.
// Otherwise the child's capture target is released from the command list,
// its capture is stopped, and the fork is moved from unjoinedForks to
// joinedForks (keyed by the id of the forking command).
void Graph::tryJoinOnNextCommand(L0::CommandList &childCmdList, L0::Event &joinEvent) {
    const auto pendingFork = unjoinedForks.find(&childCmdList);
    if (unjoinedForks.end() != pendingFork) {
        const ForkInfo &fork = pendingFork->second;

        ForkJoinInfo joined = {};
        joined.forkCommandId = fork.forkCommandId;
        joined.forkEvent = fork.forkEvent;
        // the joining command has not been captured yet, so its id is the current command count
        joined.joinCommandId = static_cast<CapturedCommandId>(commands.size());
        joined.joinEvent = &joinEvent;
        joined.forkDestiny = childCmdList.releaseCaptureTarget();
        joined.forkDestiny->stopCapturing();

        joinedForks[fork.forkCommandId] = joined;
        unjoinedForks.erase(pendingFork); // invalidates `fork`
    }
}
// Starts a child graph that captures from childCmdList, forked from this graph.
//
// childCmdList - command list that begins capturing into the new child graph
// child        - out: receives the newly allocated child graph; must be null on entry
// forkEvent    - event whose signal was previously recorded by this graph
//
// Aborts (UNRECOVERABLE_IF) when the child is already capturing or when
// forkEvent was never registered in recordedSignals.
void Graph::forkTo(L0::CommandList &childCmdList, Graph *&child, L0::Event &forkEvent) {
    // should not be capturing already
    UNRECOVERABLE_IF((nullptr != child) || (nullptr != childCmdList.getCaptureTarget()));

    ze_context_handle_t hContext = nullptr;
    childCmdList.getContextHandle(&hContext);

    child = new Graph(L0::Context::fromHandle(hContext), false);
    child->startCapturingFrom(childCmdList, true);
    childCmdList.setCaptureTarget(child);
    subGraphs.push_back(child);

    const auto signalRecord = recordedSignals.find(&forkEvent);
    UNRECOVERABLE_IF(recordedSignals.end() == signalRecord);

    ForkInfo pendingFork = {};
    pendingFork.forkCommandId = signalRecord->second;
    pendingFork.forkEvent = &forkEvent;
    unjoinedForks[&childCmdList] = pendingFork;
}
// Remembers that `ev` is signalled by the most recently captured command and
// tags the event with the command list this graph is capturing from.
// Expects at least one command to have been captured already.
void Graph::registerSignallingEventFromPreviousCommand(L0::Event &ev) {
    ev.setRecordedSignalFrom(captureSrc);

    const auto lastCommandId = static_cast<CapturedCommandId>(commands.size() - 1);
    recordedSignals[&ev] = lastCommandId;
}
// Detaches every event registered through registerSignallingEventFromPreviousCommand
// by resetting its recorded-signal source to nullptr.
//
// The bookkeeping is dropped afterwards so the call is idempotent: this is
// invoked from both stopCapturing() and the destructor, and the second call
// must not touch Event pointers that may have been destroyed, or re-registered
// with another graph, in the meantime.
void Graph::unregisterSignallingEvents() {
    for (const auto &recordedSignal : this->recordedSignals) {
        recordedSignal.first->setRecordedSignalFrom(nullptr);
    }
    this->recordedSignals.clear();
}
// Returns container.data(), or nullptr when the container is empty.
template <typename ContainerT>
auto getOptionalData(ContainerT &container) {
    decltype(container.data()) dataPtr = nullptr;
    if (false == container.empty()) {
        dataPtr = container.data();
    }
    return dataPtr;
}
Closure<CaptureApi::zeCommandListAppendMemoryCopy>::Closure(const ApiArgs &apiArgs)
@@ -44,7 +97,7 @@ Closure<CaptureApi::zeCommandListAppendMemoryCopy>::Closure(const ApiArgs &apiAr
}
// Replays the captured zeCommandListAppendMemoryCopy call onto executionTarget.
// The wait list comes from the closure's own copy of the events
// (indirectArgs.waitEvents); nullptr is passed when that copy is empty.
ze_result_t Closure<CaptureApi::zeCommandListAppendMemoryCopy>::instantiateTo(L0::CommandList &executionTarget) const {
    return zeCommandListAppendMemoryCopy(&executionTarget, apiArgs.dstptr, apiArgs.srcptr, apiArgs.size, apiArgs.hSignalEvent, apiArgs.numWaitEvents, const_cast<ze_event_handle_t *>(getOptionalData(indirectArgs.waitEvents)));
}
Closure<CaptureApi::zeCommandListAppendBarrier>::Closure(const ApiArgs &apiArgs)
@@ -56,44 +109,92 @@ Closure<CaptureApi::zeCommandListAppendBarrier>::Closure(const ApiArgs &apiArgs)
}
// Replays the captured zeCommandListAppendBarrier call onto executionTarget.
// The wait list comes from the closure's own copy of the events
// (indirectArgs.waitEvents); nullptr is passed when that copy is empty.
ze_result_t Closure<CaptureApi::zeCommandListAppendBarrier>::instantiateTo(L0::CommandList &executionTarget) const {
    return zeCommandListAppendBarrier(&executionTarget, apiArgs.hSignalEvent, apiArgs.numWaitEvents, const_cast<ze_event_handle_t *>(getOptionalData(indirectArgs.waitEvents)));
}
// Captures a zeCommandListAppendWaitOnEvents call: stores the raw arguments
// and copies the numEvents handles from phEvents into indirectArgs.waitEvents.
Closure<CaptureApi::zeCommandListAppendWaitOnEvents>::Closure(const ApiArgs &apiArgs)
    : apiArgs(apiArgs) {
    auto &capturedEvents = this->indirectArgs.waitEvents;
    const uint32_t eventCount = apiArgs.numEvents;

    capturedEvents.reserve(eventCount);
    for (uint32_t eventIdx = 0; eventIdx < eventCount; ++eventIdx) {
        capturedEvents.push_back(apiArgs.phEvents[eventIdx]);
    }
}
// Replays the captured zeCommandListAppendWaitOnEvents call onto executionTarget,
// using the closure's own copy of the event list (nullptr when it is empty).
ze_result_t Closure<CaptureApi::zeCommandListAppendWaitOnEvents>::instantiateTo(L0::CommandList &executionTarget) const {
    auto *waitList = const_cast<ze_event_handle_t *>(getOptionalData(indirectArgs.waitEvents));
    return zeCommandListAppendWaitOnEvents(&executionTarget, apiArgs.numEvents, waitList);
}
// Defaulted: the owned command lists and sub-graphs are held through
// std::unique_ptr members and are released together with the object.
ExecutableGraph::~ExecutableGraph() = default;
void ExecutableGraph::instantiateFrom(Graph &graph) {
// Creates a command list matching the source graph's capture-target description,
// takes ownership of it (myCommandLists) and appends it to the submission chain.
// Aborts (UNRECOVERABLE_IF) when no command list could be obtained.
// Returns the new, non-owning command list pointer.
L0::CommandList *ExecutableGraph::allocateAndAddCommandListSubmissionNode() {
    auto &targetDesc = src->getCaptureTargetDesc();

    ze_command_list_handle_t hCreated = nullptr;
    src->getContext()->createCommandList(targetDesc.hDevice, &targetDesc.desc, &hCreated);

    L0::CommandList *created = L0::CommandList::fromHandle(hCreated);
    UNRECOVERABLE_IF(nullptr == created);

    myCommandLists.emplace_back(created);
    submissionChain.emplace_back(created);
    return created;
}
// Appends a (non-owned) executable sub-graph as the next segment of the submission chain.
void ExecutableGraph::addSubGraphSubmissionNode(ExecutableGraph *subGraph) {
    submissionChain.push_back(GraphSubmissionSegment{subGraph});
}
// Builds the executable form of `graph`.
//
// graph    - captured graph to instantiate; also remembered as this->src
// settings - instantiation options; forkPolicy decides how forks are laid out
//
// Steps:
//  1. Every sub-graph of `graph` is instantiated first (recursively, with the
//     same settings) and owned through this->subGraphs.
//  2. Captured commands are replayed, in order, into command lists created by
//     allocateAndAddCommandListSubmissionNode().
//  3. After a command that forked to a sub-graph which was later joined, the
//     sub-graph is appended to the submission chain:
//       - ForkPolicySplitLevels: the current command list is closed first and a
//         fresh one is started for the commands that follow (interleaving),
//       - ForkPolicyMonolythicLevels: the current command list keeps growing.
//
// An empty graph produces no command lists and no submission nodes.
void ExecutableGraph::instantiateFrom(Graph &graph, const GraphInstatiateSettings &settings) {
    this->src = &graph;
    this->executionTarget = graph.getExecutionTarget();

    const auto &srcSubGraphs = graph.getSubgraphs();
    std::unordered_map<Graph *, ExecutableGraph *> executableSubGraphMap;
    executableSubGraphMap.reserve(srcSubGraphs.size());
    this->subGraphs.reserve(srcSubGraphs.size());
    for (auto &srcSubgraph : srcSubGraphs) {
        auto execSubGraph = std::make_unique<ExecutableGraph>();
        execSubGraph->instantiateFrom(*srcSubgraph, settings);
        executableSubGraphMap[srcSubgraph] = execSubGraph.get();
        this->subGraphs.push_back(std::move(execSubGraph));
    }

    if (graph.empty()) {
        return;
    }

    [[maybe_unused]] ze_result_t err = ZE_RESULT_SUCCESS;
    L0::CommandList *currCmdList = nullptr;

    const auto &allCommands = src->getCapturedCommands();
    for (CapturedCommandId cmdId = 0; cmdId < static_cast<uint32_t>(allCommands.size()); ++cmdId) {
        const auto &cmd = allCommands[cmdId];
        if (nullptr == currCmdList) {
            // first command, or the previous command list was closed on a fork
            currCmdList = this->allocateAndAddCommandListSubmissionNode();
        }

        switch (static_cast<CaptureApi>(cmd.index())) {
        default:
            break;
// keep the replay result so the debug check below actually observes failures
#define RR_CAPTURED_API(X)                                                                   \
    case CaptureApi::X:                                                                      \
        err = std::get<static_cast<size_t>(CaptureApi::X)>(cmd).instantiateTo(*currCmdList); \
        DEBUG_BREAK_IF(err != ZE_RESULT_SUCCESS);                                            \
        break;
            RR_CAPTURED_APIS()
#undef RR_CAPTURED_API
        }

        auto *forkTarget = graph.getJoinedForkTarget(cmdId);
        if (nullptr == forkTarget) {
            continue;
        }

        auto execSubGraph = executableSubGraphMap.find(forkTarget);
        UNRECOVERABLE_IF(executableSubGraphMap.end() == execSubGraph);
        if (settings.forkPolicy == GraphInstatiateSettings::ForkPolicySplitLevels) {
            // interleave: finish the current command list before the sub-graph
            currCmdList->close();
            currCmdList = nullptr;
        } else {
            // submit after current
            UNRECOVERABLE_IF(settings.forkPolicy != GraphInstatiateSettings::ForkPolicyMonolythicLevels);
        }
        this->addSubGraphSubmissionNode(execSubGraph->second);
    }

    // the last captured command is expected to leave an open command list behind
    UNRECOVERABLE_IF(nullptr == currCmdList);
    currCmdList->close();
}
@@ -111,20 +212,73 @@ ze_result_t ExecutableGraph::execute(L0::CommandList *executionTarget, void *pNe
}
executionTarget->appendSignalEvent(hSignalEvent, false);
} else {
auto commands = this->hwCommands.get();
ze_command_list_handle_t graphCmdList = commands;
auto res = executionTarget->appendCommandLists(1, &graphCmdList, hSignalEvent, numWaitEvents, phWaitEvents);
if (ZE_RESULT_SUCCESS != res) {
return res;
L0::CommandList *const myLastCommandList = this->myCommandLists.rbegin()->get();
{
// first submission node
L0::CommandList **cmdList = std::get_if<L0::CommandList *>(&this->submissionChain[0]);
UNRECOVERABLE_IF(nullptr == cmdList);
auto currSignalEvent = (myLastCommandList == *cmdList) ? hSignalEvent : nullptr;
ze_command_list_handle_t hCmdList = *cmdList;
auto res = executionTarget->appendCommandLists(1, &hCmdList, currSignalEvent, numWaitEvents, phWaitEvents);
if (ZE_RESULT_SUCCESS != res) {
return res;
}
}
}
for (auto &subGraph : this->subGraphs) {
auto res = subGraph->execute(nullptr, pNext, nullptr, 0, nullptr);
if (ZE_RESULT_SUCCESS != res) {
return res;
for (size_t submissioNodeId = 1; submissioNodeId < this->submissionChain.size(); ++submissioNodeId) {
if (L0::CommandList **cmdList = std::get_if<L0::CommandList *>(&this->submissionChain[submissioNodeId])) {
auto currSignalEvent = (myLastCommandList == *cmdList) ? hSignalEvent : nullptr;
ze_command_list_handle_t hCmdList = *cmdList;
auto res = executionTarget->appendCommandLists(1, &hCmdList, currSignalEvent, numWaitEvents, phWaitEvents);
if (ZE_RESULT_SUCCESS != res) {
return res;
}
} else {
L0::ExecutableGraph **subGraph = std::get_if<L0::ExecutableGraph *>(&this->submissionChain[submissioNodeId]);
UNRECOVERABLE_IF(nullptr == subGraph);
auto res = (*subGraph)->execute(nullptr, pNext, nullptr, 0, nullptr);
if (ZE_RESULT_SUCCESS != res) {
return res;
}
}
}
}
return ZE_RESULT_SUCCESS;
}
// Inspects the wait events of the command about to be captured on srcCmdList.
//
// srcCmdList    - command list the next command is being appended to
// captureTarget - in/out: graph srcCmdList captures into; may be set by a fork
// events        - wait events of the next command
//
// Only events carrying a recorded-signal source are considered. If srcCmdList
// was already capturing on entry, each such event is offered to the graph as a
// potential join; otherwise each such event forks a new child graph from the
// graph of the command list that recorded the signal.
//
// NOTE(review): in fork mode a second qualifying event reaches Graph::forkTo
// with captureTarget already set, which trips its UNRECOVERABLE_IF - confirm
// whether waiting on several recorded events should be supported here.
void recordHandleWaitEventsFromNextCommand(L0::CommandList &srcCmdList, Graph *&captureTarget, NEO::Range<ze_event_handle_t> events) {
    // the mode is fixed up-front: a fork performed below must not turn later events into joins
    const bool alreadyRecording = (nullptr != captureTarget);

    for (auto hEvent : events) {
        auto *event = L0::Event::fromHandle(hEvent);
        auto signalSource = event->getRecordedSignalFrom();
        if (nullptr == signalSource) {
            continue;
        }

        if (alreadyRecording) {
            // already recording, look for joins
            captureTarget->tryJoinOnNextCommand(*signalSource, *event);
        } else {
            // not recording yet, look for forks
            signalSource->getCaptureTarget()->forkTo(srcCmdList, captureTarget, *event);
        }
    }
}
// Registers the signal event of the command that was just captured, if it has one.
// A null event handle is ignored.
void recordHandleSignalEventFromPreviousCommand(L0::CommandList &srcCmdList, Graph &captureTarget, ze_event_handle_t event) {
    if (nullptr != event) {
        auto *signalEvent = L0::Event::fromHandle(event);
        captureTarget.registerSignallingEventFromPreviousCommand(*signalEvent);
    }
}
} // namespace L0

View File

@@ -7,11 +7,14 @@
#pragma once
#include "shared/source/utilities/range.h"
#include "shared/source/utilities/stackvec.h"
#include "level_zero/ze_api.h"
#include <atomic>
#include <memory>
#include <unordered_map>
#include <variant>
#include <vector>
@@ -26,7 +29,18 @@ struct _ze_executable_graph_handle_t {
namespace L0 {
// Process-wide flag: set once the first Graph object is constructed.
inline std::atomic<bool> processUsesGraphs{false};

// Marks graphs as used by this process.
// A strong compare-exchange is required: this is a single attempt without a
// retry loop, and the weak variant is allowed to fail spuriously, which could
// leave the flag unset.
inline void enabledGraphs() {
    bool graphsEnabled = false;
    processUsesGraphs.compare_exchange_strong(graphsEnabled, true, std::memory_order_relaxed);
}

// Returns true once enabledGraphs() has been called in this process.
inline bool areGraphsEnabled() {
    return processUsesGraphs.load();
}
struct Context;
struct Event;
#define RR_CAPTURED_APIS() \
RR_CAPTURED_API(zeCommandListAppendWriteGlobalTimestamp) \
@@ -70,6 +84,10 @@ struct Closure {
struct ApiArgs {
template <typename ArgsT>
ApiArgs(ArgsT...) {}
ze_event_handle_t hSignalEvent = nullptr;
uint32_t numWaitEvents = 0;
ze_event_handle_t *phWaitEvents = nullptr;
};
Closure(const ApiArgs &apiArgs) {}
@@ -80,6 +98,18 @@ struct Closure {
}
};
// Extracts the wait-event list from the raw arguments of captured API `api`
// by viewing them through Closure<api>::ApiArgs (phWaitEvents / numWaitEvents).
template <CaptureApi api, typename... TArgs>
inline NEO::Range<ze_event_handle_t> getCommandsWaitEventsList(TArgs... args) {
    using ApiArgsT = typename Closure<api>::ApiArgs;
    ApiArgsT parsedArgs{args...};
    return NEO::Range<ze_event_handle_t>{parsedArgs.phWaitEvents, parsedArgs.numWaitEvents};
}
// Extracts the signal event from the raw arguments of captured API `api`
// by viewing them through Closure<api>::ApiArgs (hSignalEvent).
template <CaptureApi api, typename... TArgs>
inline ze_event_handle_t getCommandsSignalEvent(TArgs... args) {
    using ApiArgsT = typename Closure<api>::ApiArgs;
    ApiArgsT parsedArgs{args...};
    return parsedArgs.hSignalEvent;
}
template <>
struct Closure<CaptureApi::zeCommandListAppendMemoryCopy> {
inline static constexpr bool isSupported = true;
@@ -123,6 +153,35 @@ struct Closure<CaptureApi::zeCommandListAppendBarrier> {
ze_result_t instantiateTo(CommandList &executionTarget) const;
};
// Captured form of a zeCommandListAppendWaitOnEvents call.
template <>
struct Closure<CaptureApi::zeCommandListAppendWaitOnEvents> {
    static constexpr bool isSupported = true;

    // Raw API arguments, in the order of the API call.
    struct ApiArgs {
        ze_command_list_handle_t hCommandList;
        uint32_t numEvents;
        ze_event_handle_t *phEvents;
    };

    // Data referenced indirectly by the arguments, copied at capture time.
    struct IndirectArgs {
        StackVec<ze_event_handle_t, 8> waitEvents;
    };

    ApiArgs apiArgs;
    IndirectArgs indirectArgs;

    Closure(const ApiArgs &apiArgs);

    // Replays the captured call onto executionTarget.
    ze_result_t instantiateTo(CommandList &executionTarget) const;
};
// zeCommandListAppendWaitOnEvents: the API's own event list is the wait list.
template <>
inline NEO::Range<ze_event_handle_t> getCommandsWaitEventsList<CaptureApi::zeCommandListAppendWaitOnEvents>(ze_command_list_handle_t, uint32_t numEvents, ze_event_handle_t *phEvents) {
    NEO::Range<ze_event_handle_t> waitedEvents{phEvents, numEvents};
    return waitedEvents;
}
// zeCommandListAppendWaitOnEvents takes no signal-event argument, so there is never one to report.
template <>
inline ze_event_handle_t getCommandsSignalEvent<CaptureApi::zeCommandListAppendWaitOnEvents>(ze_command_list_handle_t, uint32_t, ze_event_handle_t *) {
    return nullptr;
}
using ClosureVariants = std::variant<
#define RR_CAPTURED_API(X) Closure<CaptureApi::X>,
RR_CAPTURED_APIS()
@@ -130,9 +189,11 @@ using ClosureVariants = std::variant<
int>;
using CapturedCommand = ClosureVariants;
using CapturedCommandId = uint32_t;
struct Graph : _ze_graph_handle_t {
// ctx          - context the graph belongs to
// preallocated - stored as-is and reported by wasPreallocated()
// Constructing a graph also flips the process-wide "graphs in use" flag.
Graph(L0::Context *ctx, bool preallocated) : ctx(ctx), preallocated(preallocated) {
    constexpr size_t initialCommandsCapacity = 16;
    commands.reserve(initialCommandsCapacity);
    enabledGraphs();
}
~Graph();
@@ -167,6 +228,14 @@ struct Graph : _ze_graph_handle_t {
return commands;
}
// Returns the child graph forked at command cmdId if that fork was joined, nullptr otherwise.
Graph *getJoinedForkTarget(CapturedCommandId cmdId) {
    const auto joined = joinedForks.find(cmdId);
    return (joinedForks.end() != joined) ? joined->second.forkDestiny : nullptr;
}
// Read-only view of the child graphs registered in this graph.
const StackVec<Graph *, 16> &getSubgraphs() {
    return this->subGraphs;
}
@@ -188,6 +257,14 @@ struct Graph : _ze_graph_handle_t {
return commands.empty();
}
// A graph can be instantiated once its capture was stopped, every fork was
// joined, and the same holds recursively for all of its sub-graphs.
bool validForInstantiation() const {
    if ((false == closed()) || (false == unjoinedForks.empty())) {
        return false;
    }
    for (auto it = subGraphs.begin(); it != subGraphs.end(); ++it) {
        if (false == (*it)->validForInstantiation()) {
            return false;
        }
    }
    return true;
}
bool closed() const {
return wasCapturingStopped;
}
// Command list stored as this graph's execution target (may be nullptr).
L0::CommandList *getExecutionTarget() const {
    return this->executionTarget;
}
@@ -196,11 +273,17 @@ struct Graph : _ze_graph_handle_t {
return (nullptr != executionTarget);
}
void addSubGraph(Graph *subGraph) {
subGraphs.push_back(subGraph);
// True while at least one fork started from this graph has not been joined yet.
bool hasUnjoinedForks() const {
    return !unjoinedForks.empty();
}
void tryJoinOnNextCommand(L0::CommandList &childCmdList, L0::Event &joinEvent);
void forkTo(L0::CommandList &childCmdList, Graph *&child, L0::Event &forkEvent);
void registerSignallingEventFromPreviousCommand(L0::Event &ev);
protected:
void unregisterSignallingEvents();
std::vector<CapturedCommand> commands;
StackVec<Graph *, 16> subGraphs;
@@ -210,13 +293,82 @@ struct Graph : _ze_graph_handle_t {
L0::Context *ctx = nullptr;
bool preallocated = false;
bool wasCapturingStopped = false;
std::unordered_map<L0::Event *, CapturedCommandId> recordedSignals;
// Bookkeeping for a fork that has not been joined yet (see forkTo).
struct ForkInfo {
CapturedCommandId forkCommandId = 0; // id of the captured command that signals forkEvent
ze_event_handle_t forkEvent = nullptr; // event the child command list waited on when it forked
};
std::unordered_map<L0::CommandList *, ForkInfo> unjoinedForks;
// Bookkeeping for a fork that was joined back (see tryJoinOnNextCommand).
struct ForkJoinInfo {
CapturedCommandId forkCommandId = 0; // id of the captured command that signals forkEvent
CapturedCommandId joinCommandId = 0; // number of commands captured when the join was detected, i.e. id of the joining command
ze_event_handle_t forkEvent = nullptr; // event the child command list waited on when it forked
ze_event_handle_t joinEvent = nullptr; // event, signalled from the child, that the joining command waits on
Graph *forkDestiny = nullptr; // child graph that captured the forked work
};
std::unordered_map<CapturedCommandId, ForkJoinInfo> joinedForks;
};
void recordHandleWaitEventsFromNextCommand(L0::CommandList &srcCmdList, Graph *&captureTarget, NEO::Range<ze_event_handle_t> events);
void recordHandleSignalEventFromPreviousCommand(L0::CommandList &srcCmdList, Graph &captureTarget, ze_event_handle_t event);
// Records API call `api` (with arguments apiArgs) into the graph srcCmdList captures into.
//
// srcCmdList    - command list the API was called on
// captureTarget - in/out: graph being captured into; may be created here when
//                 the command's wait events fork from another capturing list
//
// Returns ZE_RESULT_ERROR_NOT_AVAILABLE when graphs are not in use by the
// process or when there is no capture target after fork detection, the result
// of Graph::capture on failure, and ZE_RESULT_SUCCESS otherwise.
template <CaptureApi api, typename... TArgs>
ze_result_t captureCommand(L0::CommandList &srcCmdList, Graph *&captureTarget, TArgs... apiArgs) {
    if (false == areGraphsEnabled()) {
        return ZE_RESULT_ERROR_NOT_AVAILABLE;
    }

    auto waitList = getCommandsWaitEventsList<api>(apiArgs...);
    if (false == waitList.empty()) {
        // either not capturing yet (potential fork) or capturing with open forks (potential join)
        const bool mayFork = (nullptr == captureTarget);
        if (mayFork || captureTarget->hasUnjoinedForks()) {
            recordHandleWaitEventsFromNextCommand(srcCmdList, captureTarget, waitList);
        }
    }

    if (nullptr == captureTarget) {
        return ZE_RESULT_ERROR_NOT_AVAILABLE;
    }

    const auto captureResult = captureTarget->capture<api>(apiArgs...);
    if (ZE_RESULT_SUCCESS != captureResult) {
        return captureResult;
    }

    const ze_event_handle_t hSignal = getCommandsSignalEvent<api>(apiArgs...);
    if (hSignal) {
        recordHandleSignalEventFromPreviousCommand(srcCmdList, *captureTarget, hSignal);
    }
    return ZE_RESULT_SUCCESS;
}
struct ExecutableGraph;
using GraphSubmissionSegment = std::variant<L0::CommandList *, ExecutableGraph *>;
using GraphSubmissionChain = std::vector<GraphSubmissionSegment>;
// Options controlling how a captured graph is turned into an executable one.
struct GraphInstatiateSettings {
    enum ForkPolicy {
        ForkPolicyMonolythicLevels, // build and submit monolithic command lists for each level
        ForkPolicySplitLevels       // split command lists on forks and interleave submission with child graphs (prevents deadlocks when submitting to single HW queue)
    };

    GraphInstatiateSettings() = default;

    // No extension structures are understood: a non-null pNext aborts.
    GraphInstatiateSettings(void *pNext) {
        UNRECOVERABLE_IF(nullptr != pNext);
    }

    ForkPolicy forkPolicy = ForkPolicySplitLevels;
};
struct ExecutableGraph : _ze_executable_graph_handle_t {
// Creates an empty executable graph; it is populated by instantiateFrom().
ExecutableGraph() {
}
void instantiateFrom(Graph &graph);
void instantiateFrom(Graph &graph, const GraphInstatiateSettings &settings);
// Convenience overload: instantiates with default-constructed settings.
void instantiateFrom(Graph &graph) {
    const GraphInstatiateSettings defaultSettings{};
    instantiateFrom(graph, defaultSettings);
}
~ExecutableGraph();
@@ -225,7 +377,7 @@ struct ExecutableGraph : _ze_executable_graph_handle_t {
}
// True when instantiation produced no command lists.
// Command lists are tracked in myCommandLists; hwCommands is never populated
// by instantiateFrom(), so testing it would always report "empty".
bool empty() {
    return myCommandLists.empty();
}
bool isSubGraph() const {
@@ -239,11 +391,16 @@ struct ExecutableGraph : _ze_executable_graph_handle_t {
ze_result_t execute(L0::CommandList *executionTarget, void *pNext, ze_event_handle_t hSignalEvent, uint32_t numWaitEvents, ze_event_handle_t *phWaitEvents);
protected:
L0::CommandList *allocateAndAddCommandListSubmissionNode();
void addSubGraphSubmissionNode(ExecutableGraph *subGraph);
Graph *src = nullptr;
L0::CommandList *executionTarget = nullptr;
std::unique_ptr<L0::CommandList> hwCommands;
std::vector<std::unique_ptr<L0::CommandList>> myCommandLists;
StackVec<std::unique_ptr<ExecutableGraph>, 16> subGraphs;
GraphSubmissionChain submissionChain;
};
} // namespace L0