mirror of
https://github.com/intel/compute-runtime.git
synced 2025-09-10 12:53:42 +08:00
Make sure that local workgroup size is properly passed for IOH estimation.
Change-Id: I0ad5da4fffd1575f64d44803ce8eb4a6a0ab1532
This commit is contained in:

committed by
sys_ocldev

parent
8a5b0ee518
commit
2d0af9d4a4
@ -455,6 +455,14 @@ void dispatchWalker(
|
||||
OCLRT::IndirectHeap *dsh = nullptr, *ish = nullptr, *ioh = nullptr, *ssh = nullptr;
|
||||
bool executionModelKernel = multiDispatchInfo.begin()->getKernel()->isParentKernel;
|
||||
|
||||
for (auto &dispatchInfo : multiDispatchInfo) {
|
||||
// Compute local workgroup sizes
|
||||
if (dispatchInfo.getLocalWorkgroupSize().x == 0) {
|
||||
const auto lws = generateWorkgroupSize(dispatchInfo);
|
||||
const_cast<DispatchInfo &>(dispatchInfo).setLWS(lws);
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate command stream and indirect heaps
|
||||
size_t cmdQInstructionHeapReservedBlockSize = 0;
|
||||
if (blockQueue) {
|
||||
@ -541,7 +549,7 @@ void dispatchWalker(
|
||||
Vec3<size_t> swgs = dispatchInfo.getStartOfWorkgroups();
|
||||
|
||||
// Compute local workgroup sizes
|
||||
Vec3<size_t> lws = (dispatchInfo.getLocalWorkgroupSize().x > 0) ? dispatchInfo.getLocalWorkgroupSize() : generateWorkgroupSize(dispatchInfo);
|
||||
Vec3<size_t> lws = dispatchInfo.getLocalWorkgroupSize();
|
||||
Vec3<size_t> elws = (dispatchInfo.getEnqueuedWorkgroupSize().x > 0) ? dispatchInfo.getEnqueuedWorkgroupSize() : lws;
|
||||
|
||||
// Compute number of work groups
|
||||
|
@ -245,7 +245,7 @@ void Device::initializeCaps() {
|
||||
//default value if systemInfo not provided
|
||||
deviceInfo.maxWorkGroupSize = 128;
|
||||
}
|
||||
DEBUG_BREAK_IF(deviceInfo.maxWorkGroupSize > 256);
|
||||
DEBUG_BREAK_IF(!DebugManager.flags.UseMaxSimdSizeToDeduceMaxWorkgroupSize.get() && deviceInfo.maxWorkGroupSize > 256);
|
||||
|
||||
// calculate a maximum number of subgroups in a workgroup (for the required SIMD size)
|
||||
deviceInfo.maxNumOfSubGroups = static_cast<uint32_t>(deviceInfo.maxWorkGroupSize / simdSizeUsed);
|
||||
|
@ -22,8 +22,10 @@
|
||||
|
||||
#pragma once
|
||||
#include "runtime/helpers/debug_helpers.h"
|
||||
#include "runtime/utilities/vec.h"
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <stdio.h>
|
||||
|
||||
#define KB 1024uLL
|
||||
@ -153,5 +155,14 @@ inline bool isDivisableByPowerOfTwoDivisor(uint32_t number, uint32_t divisor) {
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
inline size_t computeTotalElementsCount(const Vec3<size_t> &inputVector) {
|
||||
size_t minElementCount = 1;
|
||||
auto xDim = std::max(minElementCount, inputVector.x);
|
||||
auto yDim = std::max(minElementCount, inputVector.y);
|
||||
auto zDim = std::max(minElementCount, inputVector.z);
|
||||
return xDim * yDim * zDim;
|
||||
}
|
||||
|
||||
} // namespace Math
|
||||
} // namespace OCLRT
|
||||
|
@ -118,8 +118,7 @@ struct KernelCommandsHelper : public PerThreadDataHelper {
|
||||
static size_t getTotalSizeRequiredIH(
|
||||
const MultiDispatchInfo &multiDispatchInfo);
|
||||
static size_t getTotalSizeRequiredIOH(
|
||||
const MultiDispatchInfo &multiDispatchInfo,
|
||||
size_t localWorkSize = 256);
|
||||
const MultiDispatchInfo &multiDispatchInfo);
|
||||
static size_t getTotalSizeRequiredSSH(
|
||||
const MultiDispatchInfo &multiDispatchInfo);
|
||||
|
||||
|
@ -112,7 +112,7 @@ size_t getSizeRequired(const MultiDispatchInfo &multiDispatchInfo, SizeGetterT &
|
||||
auto it = multiDispatchInfo.begin();
|
||||
for (auto e = multiDispatchInfo.end(); it != e; ++it) {
|
||||
totalSize = alignUp(totalSize, MemoryConstants::pageSize);
|
||||
totalSize += getSize(it->getKernel(), std::forward<ArgsT>(args)...);
|
||||
totalSize += getSize(*it, std::forward<ArgsT>(args)...);
|
||||
}
|
||||
return totalSize;
|
||||
}
|
||||
@ -120,25 +120,25 @@ size_t getSizeRequired(const MultiDispatchInfo &multiDispatchInfo, SizeGetterT &
|
||||
template <typename GfxFamily>
|
||||
size_t KernelCommandsHelper<GfxFamily>::getTotalSizeRequiredDSH(
|
||||
const MultiDispatchInfo &multiDispatchInfo) {
|
||||
return getSizeRequired(multiDispatchInfo, [](const Kernel *k) { return getSizeRequiredDSH(*k); });
|
||||
return getSizeRequired(multiDispatchInfo, [](const DispatchInfo &dispatchInfo) { return getSizeRequiredDSH(*dispatchInfo.getKernel()); });
|
||||
}
|
||||
|
||||
template <typename GfxFamily>
|
||||
size_t KernelCommandsHelper<GfxFamily>::getTotalSizeRequiredIH(
|
||||
const MultiDispatchInfo &multiDispatchInfo) {
|
||||
return getSizeRequired(multiDispatchInfo, [](const Kernel *k) { return getSizeRequiredIH(*k); });
|
||||
return getSizeRequired(multiDispatchInfo, [](const DispatchInfo &dispatchInfo) { return getSizeRequiredIH(*dispatchInfo.getKernel()); });
|
||||
}
|
||||
|
||||
template <typename GfxFamily>
|
||||
size_t KernelCommandsHelper<GfxFamily>::getTotalSizeRequiredIOH(
|
||||
const MultiDispatchInfo &multiDispatchInfo, size_t localWorkSize) {
|
||||
return getSizeRequired(multiDispatchInfo, [localWorkSize](Kernel *k) { return getSizeRequiredIOH(*k, localWorkSize); });
|
||||
const MultiDispatchInfo &multiDispatchInfo) {
|
||||
return getSizeRequired(multiDispatchInfo, [](const DispatchInfo &dispatchInfo) { return getSizeRequiredIOH(*dispatchInfo.getKernel(), Math::computeTotalElementsCount(dispatchInfo.getLocalWorkgroupSize())); });
|
||||
}
|
||||
|
||||
template <typename GfxFamily>
|
||||
size_t KernelCommandsHelper<GfxFamily>::getTotalSizeRequiredSSH(
|
||||
const MultiDispatchInfo &multiDispatchInfo) {
|
||||
return getSizeRequired(multiDispatchInfo, [](const Kernel *k) { return getSizeRequiredSSH(*k); });
|
||||
return getSizeRequired(multiDispatchInfo, [](const DispatchInfo &dispatchInfo) { return getSizeRequiredSSH(*dispatchInfo.getKernel()); });
|
||||
}
|
||||
|
||||
template <typename GfxFamily>
|
||||
@ -185,13 +185,13 @@ size_t KernelCommandsHelper<GfxFamily>::sendInterfaceDescriptorData(
|
||||
// # of threads in thread group should be based on LWS.
|
||||
pInterfaceDescriptor->setNumberOfThreadsInGpgpuThreadGroup(threadsPerThreadGroup);
|
||||
|
||||
DEBUG_BREAK_IF((sizeCrossThreadData % sizeof(GRF)) != 0);
|
||||
DEBUG_BREAK_IF((sizeCrossThreadData % sizeof(GRF)) != 0);
|
||||
auto numGrfCrossThreadData = static_cast<uint32_t>(sizeCrossThreadData / sizeof(GRF));
|
||||
DEBUG_BREAK_IF(numGrfCrossThreadData == 0);
|
||||
pInterfaceDescriptor->setCrossThreadConstantDataReadLength(numGrfCrossThreadData);
|
||||
pInterfaceDescriptor->setDenormMode(INTERFACE_DESCRIPTOR_DATA::DENORM_MODE_SETBYKERNEL);
|
||||
|
||||
DEBUG_BREAK_IF((sizePerThreadData % sizeof(GRF)) != 0);
|
||||
DEBUG_BREAK_IF((sizePerThreadData % sizeof(GRF)) != 0);
|
||||
auto numGrfPerThreadData = static_cast<uint32_t>(sizePerThreadData / sizeof(GRF));
|
||||
|
||||
// at least 1 GRF of perThreadData for each thread in a thread group when sizeCrossThreadData != 0
|
||||
@ -299,7 +299,7 @@ size_t KernelCommandsHelper<GfxFamily>::pushBindingTableAndSurfaceStates(Indirec
|
||||
|
||||
// march over BTIs and offset the pointers based on surface state base address
|
||||
auto *dstBtiTableBase = reinterpret_cast<BINDING_TABLE_STATE *>(ptrOffset(dstSurfaceState, localBtiOffset));
|
||||
DEBUG_BREAK_IF(reinterpret_cast<uintptr_t>(dstBtiTableBase) % INTERFACE_DESCRIPTOR_DATA::BINDINGTABLEPOINTER_ALIGN_SIZE != 0);
|
||||
DEBUG_BREAK_IF(reinterpret_cast<uintptr_t>(dstBtiTableBase) % INTERFACE_DESCRIPTOR_DATA::BINDINGTABLEPOINTER_ALIGN_SIZE != 0);
|
||||
auto *srcBtiTableBase = reinterpret_cast<const BINDING_TABLE_STATE *>(ptrOffset(srcSurfaceState, localBtiOffset));
|
||||
BINDING_TABLE_STATE bti;
|
||||
bti.init(); // init whole DWORD - i.e. not just the SurfaceStatePointer bits
|
||||
@ -308,7 +308,7 @@ size_t KernelCommandsHelper<GfxFamily>::pushBindingTableAndSurfaceStates(Indirec
|
||||
uint32_t offsetedSurfaceStateOffset = localSurfaceStateOffset + surfaceStatesOffset;
|
||||
bti.setSurfaceStatePointer(offsetedSurfaceStateOffset); // patch just the SurfaceStatePointer bits
|
||||
dstBtiTableBase[i] = bti;
|
||||
DEBUG_BREAK_IF(bti.getRawData(0) % sizeof(BINDING_TABLE_STATE::SURFACESTATEPOINTER_ALIGN_SIZE) != 0);
|
||||
DEBUG_BREAK_IF(bti.getRawData(0) % sizeof(BINDING_TABLE_STATE::SURFACESTATEPOINTER_ALIGN_SIZE) != 0);
|
||||
}
|
||||
|
||||
return ptrDiff(dstBtiTableBase, dstHeap.getBase());
|
||||
|
@ -671,10 +671,12 @@ HWTEST_F(DispatchWalkerTest, dispatchWalkerShouldGetRequiredHeapSizesFromKernelW
|
||||
nullptr,
|
||||
blockQueue);
|
||||
|
||||
Vec3<size_t> localWorkgroupSize(workGroupSize);
|
||||
|
||||
auto expectedSizeCS = MemoryConstants::pageSize; //can get estimated more precisely
|
||||
auto expectedSizeDSH = KernelCommandsHelper<FamilyType>::getSizeRequiredDSH(kernel);
|
||||
auto expectedSizeISH = KernelCommandsHelper<FamilyType>::getSizeRequiredIH(kernel);
|
||||
auto expectedSizeIOH = KernelCommandsHelper<FamilyType>::getSizeRequiredIOH(kernel);
|
||||
auto expectedSizeIOH = KernelCommandsHelper<FamilyType>::getSizeRequiredIOH(kernel, Math::computeTotalElementsCount(localWorkgroupSize));
|
||||
auto expectedSizeSSH = KernelCommandsHelper<FamilyType>::getSizeRequiredSSH(kernel);
|
||||
|
||||
EXPECT_EQ(expectedSizeCS, blockedCommandsData->commandStream->getMaxAvailableSpace());
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace OCLRT::Math;
|
||||
using namespace OCLRT;
|
||||
|
||||
TEST(NextPowerOfTwo, aFewCases) {
|
||||
EXPECT_EQ(1u, nextPowerOfTwo(1));
|
||||
@ -166,3 +167,24 @@ TEST(l3configsGenerator, givenInputValuesWhenPassedToL3ConfigThenRawValueIsProdu
|
||||
EXPECT_EQ(0x40u, config2.bits.AllL3WayAssignement);
|
||||
EXPECT_EQ(0x20u, config2.bits.UrbAllocation);
|
||||
}
|
||||
|
||||
struct ElementCountsTestData {
|
||||
size_t x, y, z;
|
||||
size_t result;
|
||||
};
|
||||
|
||||
ElementCountsTestData elementCountInputData[] = {
|
||||
{1, 2, 3, 6},
|
||||
{0, 0, 1, 1},
|
||||
{0, 1, 0, 1},
|
||||
{1, 0, 0, 1},
|
||||
{5, 0, 10, 50},
|
||||
{0, 0, 30, 30},
|
||||
};
|
||||
|
||||
typedef ::testing::TestWithParam<ElementCountsTestData> ComputeTotalElementsCount;
|
||||
|
||||
TEST_P(ComputeTotalElementsCount, givenVariousInputVectorsWhenComputeTotalElementsCountIsUsedThenProperProductIsComputed) {
|
||||
Vec3<size_t> inputData(GetParam().x, GetParam().y, GetParam().z);
|
||||
EXPECT_EQ(GetParam().result, computeTotalElementsCount(inputData));
|
||||
}
|
||||
|
Reference in New Issue
Block a user