diff --git a/shared/source/device_binary_format/zebin/zeinfo.h b/shared/source/device_binary_format/zebin/zeinfo.h index 0a477fc5c7..721741f663 100644 --- a/shared/source/device_binary_format/zebin/zeinfo.h +++ b/shared/source/device_binary_format/zebin/zeinfo.h @@ -10,6 +10,7 @@ #include "shared/source/device_binary_format/yaml/yaml_parser.h" #include "shared/source/helpers/non_copyable_or_moveable.h" #include "shared/source/utilities/const_stringref.h" +#include "shared/source/utilities/mem_lifetime.h" #include #include @@ -664,43 +665,9 @@ inline constexpr OffsetT offset = -1; inline constexpr BtiValueT btiValue = -1; } // namespace Defaults -struct PayloadArgElasticPtrBaseT; -struct PayloadArgumentBaseT; struct PayloadArgumentExtT; -PayloadArgumentExtT *allocatePayloadArgumentExt(); -void freePayloadArgumentExt(PayloadArgumentExtT *); -void copyPayloadArgumentExt(PayloadArgumentExtT *&, const PayloadArgElasticPtrBaseT &); -struct PayloadArgElasticPtrBaseT { - PayloadArgElasticPtrBaseT() { - pPayArgExt = allocatePayloadArgumentExt(); - } - ~PayloadArgElasticPtrBaseT() { - freePayloadArgumentExt(pPayArgExt); - } - PayloadArgElasticPtrBaseT(const PayloadArgElasticPtrBaseT &src) { - copyPayloadArgumentExt(pPayArgExt, src); - } - PayloadArgElasticPtrBaseT &operator=(const PayloadArgElasticPtrBaseT &rhs) { - if (this != &rhs) { - copyPayloadArgumentExt(pPayArgExt, rhs); - } - return *this; - } - PayloadArgElasticPtrBaseT(PayloadArgElasticPtrBaseT &&src) noexcept : pPayArgExt(src.pPayArgExt) { - src.pPayArgExt = nullptr; - } - PayloadArgElasticPtrBaseT &operator=(PayloadArgElasticPtrBaseT &&rhs) noexcept { - if (this != &rhs) { - this->pPayArgExt = rhs.pPayArgExt; - rhs.pPayArgExt = nullptr; - } - return *this; - } - PayloadArgumentExtT *pPayArgExt = nullptr; -}; - -struct PayloadArgumentBaseT : PayloadArgElasticPtrBaseT { +struct PayloadArgumentBaseT { ArgTypeT argType = argTypeUnknown; OffsetT offset = Defaults::offset; SourceOffseT sourceOffset = Defaults::sourceOffset; @@ -717,6 +684,7 @@ struct PayloadArgumentBaseT : PayloadArgElasticPtrBaseT { bool imageTransformable = false; bool isPipe = false; bool isPtr = false; + Ext pPayArgExt; }; } // namespace PayloadArgument diff --git a/shared/source/device_binary_format/zebin/zeinfo_decoder_ext.cpp b/shared/source/device_binary_format/zebin/zeinfo_decoder_ext.cpp index bc1adcf249..57a454fabd 100644 --- a/shared/source/device_binary_format/zebin/zeinfo_decoder_ext.cpp +++ b/shared/source/device_binary_format/zebin/zeinfo_decoder_ext.cpp @@ -32,18 +32,16 @@ void freeExecEnvExt(ExecutionEnvExt *envExt) { void populateKernelExecutionEnvironmentExt(KernelDescriptor &dst, const KernelExecutionEnvBaseT &execEnv, const Types::Version &srcZeInfoVersion) { } -namespace Types::Kernel::PayloadArgument { -PayloadArgumentExtT *allocatePayloadArgumentExt() { - return nullptr; -} -void freePayloadArgumentExt(PayloadArgumentExtT *pPayArgExt) { -} -void copyPayloadArgumentExt(PayloadArgumentExtT *&pPayArgExtOut, const PayloadArgElasticPtrBaseT &src) { -} -} // namespace Types::Kernel::PayloadArgument - DecodeError populateKernelPayloadArgumentExt(NEO::KernelDescriptor &dst, const KernelPayloadArgBaseT &src, std::string &outErrReason) { return DecodeError::unhandledBinary; } } // namespace NEO::Zebin::ZeInfo + +template <> +void cloneExt(ExtUniquePtrT &dst, const NEO::Zebin::ZeInfo::Types::Kernel::PayloadArgument::PayloadArgumentExtT &src) { +} + +template <> +void destroyExt(NEO::Zebin::ZeInfo::Types::Kernel::PayloadArgument::PayloadArgumentExtT *dst) { +} diff --git a/shared/source/utilities/mem_lifetime.h b/shared/source/utilities/mem_lifetime.h index 0d4b673f2b..e0e5717837 100644 --- a/shared/source/utilities/mem_lifetime.h +++ b/shared/source/utilities/mem_lifetime.h @@ -13,24 +13,63 @@ template using ExtUniquePtrT = std::unique_ptr; template -void cloneExt(ExtUniquePtrT &dst, const T *src); +void cloneExt(ExtUniquePtrT &dst, const T &src); template void destroyExt(T *dst); +namespace Impl { + +template +struct UniquePtrWrapperOps { + auto operator*() const noexcept(noexcept(std::declval().operator*())) { + return *static_cast(this)->ptr; + } + + auto operator->() const noexcept { + return static_cast(this)->ptr.operator->(); + } + + explicit operator bool() const noexcept { + return static_cast(static_cast(this)->ptr); + } + + template + friend bool operator==(const ParentT &lhs, const Rhs &rhs) { + return lhs.ptr == rhs; + } + + template + friend bool operator==(const Lhs &lhs, const ParentT &rhs) { + return lhs == rhs.ptr; + } + + friend bool operator==(const ParentT &lhs, const ParentT &rhs) { + return lhs.ptr == rhs.ptr; + } +}; + +} // namespace Impl + template -struct Ext { +struct Ext : Impl::UniquePtrWrapperOps> { Ext(T *ptr) : ptr(ptr, destroyExt) {} Ext() = default; Ext(const Ext &rhs) { - cloneExt(this->ptr, rhs.ptr.get()); + if (rhs.ptr.get()) { + cloneExt(this->ptr, *rhs.ptr.get()); + } } Ext &operator=(const Ext &rhs) { if (this == &rhs) { return *this; } - cloneExt(this->ptr, rhs.ptr.get()); + if (rhs.ptr.get()) { + cloneExt(this->ptr, *rhs.ptr.get()); + } else { + ptr.reset(); + } return *this; } @@ -38,7 +77,7 @@ struct Ext { }; template -struct Clonable { +struct Clonable : Impl::UniquePtrWrapperOps>> { Clonable(T *ptr) : ptr(ptr) {} Clonable() = default; Clonable(const Clonable &rhs) { diff --git a/shared/test/unit_test/device_binary_format/zebin_decoder_exclusive_tests.cpp b/shared/test/unit_test/device_binary_format/zebin_decoder_exclusive_tests.cpp index 00dade2546..313bbc03d9 100644 --- a/shared/test/unit_test/device_binary_format/zebin_decoder_exclusive_tests.cpp +++ b/shared/test/unit_test/device_binary_format/zebin_decoder_exclusive_tests.cpp @@ -17,5 +17,4 @@ TEST(ExtBaseKernelDescriptorAndPayloadArgumentPointers, givenKernelDescriptorAnd EXPECT_EQ(nullptr, kd.kernelDescriptorExt); EXPECT_EQ(nullptr, arg.pPayArgExt); - EXPECT_EQ(nullptr, NEO::Zebin::ZeInfo::Types::Kernel::PayloadArgument::allocatePayloadArgumentExt()); } diff --git a/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp b/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp index 974e8fb83f..55075c4332 100644 --- a/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp +++ b/shared/test/unit_test/device_binary_format/zebin_decoder_tests.cpp @@ -2394,11 +2394,7 @@ TEST(BaseKernelDescriptorAndPayloadArgumentPoinetrsExt, givenKernelDescriptorAnd NEO::KernelDescriptor kd; NEO::Zebin::ZeInfo::KernelPayloadArgBaseT arg; - if (nullptr != kd.kernelDescriptorExt) { - EXPECT_NE(nullptr, arg.pPayArgExt); - } else { - EXPECT_EQ(nullptr, arg.pPayArgExt); - } + EXPECT_EQ(nullptr, arg.pPayArgExt); } TEST(PopulateKernelSourceAttributes, GivenInvalidKernelAttributeWhenPopulatingKernelSourceAttributesThenKernelIsInvalidFlagIsSet) { diff --git a/shared/test/unit_test/utilities/mem_lifetime_tests.cpp b/shared/test/unit_test/utilities/mem_lifetime_tests.cpp index cab342cf8c..6cee1b8405 100644 --- a/shared/test/unit_test/utilities/mem_lifetime_tests.cpp +++ b/shared/test/unit_test/utilities/mem_lifetime_tests.cpp @@ -22,7 +22,7 @@ struct TestStruct { }; template <> -void cloneExt(ExtUniquePtrT &dst, const TestStruct *src) { +void cloneExt(ExtUniquePtrT &dst, const TestStruct &src) { dst.reset(new TestStruct{true}); } @@ -89,4 +89,92 @@ TEST(Clonable, GivenTypeThenProvidesClonableUniquePtr) { auto &re = e; e = re; EXPECT_EQ(nullptr, e.ptr.get()); -} \ No newline at end of file +} + +struct PtrOpsMock { + auto operator*() const noexcept { + opCalled |= deref; + return *this; + } + + auto operator->() const noexcept { + opCalled |= arrow; + return this; + } + + explicit operator bool() const noexcept { + opCalled |= boolCast; + return true; + } + + enum OpCalled : uint32_t { + deref = 1 << 0, + arrow = 1 << 1, + boolCast = 1 << 2, + compareAsLhs = 1 << 3, + compareAsRhs = 1 << 4, + compare = 1 << 5, + }; + + mutable uint32_t opCalled = 0; +}; + +static bool operator==(const PtrOpsMock &lhs, const void *) { + lhs.opCalled |= PtrOpsMock::compareAsLhs; + return true; +} + +static bool operator==(const void *, const PtrOpsMock &rhs) { + rhs.opCalled |= PtrOpsMock::compareAsRhs; + return true; +} + +static bool operator==(const PtrOpsMock &lhs, const PtrOpsMock &rhs) { + lhs.opCalled |= PtrOpsMock::compare; + rhs.opCalled |= PtrOpsMock::compare; + return true; +} + +TEST(Clonable, WhenUsingUniquePtrWrapperOpsThenForwardsToParentsPtr) { + struct Parent : Impl::UniquePtrWrapperOps { + PtrOpsMock ptr; + }; + + { + Parent parent; + [[maybe_unused]] auto res = *parent; + EXPECT_EQ(PtrOpsMock::deref, parent.ptr.opCalled); + } + + { + Parent parent; + [[maybe_unused]] auto res = parent->opCalled; + EXPECT_EQ(PtrOpsMock::arrow, parent.ptr.opCalled); + } + + { + Parent parent; + [[maybe_unused]] auto res = static_cast(parent); + EXPECT_EQ(PtrOpsMock::boolCast, parent.ptr.opCalled); + } + + { + Parent parent; + [[maybe_unused]] auto res = parent == nullptr; + EXPECT_EQ(PtrOpsMock::compareAsLhs, parent.ptr.opCalled); + } + + { + Parent parent; + [[maybe_unused]] auto res = nullptr == parent; + EXPECT_EQ(PtrOpsMock::compareAsRhs, parent.ptr.opCalled); + } + + { + Parent parentA; + Parent parentB; + [[maybe_unused]] auto res = parentA == parentB; + EXPECT_EQ(PtrOpsMock::compare, parentA.ptr.opCalled); + EXPECT_EQ(PtrOpsMock::compare, parentB.ptr.opCalled); + } +}