Fix windows wrappers

Signed-off-by: Bartosz Dunajski <bartosz.dunajski@intel.com>
This commit is contained in:
Bartosz Dunajski
2021-10-05 11:21:46 +00:00
committed by Compute-Runtime-Automation
parent cfad41f28a
commit cd702af3a1
8 changed files with 165 additions and 182 deletions

View File

@@ -37,7 +37,6 @@ set(IGDRCL_SRCS_tests_os_interface_windows
${CMAKE_CURRENT_SOURCE_DIR}/wddm_memory_manager_allocate_in_device_pool_tests.inl
${CMAKE_CURRENT_SOURCE_DIR}/wddm_residency_controller_tests.cpp
${CMAKE_CURRENT_SOURCE_DIR}/wddm_residency_handler_tests.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mock_registry_reader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/mock_sys_calls.h
)

View File

@@ -1,87 +0,0 @@
/*
* Copyright (C) 2018-2021 Intel Corporation
*
* SPDX-License-Identifier: MIT
*
*/
#include "shared/source/os_interface/windows/windows_wrapper.h"
uint32_t regOpenKeySuccessCount = 0u;
uint32_t regQueryValueSuccessCount = 0u;
uint64_t regQueryValueExpectedData = 0ull;
const HKEY validHkey = reinterpret_cast<HKEY>(0);
LSTATUS APIENTRY RegOpenKeyExA(
HKEY hKey,
LPCSTR lpSubKey,
DWORD ulOptions,
REGSAM samDesired,
PHKEY phkResult) {
if (regOpenKeySuccessCount > 0) {
regOpenKeySuccessCount--;
if (phkResult) {
*phkResult = validHkey;
}
return ERROR_SUCCESS;
}
return ERROR_FILE_NOT_FOUND;
};
LSTATUS APIENTRY RegQueryValueExA(
HKEY hKey,
LPCSTR lpValueName,
LPDWORD lpReserved,
LPDWORD lpType,
LPBYTE lpData,
LPDWORD lpcbData) {
if (hKey == validHkey && regQueryValueSuccessCount > 0) {
regQueryValueSuccessCount--;
if (lpcbData) {
if (strcmp(lpValueName, "settingSourceString") == 0) {
const auto settingSource = "registry";
if (lpData) {
strcpy(reinterpret_cast<char *>(lpData), settingSource);
} else {
*lpcbData = static_cast<DWORD>(strlen(settingSource) + 1u);
if (lpType) {
*lpType = REG_SZ;
}
}
} else if (strcmp(lpValueName, "settingSourceInt") == 0) {
if (lpData) {
*reinterpret_cast<DWORD *>(lpData) = 1;
} else {
*lpcbData = sizeof(DWORD);
}
} else if (strcmp(lpValueName, "settingSourceInt64") == 0) {
if (lpData) {
*reinterpret_cast<INT64 *>(lpData) = 0xffffffffeeeeeeee;
} else {
*lpcbData = sizeof(INT64);
}
} else if (strcmp(lpValueName, "settingSourceBinary") == 0) {
const auto settingSource = L"registry";
auto size = wcslen(settingSource) * sizeof(wchar_t);
if (lpData) {
memcpy(reinterpret_cast<wchar_t *>(lpData), settingSource, size);
} else {
*lpcbData = static_cast<DWORD>(size);
if (lpType) {
*lpType = REG_BINARY;
}
}
} else if (strcmp(lpValueName, "boolRegistryKey") == 0) {
if (*lpcbData == sizeof(int64_t)) {
if (lpData) {
*reinterpret_cast<uint64_t *>(lpData) = regQueryValueExpectedData;
}
}
}
}
return ERROR_SUCCESS;
}
return ERROR_FILE_NOT_FOUND;
};

View File

@@ -7,17 +7,20 @@
#include "opencl/test/unit_test/os_interface/windows/registry_reader_tests.h"
#include "shared/source/os_interface/windows/sys_calls.h"
#include "shared/test/common/helpers/variable_backup.h"
#include "test.h"
using namespace NEO;
namespace NEO {
using RegistryReaderTest = ::testing::Test;
namespace SysCalls {
extern uint32_t regOpenKeySuccessCount;
extern uint32_t regQueryValueSuccessCount;
extern uint64_t regQueryValueExpectedData;
} // namespace SysCalls
TEST_F(RegistryReaderTest, givenRegistryReaderWhenItIsCreatedWithUserScopeSetToFalseThenItsHkeyTypeIsInitializedToHkeyLocalMachine) {
bool userScope = false;
@@ -69,89 +72,89 @@ TEST_F(RegistryReaderTest, givenRegistryReaderWhenEnvironmentIntVariableExistsTh
}
struct DebugReaderWithRegistryAndEnvTest : ::testing::Test {
VariableBackup<uint32_t> openRegCountBackup{&regOpenKeySuccessCount};
VariableBackup<uint32_t> queryRegCountBackup{&regQueryValueSuccessCount};
VariableBackup<uint32_t> openRegCountBackup{&SysCalls::regOpenKeySuccessCount};
VariableBackup<uint32_t> queryRegCountBackup{&SysCalls::regQueryValueSuccessCount};
TestedRegistryReader registryReader{""};
};
TEST_F(DebugReaderWithRegistryAndEnvTest, givenIntDebugKeyWhenReadFromRegistrySucceedsThenReturnObtainedValue) {
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 1u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 1u;
EXPECT_EQ(1, registryReader.getSetting("settingSourceInt", 0));
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenInt64DebugKeyWhenReadFromRegistrySucceedsThenReturnObtainedValue) {
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 1u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 1u;
EXPECT_EQ(0xffffffffeeeeeeee, registryReader.getSetting("settingSourceInt64", 0));
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenIntDebugKeyWhenQueryValueFailsThenObtainValueFromEnv) {
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_EQ(2, registryReader.getSetting("settingSourceInt", 0));
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenIntDebugKeyWhenOpenKeyFailsThenObtainValueFromEnv) {
regOpenKeySuccessCount = 0u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 0u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_EQ(2, registryReader.getSetting("settingSourceInt", 0));
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenStringDebugKeyWhenReadFromRegistrySucceedsThenReturnObtainedValue) {
std::string defaultValue("default");
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 2u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 2u;
EXPECT_STREQ("registry", registryReader.getSetting("settingSourceString", defaultValue).c_str());
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenStringDebugKeyWhenQueryValueFailsThenObtainValueFromEnv) {
std::string defaultValue("default");
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_STREQ("environment", registryReader.getSetting("settingSourceString", defaultValue).c_str());
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 1u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 1u;
EXPECT_STREQ("environment", registryReader.getSetting("settingSourceString", defaultValue).c_str());
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenStringDebugKeyWhenOpenKeyFailsThenObtainValueFromEnv) {
std::string defaultValue("default");
regOpenKeySuccessCount = 0u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 0u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_STREQ("environment", registryReader.getSetting("settingSourceString", defaultValue).c_str());
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenBinaryDebugKeyWhenReadFromRegistrySucceedsThenReturnObtainedValue) {
std::string defaultValue("default");
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 2u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 2u;
EXPECT_STREQ("registry", registryReader.getSetting("settingSourceBinary", defaultValue).c_str());
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenBinaryDebugKeyOnlyInRegistryWhenReadFromRegistryFailsThenReturnDefaultValue) {
std::string defaultValue("default");
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 1u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 1u;
EXPECT_STREQ("default", registryReader.getSetting("settingSourceBinary", defaultValue).c_str());
regOpenKeySuccessCount = 1u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 1u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_STREQ("default", registryReader.getSetting("settingSourceBinary", defaultValue).c_str());
regOpenKeySuccessCount = 0u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 0u;
SysCalls::regQueryValueSuccessCount = 0u;
EXPECT_STREQ("default", registryReader.getSetting("settingSourceBinary", defaultValue).c_str());
}
@@ -161,9 +164,9 @@ TEST_F(RegistryReaderTest, givenRegistryKeyPresentWhenValueIsZeroThenExpectBoole
std::string keyName = "boolRegistryKey";
bool defaultValue = false;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 1;
regQueryValueExpectedData = 0ull;
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 1;
SysCalls::regQueryValueExpectedData = 0ull;
TestedRegistryReader registryReader(regKey);
bool value = registryReader.getSetting(keyName.c_str(), defaultValue);
@@ -175,18 +178,18 @@ TEST_F(RegistryReaderTest, givenRegistryKeyNotPresentWhenDefaulValueIsFalseOrTru
std::string keyName = "boolRegistryKey";
bool defaultValue = false;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 0;
regQueryValueExpectedData = 1ull;
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 0;
SysCalls::regQueryValueExpectedData = 1ull;
TestedRegistryReader registryReader(regKey);
bool value = registryReader.getSetting(keyName.c_str(), defaultValue);
EXPECT_FALSE(value);
defaultValue = true;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 0;
regQueryValueExpectedData = 0ull;
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 0;
SysCalls::regQueryValueExpectedData = 0ull;
value = registryReader.getSetting(keyName.c_str(), defaultValue);
EXPECT_TRUE(value);
@@ -197,9 +200,9 @@ TEST_F(RegistryReaderTest, givenRegistryKeyPresentWhenValueIsNonZeroInHigherDwor
std::string keyName = "boolRegistryKey";
bool defaultValue = true;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 1;
regQueryValueExpectedData = 1ull << 32;
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 1;
SysCalls::regQueryValueExpectedData = 1ull << 32;
TestedRegistryReader registryReader(regKey);
bool value = registryReader.getSetting(keyName.c_str(), defaultValue);
@@ -211,9 +214,9 @@ TEST_F(RegistryReaderTest, givenRegistryKeyPresentWhenValueIsNonZeroInLowerDword
std::string keyName = "boolRegistryKey";
bool defaultValue = false;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 1;
regQueryValueExpectedData = 1ull;
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 1;
SysCalls::regQueryValueExpectedData = 1ull;
TestedRegistryReader registryReader(regKey);
bool value = registryReader.getSetting(keyName.c_str(), defaultValue);
@@ -225,9 +228,9 @@ TEST_F(RegistryReaderTest, givenRegistryKeyPresentWhenValueIsNonZeroInBothDwords
std::string keyName = "boolRegistryKey";
bool defaultValue = false;
regOpenKeySuccessCount = 1;
regQueryValueSuccessCount = 1;
regQueryValueExpectedData = 1ull | (1ull << 32);
SysCalls::regOpenKeySuccessCount = 1;
SysCalls::regQueryValueSuccessCount = 1;
SysCalls::regQueryValueExpectedData = 1ull | (1ull << 32);
TestedRegistryReader registryReader(regKey);
bool value = registryReader.getSetting(keyName.c_str(), defaultValue);
@@ -235,10 +238,11 @@ TEST_F(RegistryReaderTest, givenRegistryKeyPresentWhenValueIsNonZeroInBothDwords
}
TEST_F(DebugReaderWithRegistryAndEnvTest, givenSetProcessNameWhenReadFromEnvironmentVariableThenReturnClCacheDir) {
regOpenKeySuccessCount = 0u;
regQueryValueSuccessCount = 0u;
SysCalls::regOpenKeySuccessCount = 0u;
SysCalls::regQueryValueSuccessCount = 0u;
registryReader.processName = "processName";
std::string defaultCacheDir = "";
std::string cacheDir = registryReader.getSetting("processName", defaultCacheDir);
EXPECT_STREQ("./tested_cl_cache_dir", cacheDir.c_str());
}
} // namespace NEO

View File

@@ -9,6 +9,8 @@
#include "opencl/test/unit_test/os_interface/windows/mock_sys_calls.h"
#include <cstdint>
namespace NEO {
namespace SysCalls {
@@ -20,6 +22,10 @@ unsigned int getProcessId() {
BOOL systemPowerStatusRetVal = 1;
BYTE systemPowerStatusACLineStatusOverride = 1;
const wchar_t *currentLibraryPath = L"";
uint32_t regOpenKeySuccessCount = 0u;
uint32_t regQueryValueSuccessCount = 0u;
uint64_t regQueryValueExpectedData = 0ull;
const HKEY validHkey = reinterpret_cast<HKEY>(0);
HANDLE createEvent(LPSECURITY_ATTRIBUTES lpEventAttributes, BOOL bManualReset, BOOL bInitialState, LPCSTR lpName) {
if (mockCreateEventClb) {
@@ -60,6 +66,69 @@ char *getenv(const char *variableName) {
}
return ::getenv(variableName);
}
LSTATUS regOpenKeyExA(HKEY hKey, LPCSTR lpSubKey, DWORD ulOptions, REGSAM samDesired, PHKEY phkResult) {
if (regOpenKeySuccessCount > 0) {
regOpenKeySuccessCount--;
if (phkResult) {
*phkResult = validHkey;
}
return ERROR_SUCCESS;
}
return ERROR_FILE_NOT_FOUND;
};
LSTATUS regQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData) {
if (hKey == validHkey && regQueryValueSuccessCount > 0) {
regQueryValueSuccessCount--;
if (lpcbData) {
if (strcmp(lpValueName, "settingSourceString") == 0) {
const auto settingSource = "registry";
if (lpData) {
strcpy(reinterpret_cast<char *>(lpData), settingSource);
} else {
*lpcbData = static_cast<DWORD>(strlen(settingSource) + 1u);
if (lpType) {
*lpType = REG_SZ;
}
}
} else if (strcmp(lpValueName, "settingSourceInt") == 0) {
if (lpData) {
*reinterpret_cast<DWORD *>(lpData) = 1;
} else {
*lpcbData = sizeof(DWORD);
}
} else if (strcmp(lpValueName, "settingSourceInt64") == 0) {
if (lpData) {
*reinterpret_cast<INT64 *>(lpData) = 0xffffffffeeeeeeee;
} else {
*lpcbData = sizeof(INT64);
}
} else if (strcmp(lpValueName, "settingSourceBinary") == 0) {
const auto settingSource = L"registry";
auto size = wcslen(settingSource) * sizeof(wchar_t);
if (lpData) {
memcpy(reinterpret_cast<wchar_t *>(lpData), settingSource, size);
} else {
*lpcbData = static_cast<DWORD>(size);
if (lpType) {
*lpType = REG_BINARY;
}
}
} else if (strcmp(lpValueName, "boolRegistryKey") == 0) {
if (*lpcbData == sizeof(int64_t)) {
if (lpData) {
*reinterpret_cast<uint64_t *>(lpData) = regQueryValueExpectedData;
}
}
}
}
return ERROR_SUCCESS;
}
return ERROR_FILE_NOT_FOUND;
};
} // namespace SysCalls
bool isShutdownInProgress() {

View File

@@ -57,22 +57,22 @@ int64_t RegistryReader::getSetting(const char *settingName, int64_t defaultValue
DWORD success = ERROR_SUCCESS;
bool readSettingFromEnv = true;
success = RegOpenKeyExA(hkeyType,
registryReadRootKey.c_str(),
0,
KEY_READ,
&Key);
success = SysCalls::regOpenKeyExA(hkeyType,
registryReadRootKey.c_str(),
0,
KEY_READ,
&Key);
if (ERROR_SUCCESS == success) {
DWORD size = sizeof(int64_t);
int64_t regData;
success = RegQueryValueExA(Key,
settingName,
NULL,
NULL,
reinterpret_cast<LPBYTE>(&regData),
&size);
success = SysCalls::regQueryValueExA(Key,
settingName,
NULL,
NULL,
reinterpret_cast<LPBYTE>(&regData),
&size);
if (ERROR_SUCCESS == success) {
value = regData;
readSettingFromEnv = false;
@@ -95,30 +95,30 @@ std::string RegistryReader::getSetting(const char *settingName, const std::strin
std::string keyValue = value;
bool readSettingFromEnv = true;
success = RegOpenKeyExA(hkeyType,
registryReadRootKey.c_str(),
0,
KEY_READ,
&Key);
success = SysCalls::regOpenKeyExA(hkeyType,
registryReadRootKey.c_str(),
0,
KEY_READ,
&Key);
if (ERROR_SUCCESS == success) {
DWORD regType = REG_NONE;
DWORD regSize = 0;
success = RegQueryValueExA(Key,
settingName,
NULL,
&regType,
NULL,
&regSize);
success = SysCalls::regQueryValueExA(Key,
settingName,
NULL,
&regType,
NULL,
&regSize);
if (ERROR_SUCCESS == success) {
if (regType == REG_SZ || regType == REG_MULTI_SZ) {
auto regData = std::make_unique<char[]>(regSize);
success = RegQueryValueExA(Key,
settingName,
NULL,
&regType,
reinterpret_cast<LPBYTE>(regData.get()),
&regSize);
success = SysCalls::regQueryValueExA(Key,
settingName,
NULL,
&regType,
reinterpret_cast<LPBYTE>(regData.get()),
&regSize);
if (success == ERROR_SUCCESS) {
keyValue.assign(regData.get());
readSettingFromEnv = false;
@@ -126,12 +126,12 @@ std::string RegistryReader::getSetting(const char *settingName, const std::strin
} else if (regType == REG_BINARY) {
size_t charCount = regSize / sizeof(wchar_t);
auto regData = std::make_unique<wchar_t[]>(charCount);
success = RegQueryValueExA(Key,
settingName,
NULL,
&regType,
reinterpret_cast<LPBYTE>(regData.get()),
&regSize);
success = SysCalls::regQueryValueExA(Key,
settingName,
NULL,
&regType,
reinterpret_cast<LPBYTE>(regData.get()),
&regSize);
if (ERROR_SUCCESS == success) {

View File

@@ -50,6 +50,14 @@ DWORD getModuleFileName(HMODULE hModule, LPWSTR lpFilename, DWORD nSize) {
char *getenv(const char *variableName) {
return ::getenv(variableName);
}
LSTATUS regOpenKeyExA(HKEY hKey, LPCSTR lpSubKey, DWORD ulOptions, REGSAM samDesired, PHKEY phkResult) {
return RegOpenKeyExA(hKey, lpSubKey, ulOptions, samDesired, phkResult);
}
LSTATUS regQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData) {
return RegQueryValueExA(hKey, lpValueName, lpReserved, lpType, lpData, lpcbData);
}
} // namespace SysCalls
} // namespace NEO

View File

@@ -21,6 +21,9 @@ BOOL getModuleHandle(DWORD dwFlags, LPCWSTR lpModuleName, HMODULE *phModule);
DWORD getModuleFileName(HMODULE hModule, LPWSTR lpFilename, DWORD nSize);
char *getenv(const char *variableName);
LSTATUS regOpenKeyExA(HKEY hKey, LPCSTR lpSubKey, DWORD ulOptions, REGSAM samDesired, PHKEY phkResult);
LSTATUS regQueryValueExA(HKEY hKey, LPCSTR lpValueName, LPDWORD lpReserved, LPDWORD lpType, LPBYTE lpData, LPDWORD lpcbData);
} // namespace SysCalls
} // namespace NEO

View File

@@ -18,19 +18,6 @@
#undef RegOpenKeyExA
#undef RegQueryValueExA
#pragma warning(disable : 4273)
LSTATUS APIENTRY RegOpenKeyExA(
HKEY hKey,
LPCSTR lpSubKey,
DWORD ulOptions,
REGSAM samDesired,
PHKEY phkResult);
LSTATUS APIENTRY RegQueryValueExA(
HKEY hKey,
LPCSTR lpValueName,
LPDWORD lpReserved,
LPDWORD lpType,
LPBYTE lpData,
LPDWORD lpcbData);
#else
#include <cstdint>
#if __clang__