feature: check graph capture restrictions

Related-To: NEO-15377
Signed-off-by: Naklicki, Mateusz <mateusz.naklicki@intel.com>
This commit is contained in:
Naklicki, Mateusz
2025-10-17 14:55:53 +00:00
committed by Compute-Runtime-Automation
parent a503776008
commit 0913ef4e7a
3 changed files with 150 additions and 13 deletions

View File

@@ -43,7 +43,7 @@ ze_result_t ZE_APICALL zeCommandListBeginGraphCaptureExp(ze_command_list_handle_
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
if (cmdList->getCaptureTarget() != nullptr) { if (cmdList->isCapturing() || !cmdList->isImmediateType() || cmdList->isInSynchronousMode()) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
@@ -68,12 +68,12 @@ ze_result_t ZE_APICALL zeCommandListBeginCaptureIntoGraphExp(ze_command_list_han
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
if (cmdList->getCaptureTarget() != nullptr) { if (cmdList->isCapturing()) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
auto graph = L0::Graph::fromHandle(hGraph); auto graph = L0::Graph::fromHandle(hGraph);
if (nullptr == graph) { if (nullptr == graph || !graph->empty()) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
@@ -93,21 +93,22 @@ ze_result_t ZE_APICALL zeCommandListEndGraphCaptureExp(ze_command_list_handle_t
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
if (nullptr == cmdList->getCaptureTarget()) { auto *graph = cmdList->getCaptureTarget();
if (nullptr == graph || graph->hasUnjoinedForks()) {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
cmdList->getCaptureTarget()->stopCapturing(); cmdList->getCaptureTarget()->stopCapturing();
if (nullptr == phGraph) { if (nullptr == phGraph) {
if (cmdList->getCaptureTarget()->wasPreallocated()) { if (graph->wasPreallocated()) {
cmdList->setCaptureTarget(nullptr); cmdList->setCaptureTarget(nullptr);
return ZE_RESULT_SUCCESS; return ZE_RESULT_SUCCESS;
} else { } else {
return ZE_RESULT_ERROR_INVALID_ARGUMENT; return ZE_RESULT_ERROR_INVALID_ARGUMENT;
} }
} else { } else {
*phGraph = cmdList->getCaptureTarget(); *phGraph = graph->toHandle();
cmdList->setCaptureTarget(nullptr); cmdList->setCaptureTarget(nullptr);
} }

View File

@@ -487,6 +487,10 @@ struct CommandList : _ze_command_list_handle_t {
return this->captureTarget; return this->captureTarget;
} }
bool isCapturing() const {
return this->captureTarget != nullptr;
}
Graph *releaseCaptureTarget() { Graph *releaseCaptureTarget() {
return std::exchange(this->captureTarget, nullptr); return std::exchange(this->captureTarget, nullptr);
} }
@@ -539,6 +543,10 @@ struct CommandList : _ze_command_list_handle_t {
return this->latestTaskCount; return this->latestTaskCount;
} }
bool isInSynchronousMode() const {
return this->isSyncModeQueue;
}
protected: protected:
virtual void dispatchHostFunction(void *pHostFunction, virtual void dispatchHostFunction(void *pHostFunction,
void *pUserData) = 0; void *pUserData) = 0;

View File

@@ -43,12 +43,17 @@ struct GraphFixture : public DeviceFixture {
L0::CommandList *immCmdList = nullptr; L0::CommandList *immCmdList = nullptr;
}; };
struct MockEventWithRecordedSignal : public Mock<Event> {
ADDMETHOD_CONST_NOBASE(getRecordedSignalFrom, CommandList *, nullptr, ());
};
using GraphTestApiSubmit = Test<GraphFixture>; using GraphTestApiSubmit = Test<GraphFixture>;
using GraphTestApiInstantiate = Test<GraphFixture>; using GraphTestApiInstantiate = Test<GraphFixture>;
using GraphTestApiCaptureWithDevice = Test<GraphFixture>; using GraphTestApiCaptureWithDevice = Test<GraphFixture>;
using GraphTestInstantiationTest = Test<GraphFixture>; using GraphTestInstantiationTest = Test<GraphFixture>;
using GraphInstantiation = Test<GraphFixture>; using GraphInstantiation = Test<GraphFixture>;
using GraphExecution = Test<GraphFixture>; using GraphExecution = Test<GraphFixture>;
using GraphTestCaptureRestrictions = Test<GraphFixture>;
TEST(GraphTestApiCreate, GivenNonNullPNextThenGraphCreateReturnsError) { TEST(GraphTestApiCreate, GivenNonNullPNextThenGraphCreateReturnsError) {
GraphsCleanupGuard graphCleanup; GraphsCleanupGuard graphCleanup;
@@ -93,13 +98,7 @@ TEST(GraphTestApiCreate, GivenInvalidGraphThenGraphDestroyReturnsError) {
TEST(GraphTestApiCaptureBeginEnd, GivenGraphsEnabledWhenCapturingCmdlistThenItWorksForImmediateAndReturnsEarlyForRegular) { TEST(GraphTestApiCaptureBeginEnd, GivenGraphsEnabledWhenCapturingCmdlistThenItWorksForImmediateAndReturnsEarlyForRegular) {
GraphsCleanupGuard graphCleanup; GraphsCleanupGuard graphCleanup;
struct MyMockEvent : public Mock<Event> { MockEventWithRecordedSignal event;
CommandList *getRecordedSignalFrom() const override {
getRecordedSignalFromCalled = true;
return nullptr;
}
mutable bool getRecordedSignalFromCalled = false;
} event;
auto hEvent = event.toHandle(); auto hEvent = event.toHandle();
L0::Graph *pGraph = nullptr; L0::Graph *pGraph = nullptr;
@@ -171,6 +170,7 @@ TEST(GraphTestApiCaptureBeginEnd, GivenNonNullPNextThenGraphEndCaptureReturnsErr
GraphsCleanupGuard graphCleanup; GraphsCleanupGuard graphCleanup;
Mock<Context> ctx; Mock<Context> ctx;
Mock<CommandList> cmdlist; Mock<CommandList> cmdlist;
cmdlist.cmdListType = L0::CommandList::CommandListType::typeImmediate;
auto cmdListHandle = cmdlist.toHandle(); auto cmdListHandle = cmdlist.toHandle();
ze_base_desc_t ext = {}; ze_base_desc_t ext = {};
ext.stype = ZE_STRUCTURE_TYPE_MUTABLE_GRAPH_ARGUMENT_EXP_DESC; ext.stype = ZE_STRUCTURE_TYPE_MUTABLE_GRAPH_ARGUMENT_EXP_DESC;
@@ -189,6 +189,7 @@ TEST(GraphTestApiCaptureBeginEnd, WhenNoDestinyGraphProvidedThenEndCaptureReturn
GraphsCleanupGuard graphCleanup; GraphsCleanupGuard graphCleanup;
Mock<Context> ctx; Mock<Context> ctx;
Mock<CommandList> cmdlist; Mock<CommandList> cmdlist;
cmdlist.cmdListType = L0::CommandList::CommandListType::typeImmediate;
auto cmdListHandle = cmdlist.toHandle(); auto cmdListHandle = cmdlist.toHandle();
auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr); auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_SUCCESS, err); EXPECT_EQ(ZE_RESULT_SUCCESS, err);
@@ -214,6 +215,7 @@ TEST(GraphTestApiCaptureBeginEnd, WhenCommandListIsNotRecordingThenEndCaptureRet
TEST(GraphTestApiCaptureBeginEnd, WhenNoDestinyGraphProvidedThenEndCaptureRequiresOutputGraphPlaceholder) { TEST(GraphTestApiCaptureBeginEnd, WhenNoDestinyGraphProvidedThenEndCaptureRequiresOutputGraphPlaceholder) {
GraphsCleanupGuard graphCleanup; GraphsCleanupGuard graphCleanup;
Mock<CommandList> cmdlist; Mock<CommandList> cmdlist;
cmdlist.cmdListType = L0::CommandList::CommandListType::typeImmediate;
auto cmdListHandle = cmdlist.toHandle(); auto cmdListHandle = cmdlist.toHandle();
auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr); auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_SUCCESS, err); EXPECT_EQ(ZE_RESULT_SUCCESS, err);
@@ -1609,5 +1611,131 @@ TEST(ClosureExternalStorage, GivenCopyRegionThenRecordsItProperly) {
EXPECT_EQ(7U, storage.getCopyRegion(copyRegion2Id)->width); EXPECT_EQ(7U, storage.getCopyRegion(copyRegion2Id)->width);
} }
TEST_F(GraphTestCaptureRestrictions, GivenNonImmediateCommandListWhenBeginGraphCaptureCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
Mock<Context> ctx;
Mock<CommandList> regularCmdList;
regularCmdList.cmdListType = L0::CommandList::CommandListType::typeRegular;
auto cmdListHandle = regularCmdList.toHandle();
auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
}
TEST_F(GraphTestCaptureRestrictions, GivenSynchronousImmediateCommandListWhenBeginGraphCaptureCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
ze_command_list_handle_t syncCmdListHandle = nullptr;
ze_command_queue_desc_t cmdQueueDesc = {.stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, .mode = ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS};
EXPECT_EQ(ZE_RESULT_SUCCESS, zeCommandListCreateImmediate(context->toHandle(), device->toHandle(), &cmdQueueDesc, &syncCmdListHandle));
auto err = zeCommandListBeginGraphCaptureExp(syncCmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
zeCommandListDestroy(syncCmdListHandle);
}
TEST_F(GraphTestCaptureRestrictions, GivenGraphWithUnjoinedForksWhenEndGraphCaptureCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
Mock<Context> ctx;
MockGraphCmdListWithContext mainCmdlist{&ctx};
MockGraphCmdListWithContext subCmdlist{&ctx};
Mock<Event> forkEvent;
auto mainCmdlistHandle = mainCmdlist.toHandle();
auto forkEventHandle = forkEvent.toHandle();
Graph srcGraph(&ctx, true);
mainCmdlist.setCaptureTarget(&srcGraph);
srcGraph.startCapturingFrom(mainCmdlist, false);
mainCmdlist.capture<CaptureApi::zeCommandListAppendBarrier>(mainCmdlistHandle, forkEventHandle, 0U, nullptr);
Graph *srcSubGraph = nullptr;
srcGraph.forkTo(subCmdlist, srcSubGraph, forkEvent);
// try to end capture with unjoined fork
ze_graph_handle_t retGraph = nullptr;
auto err = zeCommandListEndGraphCaptureExp(mainCmdlistHandle, &retGraph, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
EXPECT_EQ(nullptr, retGraph);
srcSubGraph->stopCapturing();
subCmdlist.setCaptureTarget(nullptr);
srcGraph.stopCapturing();
mainCmdlist.setCaptureTarget(nullptr);
}
TEST_F(GraphTestCaptureRestrictions, GivenCommandListAlreadyCapturingWhenBeginGraphCaptureCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
Mock<Context> ctx;
Mock<CommandList> cmdList;
cmdList.cmdListType = L0::CommandList::CommandListType::typeImmediate;
auto cmdListHandle = cmdList.toHandle();
auto err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_SUCCESS, err);
err = zeCommandListBeginGraphCaptureExp(cmdListHandle, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
// check also zeCommandListBeginCaptureIntoGraphExp
Graph graph(&ctx, true);
auto graphHandle = graph.toHandle();
err = zeCommandListBeginCaptureIntoGraphExp(cmdListHandle, graphHandle, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
ze_graph_handle_t retGraph = nullptr;
err = zeCommandListEndGraphCaptureExp(cmdListHandle, &retGraph, nullptr);
EXPECT_EQ(ZE_RESULT_SUCCESS, err);
zeGraphDestroyExp(retGraph);
}
TEST_F(GraphTestCaptureRestrictions, GivenCommandListNotCapturingWhenEndGraphCaptureCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
ze_graph_handle_t retGraph = nullptr;
auto err = zeCommandListEndGraphCaptureExp(immCmdListHandle, &retGraph, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
EXPECT_EQ(nullptr, retGraph);
}
TEST_F(GraphTestCaptureRestrictions, GivenNonEmptyGraphWhenBeginCaptureIntoGraphCalledThenErrorIsReturned) {
GraphsCleanupGuard graphCleanup;
Mock<Context> ctx;
MockGraphCmdListWithContext cmdList{&ctx};
Mock<Event> event;
auto cmdListHandle = cmdList.toHandle();
auto eventHandle = event.toHandle();
auto ctxHandle = ctx.toHandle();
ze_graph_handle_t graphHandle = nullptr;
auto err = zeGraphCreateExp(ctxHandle, &graphHandle, nullptr);
EXPECT_EQ(ZE_RESULT_SUCCESS, err);
ASSERT_NE(nullptr, graphHandle);
auto *graph = L0::Graph::fromHandle(graphHandle);
cmdList.setCaptureTarget(graph);
graph->startCapturingFrom(cmdList, false);
cmdList.capture<CaptureApi::zeCommandListAppendBarrier>(cmdListHandle, eventHandle, 0U, nullptr);
graph->stopCapturing();
cmdList.setCaptureTarget(nullptr);
MockGraphCmdListWithContext cmdList2{&ctx};
auto cmdList2Handle = cmdList2.toHandle();
err = zeCommandListBeginCaptureIntoGraphExp(cmdList2Handle, graphHandle, nullptr);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, err);
EXPECT_EQ(nullptr, cmdList2.getCaptureTarget());
zeGraphDestroyExp(graphHandle);
}
} // namespace ult } // namespace ult
} // namespace L0 } // namespace L0