diff --git a/level_zero/core/source/kernel/kernel_imp.cpp b/level_zero/core/source/kernel/kernel_imp.cpp
index 89374c2059..c346d63cfc 100644
--- a/level_zero/core/source/kernel/kernel_imp.cpp
+++ b/level_zero/core/source/kernel/kernel_imp.cpp
@@ -272,11 +272,14 @@ KernelMutableState &KernelMutableState::operator=(const KernelMutableState &rhs)
         std::memcpy(dynamicStateHeapData.get(), rhs.dynamicStateHeapData.get(), dynamicStateHeapDataSize);
     }
 
-    reservePerThreadDataForWholeThreadGroup(rhs.perThreadDataSizeForWholeThreadGroup);
-    DEBUG_BREAK_IF(perThreadDataSizeForWholeThreadGroupAllocated < perThreadDataSizeForWholeThreadGroup);
-    std::memcpy(perThreadDataForWholeThreadGroup, rhs.perThreadDataForWholeThreadGroup, perThreadDataSizeForWholeThreadGroup);
-    const size_t tailSize = perThreadDataSizeForWholeThreadGroupAllocated - perThreadDataSizeForWholeThreadGroup;
-    std::memset(perThreadDataForWholeThreadGroup + perThreadDataSizeForWholeThreadGroup, 0x0, tailSize);
+    if (rhs.perThreadDataSizeForWholeThreadGroup) {
+        reservePerThreadDataForWholeThreadGroup(rhs.perThreadDataSizeForWholeThreadGroup);
+        DEBUG_BREAK_IF(perThreadDataSizeForWholeThreadGroupAllocated < perThreadDataSizeForWholeThreadGroup);
+        DEBUG_BREAK_IF(nullptr == rhs.perThreadDataForWholeThreadGroup);
+        std::memcpy(perThreadDataForWholeThreadGroup, rhs.perThreadDataForWholeThreadGroup, perThreadDataSizeForWholeThreadGroup);
+        const size_t tailSize = perThreadDataSizeForWholeThreadGroupAllocated - perThreadDataSizeForWholeThreadGroup;
+        std::memset(perThreadDataForWholeThreadGroup + perThreadDataSizeForWholeThreadGroup, 0x0, tailSize);
+    }
 
     return *this;
 }
diff --git a/level_zero/core/test/black_box_tests/zello_graph.cpp b/level_zero/core/test/black_box_tests/zello_graph.cpp
index d018d50043..2d76d93253 100644
--- a/level_zero/core/test/black_box_tests/zello_graph.cpp
+++ b/level_zero/core/test/black_box_tests/zello_graph.cpp
@@ -238,6 +238,176 @@ void testMultiGraph(ze_driver_handle_t driver, ze_context_handle_t &context, ze_
     SUCCESS_OR_TERMINATE(zeEventPoolDestroy(eventPool));
 }
 
+// Captures two chained launches of the memcpy_bytes kernel (src -> interim -> dst) into a graph,
+// instantiates and executes that graph, then compares the bytes read back from dstBuffer with
+// the pattern written to srcBuffer. validRet receives the comparison result, or false when the
+// graph API cannot be loaded.
+void testAppendLaunchKernel(ze_driver_handle_t driver, ze_context_handle_t &context, ze_device_handle_t &device, bool &validRet) {
+    auto graphApi = loadGraphApi(driver);
+    if (false == graphApi.valid()) {
+        std::cerr << "Graph API not available" << std::endl;
+        validRet = false;
+        return;
+    }
+
+    // Buffers
+    constexpr size_t allocSize = 4096;
+    void *srcBuffer = nullptr;
+    void *interimBuffer = nullptr;
+    void *dstBuffer = nullptr;
+    ze_device_mem_alloc_desc_t deviceDesc = {
+        .stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
+        .pNext = nullptr,
+        .flags = 0,
+        .ordinal = 0,
+    };
+    SUCCESS_OR_TERMINATE(zeMemAllocDevice(context, &deviceDesc, allocSize, allocSize, device, &srcBuffer));
+    SUCCESS_OR_TERMINATE(zeMemAllocDevice(context, &deviceDesc, allocSize, allocSize, device, &interimBuffer));
+    SUCCESS_OR_TERMINATE(zeMemAllocDevice(context, &deviceDesc, allocSize, allocSize, device, &dstBuffer));
+
+    // SpirV for a kernel
+    std::string buildLog;
+    auto moduleBinary = LevelZeroBlackBoxTests::compileToSpirV(LevelZeroBlackBoxTests::memcpyBytesTestKernelSrc, "", buildLog);
+    LevelZeroBlackBoxTests::printBuildLog(buildLog);
+    SUCCESS_OR_TERMINATE_BOOL(0 != moduleBinary.size());
+
+    // Module
+    ze_module_handle_t module;
+    ze_module_desc_t moduleDesc = {
+        .stype = ZE_STRUCTURE_TYPE_MODULE_DESC,
+        .pNext = nullptr,
+        .format = ZE_MODULE_FORMAT_IL_SPIRV,
+        .inputSize = moduleBinary.size(),
+        .pInputModule = reinterpret_cast<const uint8_t *>(moduleBinary.data()),
+    };
+    SUCCESS_OR_TERMINATE(zeModuleCreate(context, device, &moduleDesc, &module, nullptr));
+
+    // Kernel
+    ze_kernel_handle_t kernel;
+    ze_kernel_desc_t kernelDesc = {
+        .stype = ZE_STRUCTURE_TYPE_KERNEL_DESC,
+        .pNext = nullptr,
+        .flags = 0,
+        .pKernelName = "memcpy_bytes",
+    };
+    SUCCESS_OR_TERMINATE(zeKernelCreate(module, &kernelDesc, &kernel));
+
+    constexpr size_t bytesPerThread = sizeof(std::byte);
+    constexpr size_t numThreads = allocSize / bytesPerThread;
+    uint32_t groupSizeX = 32u;
+    uint32_t groupSizeY = 1u;
+    uint32_t groupSizeZ = 1u;
+    SUCCESS_OR_TERMINATE(zeKernelSuggestGroupSize(kernel, static_cast<uint32_t>(numThreads), 1U, 1U, &groupSizeX, &groupSizeY, &groupSizeZ));
+    SUCCESS_OR_TERMINATE_BOOL(numThreads % groupSizeX == 0);
+    if (LevelZeroBlackBoxTests::verbose) {
+        std::cout << "Group size : (" << groupSizeX << ", " << groupSizeY << ", " << groupSizeZ << ")" << std::endl;
+    }
+    SUCCESS_OR_TERMINATE(zeKernelSetGroupSize(kernel, groupSizeX, groupSizeY, groupSizeZ));
+
+    // Events
+    ze_event_pool_handle_t eventPool = nullptr;
+    ze_event_pool_desc_t eventPoolDesc{
+        .stype = ZE_STRUCTURE_TYPE_EVENT_POOL_DESC,
+        .pNext = nullptr,
+        .flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE,
+        .count = 1,
+    };
+    SUCCESS_OR_TERMINATE(zeEventPoolCreate(context, &eventPoolDesc, 1, &device, &eventPool));
+
+    ze_event_handle_t eventCopied = nullptr;
+    ze_event_desc_t eventDesc{
+        .stype = ZE_STRUCTURE_TYPE_EVENT_DESC,
+        .pNext = nullptr,
+        .index = 0,
+        .signal = ZE_EVENT_SCOPE_FLAG_HOST,
+        .wait = ZE_EVENT_SCOPE_FLAG_HOST,
+    };
+    SUCCESS_OR_TERMINATE(zeEventCreate(eventPool, &eventDesc, &eventCopied));
+
+    // Create cmdList
+    ze_command_list_handle_t cmdList;
+    ze_command_queue_desc_t cmdQueueDesc{
+        .stype = ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
+        .pNext = nullptr,
+        .ordinal = LevelZeroBlackBoxTests::getCommandQueueOrdinal(device, false),
+        .index = 0,
+        .flags = 0,
+        .mode = ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
+        .priority = ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
+    };
+    SUCCESS_OR_TERMINATE(zeCommandListCreateImmediate(context, device, &cmdQueueDesc, &cmdList));
+
+    // Start capturing commands
+    ze_graph_handle_t virtualGraph = nullptr;
+    SUCCESS_OR_TERMINATE(graphApi.graphCreate(context, &virtualGraph, nullptr));
+    SUCCESS_OR_TERMINATE(graphApi.commandListBeginCaptureIntoGraph(cmdList, virtualGraph, nullptr));
+
+    // Encode buffers initialization
+    auto srcInitData = std::make_unique<std::byte[]>(allocSize);
+    std::memset(srcInitData.get(), 0xa, allocSize);
+    auto dstInitData = std::make_unique<std::byte[]>(allocSize);
+    std::memset(dstInitData.get(), 0x5, allocSize);
+    SUCCESS_OR_TERMINATE(zeCommandListAppendMemoryCopy(cmdList, srcBuffer, srcInitData.get(), allocSize, nullptr, 0, nullptr));
+    SUCCESS_OR_TERMINATE(zeCommandListAppendMemoryCopy(cmdList, dstBuffer, dstInitData.get(), allocSize, nullptr, 0, nullptr));
+    SUCCESS_OR_TERMINATE(zeCommandListAppendBarrier(cmdList, nullptr, 0, nullptr));
+
+    ze_group_count_t dispatchTraits{
+        .groupCountX = static_cast<uint32_t>(numThreads) / groupSizeX,
+        .groupCountY = 1u,
+        .groupCountZ = 1u,
+    };
+    LevelZeroBlackBoxTests::printGroupCount(dispatchTraits);
+    SUCCESS_OR_TERMINATE_BOOL(dispatchTraits.groupCountX * groupSizeX == allocSize);
+
+    // Launch first copy; it signals eventCopied when done
+    SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 0, sizeof(interimBuffer), &interimBuffer));
+    SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 1, sizeof(srcBuffer), &srcBuffer));
+    SUCCESS_OR_TERMINATE(zeCommandListAppendLaunchKernel(cmdList, kernel, &dispatchTraits, eventCopied, 0, nullptr));
+
+    // Launch second copy with the same kernel object and new arguments; it waits on eventCopied
+    SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 0, sizeof(dstBuffer), &dstBuffer));
+    SUCCESS_OR_TERMINATE(zeKernelSetArgumentValue(kernel, 1, sizeof(interimBuffer), &interimBuffer));
+    SUCCESS_OR_TERMINATE(zeCommandListAppendLaunchKernel(cmdList, kernel, &dispatchTraits, nullptr, 1, &eventCopied));
+    SUCCESS_OR_TERMINATE(zeCommandListAppendBarrier(cmdList, nullptr, 0, nullptr));
+
+    // Encode reading data back
+    auto outputData = std::make_unique<std::byte[]>(allocSize);
+    std::memset(outputData.get(), 0x9, allocSize);
+    SUCCESS_OR_TERMINATE(zeCommandListAppendMemoryCopy(cmdList, outputData.get(), dstBuffer, allocSize, nullptr, 0, nullptr));
+
+    SUCCESS_OR_TERMINATE(graphApi.commandListEndGraphCapture(cmdList, nullptr, nullptr));
+    ze_executable_graph_handle_t physicalGraph = nullptr;
+    SUCCESS_OR_TERMINATE(graphApi.commandListInstantiateGraph(virtualGraph, &physicalGraph, nullptr));
+
+    // Dispatch and wait
+    SUCCESS_OR_TERMINATE(graphApi.commandListAppendGraph(cmdList, physicalGraph, nullptr, nullptr, 0, nullptr));
+    SUCCESS_OR_TERMINATE(zeCommandListHostSynchronize(cmdList, -1));
+
+    // Validate
+    validRet = LevelZeroBlackBoxTests::validate(outputData.get(), srcInitData.get(), allocSize);
+    if (!validRet) {
+        std::cerr << "Data mismatches found!\n";
+        std::cerr << "srcInitData == " << static_cast<void *>(srcInitData.get()) << "\n";
+        std::cerr << "outputData == " << static_cast<void *>(outputData.get()) << std::endl;
+    }
+
+    // Cleanup
+    SUCCESS_OR_TERMINATE(zeMemFree(context, dstBuffer));
+    SUCCESS_OR_TERMINATE(zeMemFree(context, interimBuffer));
+    SUCCESS_OR_TERMINATE(zeMemFree(context, srcBuffer));
+
+    SUCCESS_OR_TERMINATE(zeCommandListDestroy(cmdList));
+    SUCCESS_OR_TERMINATE(zeKernelDestroy(kernel));
+    SUCCESS_OR_TERMINATE(zeModuleDestroy(module));
+
+    SUCCESS_OR_TERMINATE(graphApi.graphDestroy(virtualGraph));
+    SUCCESS_OR_TERMINATE(graphApi.executableGraphDestroy(physicalGraph));
+
+    // The event and its pool are released last, after the graphs that captured the event handle
+    SUCCESS_OR_TERMINATE(zeEventDestroy(eventCopied));
+    SUCCESS_OR_TERMINATE(zeEventPoolDestroy(eventPool));
+}
+
 int main(int argc, char *argv[]) {
     const std::string blackBoxName("Zello Graph");
     LevelZeroBlackBoxTests::verbose = LevelZeroBlackBoxTests::isVerbose(argc, argv);
@@ -264,6 +434,10 @@ int main(int argc, char *argv[]) {
     testMultiGraph(driverHandle, context, device0, outputValidationSuccessful);
     LevelZeroBlackBoxTests::printResult(aubMode, outputValidationSuccessful, blackBoxName, currentTest);
 
+    currentTest = "AppendLaunchKernel";
+    testAppendLaunchKernel(driverHandle, context, device0, outputValidationSuccessful);
+    LevelZeroBlackBoxTests::printResult(aubMode, outputValidationSuccessful, blackBoxName, currentTest);
+
     SUCCESS_OR_TERMINATE(zeContextDestroy(context));
 
     int resultOnFailure = aubMode ? 
0 : 1; diff --git a/level_zero/core/test/unit_tests/experimental/test_graph.cpp b/level_zero/core/test/unit_tests/experimental/test_graph.cpp index ba9654cb74..2b84a72978 100644 --- a/level_zero/core/test/unit_tests/experimental/test_graph.cpp +++ b/level_zero/core/test/unit_tests/experimental/test_graph.cpp @@ -468,10 +468,11 @@ TEST(GraphTestApiCapture, GivenCommandListInRecordStateThenCaptureCommandsInstea Mock otherCtx; Mock cmdlist; Mock event; + Mock kernel; ze_image_handle_t imgA = nullptr; ze_image_handle_t imgB = nullptr; zes_device_handle_t device = nullptr; - zet_kernel_handle_t kernel = nullptr; + zet_kernel_handle_t kernelHandle = &kernel; ze_external_semaphore_ext_handle_t sem = nullptr; ze_event_handle_t eventHandle = &event; ze_external_semaphore_signal_params_ext_t semSignalParams = {}; @@ -517,17 +518,17 @@ TEST(GraphTestApiCapture, GivenCommandListInRecordStateThenCaptureCommandsInstea EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendWaitExternalSemaphoreExt(&cmdlist, 1, &sem, &semWaitParams, nullptr, 0, nullptr)); EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendImageCopyToMemoryExt(&cmdlist, memA, imgA, &imgRegion, 16, 16, nullptr, 0, nullptr)); EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendImageCopyFromMemoryExt(&cmdlist, imgA, memA, &imgRegion, 16, 16, nullptr, 0, nullptr)); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendLaunchKernel(&cmdlist, kernelHandle, &groupCount, nullptr, 0, nullptr)); // temporarily unsupported - EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchKernel(&cmdlist, kernel, &groupCount, nullptr, 0, nullptr)); - EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchCooperativeKernel(&cmdlist, kernel, &groupCount, nullptr, 0, nullptr)); - EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchKernelIndirect(&cmdlist, kernel, &groupCount, nullptr, 0, nullptr)); - EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, 
L0::zeCommandListAppendLaunchMultipleKernelsIndirect(&cmdlist, 1, &kernel, &kernelCount, &groupCount, nullptr, 0, nullptr)); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchCooperativeKernel(&cmdlist, kernelHandle, &groupCount, nullptr, 0, nullptr)); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchKernelIndirect(&cmdlist, kernelHandle, &groupCount, nullptr, 0, nullptr)); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, L0::zeCommandListAppendLaunchMultipleKernelsIndirect(&cmdlist, 1, &kernelHandle, &kernelCount, &groupCount, nullptr, 0, nullptr)); ze_graph_handle_t hgraph = &graph; EXPECT_EQ(ZE_RESULT_SUCCESS, ::zeCommandListEndGraphCaptureExp(&cmdlist, &hgraph, nullptr)); - ASSERT_EQ(21U, graph.getCapturedCommands().size()); + ASSERT_EQ(22U, graph.getCapturedCommands().size()); uint32_t i = 0; EXPECT_EQ(CaptureApi::zeCommandListAppendBarrier, static_cast(graph.getCapturedCommands()[i++].index())); EXPECT_EQ(CaptureApi::zeCommandListAppendMemoryCopy, static_cast(graph.getCapturedCommands()[i++].index())); @@ -550,6 +551,7 @@ TEST(GraphTestApiCapture, GivenCommandListInRecordStateThenCaptureCommandsInstea EXPECT_EQ(CaptureApi::zeCommandListAppendWaitExternalSemaphoreExt, static_cast(graph.getCapturedCommands()[i++].index())); EXPECT_EQ(CaptureApi::zeCommandListAppendImageCopyToMemoryExt, static_cast(graph.getCapturedCommands()[i++].index())); EXPECT_EQ(CaptureApi::zeCommandListAppendImageCopyFromMemoryExt, static_cast(graph.getCapturedCommands()[i++].index())); + EXPECT_EQ(CaptureApi::zeCommandListAppendLaunchKernel, static_cast(graph.getCapturedCommands()[i++].index())); } TEST(GraphForks, GivenUnknownChildCommandlistThenJoinDoesNothing) { @@ -788,11 +790,13 @@ TEST(GraphTestInstantiation, WhenInstantiatingGraphThenBakeCommandsIntoCommandli MockGraphContextReturningSpecificCmdList otherCtx; Mock cmdlist; Mock event; + Mock kernel; ze_image_handle_t imgA = nullptr; ze_image_handle_t imgB = nullptr; 
zes_device_handle_t device = nullptr; ze_external_semaphore_ext_handle_t sem = nullptr; ze_event_handle_t eventHandle = &event; + zet_kernel_handle_t kernelHandle = &kernel; ze_external_semaphore_signal_params_ext_t semSignalParams = {}; ze_external_semaphore_wait_params_ext_t semWaitParams = {}; @@ -808,6 +812,7 @@ TEST(GraphTestInstantiation, WhenInstantiatingGraphThenBakeCommandsIntoCommandli imgRegion.width = 1; imgRegion.height = 1; imgRegion.depth = 1; + ze_group_count_t groupCount = {1, 1, 1}; L0::Graph srcGraph(&ctx, true); ASSERT_EQ(ZE_RESULT_SUCCESS, ::zeCommandListBeginCaptureIntoGraphExp(&cmdlist, &srcGraph, nullptr)); @@ -833,6 +838,7 @@ TEST(GraphTestInstantiation, WhenInstantiatingGraphThenBakeCommandsIntoCommandli EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendWaitExternalSemaphoreExt(&cmdlist, 1, &sem, &semWaitParams, nullptr, 0, nullptr)); EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendImageCopyToMemoryExt(&cmdlist, memA, imgA, &imgRegion, 16, 16, nullptr, 0, nullptr)); EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendImageCopyFromMemoryExt(&cmdlist, imgA, memA, &imgRegion, 16, 16, nullptr, 0, nullptr)); + EXPECT_EQ(ZE_RESULT_SUCCESS, L0::zeCommandListAppendLaunchKernel(&cmdlist, kernelHandle, &groupCount, nullptr, 0, nullptr)); ze_graph_handle_t hgraph = &srcGraph; EXPECT_EQ(ZE_RESULT_SUCCESS, ::zeCommandListEndGraphCaptureExp(&cmdlist, &hgraph, nullptr)); @@ -862,6 +868,7 @@ TEST(GraphTestInstantiation, WhenInstantiatingGraphThenBakeCommandsIntoCommandli EXPECT_EQ(0U, graphHwCommands->appendWaitExternalSemaphoresCalled); EXPECT_EQ(0U, graphHwCommands->appendImageCopyToMemoryExtCalled); EXPECT_EQ(0U, graphHwCommands->appendImageCopyFromMemoryExtCalled); + EXPECT_EQ(0U, graphHwCommands->appendLaunchKernelCalled); execGraph.instantiateFrom(srcGraph); EXPECT_EQ(1U, graphHwCommands->appendBarrierCalled); EXPECT_EQ(1U, graphHwCommands->appendMemoryCopyCalled); @@ -884,6 +891,7 @@ TEST(GraphTestInstantiation, 
WhenInstantiatingGraphThenBakeCommandsIntoCommandli
     EXPECT_EQ(1U, graphHwCommands->appendWaitExternalSemaphoresCalled);
     EXPECT_EQ(1U, graphHwCommands->appendImageCopyToMemoryExtCalled);
     EXPECT_EQ(1U, graphHwCommands->appendImageCopyFromMemoryExtCalled);
+    EXPECT_EQ(1U, graphHwCommands->appendLaunchKernelCalled);
 }
 
 TEST(GraphExecution, GivenEmptyExecutableGraphWhenSubmittingItToCommandListThenTakeCareOnlyOfEvents) {
diff --git a/level_zero/experimental/source/graph/graph.cpp b/level_zero/experimental/source/graph/graph.cpp
index dee03be86b..1ee5b45838 100644
--- a/level_zero/experimental/source/graph/graph.cpp
+++ b/level_zero/experimental/source/graph/graph.cpp
@@ -10,6 +10,7 @@
 #include "level_zero/core/source/cmdlist/cmdlist.h"
 #include "level_zero/core/source/context/context.h"
 #include "level_zero/core/source/event/event.h"
+#include "level_zero/core/source/kernel/kernel_imp.h"
 
 namespace L0 {
 
@@ -219,6 +220,33 @@ ze_result_t Closure::inst
                                                      apiArgs.hSignalEvent, apiArgs.numWaitEvents, const_cast(getOptionalData(indirectArgs.waitEvents)));
 }
 
+// Snapshots what a captured zeCommandListAppendLaunchKernel call needs for later replay:
+// the group count, the kernel's mutable state at capture time and the wait-event handles.
+// The pointer members of the stored apiArgs are nulled; replay uses the copies in indirectArgs.
+Closure<CaptureApi::zeCommandListAppendLaunchKernel>::Closure(const ApiArgs &apiArgs) : apiArgs{apiArgs} {
+    this->apiArgs.launchKernelArgs = nullptr;
+    this->apiArgs.phWaitEvents = nullptr;
+    // Unqualified apiArgs is the constructor parameter (it shadows the member), so its pointers are still set
+    this->indirectArgs.launchKernelArgs = *apiArgs.launchKernelArgs;
+
+    auto kernel = static_cast<KernelImp *>(Kernel::fromHandle(apiArgs.kernelHandle));
+    this->indirectArgs.kernelState = kernel->getMutableState();
+
+    this->indirectArgs.waitEvents.reserve(apiArgs.numWaitEvents);
+    for (uint32_t i = 0; i < apiArgs.numWaitEvents; ++i) {
+        this->indirectArgs.waitEvents.push_back(apiArgs.phWaitEvents[i]);
+    }
+}
+
+// Replays the captured launch on executionTarget and returns the result of the append call.
+// NOTE: this overwrites the kernel object's current mutable state with the snapshot taken at
+// capture time and does not restore the previous state afterwards.
+ze_result_t Closure<CaptureApi::zeCommandListAppendLaunchKernel>::instantiateTo(L0::CommandList &executionTarget) const {
+    auto kernel = static_cast<KernelImp *>(Kernel::fromHandle(apiArgs.kernelHandle));
+    kernel->getMutableState() = this->indirectArgs.kernelState;
+    return zeCommandListAppendLaunchKernel(&executionTarget, apiArgs.kernelHandle, &indirectArgs.launchKernelArgs, apiArgs.hSignalEvent, apiArgs.numWaitEvents, const_cast<ze_event_handle_t *>(getOptionalData(indirectArgs.waitEvents)));
+}
+
 ExecutableGraph::~ExecutableGraph() = default;
 
 L0::CommandList *ExecutableGraph::allocateAndAddCommandListSubmissionNode() {
diff --git a/level_zero/experimental/source/graph/graph.h b/level_zero/experimental/source/graph/graph.h
index af89866ba0..d65459774c 100644
--- a/level_zero/experimental/source/graph/graph.h
+++ b/level_zero/experimental/source/graph/graph.h
@@ -10,6 +10,7 @@
 #include "shared/source/helpers/string.h"
 #include "shared/source/utilities/stackvec.h"
 
+#include "level_zero/core/source/kernel/kernel_mutable_state.h"
 #include "level_zero/ze_api.h"
 
 #include 
@@ -697,6 +698,38 @@ struct Closure {
     ze_result_t instantiateTo(CommandList &executionTarget) const;
 };
 
+// Captured form of zeCommandListAppendLaunchKernel.
+// Move-constructible only: copy construction and both assignments are deleted.
+template <>
+struct Closure<CaptureApi::zeCommandListAppendLaunchKernel> {
+    inline static constexpr bool isSupported = true;
+
+    struct ApiArgs {
+        ze_command_list_handle_t hCommandList;
+        ze_kernel_handle_t kernelHandle;
+        const ze_group_count_t *launchKernelArgs;
+        ze_event_handle_t hSignalEvent;
+        uint32_t numWaitEvents;
+        ze_event_handle_t *phWaitEvents;
+    } apiArgs;
+
+    // Copies taken by the constructor of the data the pointer arguments referred to
+    struct IndirectArgs {
+        ze_group_count_t launchKernelArgs;
+        KernelMutableState kernelState;
+        StackVec<ze_event_handle_t, 8> waitEvents;
+    } indirectArgs;
+
+    Closure(const ApiArgs &apiArgs);
+    Closure(const Closure &) = delete;
+    Closure(Closure &&rhs) = default;
+    Closure &operator=(const Closure &) = delete;
+    Closure &operator=(Closure &&) = delete;
+    ~Closure() = default;
+
+    ze_result_t instantiateTo(CommandList &executionTarget) const;
+};
+
 using ClosureVariants = std::variant<
 #define RR_CAPTURED_API(X) Closure, 
 RR_CAPTURED_APIS()