diff --git a/opencl/source/api/api.cpp b/opencl/source/api/api.cpp index 9e56617521..b324e3a3f6 100644 --- a/opencl/source/api/api.cpp +++ b/opencl/source/api/api.cpp @@ -381,12 +381,25 @@ cl_context CL_API_CALL clCreateContext(const cl_context_properties *properties, retVal = CL_INVALID_VALUE; break; } + auto pPlatform = Context::getPlatformFromProperties(properties, retVal); + if (CL_SUCCESS != retVal) { + break; + } ClDeviceVector allDevs(devices, numDevices); - context = Context::create(properties, allDevs, funcNotify, userData, retVal); - if (context != nullptr) { - gtpinNotifyContextCreate(context); + if (!pPlatform) { + pPlatform = allDevs[0]->getPlatform(); } + for (auto &pClDevice : allDevs) { + if (pClDevice->getPlatform() != pPlatform) { + retVal = CL_INVALID_DEVICE; + break; + } + } + if (CL_SUCCESS != retVal) { + break; + } + context = Context::create(properties, allDevs, funcNotify, userData, retVal); } while (false); if (errcodeRet) { @@ -413,10 +426,13 @@ cl_context CL_API_CALL clCreateContextFromType(const cl_context_properties *prop retVal = CL_INVALID_VALUE; break; } - + auto pPlatform = Context::getPlatformFromProperties(properties, retVal); + if (CL_SUCCESS != retVal) { + break; + } cl_uint numDevices = 0; /* Query the number of device first. */ - retVal = clGetDeviceIDs(nullptr, deviceType, 0, nullptr, &numDevices); + retVal = clGetDeviceIDs(pPlatform, deviceType, 0, nullptr, &numDevices); if (retVal != CL_SUCCESS) { break; } @@ -425,7 +441,7 @@ cl_context CL_API_CALL clCreateContextFromType(const cl_context_properties *prop StackVec supportedDevs; supportedDevs.resize(numDevices); - retVal = clGetDeviceIDs(nullptr, deviceType, numDevices, supportedDevs.begin(), nullptr); + retVal = clGetDeviceIDs(pPlatform, deviceType, numDevices, supportedDevs.begin(), nullptr); DEBUG_BREAK_IF(retVal != CL_SUCCESS); if (!DebugManager.flags.EnableMultiRootDeviceContexts.get()) { @@ -434,9 +450,6 @@ cl_context CL_API_CALL clCreateContextFromType(const cl_context_properties *prop ClDeviceVector deviceVector(supportedDevs.begin(), numDevices); pContext = Context::create(properties, deviceVector, funcNotify, userData, retVal); - if (pContext != nullptr) { - gtpinNotifyContextCreate(pContext); - } } while (false); if (errcodeRet) { diff --git a/opencl/source/cl_device/cl_device.cpp b/opencl/source/cl_device/cl_device.cpp index 76b356d126..d571590aad 100644 --- a/opencl/source/cl_device/cl_device.cpp +++ b/opencl/source/cl_device/cl_device.cpp @@ -260,5 +260,7 @@ void ClDevice::getQueueFamilyName(char *outputName, size_t maxOutputNameLength, UNRECOVERABLE_IF(name.length() > maxOutputNameLength + 1); strncpy_s(outputName, maxOutputNameLength, name.c_str(), name.size()); } - +Platform *ClDevice::getPlatform() const { + return castToObject(platformId); +} } // namespace NEO diff --git a/opencl/source/cl_device/cl_device.h b/opencl/source/cl_device/cl_device.h index 946a9743ea..5e65f9cfa1 100644 --- a/opencl/source/cl_device/cl_device.h +++ b/opencl/source/cl_device/cl_device.h @@ -125,6 +125,7 @@ class ClDevice : public BaseObject<_cl_device_id> { static cl_command_queue_capabilities_intel getQueueFamilyCapabilitiesAll(); MOCKABLE_VIRTUAL cl_command_queue_capabilities_intel getQueueFamilyCapabilities(EngineGroupType type); void getQueueFamilyName(char *outputName, size_t maxOutputNameLength, EngineGroupType type); + Platform *getPlatform() const; protected: void initializeCaps(); diff --git a/opencl/source/context/context.cpp b/opencl/source/context/context.cpp index 79f5b293ae..8e35e6e28a 100644 --- a/opencl/source/context/context.cpp +++ b/opencl/source/context/context.cpp @@ -135,12 +135,8 @@ bool Context::createImpl(const cl_context_properties *properties, propertiesCurrent += 2; switch (propertyType) { - case CL_CONTEXT_PLATFORM: { - if (castToObject(reinterpret_cast(propertyValue)) == nullptr) { - errcodeRet = CL_INVALID_PLATFORM; - return false; - } - } break; + case CL_CONTEXT_PLATFORM: + break; case CL_CONTEXT_SHOW_DIAGNOSTICS_INTEL: driverDiagnosticsUsed = static_cast(propertyValue); break; @@ -457,4 +453,19 @@ void Context::setupContextType() { } } +Platform *Context::getPlatformFromProperties(const cl_context_properties *properties, cl_int &errcode) { + errcode = CL_SUCCESS; + auto propertiesCurrent = properties; + while (propertiesCurrent && *propertiesCurrent) { + auto propertyType = propertiesCurrent[0]; + auto propertyValue = propertiesCurrent[1]; + propertiesCurrent += 2; + if (CL_CONTEXT_PLATFORM == propertyType) { + Platform *pPlatform = nullptr; + errcode = validateObject(WithCastToInternal(reinterpret_cast(propertyValue), &pPlatform)); + return pPlatform; + } + } + return nullptr; +} } // namespace NEO diff --git a/opencl/source/context/context.h b/opencl/source/context/context.h index 8127650980..606c10412c 100644 --- a/opencl/source/context/context.h +++ b/opencl/source/context/context.h @@ -13,6 +13,7 @@ #include "opencl/source/cl_device/cl_device_vector.h" #include "opencl/source/context/context_type.h" #include "opencl/source/context/driver_diagnostics.h" +#include "opencl/source/gtpin/gtpin_notify.h" #include "opencl/source/helpers/base_object.h" #include "opencl/source/helpers/destructor_callbacks.h" @@ -33,6 +34,7 @@ class SharingFunctions; class SVMAllocsManager; class SchedulerKernel; class Program; +class Platform; template <> struct OpenCLObjectMapper<_cl_context> { @@ -60,7 +62,7 @@ class Context : public BaseObject<_cl_context> { delete pContext; pContext = nullptr; } - + gtpinNotifyContextCreate(pContext); return pContext; } @@ -162,6 +164,8 @@ class Context : public BaseObject<_cl_context> { } const std::map &getDeviceBitfields() const { return deviceBitfields; }; + static Platform *getPlatformFromProperties(const cl_context_properties *properties, cl_int &errcode); + protected: struct BuiltInKernel { const char *pSource = nullptr; diff --git a/opencl/test/unit_test/api/cl_create_context_from_type_tests.inl b/opencl/test/unit_test/api/cl_create_context_from_type_tests.inl index b6ab8600fb..5101ddb224 100644 --- a/opencl/test/unit_test/api/cl_create_context_from_type_tests.inl +++ b/opencl/test/unit_test/api/cl_create_context_from_type_tests.inl @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017-2020 Intel Corporation + * Copyright (C) 2017-2021 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -86,6 +86,10 @@ TEST_F(clCreateContextFromTypeTests, GivenNonDefaultPlatformInContextCreationPro auto clContext = clCreateContextFromType(properties, CL_DEVICE_TYPE_GPU, nullptr, nullptr, &retVal); EXPECT_EQ(CL_SUCCESS, retVal); EXPECT_NE(nullptr, clContext); + auto pContext = castToObject(clContext); + for (auto i = 0u; i < nonDefaultPlatform->getNumDevices(); i++) { + EXPECT_EQ(nonDefaultPlatform->getClDevice(i), pContext->getDevice(i)); + } clReleaseContext(clContext); } diff --git a/opencl/test/unit_test/api/cl_create_context_tests.inl b/opencl/test/unit_test/api/cl_create_context_tests.inl index 7014bbd97a..e0e1a11de5 100644 --- a/opencl/test/unit_test/api/cl_create_context_tests.inl +++ b/opencl/test/unit_test/api/cl_create_context_tests.inl @@ -1,5 +1,5 @@ /* - * Copyright (C) 2017-2020 Intel Corporation + * Copyright (C) 2017-2021 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -180,7 +180,7 @@ TEST_F(clCreateContextTests, GivenNonDefaultPlatformInContextCreationPropertiesW clReleaseContext(clContext); } -TEST_F(clCreateContextFromTypeTests, GivenNonDefaultPlatformWithInvalidIcdDispatchInContextCreationPropertiesWhenCreatingContextThenInvalidPlatformErrorIsReturned) { +TEST_F(clCreateContextTests, GivenNonDefaultPlatformWithInvalidIcdDispatchInContextCreationPropertiesWhenCreatingContextThenInvalidPlatformErrorIsReturned) { auto nonDefaultPlatform = std::make_unique(); nonDefaultPlatform->initializeWithNewDevices(); cl_platform_id nonDefaultPlatformCl = nonDefaultPlatform.get(); @@ -192,4 +192,42 @@ TEST_F(clCreateContextFromTypeTests, GivenNonDefaultPlatformWithInvalidIcdDispat EXPECT_EQ(nullptr, clContext); } +TEST_F(clCreateContextTests, GivenDeviceNotAssociatedToPlatformInPropertiesWhenCreatingContextThenInvalidDeviceErrorIsReturned) { + auto nonDefaultPlatform = std::make_unique(); + nonDefaultPlatform->initializeWithNewDevices(); + cl_device_id clDevice = platform()->getClDevice(0); + cl_platform_id nonDefaultPlatformCl = nonDefaultPlatform.get(); + cl_context_properties properties[3] = {CL_CONTEXT_PLATFORM, reinterpret_cast(nonDefaultPlatformCl), 0}; + + auto clContext = clCreateContext(properties, 1, &clDevice, nullptr, nullptr, &retVal); + EXPECT_EQ(CL_INVALID_DEVICE, retVal); + EXPECT_EQ(nullptr, clContext); +} + +TEST_F(clCreateContextTests, GivenDevicesFromDifferentPlatformsWhenCreatingContextWithoutSpecifiedPlatformThenInvalidDeviceErrorIsReturned) { + auto platform1 = std::make_unique(); + auto platform2 = std::make_unique(); + platform1->initializeWithNewDevices(); + platform2->initializeWithNewDevices(); + cl_device_id clDevices[] = {platform1->getClDevice(0), platform2->getClDevice(0)}; + + auto clContext = clCreateContext(nullptr, 2, clDevices, nullptr, nullptr, &retVal); + EXPECT_EQ(CL_INVALID_DEVICE, retVal); + EXPECT_EQ(nullptr, clContext); +} + +TEST_F(clCreateContextTests, GivenDevicesFromDifferentPlatformsWhenCreatingContextWithSpecifiedPlatformThenInvalidDeviceErrorIsReturned) { + auto platform1 = std::make_unique(); + auto platform2 = std::make_unique(); + platform1->initializeWithNewDevices(); + platform2->initializeWithNewDevices(); + cl_device_id clDevices[] = {platform1->getClDevice(0), platform2->getClDevice(0)}; + + cl_platform_id clPlatform = platform1.get(); + cl_context_properties properties[3] = {CL_CONTEXT_PLATFORM, reinterpret_cast(clPlatform), 0}; + + auto clContext = clCreateContext(properties, 2, clDevices, nullptr, nullptr, &retVal); + EXPECT_EQ(CL_INVALID_DEVICE, retVal); + EXPECT_EQ(nullptr, clContext); +} } // namespace ClCreateContextTests