diff --git a/runtime/device/device.h b/runtime/device/device.h index 39871e2913..870da62390 100644 --- a/runtime/device/device.h +++ b/runtime/device/device.h @@ -91,6 +91,7 @@ class Device : public BaseObject<_cl_device_id> { ExecutionEnvironment *getExecutionEnvironment() const { return executionEnvironment; } const HardwareCapabilities &getHardwareCapabilities() const { return hardwareCapabilities; } uint32_t getDeviceIndex() const { return deviceIndex; } + virtual uint32_t getRootDeviceIndex() const = 0; bool isFullRangeSvm() const { return executionEnvironment->isFullRangeSvm(); } diff --git a/runtime/device/root_device.cpp b/runtime/device/root_device.cpp index e694f13fe7..bbd59b7c2c 100644 --- a/runtime/device/root_device.cpp +++ b/runtime/device/root_device.cpp @@ -18,6 +18,10 @@ uint32_t RootDevice::getNumSubDevices() const { return static_cast(subdevices.size()); } +uint32_t RootDevice::getRootDeviceIndex() const { + return this->deviceIndex; +} + uint32_t RootDevice::getNumAvailableDevices() const { if (subdevices.empty()) { return 1u; diff --git a/runtime/device/root_device.h b/runtime/device/root_device.h index 0a47bb0916..b83226cfda 100644 --- a/runtime/device/root_device.h +++ b/runtime/device/root_device.h @@ -19,6 +19,7 @@ class RootDevice : public Device { bool createDeviceImpl() override; uint32_t getNumAvailableDevices() const override; uint32_t getNumSubDevices() const; + uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override; /* We hide the retain and release function of BaseObject. */ diff --git a/runtime/device/sub_device.cpp b/runtime/device/sub_device.cpp index 887a1bedf8..5cdb266554 100644 --- a/runtime/device/sub_device.cpp +++ b/runtime/device/sub_device.cpp @@ -35,6 +35,9 @@ DeviceBitfield SubDevice::getDeviceBitfieldForOsContext() const { uint32_t SubDevice::getNumAvailableDevices() const { return 1u; } +uint32_t SubDevice::getRootDeviceIndex() const { + return this->rootDevice.getRootDeviceIndex(); +} Device *SubDevice::getDeviceById(uint32_t deviceId) const { UNRECOVERABLE_IF(deviceId >= getNumAvailableDevices()); return const_cast(this); diff --git a/runtime/device/sub_device.h b/runtime/device/sub_device.h index cf17b0b2b2..759b8571c0 100644 --- a/runtime/device/sub_device.h +++ b/runtime/device/sub_device.h @@ -18,6 +18,7 @@ class SubDevice : public Device { void retainInternal(); void releaseInternal(); uint32_t getNumAvailableDevices() const override; + uint32_t getRootDeviceIndex() const override; Device *getDeviceById(uint32_t deviceId) const override; protected: diff --git a/unit_tests/device/sub_device_tests.cpp b/unit_tests/device/sub_device_tests.cpp index 4a655c3874..6b02c02d24 100644 --- a/unit_tests/device/sub_device_tests.cpp +++ b/unit_tests/device/sub_device_tests.cpp @@ -22,6 +22,21 @@ TEST(SubDevicesTest, givenDefaultConfigWhenCreateRootDeviceThenItDoesntContainSu EXPECT_EQ(1u, device->getNumAvailableDevices()); } +TEST(SubDevicesTest, givenCreateMultipleSubDevicesFlagSetWhenCreateRootDeviceThenItsSubdevicesHaveProperRootIdSet) { + DebugManagerStateRestore restorer; + DebugManager.flags.CreateMultipleSubDevices.set(2); + auto device = std::unique_ptr(MockDevice::createWithNewExecutionEnvironment(*platformDevices)); + + EXPECT_EQ(2u, device->getNumSubDevices()); + EXPECT_EQ(0u, device->getDeviceIndex()); + EXPECT_EQ(1u, device->subdevices.at(0)->getDeviceIndex()); + EXPECT_EQ(2u, device->subdevices.at(1)->getDeviceIndex()); + + EXPECT_EQ(0u, device->getRootDeviceIndex()); + EXPECT_EQ(0u, device->subdevices.at(0)->getRootDeviceIndex()); + EXPECT_EQ(0u, device->subdevices.at(1)->getRootDeviceIndex()); +} + TEST(SubDevicesTest, givenCreateMultipleSubDevicesFlagSetWhenCreateRootDeviceThenItContainsSubDevices) { DebugManagerStateRestore restorer; DebugManager.flags.CreateMultipleSubDevices.set(2);