diff --git a/opencl/test/unit_test/device/device_tests.cpp b/opencl/test/unit_test/device/device_tests.cpp index cb89807ab2..12743f6485 100644 --- a/opencl/test/unit_test/device/device_tests.cpp +++ b/opencl/test/unit_test/device/device_tests.cpp @@ -12,6 +12,7 @@ #include "shared/test/unit_test/helpers/debug_manager_state_restore.h" #include "shared/test/unit_test/helpers/ult_hw_config.h" #include "shared/test/unit_test/helpers/variable_backup.h" +#include "shared/test/unit_test/mocks/ult_device_factory.h" #include "opencl/source/command_stream/tbx_command_stream_receiver.h" #include "opencl/source/platform/platform.h" @@ -329,6 +330,18 @@ TEST(DeviceCreation, givenDeviceWhenCheckingEnginesCountThenNumberGreaterThanZer EXPECT_GT(HwHelper::getEnginesCount(device->getHardwareInfo()), 0u); } +TEST(DeviceCreation, givenDeviceWhenCheckingParentDeviceThenCorrectValueIsReturned) { + UltDeviceFactory deviceFactory{2, 2}; + + EXPECT_EQ(nullptr, deviceFactory.rootDevices[0]->getParentDevice()); + EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[0]->getParentDevice()); + EXPECT_EQ(deviceFactory.rootDevices[0], deviceFactory.subDevices[1]->getParentDevice()); + + EXPECT_EQ(nullptr, deviceFactory.rootDevices[1]->getParentDevice()); + EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[2]->getParentDevice()); + EXPECT_EQ(deviceFactory.rootDevices[1], deviceFactory.subDevices[3]->getParentDevice()); +} + using DeviceHwTest = ::testing::Test; HWTEST_F(DeviceHwTest, givenHwHelperInputWhenInitializingCsrThenCreatePageTableManagerIfNeeded) { diff --git a/shared/source/device/device.h b/shared/source/device/device.h index 167ab0fb40..47aaf6a98b 100644 --- a/shared/source/device/device.h +++ b/shared/source/device/device.h @@ -86,6 +86,7 @@ class Device : public ReferenceTrackedObject { virtual uint32_t getRootDeviceIndex() const = 0; virtual uint32_t getNumAvailableDevices() const = 0; virtual Device *getDeviceById(uint32_t deviceId) const = 0; + virtual Device *getParentDevice() const = 0; virtual DeviceBitfield getDeviceBitfield() const = 0; static decltype(&PerformanceCounters::create) createPerformanceCountersFunc; diff --git a/shared/source/device/root_device.cpp b/shared/source/device/root_device.cpp index 9cfd52ecd7..29250bb895 100644 --- a/shared/source/device/root_device.cpp +++ b/shared/source/device/root_device.cpp @@ -48,7 +48,11 @@ Device *RootDevice::getDeviceById(uint32_t deviceId) const { return const_cast(this); } return subdevices[deviceId]; -}; +} + +Device *RootDevice::getParentDevice() const { + return nullptr; +} SubDevice *RootDevice::createSubDevice(uint32_t subDeviceIndex) { return Device::create(executionEnvironment, subDeviceIndex, *this); diff --git a/shared/source/device/root_device.h b/shared/source/device/root_device.h index b8b62e78f8..6cd2ce815c 100644 --- a/shared/source/device/root_device.h +++ b/shared/source/device/root_device.h @@ -20,6 +20,7 @@ class RootDevice : public Device { uint32_t getNumAvailableDevices() const override; uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override; + Device *getParentDevice() const override; uint32_t getNumSubDevices() const; diff --git a/shared/source/device/sub_device.cpp b/shared/source/device/sub_device.cpp index e65ca27a04..afaa7c08c0 100644 --- a/shared/source/device/sub_device.cpp +++ b/shared/source/device/sub_device.cpp @@ -43,6 +43,10 @@ Device *SubDevice::getDeviceById(uint32_t deviceId) const { return const_cast(this); } +Device *SubDevice::getParentDevice() const { + return &rootDevice; +} + uint64_t SubDevice::getGlobalMemorySize() const { auto globalMemorySize = Device::getGlobalMemorySize(); return globalMemorySize / rootDevice.getNumAvailableDevices(); diff --git a/shared/source/device/sub_device.h b/shared/source/device/sub_device.h index 024a4f3b63..517f05188b 100644 --- a/shared/source/device/sub_device.h +++ b/shared/source/device/sub_device.h @@ -19,6 +19,7 @@ class SubDevice : public Device { uint32_t getNumAvailableDevices() const override; uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override; + Device *getParentDevice() const override; uint32_t getSubDeviceIndex() const;