diff --git a/level_zero/api/core/ze_context.cpp b/level_zero/api/core/ze_context.cpp index 5c76d1b90e..c7648fdeb4 100644 --- a/level_zero/api/core/ze_context.cpp +++ b/level_zero/api/core/ze_context.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2019-2020 Intel Corporation + * Copyright (C) 2019-2021 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -14,7 +14,17 @@ zeContextCreate( ze_driver_handle_t hDriver, const ze_context_desc_t *desc, ze_context_handle_t *phContext) { - return L0::DriverHandle::fromHandle(hDriver)->createContext(desc, phContext); + return L0::DriverHandle::fromHandle(hDriver)->createContext(desc, 0u, nullptr, phContext); +} + +ZE_APIEXPORT ze_result_t ZE_APICALL +zeContextCreateEx( + ze_driver_handle_t hDriver, + const ze_context_desc_t *desc, + uint32_t numDevices, + ze_device_handle_t *phDevices, + ze_context_handle_t *phContext) { + return L0::DriverHandle::fromHandle(hDriver)->createContext(desc, numDevices, phDevices, phContext); } ZE_APIEXPORT ze_result_t ZE_APICALL diff --git a/level_zero/api/core/ze_core_loader.cpp b/level_zero/api/core/ze_core_loader.cpp index 45f3df4433..2a92c2bbbb 100644 --- a/level_zero/api/core/ze_core_loader.cpp +++ b/level_zero/api/core/ze_core_loader.cpp @@ -94,6 +94,7 @@ zeGetContextProcAddrTable( ze_result_t result = ZE_RESULT_SUCCESS; pDdiTable->pfnCreate = zeContextCreate; + pDdiTable->pfnCreateEx = zeContextCreateEx; pDdiTable->pfnDestroy = zeContextDestroy; pDdiTable->pfnGetStatus = zeContextGetStatus; pDdiTable->pfnSystemBarrier = zeContextSystemBarrier; diff --git a/level_zero/core/source/context/context_imp.h b/level_zero/core/source/context/context_imp.h index 4a93353c83..551343f886 100644 --- a/level_zero/core/source/context/context_imp.h +++ b/level_zero/core/source/context/context_imp.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020 Intel Corporation + * Copyright (C) 2020-2021 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -109,7 +109,12 @@ struct ContextImp : Context { const ze_image_desc_t *desc, ze_image_handle_t *phImage) override; + std::vector &getDevices() { + return devices; + } + protected: + std::vector devices; DriverHandle *driverHandle = nullptr; }; diff --git a/level_zero/core/source/driver/driver_handle.h b/level_zero/core/source/driver/driver_handle.h index 208bb083e0..ca4af663fe 100644 --- a/level_zero/core/source/driver/driver_handle.h +++ b/level_zero/core/source/driver/driver_handle.h @@ -24,6 +24,8 @@ struct L0EnvVariables; struct DriverHandle : _ze_driver_handle_t { virtual ze_result_t createContext(const ze_context_desc_t *desc, + uint32_t numDevices, + ze_device_handle_t *phDevices, ze_context_handle_t *phContext) = 0; virtual ze_result_t getDevice(uint32_t *pCount, ze_device_handle_t *phDevices) = 0; virtual ze_result_t getProperties(ze_driver_properties_t *properties) = 0; diff --git a/level_zero/core/source/driver/driver_handle_imp.cpp b/level_zero/core/source/driver/driver_handle_imp.cpp index 734a5c1964..220d5265ab 100644 --- a/level_zero/core/source/driver/driver_handle_imp.cpp +++ b/level_zero/core/source/driver/driver_handle_imp.cpp @@ -33,6 +33,8 @@ struct DriverHandleImp *GlobalDriver; DriverHandleImp::DriverHandleImp() = default; ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc, + uint32_t numDevices, + ze_device_handle_t *phDevices, ze_context_handle_t *phContext) { ContextImp *context = new ContextImp(this); if (nullptr == context) { @@ -41,6 +43,15 @@ ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc, *phContext = context->toHandle(); + if (numDevices == 0) { + context->getDevices().resize(numDevices); + context->getDevices() = this->devices; + } else { + for (uint32_t i = 0; i < numDevices; i++) { + context->getDevices().push_back(Device::fromHandle(phDevices[i])); + } + } + return ZE_RESULT_SUCCESS; } diff --git a/level_zero/core/source/driver/driver_handle_imp.h b/level_zero/core/source/driver/driver_handle_imp.h index 815806cdae..2bd7a076d6 100644 --- a/level_zero/core/source/driver/driver_handle_imp.h +++ b/level_zero/core/source/driver/driver_handle_imp.h @@ -20,7 +20,10 @@ struct DriverHandleImp : public DriverHandle { ~DriverHandleImp() override; DriverHandleImp(); - ze_result_t createContext(const ze_context_desc_t *desc, ze_context_handle_t *phContext) override; + ze_result_t createContext(const ze_context_desc_t *desc, + uint32_t numDevices, + ze_device_handle_t *phDevices, + ze_context_handle_t *phContext) override; ze_result_t getDevice(uint32_t *pCount, ze_device_handle_t *phDevices) override; ze_result_t getProperties(ze_driver_properties_t *properties) override; ze_result_t getApiVersion(ze_api_version_t *version) override; diff --git a/level_zero/core/test/unit_tests/fixtures/device_fixture.cpp b/level_zero/core/test/unit_tests/fixtures/device_fixture.cpp index 5e2d7594ba..e3be0d62c2 100644 --- a/level_zero/core/test/unit_tests/fixtures/device_fixture.cpp +++ b/level_zero/core/test/unit_tests/fixtures/device_fixture.cpp @@ -46,7 +46,7 @@ void ContextFixture::SetUp() { ze_context_handle_t hContext = {}; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); EXPECT_NE(nullptr, hContext); context = L0::Context::fromHandle(hContext); diff --git a/level_zero/core/test/unit_tests/fixtures/host_pointer_manager_fixture.h b/level_zero/core/test/unit_tests/fixtures/host_pointer_manager_fixture.h index 745317b200..29d19dbe29 100644 --- a/level_zero/core/test/unit_tests/fixtures/host_pointer_manager_fixture.h +++ b/level_zero/core/test/unit_tests/fixtures/host_pointer_manager_fixture.h @@ -46,7 +46,7 @@ struct HostPointerManagerFixure { ASSERT_NE(nullptr, heapPointer); ze_context_desc_t desc; - ze_result_t ret = hostDriverHandle->createContext(&desc, &hContext); + ze_result_t ret = hostDriverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, ret); context = L0::Context::fromHandle(hContext); } diff --git a/level_zero/core/test/unit_tests/gen12lp/test_device_gen12lp.cpp b/level_zero/core/test/unit_tests/gen12lp/test_device_gen12lp.cpp index d9ddf52f4a..89e39aa0bd 100644 --- a/level_zero/core/test/unit_tests/gen12lp/test_device_gen12lp.cpp +++ b/level_zero/core/test/unit_tests/gen12lp/test_device_gen12lp.cpp @@ -255,7 +255,7 @@ HWTEST2_F(DeviceQueueGroupTest, ze_context_handle_t hContext; ze_context_desc_t desc; - res = driverHandle->createContext(&desc, &hContext); + res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); L0::Context *context = Context::fromHandle(hContext); diff --git a/level_zero/core/test/unit_tests/gen9/test_device_gen9.cpp b/level_zero/core/test/unit_tests/gen9/test_device_gen9.cpp index 5b436a2926..8ed55d700b 100644 --- a/level_zero/core/test/unit_tests/gen9/test_device_gen9.cpp +++ b/level_zero/core/test/unit_tests/gen9/test_device_gen9.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020 Intel Corporation + * Copyright (C) 2020-2021 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -76,7 +76,7 @@ HWTEST2_F(DeviceQueueGroupTest, givenQueueGroupsReturnedThenCommandListIsCreated ze_context_handle_t hContext; ze_context_desc_t contextDesc; - res = driverHandle->createContext(&contextDesc, &hContext); + res = driverHandle->createContext(&contextDesc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); L0::Context *context = Context::fromHandle(hContext); diff --git a/level_zero/core/test/unit_tests/sources/context/test_context.cpp b/level_zero/core/test/unit_tests/sources/context/test_context.cpp index fc8828ada6..a18cf11de2 100644 --- a/level_zero/core/test/unit_tests/sources/context/test_context.cpp +++ b/level_zero/core/test/unit_tests/sources/context/test_context.cpp @@ -12,6 +12,7 @@ #include "test.h" #include "level_zero/core/source/context/context_imp.h" +#include "level_zero/core/source/driver/driver_handle_imp.h" #include "level_zero/core/source/image/image.h" #include "level_zero/core/test/unit_tests/fixtures/device_fixture.h" #include "level_zero/core/test/unit_tests/fixtures/host_pointer_manager_fixture.h" @@ -22,11 +23,51 @@ namespace L0 { namespace ult { +using MultiDeviceContextTests = Test; + +TEST_F(MultiDeviceContextTests, + whenCreatingContextWithZeroNumDevicesThenAllDevicesAreAssociatedWithTheContext) { + ze_context_handle_t hContext; + ze_context_desc_t desc; + + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); + EXPECT_EQ(ZE_RESULT_SUCCESS, res); + + ContextImp *contextImp = static_cast(Context::fromHandle(hContext)); + + for (size_t i = 0; i < driverHandle->devices.size(); i++) { + EXPECT_EQ(driverHandle->devices[i], contextImp->getDevices()[i]); + } + + res = L0::Context::fromHandle(hContext)->destroy(); + EXPECT_EQ(ZE_RESULT_SUCCESS, res); +} + +TEST_F(MultiDeviceContextTests, + whenCreatingContextWithNonZeroNumDevicesThenOnlySpecifiedDeviceIsAssociatedWithTheContext) { + ze_context_handle_t hContext; + ze_context_desc_t desc; + + uint32_t count = 1; + ze_device_handle_t device = driverHandle->devices[1]->toHandle(); + + ze_result_t res = driverHandle->createContext(&desc, 1u, &device, &hContext); + EXPECT_EQ(ZE_RESULT_SUCCESS, res); + + ContextImp *contextImp = static_cast(Context::fromHandle(hContext)); + + EXPECT_EQ(contextImp->getDevices().size(), count); + EXPECT_EQ(contextImp->getDevices()[0], driverHandle->devices[1]); + + res = L0::Context::fromHandle(hContext)->destroy(); + EXPECT_EQ(ZE_RESULT_SUCCESS, res); +} + using ContextGetStatusTest = Test; TEST_F(ContextGetStatusTest, givenCallToContextGetStatusThenCorrectErrorCodeIsReturnedWhenResourcesHaveBeenReleased) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); L0::Context *context = L0::Context::fromHandle(hContext); @@ -51,7 +92,7 @@ TEST_F(ContextTest, whenCreatingAndDestroyingContextThenSuccessIsReturned) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); res = L0::Context::fromHandle(hContext)->destroy(); @@ -338,7 +379,7 @@ TEST_F(ContextTest, whenGettingDriverThenDriverIsRetrievedSuccessfully) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -353,7 +394,7 @@ TEST_F(ContextTest, whenCallingVirtualMemInterfacesThenUnsupportedIsReturned) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -379,7 +420,7 @@ TEST_F(ContextTest, whenCallingPhysicalMemInterfacesThenUnsupportedIsReturned) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -400,7 +441,7 @@ TEST_F(ContextTest, whenCallingMappingVirtualInterfacesThenUnsupportedIsReturned ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -440,7 +481,7 @@ using IsAtMostProductDG1 = IsAtMostProduct; HWTEST2_F(ContextTest, WhenCreatingImageThenSuccessIsReturned, IsAtMostProductDG1) { ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t res = driverHandle->createContext(&desc, &hContext); + ze_result_t res = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, res); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -459,4 +500,4 @@ HWTEST2_F(ContextTest, WhenCreatingImageThenSuccessIsReturned, IsAtMostProductDG } } // namespace ult -} // namespace L0 \ No newline at end of file +} // namespace L0 diff --git a/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp b/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp index fd75062b2a..c9a997c484 100644 --- a/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp +++ b/level_zero/core/test/unit_tests/sources/memory/test_memory.cpp @@ -641,7 +641,7 @@ TEST_F(MemoryIPCTests, ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t result = driverHandle->createContext(&desc, &hContext); + ze_result_t result = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, result); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext)); @@ -682,7 +682,7 @@ TEST_F(MemoryIPCTests, ze_context_handle_t hContext; ze_context_desc_t desc; - ze_result_t result = driverHandle->createContext(&desc, &hContext); + ze_result_t result = driverHandle->createContext(&desc, 0u, nullptr, &hContext); EXPECT_EQ(ZE_RESULT_SUCCESS, result); ContextImp *contextImp = static_cast(L0::Context::fromHandle(hContext));