Add sub-devices to the list of devices in a context

Signed-off-by: Jaime Arteaga <jaime.a.arteaga.molina@intel.com>
This commit is contained in:
Jaime Arteaga
2021-04-10 02:14:57 +00:00
committed by Compute-Runtime-Automation
parent e35ffb0601
commit b05be25349
4 changed files with 45 additions and 9 deletions

View File

@@ -42,6 +42,14 @@ ContextImp::ContextImp(DriverHandle *driverHandle) {
this->driverHandle = static_cast<DriverHandleImp *>(driverHandle);
}
void ContextImp::addDeviceAndSubDevices(Device *device) {
this->devices.insert(std::make_pair(device->toHandle(), device));
DeviceImp *deviceImp = static_cast<DeviceImp *>(device);
for (auto subDevice : deviceImp->subDevices) {
this->addDeviceAndSubDevices(subDevice);
}
}
ze_result_t ContextImp::allocHostMem(const ze_host_mem_alloc_desc_t *hostDesc,
size_t size,
size_t alignment,

View File

@@ -109,6 +109,8 @@ struct ContextImp : Context {
const ze_image_desc_t *desc,
ze_image_handle_t *phImage) override;
void addDeviceAndSubDevices(Device *device);
std::map<ze_device_handle_t, Device *> &getDevices() {
return devices;
}

View File

@@ -45,11 +45,11 @@ ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc,
if (numDevices == 0) {
for (auto device : this->devices) {
context->getDevices().insert(std::make_pair(device->toHandle(), device));
context->addDeviceAndSubDevices(device);
}
} else {
for (uint32_t i = 0; i < numDevices; i++) {
context->getDevices().insert(std::make_pair(phDevices[i], Device::fromHandle(phDevices[i])));
context->addDeviceAndSubDevices(Device::fromHandle(phDevices[i]));
}
}

View File

@@ -45,21 +45,47 @@ TEST_F(MultiDeviceContextTests,
}
TEST_F(MultiDeviceContextTests,
whenCreatingContextWithNonZeroNumDevicesThenOnlySpecifiedDeviceIsAssociatedWithTheContext) {
whenCreatingContextWithNonZeroNumDevicesThenOnlySpecifiedDeviceAndItsSubDevicesAreAssociatedWithTheContext) {
ze_context_handle_t hContext;
ze_context_desc_t desc;
uint32_t count = 1;
ze_device_handle_t device = driverHandle->devices[1]->toHandle();
ze_device_handle_t device0 = driverHandle->devices[0]->toHandle();
DeviceImp *deviceImp0 = static_cast<DeviceImp *>(device0);
uint32_t subDeviceCount0 = 0;
ze_result_t res = deviceImp0->getSubDevices(&subDeviceCount0, nullptr);
EXPECT_EQ(res, ZE_RESULT_SUCCESS);
EXPECT_EQ(subDeviceCount0, numSubDevices);
std::vector<ze_device_handle_t> subDevices0(subDeviceCount0);
res = deviceImp0->getSubDevices(&subDeviceCount0, subDevices0.data());
EXPECT_EQ(res, ZE_RESULT_SUCCESS);
ze_result_t res = driverHandle->createContext(&desc, 1u, &device, &hContext);
ze_device_handle_t device1 = driverHandle->devices[1]->toHandle();
DeviceImp *deviceImp1 = static_cast<DeviceImp *>(device1);
uint32_t subDeviceCount1 = 0;
res = deviceImp1->getSubDevices(&subDeviceCount1, nullptr);
EXPECT_EQ(res, ZE_RESULT_SUCCESS);
EXPECT_EQ(subDeviceCount1, numSubDevices);
std::vector<ze_device_handle_t> subDevices1(subDeviceCount1);
res = deviceImp1->getSubDevices(&subDeviceCount1, subDevices1.data());
EXPECT_EQ(res, ZE_RESULT_SUCCESS);
res = driverHandle->createContext(&desc, 1u, &device1, &hContext);
EXPECT_EQ(ZE_RESULT_SUCCESS, res);
ContextImp *contextImp = static_cast<ContextImp *>(Context::fromHandle(hContext));
EXPECT_EQ(contextImp->getDevices().size(), count);
EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(driverHandle->devices[0]));
EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(driverHandle->devices[1]));
uint32_t expectedDeviceCountInContext = 1 + subDeviceCount1;
EXPECT_EQ(contextImp->getDevices().size(), expectedDeviceCountInContext);
EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(device0)));
for (auto subDevice : subDevices0) {
EXPECT_FALSE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(subDevice)));
}
EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(device1)));
for (auto subDevice : subDevices1) {
EXPECT_TRUE(contextImp->isDeviceDefinedForThisContext(L0::Device::fromHandle(subDevice)));
}
res = L0::Context::fromHandle(hContext)->destroy();
EXPECT_EQ(ZE_RESULT_SUCCESS, res);