From 3297cd8342927cf5e4561e7d044c1d5e48e2c009 Mon Sep 17 00:00:00 2001 From: Artur Harasimiuk Date: Fri, 2 Mar 2018 15:19:29 +0100 Subject: [PATCH] add function to properly load dependency Change-Id: I0cec677ae19fa6525890c9b0abe0601a0c11e7df --- runtime/os_interface/windows/os_library.cpp | 19 +++- runtime/os_interface/windows/os_library.h | 6 ++ unit_tests/os_interface/CMakeLists.txt | 3 +- .../windows/os_library_win_tests.cpp | 98 +++++++++++++++++++ 4 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 unit_tests/os_interface/windows/os_library_win_tests.cpp diff --git a/runtime/os_interface/windows/os_library.cpp b/runtime/os_interface/windows/os_library.cpp index b4b49cb0bf..5a523133a7 100644 --- a/runtime/os_interface/windows/os_library.cpp +++ b/runtime/os_interface/windows/os_library.cpp @@ -22,7 +22,6 @@ #include "runtime/os_interface/os_library.h" #include "os_library.h" -#include "DriverStore.h" namespace OCLRT { @@ -37,12 +36,28 @@ OsLibrary *OsLibrary::load(const std::string &name) { } namespace Windows { +decltype(&LoadLibraryExA) OsLibrary::loadLibraryExA = LoadLibraryExA; +decltype(&GetModuleFileNameA) OsLibrary::getModuleFileNameA = GetModuleFileNameA; + +HMODULE OsLibrary::loadDependency(const std::string &dependencyFileName) const { + char dllPath[MAX_PATH]; + DWORD length = getModuleFileNameA(GetModuleHandle(NULL), dllPath, MAX_PATH); + for (DWORD idx = length; idx > 0; idx--) { + if (dllPath[idx - 1] == '\\') { + dllPath[idx] = '\0'; + break; + } + } + strcat_s(dllPath, MAX_PATH, dependencyFileName.c_str()); + + return loadLibraryExA(dllPath, NULL, 0); +} OsLibrary::OsLibrary(const std::string &name) { if (name.empty()) { this->handle = GetModuleHandleA(nullptr); } else { - this->handle = LoadDependency(name.c_str()); + this->handle = loadDependency(name); if (this->handle == nullptr) { this->handle = ::LoadLibraryA(name.c_str()); } diff --git a/runtime/os_interface/windows/os_library.h b/runtime/os_interface/windows/os_library.h index 405d4a075a..55b6edfb45 100644 --- a/runtime/os_interface/windows/os_library.h +++ b/runtime/os_interface/windows/os_library.h @@ -39,6 +39,12 @@ class OsLibrary : public OCLRT::OsLibrary { bool isLoaded(); void *getProcAddress(const std::string &procName); + + protected: + HMODULE loadDependency(const std::string &dependencyFileName) const; + + static decltype(&LoadLibraryExA) loadLibraryExA; + static decltype(&GetModuleFileNameA) getModuleFileNameA; }; } } diff --git a/unit_tests/os_interface/CMakeLists.txt b/unit_tests/os_interface/CMakeLists.txt index bfa20d5b94..bab336c9da 100644 --- a/unit_tests/os_interface/CMakeLists.txt +++ b/unit_tests/os_interface/CMakeLists.txt @@ -30,9 +30,10 @@ set(IGDRCL_SRCS_tests_os_interface_windows "${CMAKE_CURRENT_SOURCE_DIR}/windows/mock_os_time_win.h" "${CMAKE_CURRENT_SOURCE_DIR}/windows/mock_wddm_memory_manager.h" "${CMAKE_CURRENT_SOURCE_DIR}/windows/options.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/windows/os_time_win_tests.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/windows/os_library_win_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/windows/os_interface_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/windows/os_interface_tests.h" + "${CMAKE_CURRENT_SOURCE_DIR}/windows/os_time_win_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/windows/self_lib_win.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/windows/ult_dxgi_factory.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/windows/ult_dxgi_factory.h" diff --git a/unit_tests/os_interface/windows/os_library_win_tests.cpp b/unit_tests/os_interface/windows/os_library_win_tests.cpp new file mode 100644 index 0000000000..aaf32fab85 --- /dev/null +++ b/unit_tests/os_interface/windows/os_library_win_tests.cpp @@ -0,0 +1,98 @@ +/* +* Copyright (c) 2017 - 2018, Intel Corporation +* +* Permission is hereby granted, free of charge, to any person obtaining a +* copy of this software and associated documentation files (the "Software"), +* to deal in the Software without restriction, including without limitation +* the rights to use, copy, modify, merge, publish, distribute, sublicense, +* and/or sell copies of the Software, and to permit persons to whom the +* Software is furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included +* in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR +* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +* OTHER DEALINGS IN THE SOFTWARE. +*/ + +#include "runtime/os_interface/windows/os_library.h" +#include "unit_tests/helpers/variable_backup.h" +#include "test.h" +#include "gtest/gtest.h" + +#include + +namespace Os { +extern const char *testDllName; +} + +using namespace OCLRT; + +class OsLibraryBackup : public Windows::OsLibrary { + using Type = decltype(Windows::OsLibrary::loadLibraryExA); + using BackupType = typename VariableBackup; + + using ModuleNameType = decltype(Windows::OsLibrary::getModuleFileNameA); + using ModuleNameBackupType = typename VariableBackup; + + struct Backup { + std::unique_ptr bkp1 = nullptr; + std::unique_ptr bkp2 = nullptr; + }; + + public: + static std::unique_ptr backup(Type newValue, ModuleNameType newModuleName) { + std::unique_ptr bkp(new Backup()); + bkp->bkp1.reset(new BackupType(&OsLibrary::loadLibraryExA, newValue)); + bkp->bkp2.reset(new ModuleNameBackupType(&OsLibrary::getModuleFileNameA, newModuleName)); + return bkp; + }; +}; + +bool mockWillFail = true; +void trimFileName(char *buff, size_t length) { + for (size_t l = length; l > 0; l--) { + if (buff[l - 1] == '\\') { + buff[l] = '\0'; + break; + } + } +} + +DWORD WINAPI GetModuleFileNameAMock(HMODULE hModule, LPSTR lpFilename, DWORD nSize) { + return snprintf(lpFilename, nSize, "z:\\SomeFakeName.dll"); +} + +HMODULE WINAPI LoadLibraryExAMock(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags) { + if (mockWillFail) + return NULL; + + char fName[MAX_PATH]; + auto lenFn = strlen(lpFileName); + strcpy_s(fName, sizeof(fName), lpFileName); + trimFileName(fName, lenFn); + + EXPECT_STREQ("z:\\", fName); + + return (HMODULE)1; +} + +TEST(OSLibraryWinTest, gitOsLibraryWinWhenLoadDependencyFailsThenFallbackToNonDriverStore) { + auto bkp = OsLibraryBackup::backup(LoadLibraryExAMock, GetModuleFileNameAMock); + + std::unique_ptr library(OsLibrary::load(Os::testDllName)); + EXPECT_NE(nullptr, library); +} + +TEST(OSLibraryWinTest, gitOsLibraryWinWhenLoadDependencyThenProperPathIsConstructed) { + auto bkp = OsLibraryBackup::backup(LoadLibraryExAMock, GetModuleFileNameAMock); + VariableBackup bkpM(&mockWillFail, false); + + std::unique_ptr library(OsLibrary::load(Os::testDllName)); + EXPECT_NE(nullptr, library); +}