feature(zebin): add support for spill/private size in execution env

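Fall back to the previous logic (scratch sizes taken from the per-thread memory buffer scratch slots) when the zeinfo version does not define scratch memory usage in the execution environment.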

Related-To: NEO-9944
Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2024-01-22 15:17:44 +00:00
committed by Compute-Runtime-Automation
parent c0686da2d6
commit dd7083d710
4 changed files with 114 additions and 37 deletions

View File

@@ -420,7 +420,8 @@ DecodeError decodeZeInfo(ProgramInfo &dst, ConstStringRef zeInfo, std::string &o
return DecodeError::invalidBinary;
}
auto zeInfoDecodeError = decodeZeInfoVersion(yamlParser, zeInfoSections, outErrReason, outWarning);
Types::Version zeInfoVersion{};
auto zeInfoDecodeError = decodeZeInfoVersion(yamlParser, zeInfoSections, outErrReason, outWarning, zeInfoVersion);
if (DecodeError::success != zeInfoDecodeError) {
return zeInfoDecodeError;
}
@@ -435,7 +436,7 @@ DecodeError decodeZeInfo(ProgramInfo &dst, ConstStringRef zeInfo, std::string &o
return zeInfoDecodeError;
}
zeInfoDecodeError = decodeZeInfoKernels(dst, yamlParser, zeInfoSections, outErrReason, outWarning);
zeInfoDecodeError = decodeZeInfoKernels(dst, yamlParser, zeInfoSections, outErrReason, outWarning, zeInfoVersion);
if (DecodeError::success != zeInfoDecodeError) {
return zeInfoDecodeError;
}
@@ -443,18 +444,18 @@ DecodeError decodeZeInfo(ProgramInfo &dst, ConstStringRef zeInfo, std::string &o
return DecodeError::success;
}
DecodeError decodeZeInfoVersion(Yaml::YamlParser &parser, const ZeInfoSections &zeInfoSections, std::string &outErrReason, std::string &outWarning) {
DecodeError decodeZeInfoVersion(Yaml::YamlParser &parser, const ZeInfoSections &zeInfoSections, std::string &outErrReason, std::string &outWarning, Types::Version &srcZeInfoVersion) {
if (false == zeInfoSections.version.empty()) {
Types::Version zeInfoVersion;
auto err = readZeInfoVersionFromZeInfo(zeInfoVersion, parser, *zeInfoSections.version[0], outErrReason, outWarning);
auto err = readZeInfoVersionFromZeInfo(srcZeInfoVersion, parser, *zeInfoSections.version[0], outErrReason, outWarning);
if (DecodeError::success != err) {
return err;
}
err = validateZeInfoVersion(zeInfoVersion, outErrReason, outWarning);
err = validateZeInfoVersion(srcZeInfoVersion, outErrReason, outWarning);
if (DecodeError::success != err) {
return err;
}
} else {
srcZeInfoVersion = zeInfoDecoderVersion;
outWarning.append("DeviceBinaryFormat::zebin::.ze_info : No version info provided (i.e. no " + Tags::version.str() + " entry in global scope of DeviceBinaryFormat::zebin::.ze_info) - will use decoder's default : \'" + std::to_string(zeInfoDecoderVersion.major) + "." + std::to_string(zeInfoDecoderVersion.minor) + "\'\n");
}
return DecodeError::success;
@@ -487,11 +488,11 @@ DecodeError decodeZeInfoFunctions(ProgramInfo &dst, Yaml::YamlParser &parser, co
return DecodeError::success;
}
DecodeError decodeZeInfoKernels(ProgramInfo &dst, Yaml::YamlParser &parser, const ZeInfoSections &zeInfoSections, std::string &outErrReason, std::string &outWarning) {
DecodeError decodeZeInfoKernels(ProgramInfo &dst, Yaml::YamlParser &parser, const ZeInfoSections &zeInfoSections, std::string &outErrReason, std::string &outWarning, const Types::Version &srcZeInfoVersion) {
UNRECOVERABLE_IF(zeInfoSections.kernels.size() != 1U);
for (const auto &kernelNd : parser.createChildrenRange(*zeInfoSections.kernels[0])) {
auto kernelInfo = std::make_unique<KernelInfo>();
auto zeInfoErr = decodeZeInfoKernelEntry(kernelInfo->kernelDescriptor, parser, kernelNd, dst.grfSize, dst.minScratchSpaceSize, outErrReason, outWarning);
auto zeInfoErr = decodeZeInfoKernelEntry(kernelInfo->kernelDescriptor, parser, kernelNd, dst.grfSize, dst.minScratchSpaceSize, outErrReason, outWarning, srcZeInfoVersion);
if (DecodeError::success != zeInfoErr) {
return zeInfoErr;
}
@@ -504,7 +505,7 @@ DecodeError decodeZeInfoKernels(ProgramInfo &dst, Yaml::YamlParser &parser, cons
return DecodeError::success;
}
DecodeError decodeZeInfoKernelEntry(NEO::KernelDescriptor &dst, NEO::Yaml::YamlParser &yamlParser, const NEO::Yaml::Node &kernelNd, uint32_t grfSize, uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning) {
DecodeError decodeZeInfoKernelEntry(NEO::KernelDescriptor &dst, NEO::Yaml::YamlParser &yamlParser, const NEO::Yaml::Node &kernelNd, uint32_t grfSize, uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning, const Types::Version &srcZeInfoVersion) {
ZeInfoKernelSections zeInfokernelSections;
extractZeInfoKernelSections(yamlParser, kernelNd, zeInfokernelSections, ".ze_info", outWarning);
auto extractError = validateZeInfoKernelSectionsCount(zeInfokernelSections, outErrReason, outWarning);
@@ -515,7 +516,7 @@ DecodeError decodeZeInfoKernelEntry(NEO::KernelDescriptor &dst, NEO::Yaml::YamlP
dst.kernelAttributes.binaryFormat = DeviceBinaryFormat::zebin;
dst.kernelMetadata.kernelName = yamlParser.readValueNoQuotes(*zeInfokernelSections.nameNd[0]).str();
auto decodeError = decodeZeInfoKernelExecutionEnvironment(dst, yamlParser, zeInfokernelSections, outErrReason, outWarning);
auto decodeError = decodeZeInfoKernelExecutionEnvironment(dst, yamlParser, zeInfokernelSections, outErrReason, outWarning, srcZeInfoVersion);
if (DecodeError::success != decodeError) {
return decodeError;
}
@@ -545,7 +546,7 @@ DecodeError decodeZeInfoKernelEntry(NEO::KernelDescriptor &dst, NEO::Yaml::YamlP
return decodeError;
}
decodeError = decodeZeInfoKernelPerThreadMemoryBuffers(dst, yamlParser, zeInfokernelSections, minScratchSpaceSize, outErrReason, outWarning);
decodeError = decodeZeInfoKernelPerThreadMemoryBuffers(dst, yamlParser, zeInfokernelSections, minScratchSpaceSize, outErrReason, outWarning, srcZeInfoVersion);
if (DecodeError::success != decodeError) {
return decodeError;
}
@@ -580,13 +581,13 @@ DecodeError decodeZeInfoKernelEntry(NEO::KernelDescriptor &dst, NEO::Yaml::YamlP
return DecodeError::success;
}
DecodeError decodeZeInfoKernelExecutionEnvironment(KernelDescriptor &dst, Yaml::YamlParser &parser, const ZeInfoKernelSections &kernelSections, std::string &outErrReason, std::string &outWarning) {
DecodeError decodeZeInfoKernelExecutionEnvironment(KernelDescriptor &dst, Yaml::YamlParser &parser, const ZeInfoKernelSections &kernelSections, std::string &outErrReason, std::string &outWarning, const Types::Version &srcZeInfoVersion) {
KernelExecutionEnvBaseT execEnv;
auto execEnvErr = readZeInfoExecutionEnvironment(parser, *kernelSections.executionEnvNd[0], execEnv, dst.kernelMetadata.kernelName, outErrReason, outWarning);
if (DecodeError::success != execEnvErr) {
return execEnvErr;
}
populateKernelExecutionEnvironment(dst, execEnv);
populateKernelExecutionEnvironment(dst, execEnv, srcZeInfoVersion);
return DecodeError::success;
}
@@ -647,6 +648,10 @@ DecodeError readZeInfoExecutionEnvironment(const Yaml::YamlParser &parser, const
validExecEnv &= readZeInfoValueChecked(parser, execEnvMetadataNd, outExecEnv.indirectStatelessCount, context, outErrReason);
} else if (Tags::Kernel::ExecutionEnv::hasSample == key) {
validExecEnv &= readZeInfoValueChecked(parser, execEnvMetadataNd, outExecEnv.hasSample, context, outErrReason);
} else if (Tags::Kernel::ExecutionEnv::privateSize == key) {
validExecEnv &= readZeInfoValueChecked(parser, execEnvMetadataNd, outExecEnv.privateSize, context, outErrReason);
} else if (Tags::Kernel::ExecutionEnv::spillSize == key) {
validExecEnv &= readZeInfoValueChecked(parser, execEnvMetadataNd, outExecEnv.spillSize, context, outErrReason);
} else {
outWarning.append("DeviceBinaryFormat::zebin::.ze_info : Unknown entry \"" + key.str() + "\" in context of " + context.str() + "\n");
}
@@ -664,7 +669,7 @@ DecodeError readZeInfoExecutionEnvironment(const Yaml::YamlParser &parser, const
return DecodeError::success;
}
void populateKernelExecutionEnvironment(KernelDescriptor &dst, const KernelExecutionEnvBaseT &execEnv) {
void populateKernelExecutionEnvironment(KernelDescriptor &dst, const KernelExecutionEnvBaseT &execEnv, const Types::Version &srcZeInfoVersion) {
dst.entryPoints.skipPerThreadDataLoad = execEnv.offsetToSkipPerThreadDataLoad;
dst.entryPoints.skipSetFFIDGP = execEnv.offsetToSkipSetFfidGp;
dst.kernelAttributes.flags.passInlineData = (execEnv.inlineDataPayloadSize != 0);
@@ -692,6 +697,10 @@ void populateKernelExecutionEnvironment(KernelDescriptor &dst, const KernelExecu
dst.kernelAttributes.workgroupWalkOrder[2] = static_cast<uint8_t>(execEnv.workgroupWalkOrderDimensions[2]);
dst.kernelAttributes.hasIndirectStatelessAccess = (execEnv.indirectStatelessCount > 0);
dst.kernelAttributes.numThreadsRequired = static_cast<uint32_t>(execEnv.euThreadCount);
if (isScratchMemoryUsageDefinedInExecutionEnvironment(srcZeInfoVersion)) {
dst.kernelAttributes.privateScratchMemorySize = static_cast<uint32_t>(execEnv.privateSize);
dst.kernelAttributes.spillFillScratchMemorySize = static_cast<uint32_t>(execEnv.spillSize);
}
using ThreadSchedulingMode = Types::Kernel::ExecutionEnv::ThreadSchedulingMode;
switch (execEnv.threadSchedulingMode) {
@@ -1471,7 +1480,7 @@ DecodeError populateKernelInlineSampler(KernelDescriptor &dst, const KernelInlin
return DecodeError::success;
}
DecodeError decodeZeInfoKernelPerThreadMemoryBuffers(KernelDescriptor &dst, Yaml::YamlParser &parser, const ZeInfoKernelSections &kernelSections, const uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning) {
DecodeError decodeZeInfoKernelPerThreadMemoryBuffers(KernelDescriptor &dst, Yaml::YamlParser &parser, const ZeInfoKernelSections &kernelSections, const uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning, const Types::Version &srcZeInfoVersion) {
if (false == kernelSections.perThreadMemoryBuffersNd.empty()) {
KernelPerThreadMemoryBuffers perThreadMemoryBuffers{};
auto perThreadMemoryBuffersErr = readZeInfoPerThreadMemoryBuffers(parser, *kernelSections.perThreadMemoryBuffersNd[0], perThreadMemoryBuffers,
@@ -1480,7 +1489,7 @@ DecodeError decodeZeInfoKernelPerThreadMemoryBuffers(KernelDescriptor &dst, Yaml
return perThreadMemoryBuffersErr;
}
for (const auto &memBuff : perThreadMemoryBuffers) {
auto decodeErr = populateKernelPerThreadMemoryBuffer(dst, memBuff, minScratchSpaceSize, outErrReason, outWarning);
auto decodeErr = populateKernelPerThreadMemoryBuffer(dst, memBuff, minScratchSpaceSize, outErrReason, outWarning, srcZeInfoVersion);
if (DecodeError::success != decodeErr) {
return decodeErr;
}
@@ -1514,7 +1523,7 @@ DecodeError readZeInfoPerThreadMemoryBuffers(const Yaml::YamlParser &parser, con
return validBuffer ? DecodeError::success : DecodeError::invalidBinary;
}
DecodeError populateKernelPerThreadMemoryBuffer(KernelDescriptor &dst, const KernelPerThreadMemoryBufferBaseT &src, const uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning) {
DecodeError populateKernelPerThreadMemoryBuffer(KernelDescriptor &dst, const KernelPerThreadMemoryBufferBaseT &src, const uint32_t minScratchSpaceSize, std::string &outErrReason, std::string &outWarning, const Types::Version &srcZeInfoVersion) {
using namespace Types::Kernel::PerThreadMemoryBuffer;
using namespace Tags::Kernel::PerThreadMemoryBuffer::AllocationType;
using namespace Tags::Kernel::PerThreadMemoryBuffer::MemoryUsage;
@@ -1540,15 +1549,18 @@ DecodeError populateKernelPerThreadMemoryBuffer(KernelDescriptor &dst, const Ker
dst.kernelAttributes.perHwThreadPrivateMemorySize = size;
break;
case AllocationTypeScratch:
if (src.slot == 0) {
dst.kernelAttributes.spillFillScratchMemorySize = src.size;
} else if (src.slot == 1) {
dst.kernelAttributes.privateScratchMemorySize = src.size;
} else {
if (src.slot > 1) {
outErrReason.append("DeviceBinaryFormat::zebin : Invalid scratch buffer slot " + std::to_string(src.slot) + " in context of : " + dst.kernelMetadata.kernelName + ". Expected 0 or 1.\n");
return DecodeError::invalidBinary;
}
if (!isScratchMemoryUsageDefinedInExecutionEnvironment(srcZeInfoVersion)) {
if (src.slot == 0) {
dst.kernelAttributes.spillFillScratchMemorySize = src.size;
} else { // slot 1
dst.kernelAttributes.privateScratchMemorySize = src.size;
}
}
if (0 != dst.kernelAttributes.perThreadScratchSize[src.slot]) {
outErrReason.append("DeviceBinaryFormat::zebin : Invalid duplicated scratch buffer entry " + std::to_string(src.slot) + " in context of : " + dst.kernelMetadata.kernelName + ".\n");
return DecodeError::invalidBinary;