From 3a51656af5ae38814626ae3e997d29b90d64a0a0 Mon Sep 17 00:00:00 2001 From: Joshua Santosh Ranjan Date: Fri, 18 Nov 2022 14:21:34 +0000 Subject: [PATCH] Support ReturnSubDevicesAsApiDevices for zeFabricVertexGetExp Related-To: LOCI-3635 Signed-off-by: Joshua Santosh Ranjan --- .../core/source/driver/driver_handle_imp.cpp | 33 +++++++++--- .../unit_tests/sources/fabric/test_fabric.cpp | 52 +++++++++++++++++++ 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/level_zero/core/source/driver/driver_handle_imp.cpp b/level_zero/core/source/driver/driver_handle_imp.cpp index 99daa7df70..7208ebe52a 100644 --- a/level_zero/core/source/driver/driver_handle_imp.cpp +++ b/level_zero/core/source/driver/driver_handle_imp.cpp @@ -650,16 +650,37 @@ ze_result_t DriverHandleImp::fabricVertexGetExp(uint32_t *pCount, ze_fabric_vert this->initializeVertexes(); } - uint32_t fabricVertexCount = static_cast(this->fabricVertices.size()); + bool exposeSubDevices = false; + if (NEO::DebugManager.flags.ReturnSubDevicesAsApiDevices.get() != -1) { + exposeSubDevices = NEO::DebugManager.flags.ReturnSubDevicesAsApiDevices.get(); + } + if (*pCount == 0) { - *pCount = fabricVertexCount; + if (exposeSubDevices) { + for (auto &vertex : this->fabricVertices) { + *pCount += std::max(static_cast(vertex->subVertices.size()), 1u); + } + } else { + *pCount = static_cast(this->fabricVertices.size()); + } return ZE_RESULT_SUCCESS; } - *pCount = std::min(fabricVertexCount, *pCount); - - for (uint32_t index = 0; index < *pCount; index++) { - phVertices[index] = this->fabricVertices[index]->toHandle(); + uint32_t i = 0; + for (auto vertex : this->fabricVertices) { + if (vertex->subVertices.size() > 0 && exposeSubDevices) { + for (auto subVertex : vertex->subVertices) { + phVertices[i++] = subVertex->toHandle(); + if (i == *pCount) { + return ZE_RESULT_SUCCESS; + } + } + } else { + phVertices[i++] = vertex->toHandle(); + if (i == *pCount) { + return ZE_RESULT_SUCCESS; + } + } } return ZE_RESULT_SUCCESS; diff --git a/level_zero/core/test/unit_tests/sources/fabric/test_fabric.cpp b/level_zero/core/test/unit_tests/sources/fabric/test_fabric.cpp index dd2cb975d6..0d7fe2a53f 100644 --- a/level_zero/core/test/unit_tests/sources/fabric/test_fabric.cpp +++ b/level_zero/core/test/unit_tests/sources/fabric/test_fabric.cpp @@ -215,6 +215,58 @@ TEST_F(FabricVertexFixture, GivenDevicesAreCreatedWhenFabricVertexIsNotSetToDevi EXPECT_EQ(hVertex, nullptr); } +class FabricVertexSubdeviceAsDeviceTestFixture : public MultiDeviceFixture, + public ::testing::Test { + void SetUp() override { + NEO::DebugManager.flags.ZE_AFFINITY_MASK.set("0,1.1,2"); + NEO::DebugManager.flags.ReturnSubDevicesAsApiDevices.set(1); + MultiDeviceFixture::setUp(); + } + + void TearDown() override { + MultiDeviceFixture::tearDown(); + } + DebugManagerStateRestore restorer; +}; + +TEST_F(FabricVertexSubdeviceAsDeviceTestFixture, GivenReturnSubDevicesAsApiDevicesIsSetWhenFabricVerticesGetExpIsCalledCorrectVerticesAreReturned) { + + uint32_t count = 0; + std::vector phVertices; + EXPECT_EQ(driverHandle->fabricVertexGetExp(&count, nullptr), ZE_RESULT_SUCCESS); + EXPECT_EQ(count, 5u); + + // Requesting for a reduced count + count -= 1; + phVertices.resize(count); + EXPECT_EQ(driverHandle->fabricVertexGetExp(&count, phVertices.data()), ZE_RESULT_SUCCESS); + + ze_device_handle_t hDevice{}; + // 0.0 + EXPECT_EQ(L0::zeFabricVertexGetDeviceExp(phVertices[0], &hDevice), ZE_RESULT_SUCCESS); + DeviceImp *deviceImp = static_cast(hDevice); + EXPECT_TRUE(deviceImp->isSubdevice); + EXPECT_EQ(deviceImp->getPhysicalSubDeviceId(), 0u); + + // 0.1 + EXPECT_EQ(L0::zeFabricVertexGetDeviceExp(phVertices[1], &hDevice), ZE_RESULT_SUCCESS); + deviceImp = static_cast(hDevice); + EXPECT_TRUE(deviceImp->isSubdevice); + EXPECT_EQ(deviceImp->getPhysicalSubDeviceId(), 1u); + + // 1.1 + EXPECT_EQ(L0::zeFabricVertexGetDeviceExp(phVertices[2], &hDevice), ZE_RESULT_SUCCESS); + deviceImp = static_cast(hDevice); + EXPECT_FALSE(deviceImp->isSubdevice); + EXPECT_EQ(deviceImp->getPhysicalSubDeviceId(), 1u); + + // 2.0 + EXPECT_EQ(L0::zeFabricVertexGetDeviceExp(phVertices[3], &hDevice), ZE_RESULT_SUCCESS); + deviceImp = static_cast(hDevice); + EXPECT_TRUE(deviceImp->isSubdevice); + EXPECT_EQ(deviceImp->getPhysicalSubDeviceId(), 0u); +} + using FabricEdgeFixture = Test; TEST_F(FabricEdgeFixture, GivenFabricVerticesAreCreatedWhenZeFabricEdgeGetExpIsCalledThenReturnSuccess) {