diff --git a/opencl/source/device/device_caps.cpp b/opencl/source/device/device_caps.cpp index bd410d20d2..823c9b0f45 100644 --- a/opencl/source/device/device_caps.cpp +++ b/opencl/source/device/device_caps.cpp @@ -196,6 +196,7 @@ void ClDevice::initializeCaps() { if (driverInfo) { name.assign(driverInfo.get()->getDeviceName(name).c_str()); driverVersion.assign(driverInfo.get()->getVersion(driverVersion).c_str()); + sharingFactory.verifyExtensionSupport(driverInfo.get()); } auto &hwHelper = HwHelper::get(hwInfo.platform.eRenderCoreFamily); diff --git a/opencl/source/sharings/d3d/enable_d3d.cpp b/opencl/source/sharings/d3d/enable_d3d.cpp index 72a53f54f2..cf6fc87773 100644 --- a/opencl/source/sharings/d3d/enable_d3d.cpp +++ b/opencl/source/sharings/d3d/enable_d3d.cpp @@ -9,6 +9,8 @@ #include "opencl/source/sharings/d3d/enable_d3d.h" +#include "shared/source/os_interface/driver_info.h" + #include "opencl/source/api/api.h" #include "opencl/source/context/context.h" #include "opencl/source/context/context.inl" @@ -103,7 +105,7 @@ std::unique_ptr D3DSharingBuilderFactory::createCont }; std::string D3DSharingBuilderFactory::getExtensions() { - return "cl_intel_dx9_media_sharing cl_khr_dx9_media_sharing "; + return extensionEnabled ? "cl_intel_dx9_media_sharing cl_khr_dx9_media_sharing " : ""; } std::string D3DSharingBuilderFactory::getExtensions() { @@ -167,6 +169,15 @@ void *D3DSharingBuilderFactory::getExtensionFunctionAddre } return nullptr; } + +void D3DSharingBuilderFactory::setExtensionEnabled(DriverInfo *driverInfo) { + extensionEnabled = driverInfo->getMediaSharingSupport(); +} + +void D3DSharingBuilderFactory::setExtensionEnabled(DriverInfo *driverInfo) {} + +void D3DSharingBuilderFactory::setExtensionEnabled(DriverInfo *driverInfo) {} + static SharingFactory::RegisterSharing, D3DSharingFunctions> D3D9Sharing; static SharingFactory::RegisterSharing, D3DSharingFunctions> D3D10Sharing; static SharingFactory::RegisterSharing, D3DSharingFunctions> D3D11Sharing; diff --git a/opencl/source/sharings/d3d/enable_d3d.h b/opencl/source/sharings/d3d/enable_d3d.h index b6d237b45a..31e021088e 100644 --- a/opencl/source/sharings/d3d/enable_d3d.h +++ b/opencl/source/sharings/d3d/enable_d3d.h @@ -13,6 +13,7 @@ namespace NEO { class Context; +class DriverInfo; template struct D3DCreateContextProperties { @@ -37,5 +38,7 @@ class D3DSharingBuilderFactory : public SharingBuilderFactory { std::string getExtensions() override; void fillGlobalDispatchTable() override; void *getExtensionFunctionAddress(const std::string &functionName) override; + void setExtensionEnabled(DriverInfo *driverInfo) override; + bool extensionEnabled = true; }; } // namespace NEO \ No newline at end of file diff --git a/opencl/source/sharings/sharing_factory.cpp b/opencl/source/sharings/sharing_factory.cpp index 9bc9b5a37f..0759a852ff 100644 --- a/opencl/source/sharings/sharing_factory.cpp +++ b/opencl/source/sharings/sharing_factory.cpp @@ -81,5 +81,15 @@ bool SharingFactory::finalizeProperties(Context &context, int32_t &errcodeRet) { SharingBuilderFactory *SharingFactory::sharingContextBuilder[SharingType::MAX_SHARING_VALUE] = { nullptr, }; + +void SharingFactory::verifyExtensionSupport(DriverInfo *driverInfo) { + for (auto &builder : sharingContextBuilder) { + if (builder == nullptr) + continue; + builder->setExtensionEnabled(driverInfo); + } +}; +void SharingBuilderFactory::setExtensionEnabled(DriverInfo *driverInfo){}; + SharingFactory sharingFactory; } // namespace NEO diff --git a/opencl/source/sharings/sharing_factory.h b/opencl/source/sharings/sharing_factory.h index e62c4b84b5..a91254beb1 100644 --- a/opencl/source/sharings/sharing_factory.h +++ b/opencl/source/sharings/sharing_factory.h @@ -14,6 +14,7 @@ namespace NEO { class Context; +class DriverInfo; enum SharingType { CLGL_SHARING = 0, @@ -39,6 +40,7 @@ class SharingBuilderFactory { virtual std::string getExtensions() = 0; virtual void fillGlobalDispatchTable() {} virtual void *getExtensionFunctionAddress(const std::string &functionName) = 0; + virtual void setExtensionEnabled(DriverInfo *driverInfo); }; class SharingFactory { @@ -59,6 +61,7 @@ class SharingFactory { std::string getExtensions(); void fillGlobalDispatchTable(); void *getExtensionFunctionAddress(const std::string &functionName); + void verifyExtensionSupport(DriverInfo *driverInfo); }; extern SharingFactory sharingFactory; diff --git a/opencl/test/unit_test/d3d_sharing/d3d_tests_part1.cpp b/opencl/test/unit_test/d3d_sharing/d3d_tests_part1.cpp index 2e6a9f20f9..f2bce2d424 100644 --- a/opencl/test/unit_test/d3d_sharing/d3d_tests_part1.cpp +++ b/opencl/test/unit_test/d3d_sharing/d3d_tests_part1.cpp @@ -5,6 +5,7 @@ * */ +#include "shared/source/os_interface/driver_info.h" #include "shared/source/utilities/arrayref.h" #include "shared/test/unit_test/helpers/debug_manager_state_restore.h" @@ -718,6 +719,20 @@ TEST(D3DSurfaceTest, givenD3DSurfaceWhenInvalidMemObjectIsPassedToValidateUpdate EXPECT_EQ(CL_INVALID_MEM_OBJECT, result); } +TEST(D3D9, givenD3D9BuilderAndExtensionEnableTrueWhenGettingExtensionsThenCorrectExtensionsListIsReturned) { + auto builderFactory = std::make_unique>(); + builderFactory.get()->extensionEnabled = true; + EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_intel_dx9_media_sharing"))); + EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_khr_dx9_media_sharing"))); +} + +TEST(D3D9, givenD3D9BuilderAndExtensionEnableFalseWhenGettingExtensionsThenDx9MediaSheringExtensionsAreNotReturned) { + auto builderFactory = std::make_unique>(); + builderFactory.get()->extensionEnabled = false; + EXPECT_THAT(builderFactory->getExtensions(), testing::Not(testing::HasSubstr(std::string("cl_intel_dx9_media_sharing")))); + EXPECT_THAT(builderFactory->getExtensions(), testing::Not(testing::HasSubstr(std::string("cl_khr_dx9_media_sharing")))); +} + TEST(D3D10, givenD3D10BuilderWhenGettingExtensionsThenCorrectExtensionsListIsReturned) { auto builderFactory = std::make_unique>(); EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_khr_d3d10_sharing"))); @@ -743,4 +758,45 @@ TEST(D3DSharingFactory, givenEnabledFormatQueryAndFactoryWithD3DSharingsWhenGett function = sharingFactory.getExtensionFunctionAddress("clGetSupportedD3D11TextureFormatsINTEL"); EXPECT_EQ(reinterpret_cast(clGetSupportedD3D11TextureFormatsINTEL), function); } + +TEST(D3D9SharingFactory, givenDriverInfoWhenVerifyExtensionSupportThenExtensionEnableIsSetCorrect) { + class MockDriverInfo : public DriverInfo { + public: + bool getMediaSharingSupport() override { return support; }; + bool support = true; + }; + class MockSharingFactory : public SharingFactory { + public: + MockSharingFactory() { + memcpy_s(savedState, sizeof(savedState), sharingContextBuilder, sizeof(sharingContextBuilder)); + } + ~MockSharingFactory() { + memcpy_s(sharingContextBuilder, sizeof(sharingContextBuilder), savedState, sizeof(savedState)); + } + + void prepare() { + for (auto &builder : sharingContextBuilder) { + builder = nullptr; + } + d3d9SharingBuilderFactory = std::make_unique>(); + sharingContextBuilder[SharingType::D3D9_SHARING] = d3d9SharingBuilderFactory.get(); + } + + using SharingFactory::sharingContextBuilder; + std::unique_ptr> d3d9SharingBuilderFactory; + decltype(SharingFactory::sharingContextBuilder) savedState; + }; + + auto driverInfo = std::make_unique(); + auto mockSharingFactory = std::make_unique(); + mockSharingFactory->prepare(); + + driverInfo->support = true; + mockSharingFactory->verifyExtensionSupport(driverInfo.get()); + EXPECT_TRUE(mockSharingFactory->d3d9SharingBuilderFactory->extensionEnabled); + + driverInfo->support = false; + mockSharingFactory->verifyExtensionSupport(driverInfo.get()); + EXPECT_FALSE(mockSharingFactory->d3d9SharingBuilderFactory->extensionEnabled); +} } // namespace NEO diff --git a/opencl/test/unit_test/os_interface/linux/driver_info_tests.cpp b/opencl/test/unit_test/os_interface/linux/driver_info_tests.cpp index 9ee0abad29..451fd2d737 100644 --- a/opencl/test/unit_test/os_interface/linux/driver_info_tests.cpp +++ b/opencl/test/unit_test/os_interface/linux/driver_info_tests.cpp @@ -33,4 +33,9 @@ TEST(DriverInfo, GivenDriverInfoWhenLinuxThenReturnDefault) { EXPECT_STREQ(defaultVersion.c_str(), resultVersion.c_str()); } +TEST(DriverInfo, givenGetMediaSharingSupportWhenLinuxThenReturnTrue) { + std::unique_ptr driverInfo(DriverInfo::create(nullptr)); + + EXPECT_TRUE(driverInfo->getMediaSharingSupport()); +} } // namespace NEO diff --git a/opencl/test/unit_test/os_interface/windows/driver_info_tests.cpp b/opencl/test/unit_test/os_interface/windows/driver_info_tests.cpp index f20750b3c0..6412e38cbc 100644 --- a/opencl/test/unit_test/os_interface/windows/driver_info_tests.cpp +++ b/opencl/test/unit_test/os_interface/windows/driver_info_tests.cpp @@ -113,6 +113,13 @@ class MockRegistryReader : public SettingsReader { properNameKey = true; } else if (key == "DriverVersion") { properVersionKey = true; + } else if (key == "UserModeDriverName") { + properMediaSharingExtensions = true; + using64bit = true; + return returnString; + } else if (key == "UserModeDriverNameWOW") { + properMediaSharingExtensions = true; + return returnString; } if (key == "DriverStorePathForComputeRuntime") { return driverStorePath; @@ -127,6 +134,9 @@ class MockRegistryReader : public SettingsReader { bool properNameKey = false; bool properVersionKey = false; std::string driverStorePath = "driverStore\\0x8086"; + bool properMediaSharingExtensions = false; + bool using64bit = false; + std::string returnString = ""; }; struct DriverInfoWindowsTest : public ::testing::Test { @@ -160,6 +170,31 @@ TEST_F(DriverInfoWindowsTest, GivenDriverInfoWhenThenReturnNonNullptr) { EXPECT_TRUE(registryReaderMock->properVersionKey); }; +TEST(DriverInfo, givenDriverInfoWhenGetStringReturnNotMeaningEmptyStringThenEnableSharingSupport) { + MockDriverInfoWindows driverInfo(""); + MockRegistryReader *registryReaderMock = new MockRegistryReader(); + + driverInfo.registryReader.reset(registryReaderMock); + auto enable = driverInfo.getMediaSharingSupport(); + + EXPECT_TRUE(enable); + EXPECT_EQ(is64bit, registryReaderMock->using64bit); + EXPECT_TRUE(registryReaderMock->properMediaSharingExtensions); +}; + +TEST(DriverInfo, givenDriverInfoWhenGetStringReturnMeaningEmptyStringThenDisableSharingSupport) { + MockDriverInfoWindows driverInfo(""); + MockRegistryReader *registryReaderMock = new MockRegistryReader(); + registryReaderMock->returnString = "<>"; + driverInfo.registryReader.reset(registryReaderMock); + + auto enable = driverInfo.getMediaSharingSupport(); + + EXPECT_FALSE(enable); + EXPECT_EQ(is64bit, registryReaderMock->using64bit); + EXPECT_TRUE(registryReaderMock->properMediaSharingExtensions); +}; + TEST(DriverInfo, givenFullPathToRegistryWhenCreatingDriverInfoWindowsThenTheRegistryPathIsTrimmed) { std::string registryPath = "Path\\In\\Registry"; std::string fullRegistryPath = "\\REGISTRY\\MACHINE\\" + registryPath; diff --git a/opencl/test/unit_test/sharings/sharing_factory_tests.cpp b/opencl/test/unit_test/sharings/sharing_factory_tests.cpp index 8a50eea363..1f168e9787 100644 --- a/opencl/test/unit_test/sharings/sharing_factory_tests.cpp +++ b/opencl/test/unit_test/sharings/sharing_factory_tests.cpp @@ -276,4 +276,4 @@ TEST(SharingFactoryTests, givenEnabledFormatQueryAndFactoryWithNoSharingsWhenAsk auto extensionsList = sharingFactory.getExtensions(); EXPECT_THAT(extensionsList, ::testing::Not(::testing::HasSubstr(Extensions::sharingFormatQuery))); -} +} \ No newline at end of file diff --git a/shared/source/os_interface/driver_info.h b/shared/source/os_interface/driver_info.h index e91b99dc0c..187fa3c3fc 100644 --- a/shared/source/os_interface/driver_info.h +++ b/shared/source/os_interface/driver_info.h @@ -21,6 +21,7 @@ class DriverInfo { virtual std::string getDeviceName(std::string defaultName) { return defaultName; }; virtual std::string getVersion(std::string defaultVersion) { return defaultVersion; }; + virtual bool getMediaSharingSupport() { return true; }; }; } // namespace NEO diff --git a/shared/source/os_interface/windows/debug_registry_reader.cpp b/shared/source/os_interface/windows/debug_registry_reader.cpp index 26561deadd..1cb27a4cda 100644 --- a/shared/source/os_interface/windows/debug_registry_reader.cpp +++ b/shared/source/os_interface/windows/debug_registry_reader.cpp @@ -103,7 +103,7 @@ std::string RegistryReader::getSetting(const char *settingName, const std::strin NULL, ®Size); if (ERROR_SUCCESS == success) { - if (regType == REG_SZ) { + if (regType == REG_SZ || regType == REG_MULTI_SZ) { auto regData = std::make_unique(regSize); success = RegQueryValueExA(Key, settingName, diff --git a/shared/source/os_interface/windows/driver_info_windows.cpp b/shared/source/os_interface/windows/driver_info_windows.cpp index 37e64726ad..fa86f9113d 100644 --- a/shared/source/os_interface/windows/driver_info_windows.cpp +++ b/shared/source/os_interface/windows/driver_info_windows.cpp @@ -75,4 +75,8 @@ bool DriverInfoWindows::isCompatibleDriverStore() const { decltype(DriverInfoWindows::createRegistryReaderFunc) DriverInfoWindows::createRegistryReaderFunc = [](const std::string ®istryPath) -> std::unique_ptr { return std::make_unique(false, registryPath); }; + +bool DriverInfoWindows::getMediaSharingSupport() { + return registryReader.get()->getSetting(is64bit ? "UserModeDriverName" : "UserModeDriverNameWOW", std::string("")) != "<>"; +} } // namespace NEO diff --git a/shared/source/os_interface/windows/driver_info_windows.h b/shared/source/os_interface/windows/driver_info_windows.h index f33895517f..ca7b232df7 100644 --- a/shared/source/os_interface/windows/driver_info_windows.h +++ b/shared/source/os_interface/windows/driver_info_windows.h @@ -23,7 +23,7 @@ class DriverInfoWindows : public DriverInfo { std::string getDeviceName(std::string defaultName) override; std::string getVersion(std::string defaultVersion) override; bool isCompatibleDriverStore() const; - + bool getMediaSharingSupport() override; static std::function(const std::string ®istryPath)> createRegistryReaderFunc; protected: