diff --git a/opencl/source/context/context.cpp b/opencl/source/context/context.cpp index 4ce20365b5..a33545be20 100644 --- a/opencl/source/context/context.cpp +++ b/opencl/source/context/context.cpp @@ -84,6 +84,14 @@ cl_int Context::setDestructorCallback(void(CL_CALLBACK *funcNotify)(cl_context, return CL_SUCCESS; } +const std::set &Context::getRootDeviceIndices() const { + return rootDeviceIndices; +} + +uint32_t Context::getMaxRootDeviceIndex() const { + return maxRootDeviceIndex; +} + DeviceQueue *Context::getDefaultDeviceQueue() { return defaultDeviceQueue; } @@ -178,22 +186,21 @@ bool Context::createImpl(const cl_context_properties *properties, return false; } - this->driverDiagnostics = driverDiagnostics.release(); - if (inputDevices.size() > 1) { - if (!DebugManager.flags.EnableMultiRootDeviceContexts.get()) { - auto rootDeviceIndex = inputDevices[0]->getRootDeviceIndex(); - for (const auto &device : inputDevices) { - if (device->getRootDeviceIndex() != rootDeviceIndex) { - DEBUG_BREAK_IF("No support for context with multiple root devices"); - errcodeRet = CL_OUT_OF_HOST_MEMORY; - return false; - } - } - } + for (const auto &device : inputDevices) { + rootDeviceIndices.insert(device->getRootDeviceIndex()); } + + this->driverDiagnostics = driverDiagnostics.release(); + if (rootDeviceIndices.size() > 1 && !DebugManager.flags.EnableMultiRootDeviceContexts.get()) { + DEBUG_BREAK_IF("No support for context with multiple root devices"); + errcodeRet = CL_OUT_OF_HOST_MEMORY; + return false; + } + this->devices = inputDevices; if (devices.size() > 0) { + maxRootDeviceIndex = *std::max_element(rootDeviceIndices.begin(), rootDeviceIndices.end(), std::less()); auto device = this->getDevice(0); this->memoryManager = device->getMemoryManager(); if (memoryManager->isAsyncDeleterEnabled()) { diff --git a/opencl/source/context/context.h b/opencl/source/context/context.h index 2e83c8a57d..db62d8a1a9 100644 --- a/opencl/source/context/context.h +++ b/opencl/source/context/context.h @@ -17,6 +17,7 @@ #include "opencl/source/helpers/destructor_callback.h" #include +#include namespace NEO { @@ -88,6 +89,10 @@ class Context : public BaseObject<_cl_context> { return svmAllocsManager; } + const std::set &getRootDeviceIndices() const; + + uint32_t getMaxRootDeviceIndex() const; + DeviceQueue *getDefaultDeviceQueue(); void setDefaultDeviceQueue(DeviceQueue *queue); @@ -155,6 +160,9 @@ class Context : public BaseObject<_cl_context> { cl_int processExtraProperties(cl_context_properties propertyType, cl_context_properties propertyValue); void setupContextType(); + std::set rootDeviceIndices = {}; + uint32_t maxRootDeviceIndex = std::numeric_limits::max(); + const cl_context_properties *properties = nullptr; size_t numProperties = 0u; void(CL_CALLBACK *contextCallback)(const char *, const void *, size_t, void *) = nullptr; diff --git a/opencl/test/unit_test/api/cl_create_context_tests.inl b/opencl/test/unit_test/api/cl_create_context_tests.inl index 2016b723a7..3fd9b3d4ab 100644 --- a/opencl/test/unit_test/api/cl_create_context_tests.inl +++ b/opencl/test/unit_test/api/cl_create_context_tests.inl @@ -99,6 +99,39 @@ TEST_F(clCreateContextTests, givenEnabledMultipleRootDeviceSupportWhenCreateCont clReleaseContext(context); } +TEST_F(clCreateContextTests, givenMultipleRootDevicesWhenCreateContextThenRootDeviceIndicesSetIsFilled) { + UltClDeviceFactory deviceFactory{3, 2}; + DebugManager.flags.EnableMultiRootDeviceContexts.set(true); + cl_device_id devices[] = {deviceFactory.rootDevices[0], deviceFactory.rootDevices[1], deviceFactory.rootDevices[2]}; + auto context = clCreateContext(nullptr, 3u, devices, eventCallBack, nullptr, &retVal); + EXPECT_NE(nullptr, context); + EXPECT_EQ(CL_SUCCESS, retVal); + + auto pContext = castToObject(context); + auto rootDeviceIndices = pContext->getRootDeviceIndices(); + + for (auto numDevice = 0u; numDevice < pContext->getNumDevices(); numDevice++) { + auto rootDeviceIndex = rootDeviceIndices.find(pContext->getDevice(numDevice)->getRootDeviceIndex()); + EXPECT_EQ(*rootDeviceIndex, pContext->getDevice(numDevice)->getRootDeviceIndex()); + } + + clReleaseContext(context); +} + +TEST_F(clCreateContextTests, givenMultipleRootDevicesWhenCreateContextThenMaxRootDeviceIndexIsProperlyFilled) { + UltClDeviceFactory deviceFactory{3, 0}; + DebugManager.flags.EnableMultiRootDeviceContexts.set(true); + cl_device_id devices[] = {deviceFactory.rootDevices[0], deviceFactory.rootDevices[2]}; + auto context = clCreateContext(nullptr, 2u, devices, eventCallBack, nullptr, &retVal); + EXPECT_NE(nullptr, context); + EXPECT_EQ(CL_SUCCESS, retVal); + + auto pContext = castToObject(context); + EXPECT_EQ(2u, pContext->getMaxRootDeviceIndex()); + + clReleaseContext(context); +} + TEST_F(clCreateContextTests, givenInvalidContextCreationPropertiesThenContextCreationFails) { cl_context_properties invalidProperties[3] = {CL_CONTEXT_PLATFORM, (cl_context_properties) nullptr, 0}; auto context = clCreateContext(invalidProperties, 1u, &testedClDevice, nullptr, nullptr, &retVal);