From b05be25349653353372bffdc0adde0dcd870b945 Mon Sep 17 00:00:00 2001 From: Jaime Arteaga Date: Sat, 10 Apr 2021 02:14:57 +0000 Subject: [PATCH] Add sub-devices to the list of devices in a context Signed-off-by: Jaime Arteaga --- .../core/source/context/context_imp.cpp | 8 ++++ level_zero/core/source/context/context_imp.h | 2 + .../core/source/driver/driver_handle_imp.cpp | 4 +- .../sources/context/test_context.cpp | 40 +++++++++++++++---- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/level_zero/core/source/context/context_imp.cpp b/level_zero/core/source/context/context_imp.cpp index 0414debb13..adc7266de0 100644 --- a/level_zero/core/source/context/context_imp.cpp +++ b/level_zero/core/source/context/context_imp.cpp @@ -42,6 +42,14 @@ ContextImp::ContextImp(DriverHandle *driverHandle) { this->driverHandle = static_cast(driverHandle); } +void ContextImp::addDeviceAndSubDevices(Device *device) { + this->devices.insert(std::make_pair(device->toHandle(), device)); + DeviceImp *deviceImp = static_cast(device); + for (auto subDevice : deviceImp->subDevices) { + this->addDeviceAndSubDevices(subDevice); + } +} + ze_result_t ContextImp::allocHostMem(const ze_host_mem_alloc_desc_t *hostDesc, size_t size, size_t alignment, diff --git a/level_zero/core/source/context/context_imp.h b/level_zero/core/source/context/context_imp.h index 0be6e70ba7..51a7124ed0 100644 --- a/level_zero/core/source/context/context_imp.h +++ b/level_zero/core/source/context/context_imp.h @@ -109,6 +109,8 @@ struct ContextImp : Context { const ze_image_desc_t *desc, ze_image_handle_t *phImage) override; + void addDeviceAndSubDevices(Device *device); + std::map &getDevices() { return devices; } diff --git a/level_zero/core/source/driver/driver_handle_imp.cpp b/level_zero/core/source/driver/driver_handle_imp.cpp index a291c1f285..edb79622a0 100644 --- a/level_zero/core/source/driver/driver_handle_imp.cpp +++ b/level_zero/core/source/driver/driver_handle_imp.cpp @@ -45,11 +45,11 @@ ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc, if (numDevices == 0) { for (auto device : this->devices) { - context->getDevices().insert(std::make_pair(device->toHandle(), device)); + context->addDeviceAndSubDevices(device); } } else { for (uint32_t i = 0; i < numDevices; i++) { - context->getDevices().insert(std::make_pair(phDevices[i], Device::fromHandle(phDevices[i]))); + context->addDeviceAndSubDevices(Device::fromHandle(phDevices[i])); } } diff --git a/level_zero/core/test/unit_tests/sources/context/test_context.cpp b/level_zero/core/test/unit_tests/sources/context/test_context.cpp index 66da1c00ae..4db0f57045 100644 --- a/level_zero/core/test/unit_tests/sources/context/test_context.cpp +++ b/level_zero/core/test/unit_tests/sources/context/test_context.cpp @@ -45,21 +45,47 @@ TEST_F(MultiDeviceContextTests, } TEST_F(MultiDeviceContextTests, - whenCreatingContextWithNonZeroNumDevicesThenOnlySpecifiedDeviceIsAssociatedWithTheContext) { + whenCreatingContextWithNonZeroNumDevicesThenOnlySpecifiedDeviceAndItsSubDevicesAreAssociatedWithTheContext) { ze_context_handle_t hContext; ze_context_desc_t desc; - uint32_t count = 1; - ze_device_handle_t device = driverHandle->devices[1]->toHandle(); + ze_device_handle_t device0 = driverHandle->devices[0]->toHandle(); + DeviceImp *deviceImp0 = static_cast(device0); + uint32_t subDeviceCount0 = 0; + ze_result_t res = deviceImp0->getSubDevices(&subDeviceCount0, nullptr); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); + EXPECT_EQ(subDeviceCount0, numSubDevices); + std::vector subDevices0(subDeviceCount0); + res = deviceImp0->getSubDevices(&subDeviceCount0, subDevices0.data()); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); - ze_result_t res = driverHandle->createContext(&desc, 1u, &device, &hContext); + ze_device_handle_t device1 = driverHandle->devices[1]->toHandle(); + DeviceImp *deviceImp1 = static_cast(device1); + uint32_t subDeviceCount1 = 0; + res = deviceImp1->getSubDevices(&subDeviceCount1, nullptr); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); + EXPECT_EQ(subDeviceCount1, numSubDevices); + std::vector subDevices1(subDeviceCount1); + res = deviceImp1->getSubDevices(&subDeviceCount1, subDevices1.data()); + EXPECT_EQ(res, ZE_RESULT_SUCCESS); + + res = driverHandle->createContext(&desc, 1u, &device1, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(Context::fromHandle(hContext)); - EXPECT_EQ(contextImp->getDevices().size(), count); - EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(driverHandle->devices[0])); - EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(driverHandle->devices[1])); + uint32_t expectedDeviceCountInContext = 1 + subDeviceCount1; + EXPECT_EQ(contextImp->getDevices().size(), expectedDeviceCountInContext); + + EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(device0))); + for (auto subDevice : subDevices0) { + EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(subDevice))); + } + + EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(device1))); + for (auto subDevice : subDevices1) { + EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(subDevice))); + } res = L0::Context::fromHandle(hContext)->destroy(); EXPECT_EQ(ZE_RESULT_SUCCESS, res);