diff --git a/level_zero/core/source/cmdlist/cmdlist.cpp b/level_zero/core/source/cmdlist/cmdlist.cpp index d10ff3eb2a..281a07a2ea 100644 --- a/level_zero/core/source/cmdlist/cmdlist.cpp +++ b/level_zero/core/source/cmdlist/cmdlist.cpp @@ -240,4 +240,47 @@ void CommandList::registerWalkerWithProfilingEnqueued(Event *event) { } } +ze_result_t CommandList::setKernelState(Kernel *kernel, const ze_group_size_t groupSizes, void **arguments) { + if (kernel == nullptr) { + return ZE_RESULT_ERROR_INVALID_NULL_HANDLE; + } + + auto result = kernel->setGroupSize(groupSizes.groupSizeX, groupSizes.groupSizeY, groupSizes.groupSizeZ); + + if (result != ZE_RESULT_SUCCESS) { + return result; + } + + auto &args = kernel->getImmutableData()->getDescriptor().payloadMappings.explicitArgs; + + if (args.size() > 0 && !arguments) { + return ZE_RESULT_ERROR_INVALID_NULL_POINTER; + } + for (auto i = 0u; i < args.size(); i++) { + + auto &arg = args[i]; + auto argSize = sizeof(void *); + auto argValue = arguments[i]; + + switch (arg.type) { + case NEO::ArgDescriptor::argTPointer: + if (arg.getTraits().getAddressQualifier() == NEO::KernelArgMetadata::AddrLocal) { + argSize = *reinterpret_cast(argValue); + argValue = nullptr; + } + break; + case NEO::ArgDescriptor::argTValue: + argSize = std::numeric_limits::max(); + break; + default: + break; + } + result = kernel->setArgumentValue(i, argSize, argValue); + if (result != ZE_RESULT_SUCCESS) { + return result; + } + } + return ZE_RESULT_SUCCESS; +} + } // namespace L0 diff --git a/level_zero/core/source/cmdlist/cmdlist.h b/level_zero/core/source/cmdlist/cmdlist.h index 4e2e73708b..03eaa0039b 100644 --- a/level_zero/core/source/cmdlist/cmdlist.h +++ b/level_zero/core/source/cmdlist/cmdlist.h @@ -241,6 +241,8 @@ struct CommandList : _ze_command_list_handle_t { return alloc && alloc->getAllocationType() == NEO::AllocationType::externalHostPtr; } + static ze_result_t setKernelState(Kernel *kernel, const ze_group_size_t groupSizes, void **arguments); + inline ze_command_list_handle_t toHandle() { return this; } uint32_t getCommandListPerThreadScratchSize(uint32_t slotId) const { diff --git a/level_zero/core/source/cmdlist/cmdlist_hw.inl b/level_zero/core/source/cmdlist/cmdlist_hw.inl index 2adc8916c2..e34b3e4e12 100644 --- a/level_zero/core/source/cmdlist/cmdlist_hw.inl +++ b/level_zero/core/source/cmdlist/cmdlist_hw.inl @@ -587,7 +587,6 @@ template ze_result_t CommandListCoreFamily::appendLaunchKernelWithArguments(ze_kernel_handle_t hKernel, const ze_group_count_t groupCounts, const ze_group_size_t groupSizes, - void **pArguments, const void *pNext, ze_event_handle_t hSignalEvent, @@ -595,46 +594,10 @@ ze_result_t CommandListCoreFamily::appendLaunchKernelWithArgument ze_event_handle_t *phWaitEvents) { auto kernel = L0::Kernel::fromHandle(hKernel); - - if (kernel == nullptr) { - return ZE_RESULT_ERROR_INVALID_NULL_HANDLE; - } - - auto result = kernel->setGroupSize(groupSizes.groupSizeX, groupSizes.groupSizeY, groupSizes.groupSizeZ); - + auto result = CommandList::setKernelState(kernel, groupSizes, pArguments); if (result != ZE_RESULT_SUCCESS) { return result; } - - auto &args = kernel->getImmutableData()->getDescriptor().payloadMappings.explicitArgs; - - if (args.size() > 0 && !pArguments) { - return ZE_RESULT_ERROR_INVALID_NULL_POINTER; - } - for (auto i = 0u; i < args.size(); i++) { - - auto &arg = args[i]; - auto argSize = sizeof(void *); - auto argValue = pArguments[i]; - - switch (arg.type) { - case NEO::ArgDescriptor::argTPointer: - if (arg.getTraits().getAddressQualifier() == NEO::KernelArgMetadata::AddrLocal) { - argSize = *reinterpret_cast(argValue); - argValue = nullptr; - } - break; - case NEO::ArgDescriptor::argTValue: - argSize = std::numeric_limits::max(); - - default: - break; - } - result = kernel->setArgumentValue(i, argSize, argValue); - if (result != ZE_RESULT_SUCCESS) { - return result; - } - } return this->appendLaunchKernelWithParameters(hKernel, &groupCounts, pNext, hSignalEvent, numWaitEvents, phWaitEvents); }