fix: add function to calculate number of threads per tg

Signed-off-by: Cencelewska, Katarzyna <katarzyna.cencelewska@intel.com>
This commit is contained in:
Cencelewska, Katarzyna
2023-06-12 11:41:13 +00:00
committed by Compute-Runtime-Automation
parent 987394b27c
commit 7cb3278eb3
16 changed files with 71 additions and 21 deletions

View File

@@ -317,10 +317,9 @@ ze_result_t KernelImp::setGroupSize(uint32_t groupSizeX, uint32_t groupSizeY,
}
}
auto simdSize = kernelDescriptor.kernelAttributes.simdSize;
this->numThreadsPerThreadGroup = static_cast<uint32_t>((itemsInGroup + simdSize - 1u) / simdSize);
patchWorkgroupSizeInCrossThreadData(groupSizeX, groupSizeY, groupSizeZ);
auto simdSize = kernelDescriptor.kernelAttributes.simdSize;
auto remainderSimdLanes = itemsInGroup & (simdSize - 1u);
threadExecutionMask = static_cast<uint32_t>(maxNBitValue(remainderSimdLanes));
if (!threadExecutionMask) {
@@ -328,6 +327,12 @@ ze_result_t KernelImp::setGroupSize(uint32_t groupSizeX, uint32_t groupSizeY,
}
evaluateIfRequiresGenerationOfLocalIdsByRuntime(kernelDescriptor);
auto grfSize = this->module->getDevice()->getHwInfo().capabilityTable.grfSize;
auto &rootDeviceEnvironment = module->getDevice()->getNEODevice()->getRootDeviceEnvironment();
auto &gfxCoreHelper = rootDeviceEnvironment.getHelper<NEO::GfxCoreHelper>();
this->numThreadsPerThreadGroup = gfxCoreHelper.calculateNumThreadsPerThreadGroup(
simdSize, static_cast<uint32_t>(itemsInGroup), grfSize, kernelRequiresGenerationOfLocalIdsByRuntime);
if (kernelRequiresGenerationOfLocalIdsByRuntime) {
auto grfSize = this->module->getDevice()->getHwInfo().capabilityTable.grfSize;
uint32_t perThreadDataSizeForWholeThreadGroupNeeded =