diff --git a/level_zero/core/source/kernel/kernel_imp.cpp b/level_zero/core/source/kernel/kernel_imp.cpp index cda485bac8..fa18f462c1 100644 --- a/level_zero/core/source/kernel/kernel_imp.cpp +++ b/level_zero/core/source/kernel/kernel_imp.cpp @@ -375,13 +375,6 @@ ze_result_t KernelImp::setGroupSize(uint32_t groupSizeX, uint32_t groupSizeY, patchWorkgroupSizeInCrossThreadData(groupSizeX, groupSizeY, groupSizeZ); auto simdSize = kernelDescriptor.kernelAttributes.simdSize; - auto remainderSimdLanes = itemsInGroup & (simdSize - 1u); - threadExecutionMask = static_cast(maxNBitValue(remainderSimdLanes)); - if (!threadExecutionMask) { - threadExecutionMask = static_cast(maxNBitValue((isSimd1(simdSize)) ? 32 : simdSize)); - } - evaluateIfRequiresGenerationOfLocalIdsByRuntime(kernelDescriptor); - auto grfCount = kernelDescriptor.kernelAttributes.numGrfRequired; auto neoDevice = module->getDevice()->getNEODevice(); auto &rootDeviceEnvironment = neoDevice->getRootDeviceEnvironment(); @@ -389,6 +382,17 @@ ze_result_t KernelImp::setGroupSize(uint32_t groupSizeX, uint32_t groupSizeY, this->numThreadsPerThreadGroup = gfxCoreHelper.calculateNumThreadsPerThreadGroup( simdSize, static_cast(itemsInGroup), grfCount, rootDeviceEnvironment); + if (auto wgSizeRet = validateWorkgroupSize(); wgSizeRet != ZE_RESULT_SUCCESS) { + return wgSizeRet; + } + + auto remainderSimdLanes = itemsInGroup & (simdSize - 1u); + threadExecutionMask = static_cast(maxNBitValue(remainderSimdLanes)); + if (!threadExecutionMask) { + threadExecutionMask = static_cast(maxNBitValue((isSimd1(simdSize)) ? 32 : simdSize)); + } + evaluateIfRequiresGenerationOfLocalIdsByRuntime(kernelDescriptor); + if (kernelRequiresGenerationOfLocalIdsByRuntime) { auto grfSize = this->module->getDevice()->getHwInfo().capabilityTable.grfSize; uint32_t perThreadDataSizeForWholeThreadGroupNeeded = diff --git a/level_zero/core/source/kernel/kernel_imp.h b/level_zero/core/source/kernel/kernel_imp.h index 10321a6658..fcaa78a9f7 100644 --- a/level_zero/core/source/kernel/kernel_imp.h +++ b/level_zero/core/source/kernel/kernel_imp.h @@ -251,6 +251,7 @@ struct KernelImp : Kernel { virtual void evaluateIfRequiresGenerationOfLocalIdsByRuntime(const NEO::KernelDescriptor &kernelDescriptor) = 0; void *patchBindlessSurfaceState(NEO::GraphicsAllocation *alloc, uint32_t bindless); uint32_t getSurfaceStateIndexForBindlessOffset(NEO::CrossThreadDataOffset bindlessOffset) const; + ze_result_t validateWorkgroupSize() const; const KernelImmutableData *kernelImmData = nullptr; Module *module = nullptr; diff --git a/level_zero/core/source/kernel/kernel_imp_helper.cpp b/level_zero/core/source/kernel/kernel_imp_helper.cpp index e317cf27e9..d12512dc09 100644 --- a/level_zero/core/source/kernel/kernel_imp_helper.cpp +++ b/level_zero/core/source/kernel/kernel_imp_helper.cpp @@ -14,4 +14,7 @@ KernelExt *KernelImp::getExtension(uint32_t extensionType) { return nullptr; } void KernelImp::patchRegionParams(const CmdListKernelLaunchParams &launchParams, const ze_group_count_t &threadGroupDimensions) {} +ze_result_t KernelImp::validateWorkgroupSize() const { + return ZE_RESULT_SUCCESS; +} } // namespace L0