Conditionally disable DX sharing extensions

Change-Id: Idbc253f072a9400962b7500e75ba6fd86e5e6b59
Signed-off-by: Katarzyna Cencelewska <katarzyna.cencelewska@intel.com>
This commit is contained in:
Katarzyna Cencelewska
2020-03-18 15:19:03 +01:00
committed by sys_ocldev
parent 71950fa7cc
commit 9c716a8d98
13 changed files with 133 additions and 4 deletions

View File

@@ -196,6 +196,7 @@ void ClDevice::initializeCaps() {
if (driverInfo) {
name.assign(driverInfo.get()->getDeviceName(name).c_str());
driverVersion.assign(driverInfo.get()->getVersion(driverVersion).c_str());
sharingFactory.verifyExtensionSupport(driverInfo.get());
}
auto &hwHelper = HwHelper::get(hwInfo.platform.eRenderCoreFamily);

View File

@@ -9,6 +9,8 @@
#include "opencl/source/sharings/d3d/enable_d3d.h"
#include "shared/source/os_interface/driver_info.h"
#include "opencl/source/api/api.h"
#include "opencl/source/context/context.h"
#include "opencl/source/context/context.inl"
@@ -103,7 +105,7 @@ std::unique_ptr<SharingContextBuilder> D3DSharingBuilderFactory<D3D>::createCont
};
std::string D3DSharingBuilderFactory<D3DTypesHelper::D3D9>::getExtensions() {
return "cl_intel_dx9_media_sharing cl_khr_dx9_media_sharing ";
return extensionEnabled ? "cl_intel_dx9_media_sharing cl_khr_dx9_media_sharing " : "";
}
std::string D3DSharingBuilderFactory<D3DTypesHelper::D3D10>::getExtensions() {
@@ -167,6 +169,15 @@ void *D3DSharingBuilderFactory<D3DTypesHelper::D3D11>::getExtensionFunctionAddre
}
return nullptr;
}
void D3DSharingBuilderFactory<D3DTypesHelper::D3D9>::setExtensionEnabled(DriverInfo *driverInfo) {
extensionEnabled = driverInfo->getMediaSharingSupport();
}
void D3DSharingBuilderFactory<D3DTypesHelper::D3D10>::setExtensionEnabled(DriverInfo *driverInfo) {}
void D3DSharingBuilderFactory<D3DTypesHelper::D3D11>::setExtensionEnabled(DriverInfo *driverInfo) {}
static SharingFactory::RegisterSharing<D3DSharingBuilderFactory<D3DTypesHelper::D3D9>, D3DSharingFunctions<D3DTypesHelper::D3D9>> D3D9Sharing;
static SharingFactory::RegisterSharing<D3DSharingBuilderFactory<D3DTypesHelper::D3D10>, D3DSharingFunctions<D3DTypesHelper::D3D10>> D3D10Sharing;
static SharingFactory::RegisterSharing<D3DSharingBuilderFactory<D3DTypesHelper::D3D11>, D3DSharingFunctions<D3DTypesHelper::D3D11>> D3D11Sharing;

View File

@@ -13,6 +13,7 @@
namespace NEO {
class Context;
class DriverInfo;
template <typename D3D>
struct D3DCreateContextProperties {
@@ -37,5 +38,7 @@ class D3DSharingBuilderFactory : public SharingBuilderFactory {
std::string getExtensions() override;
void fillGlobalDispatchTable() override;
void *getExtensionFunctionAddress(const std::string &functionName) override;
void setExtensionEnabled(DriverInfo *driverInfo) override;
bool extensionEnabled = true;
};
} // namespace NEO

View File

@@ -81,5 +81,15 @@ bool SharingFactory::finalizeProperties(Context &context, int32_t &errcodeRet) {
SharingBuilderFactory *SharingFactory::sharingContextBuilder[SharingType::MAX_SHARING_VALUE] = {
nullptr,
};
void SharingFactory::verifyExtensionSupport(DriverInfo *driverInfo) {
for (auto &builder : sharingContextBuilder) {
if (builder == nullptr)
continue;
builder->setExtensionEnabled(driverInfo);
}
};
void SharingBuilderFactory::setExtensionEnabled(DriverInfo *driverInfo){};
SharingFactory sharingFactory;
} // namespace NEO

View File

@@ -14,6 +14,7 @@
namespace NEO {
class Context;
class DriverInfo;
enum SharingType {
CLGL_SHARING = 0,
@@ -39,6 +40,7 @@ class SharingBuilderFactory {
virtual std::string getExtensions() = 0;
virtual void fillGlobalDispatchTable() {}
virtual void *getExtensionFunctionAddress(const std::string &functionName) = 0;
virtual void setExtensionEnabled(DriverInfo *driverInfo);
};
class SharingFactory {
@@ -59,6 +61,7 @@ class SharingFactory {
std::string getExtensions();
void fillGlobalDispatchTable();
void *getExtensionFunctionAddress(const std::string &functionName);
void verifyExtensionSupport(DriverInfo *driverInfo);
};
extern SharingFactory sharingFactory;

View File

@@ -5,6 +5,7 @@
*
*/
#include "shared/source/os_interface/driver_info.h"
#include "shared/source/utilities/arrayref.h"
#include "shared/test/unit_test/helpers/debug_manager_state_restore.h"
@@ -718,6 +719,20 @@ TEST(D3DSurfaceTest, givenD3DSurfaceWhenInvalidMemObjectIsPassedToValidateUpdate
EXPECT_EQ(CL_INVALID_MEM_OBJECT, result);
}
TEST(D3D9, givenD3D9BuilderAndExtensionEnableTrueWhenGettingExtensionsThenCorrectExtensionsListIsReturned) {
auto builderFactory = std::make_unique<D3DSharingBuilderFactory<D3DTypesHelper::D3D9>>();
builderFactory.get()->extensionEnabled = true;
EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_intel_dx9_media_sharing")));
EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_khr_dx9_media_sharing")));
}
TEST(D3D9, givenD3D9BuilderAndExtensionEnableFalseWhenGettingExtensionsThenDx9MediaSheringExtensionsAreNotReturned) {
auto builderFactory = std::make_unique<D3DSharingBuilderFactory<D3DTypesHelper::D3D9>>();
builderFactory.get()->extensionEnabled = false;
EXPECT_THAT(builderFactory->getExtensions(), testing::Not(testing::HasSubstr(std::string("cl_intel_dx9_media_sharing"))));
EXPECT_THAT(builderFactory->getExtensions(), testing::Not(testing::HasSubstr(std::string("cl_khr_dx9_media_sharing"))));
}
TEST(D3D10, givenD3D10BuilderWhenGettingExtensionsThenCorrectExtensionsListIsReturned) {
auto builderFactory = std::make_unique<D3DSharingBuilderFactory<D3DTypesHelper::D3D10>>();
EXPECT_THAT(builderFactory->getExtensions(), testing::HasSubstr(std::string("cl_khr_d3d10_sharing")));
@@ -743,4 +758,45 @@ TEST(D3DSharingFactory, givenEnabledFormatQueryAndFactoryWithD3DSharingsWhenGett
function = sharingFactory.getExtensionFunctionAddress("clGetSupportedD3D11TextureFormatsINTEL");
EXPECT_EQ(reinterpret_cast<void *>(clGetSupportedD3D11TextureFormatsINTEL), function);
}
TEST(D3D9SharingFactory, givenDriverInfoWhenVerifyExtensionSupportThenExtensionEnableIsSetCorrect) {
class MockDriverInfo : public DriverInfo {
public:
bool getMediaSharingSupport() override { return support; };
bool support = true;
};
class MockSharingFactory : public SharingFactory {
public:
MockSharingFactory() {
memcpy_s(savedState, sizeof(savedState), sharingContextBuilder, sizeof(sharingContextBuilder));
}
~MockSharingFactory() {
memcpy_s(sharingContextBuilder, sizeof(sharingContextBuilder), savedState, sizeof(savedState));
}
void prepare() {
for (auto &builder : sharingContextBuilder) {
builder = nullptr;
}
d3d9SharingBuilderFactory = std::make_unique<D3DSharingBuilderFactory<D3DTypesHelper::D3D9>>();
sharingContextBuilder[SharingType::D3D9_SHARING] = d3d9SharingBuilderFactory.get();
}
using SharingFactory::sharingContextBuilder;
std::unique_ptr<D3DSharingBuilderFactory<D3DTypesHelper::D3D9>> d3d9SharingBuilderFactory;
decltype(SharingFactory::sharingContextBuilder) savedState;
};
auto driverInfo = std::make_unique<MockDriverInfo>();
auto mockSharingFactory = std::make_unique<MockSharingFactory>();
mockSharingFactory->prepare();
driverInfo->support = true;
mockSharingFactory->verifyExtensionSupport(driverInfo.get());
EXPECT_TRUE(mockSharingFactory->d3d9SharingBuilderFactory->extensionEnabled);
driverInfo->support = false;
mockSharingFactory->verifyExtensionSupport(driverInfo.get());
EXPECT_FALSE(mockSharingFactory->d3d9SharingBuilderFactory->extensionEnabled);
}
} // namespace NEO

View File

@@ -33,4 +33,9 @@ TEST(DriverInfo, GivenDriverInfoWhenLinuxThenReturnDefault) {
EXPECT_STREQ(defaultVersion.c_str(), resultVersion.c_str());
}
TEST(DriverInfo, givenGetMediaSharingSupportWhenLinuxThenReturnTrue) {
std::unique_ptr<DriverInfo> driverInfo(DriverInfo::create(nullptr));
EXPECT_TRUE(driverInfo->getMediaSharingSupport());
}
} // namespace NEO

View File

@@ -113,6 +113,13 @@ class MockRegistryReader : public SettingsReader {
properNameKey = true;
} else if (key == "DriverVersion") {
properVersionKey = true;
} else if (key == "UserModeDriverName") {
properMediaSharingExtensions = true;
using64bit = true;
return returnString;
} else if (key == "UserModeDriverNameWOW") {
properMediaSharingExtensions = true;
return returnString;
}
if (key == "DriverStorePathForComputeRuntime") {
return driverStorePath;
@@ -127,6 +134,9 @@ class MockRegistryReader : public SettingsReader {
bool properNameKey = false;
bool properVersionKey = false;
std::string driverStorePath = "driverStore\\0x8086";
bool properMediaSharingExtensions = false;
bool using64bit = false;
std::string returnString = "";
};
struct DriverInfoWindowsTest : public ::testing::Test {
@@ -160,6 +170,31 @@ TEST_F(DriverInfoWindowsTest, GivenDriverInfoWhenThenReturnNonNullptr) {
EXPECT_TRUE(registryReaderMock->properVersionKey);
};
TEST(DriverInfo, givenDriverInfoWhenGetStringReturnNotMeaningEmptyStringThenEnableSharingSupport) {
MockDriverInfoWindows driverInfo("");
MockRegistryReader *registryReaderMock = new MockRegistryReader();
driverInfo.registryReader.reset(registryReaderMock);
auto enable = driverInfo.getMediaSharingSupport();
EXPECT_TRUE(enable);
EXPECT_EQ(is64bit, registryReaderMock->using64bit);
EXPECT_TRUE(registryReaderMock->properMediaSharingExtensions);
};
TEST(DriverInfo, givenDriverInfoWhenGetStringReturnMeaningEmptyStringThenDisableSharingSupport) {
MockDriverInfoWindows driverInfo("");
MockRegistryReader *registryReaderMock = new MockRegistryReader();
registryReaderMock->returnString = "<>";
driverInfo.registryReader.reset(registryReaderMock);
auto enable = driverInfo.getMediaSharingSupport();
EXPECT_FALSE(enable);
EXPECT_EQ(is64bit, registryReaderMock->using64bit);
EXPECT_TRUE(registryReaderMock->properMediaSharingExtensions);
};
TEST(DriverInfo, givenFullPathToRegistryWhenCreatingDriverInfoWindowsThenTheRegistryPathIsTrimmed) {
std::string registryPath = "Path\\In\\Registry";
std::string fullRegistryPath = "\\REGISTRY\\MACHINE\\" + registryPath;

View File

@@ -276,4 +276,4 @@ TEST(SharingFactoryTests, givenEnabledFormatQueryAndFactoryWithNoSharingsWhenAsk
auto extensionsList = sharingFactory.getExtensions();
EXPECT_THAT(extensionsList, ::testing::Not(::testing::HasSubstr(Extensions::sharingFormatQuery)));
}
}