From c612a86d281288fd4cc9f5452c256afcd29b2779 Mon Sep 17 00:00:00 2001 From: "Dunajski, Bartosz" Date: Mon, 18 Dec 2023 18:25:00 +0000 Subject: [PATCH] feature: initial support for new zeinfo args Related-To: NEO-8070 Signed-off-by: Dunajski, Bartosz --- .../device_binary_format/zebin/zeinfo.h | 6 ++ .../zebin/zeinfo_decoder.cpp | 9 ++ .../zebin/zeinfo_enum_lookup.h | 87 ++++++++------- shared/source/kernel/kernel_descriptor.h | 4 + .../zebin_decoder_tests.cpp | 101 ++++++++++++++++++ 5 files changed, 166 insertions(+), 41 deletions(-) diff --git a/shared/source/device_binary_format/zebin/zeinfo.h b/shared/source/device_binary_format/zebin/zeinfo.h index 9c0e218be3..2fa5c810fb 100644 --- a/shared/source/device_binary_format/zebin/zeinfo.h +++ b/shared/source/device_binary_format/zebin/zeinfo.h @@ -123,6 +123,9 @@ inline constexpr ConstStringRef dataGlobalBuffer("global_base"); inline constexpr ConstStringRef assertBuffer("assert_buffer"); inline constexpr ConstStringRef indirectDataPointer("indirect_data_pointer"); inline constexpr ConstStringRef scratchPointer("scratch_pointer"); +inline constexpr ConstStringRef regionGroupSize("region_group_size"); +inline constexpr ConstStringRef regionGroupDimension("region_group_dimension"); +inline constexpr ConstStringRef regionGroupWgCount("region_group_wg_count"); namespace Image { inline constexpr ConstStringRef width("image_width"); @@ -492,6 +495,9 @@ enum ArgType : uint8_t { argTypeAssertBuffer, argTypeIndirectDataPointer, argTypeScratchPointer, + argTypeRegionGroupSize, + argTypeRegionGroupDimension, + argTypeRegionGroupWgCount, argTypeMax }; diff --git a/shared/source/device_binary_format/zebin/zeinfo_decoder.cpp b/shared/source/device_binary_format/zebin/zeinfo_decoder.cpp index b937a52113..a71b86d3be 100644 --- a/shared/source/device_binary_format/zebin/zeinfo_decoder.cpp +++ b/shared/source/device_binary_format/zebin/zeinfo_decoder.cpp @@ -1343,6 +1343,15 @@ DecodeError populateKernelPayloadArgument(NEO::KernelDescriptor &dst, const Kern case Types::Kernel::argTypeVmeSearchPathType: return populateWithOffset(getVmeDescriptor()->searchPathType); + + case Types::Kernel::argTypeRegionGroupSize: + return populateArgVec(dst.payloadMappings.dispatchTraits.regionGroupSize, Tags::Kernel::PayloadArgument::ArgType::regionGroupSize); + + case Types::Kernel::argTypeRegionGroupDimension: + return populateWithOffsetChecked(dst.payloadMappings.dispatchTraits.regionGroupDimension, sizeof(int32_t), Tags::Kernel::PayloadArgument::ArgType::regionGroupDimension); + + case Types::Kernel::argTypeRegionGroupWgCount: + return populateWithOffsetChecked(dst.payloadMappings.dispatchTraits.regionGroupWgCount, sizeof(int32_t), Tags::Kernel::PayloadArgument::ArgType::regionGroupWgCount); } UNREACHABLE(); diff --git a/shared/source/device_binary_format/zebin/zeinfo_enum_lookup.h b/shared/source/device_binary_format/zebin/zeinfo_enum_lookup.h index 3067264be2..f6393126e4 100644 --- a/shared/source/device_binary_format/zebin/zeinfo_enum_lookup.h +++ b/shared/source/device_binary_format/zebin/zeinfo_enum_lookup.h @@ -22,47 +22,52 @@ using namespace Tags::Kernel::PayloadArgument::ArgType::Sampler::Vme; using ArgType = Types::Kernel::ArgType; inline constexpr ConstStringRef name = "argument type"; -inline constexpr LookupArray lookup({{{packedLocalIds, ArgType::argTypePackedLocalIds}, - {localId, ArgType::argTypeLocalId}, - {localSize, ArgType::argTypeLocalSize}, - {groupCount, ArgType::argTypeGroupCount}, - {globalSize, ArgType::argTypeGlobalSize}, - {enqueuedLocalSize, ArgType::argTypeEnqueuedLocalSize}, - {globalIdOffset, ArgType::argTypeGlobalIdOffset}, - {privateBaseStateless, ArgType::argTypePrivateBaseStateless}, - {argByvalue, ArgType::argTypeArgByvalue}, - {argBypointer, ArgType::argTypeArgBypointer}, - {bufferAddress, ArgType::argTypeBufferAddress}, - {bufferOffset, ArgType::argTypeBufferOffset}, - {printfBuffer, ArgType::argTypePrintfBuffer}, - {workDimensions, ArgType::argTypeWorkDimensions}, - {implicitArgBuffer, ArgType::argTypeImplicitArgBuffer}, - {width, ArgType::argTypeImageWidth}, - {height, ArgType::argTypeImageHeight}, - {depth, ArgType::argTypeImageDepth}, - {channelDataType, ArgType::argTypeImageChannelDataType}, - {channelOrder, ArgType::argTypeImageChannelOrder}, - {arraySize, ArgType::argTypeImageArraySize}, - {numSamples, ArgType::argTypeImageNumSamples}, - {numMipLevels, ArgType::argTypeImageMipLevels}, - {flatBaseOffset, ArgType::argTypeImageFlatBaseOffset}, - {flatWidth, ArgType::argTypeImageFlatWidth}, - {flatHeight, ArgType::argTypeImageFlatHeight}, - {flatPitch, ArgType::argTypeImageFlatPitch}, - {snapWa, ArgType::argTypeSamplerSnapWa}, - {normCoords, ArgType::argTypeSamplerNormCoords}, - {addrMode, ArgType::argTypeSamplerAddrMode}, - {blockType, ArgType::argTypeVmeMbBlockType}, - {subpixelMode, ArgType::argTypeVmeSubpixelMode}, - {sadAdjustMode, ArgType::argTypeVmeSadAdjustMode}, - {searchPathType, ArgType::argTypeVmeSearchPathType}, - {syncBuffer, ArgType::argTypeSyncBuffer}, - {rtGlobalBuffer, ArgType::argTypeRtGlobalBuffer}, - {dataConstBuffer, ArgType::argTypeDataConstBuffer}, - {dataGlobalBuffer, ArgType::argTypeDataGlobalBuffer}, - {assertBuffer, ArgType::argTypeAssertBuffer}, - {indirectDataPointer, ArgType::argTypeIndirectDataPointer}, - {scratchPointer, ArgType::argTypeScratchPointer}}}); +inline constexpr LookupArray lookup({{ + {packedLocalIds, ArgType::argTypePackedLocalIds}, + {localId, ArgType::argTypeLocalId}, + {localSize, ArgType::argTypeLocalSize}, + {groupCount, ArgType::argTypeGroupCount}, + {globalSize, ArgType::argTypeGlobalSize}, + {enqueuedLocalSize, ArgType::argTypeEnqueuedLocalSize}, + {globalIdOffset, ArgType::argTypeGlobalIdOffset}, + {privateBaseStateless, ArgType::argTypePrivateBaseStateless}, + {argByvalue, ArgType::argTypeArgByvalue}, + {argBypointer, ArgType::argTypeArgBypointer}, + {bufferAddress, ArgType::argTypeBufferAddress}, + {bufferOffset, ArgType::argTypeBufferOffset}, + {printfBuffer, ArgType::argTypePrintfBuffer}, + {workDimensions, ArgType::argTypeWorkDimensions}, + {implicitArgBuffer, ArgType::argTypeImplicitArgBuffer}, + {width, ArgType::argTypeImageWidth}, + {height, ArgType::argTypeImageHeight}, + {depth, ArgType::argTypeImageDepth}, + {channelDataType, ArgType::argTypeImageChannelDataType}, + {channelOrder, ArgType::argTypeImageChannelOrder}, + {arraySize, ArgType::argTypeImageArraySize}, + {numSamples, ArgType::argTypeImageNumSamples}, + {numMipLevels, ArgType::argTypeImageMipLevels}, + {flatBaseOffset, ArgType::argTypeImageFlatBaseOffset}, + {flatWidth, ArgType::argTypeImageFlatWidth}, + {flatHeight, ArgType::argTypeImageFlatHeight}, + {flatPitch, ArgType::argTypeImageFlatPitch}, + {snapWa, ArgType::argTypeSamplerSnapWa}, + {normCoords, ArgType::argTypeSamplerNormCoords}, + {addrMode, ArgType::argTypeSamplerAddrMode}, + {blockType, ArgType::argTypeVmeMbBlockType}, + {subpixelMode, ArgType::argTypeVmeSubpixelMode}, + {sadAdjustMode, ArgType::argTypeVmeSadAdjustMode}, + {searchPathType, ArgType::argTypeVmeSearchPathType}, + {syncBuffer, ArgType::argTypeSyncBuffer}, + {rtGlobalBuffer, ArgType::argTypeRtGlobalBuffer}, + {dataConstBuffer, ArgType::argTypeDataConstBuffer}, + {dataGlobalBuffer, ArgType::argTypeDataGlobalBuffer}, + {assertBuffer, ArgType::argTypeAssertBuffer}, + {indirectDataPointer, ArgType::argTypeIndirectDataPointer}, + {scratchPointer, ArgType::argTypeScratchPointer}, + {regionGroupSize, ArgType::argTypeRegionGroupSize}, + {regionGroupDimension, ArgType::argTypeRegionGroupDimension}, + {regionGroupWgCount, ArgType::argTypeRegionGroupWgCount}, +}}); static_assert(lookup.size() == ArgType::argTypeMax - 1, "Every enum field must be present"); } // namespace ArgType diff --git a/shared/source/kernel/kernel_descriptor.h b/shared/source/kernel/kernel_descriptor.h index b9f03fb476..752e7b1d68 100644 --- a/shared/source/kernel/kernel_descriptor.h +++ b/shared/source/kernel/kernel_descriptor.h @@ -151,6 +151,10 @@ struct KernelDescriptor { CrossThreadDataOffset enqueuedLocalWorkSize[3] = {undefined, undefined, undefined}; CrossThreadDataOffset numWorkGroups[3] = {undefined, undefined, undefined}; CrossThreadDataOffset workDim = undefined; + + CrossThreadDataOffset regionGroupSize[3] = {undefined, undefined, undefined}; + CrossThreadDataOffset regionGroupDimension = undefined; + CrossThreadDataOffset regionGroupWgCount = undefined; } dispatchTraits; struct { diff --git a/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp b/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp index 17b55bc8f5..ae04e60d55 100644 --- a/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp +++ b/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp @@ -4815,6 +4815,107 @@ TEST_F(decodeZeInfoKernelEntryTest, GivenArgTypeGlobalSizeWhenArgSizeIsInvalidTh EXPECT_TRUE(warnings.empty()) << warnings; } +TEST_F(decodeZeInfoKernelEntryTest, givenRegionArgTypesWhenArgSizeIsInvalidThenFails) { + ConstStringRef zeInfoRegionGroupSize = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_size + offset : 16 + size : 7 +)==="; + auto err = decodeZeInfoKernelEntry(zeInfoRegionGroupSize); + EXPECT_EQ(NEO::DecodeError::invalidBinary, err); + EXPECT_STREQ("DeviceBinaryFormat::zebin : Invalid size for argument of type region_group_size in context of : some_kernel. Expected 4 or 8 or 12. Got : 7\n", errors.c_str()); + EXPECT_TRUE(warnings.empty()) << warnings; + + ConstStringRef zeInfoRegionGroupDim = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_dimension + offset : 16 + size : 7 +)==="; + err = decodeZeInfoKernelEntry(zeInfoRegionGroupDim); + EXPECT_EQ(NEO::DecodeError::invalidBinary, err); + EXPECT_STREQ("DeviceBinaryFormat::zebin : Invalid size for argument of type region_group_dimension in context of : some_kernel. Expected 4. Got : 7\n", errors.c_str()); + EXPECT_TRUE(warnings.empty()) << warnings; + + ConstStringRef zeInfoRegionGroupCount = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_wg_count + offset : 16 + size : 7 +)==="; + err = decodeZeInfoKernelEntry(zeInfoRegionGroupCount); + EXPECT_EQ(NEO::DecodeError::invalidBinary, err); + EXPECT_STREQ("DeviceBinaryFormat::zebin : Invalid size for argument of type region_group_wg_count in context of : some_kernel. Expected 4. Got : 7\n", errors.c_str()); + EXPECT_TRUE(warnings.empty()) << warnings; +} + +TEST_F(decodeZeInfoKernelEntryTest, givenRegionArgTypesWhenArgSizeIsCorrectThenReturnSuccess) { + ConstStringRef zeInfoRegionGroupSize = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_size + offset : 16 + size : 12 +)==="; + auto err = decodeZeInfoKernelEntry(zeInfoRegionGroupSize); + EXPECT_EQ(NEO::DecodeError::success, err); + EXPECT_TRUE(errors.empty()) << errors; + EXPECT_TRUE(warnings.empty()) << warnings; + for (uint32_t i = 0; i < 3; ++i) { + EXPECT_EQ(16 + sizeof(uint32_t) * i, kernelDescriptor->payloadMappings.dispatchTraits.regionGroupSize[i]); + } + + ConstStringRef zeInfoRegionGroupDim = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_dimension + offset : 16 + size : 4 +)==="; + err = decodeZeInfoKernelEntry(zeInfoRegionGroupDim); + EXPECT_EQ(NEO::DecodeError::success, err); + EXPECT_TRUE(errors.empty()) << errors; + EXPECT_TRUE(warnings.empty()) << warnings; + + EXPECT_EQ(16, kernelDescriptor->payloadMappings.dispatchTraits.regionGroupDimension); + + ConstStringRef zeInfoRegionGroupCount = R"===( + kernels: + - name : some_kernel + execution_env: + simd_size: 32 + payload_arguments: + - arg_type : region_group_wg_count + offset : 16 + size : 4 +)==="; + err = decodeZeInfoKernelEntry(zeInfoRegionGroupCount); + EXPECT_EQ(NEO::DecodeError::success, err); + EXPECT_TRUE(errors.empty()) << errors; + EXPECT_TRUE(warnings.empty()) << warnings; + + EXPECT_EQ(16, kernelDescriptor->payloadMappings.dispatchTraits.regionGroupWgCount); +} + TEST_F(decodeZeInfoKernelEntryTest, GivenArgTypeGlobalSizeWhenArgSizeValidThenPopulatesKernelDescriptor) { uint32_t vectorSizes[] = {4, 8, 12};