fix: correct loading L0 loader functions

- don't load ze_loader.dll from file system
- to perform self-open on Windows use getModuleHandleA with proper module name
- don't free library loaded with getModuleHandleA
- as loader may be not available during runtime teardown:
- load translate handle function during global setup
- load setDriverTeardown function during global teardown
- when loader is not available during teardown, unset translate handle
function

Related-To: GSD-10147

Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2024-10-16 19:53:07 +00:00
committed by Compute-Runtime-Automation
parent 9d6d6e85f1
commit 4154e6666b
38 changed files with 300 additions and 123 deletions

View File

@ -19,15 +19,30 @@ decltype(&zelLoaderTranslateHandle) loaderTranslateHandleFunc = nullptr;
decltype(&zelSetDriverTeardown) setDriverTeardownFunc = nullptr;
void globalDriverSetup() {
std::unique_ptr<NEO::OsLibrary> loaderLibrary = std::unique_ptr<NEO::OsLibrary>{NEO::OsLibrary::loadFunc("")};
loaderTranslateHandleFunc = reinterpret_cast<decltype(&zelLoaderTranslateHandle)>(loaderLibrary->getProcAddress("zelLoaderTranslateHandle"));
setDriverTeardownFunc = reinterpret_cast<decltype(&zelSetDriverTeardown)>(loaderLibrary->getProcAddress("zelSetDriverTeardown"));
NEO::OsLibraryCreateProperties loaderLibraryProperties("ze_loader.dll");
loaderLibraryProperties.performSelfLoad = true;
std::unique_ptr<NEO::OsLibrary> loaderLibrary = std::unique_ptr<NEO::OsLibrary>{NEO::OsLibrary::loadFunc(loaderLibraryProperties)};
if (loaderLibrary) {
loaderTranslateHandleFunc = reinterpret_cast<decltype(&zelLoaderTranslateHandle)>(loaderLibrary->getProcAddress("zelLoaderTranslateHandle"));
}
}
void globalDriverTeardown() {
if (levelZeroDriverInitialized && setDriverTeardownFunc) {
setDriverTeardownFunc();
if (levelZeroDriverInitialized) {
NEO::OsLibraryCreateProperties loaderLibraryProperties("ze_loader.dll");
loaderLibraryProperties.performSelfLoad = true;
std::unique_ptr<NEO::OsLibrary> loaderLibrary = std::unique_ptr<NEO::OsLibrary>{NEO::OsLibrary::loadFunc(loaderLibraryProperties)};
if (loaderLibrary) {
setDriverTeardownFunc = reinterpret_cast<decltype(&zelSetDriverTeardown)>(loaderLibrary->getProcAddress("zelSetDriverTeardown"));
if (setDriverTeardownFunc) {
setDriverTeardownFunc();
}
} else {
loaderTranslateHandleFunc = nullptr;
}
}
if (globalDriver != nullptr) {
if (globalDriver->pid == NEO::SysCalls::getCurrentProcessId()) {

View File

@ -21,7 +21,7 @@
namespace L0 {
namespace ult {
TEST(GlobalTearDownTests, whenCallingGlobalDriverSetupThenLoaderFunctionsAreLoadedIfAvailable) {
TEST(GlobalTearDownTests, whenCallingGlobalDriverSetupThenLoaderFunctionForTranslateHandleIsLoadedIfAvailable) {
void *mockSetDriverTeardownPtr = reinterpret_cast<void *>(static_cast<uintptr_t>(0x1234ABC8));
void *mockLoaderTranslateHandlePtr = reinterpret_cast<void *>(static_cast<uintptr_t>(0x5678EF08));
@ -31,6 +31,12 @@ TEST(GlobalTearDownTests, whenCallingGlobalDriverSetupThenLoaderFunctionsAreLoad
VariableBackup<decltype(NEO::OsLibrary::loadFunc)> loadFuncBackup{&NEO::OsLibrary::loadFunc, MockOsLibraryCustom::load};
VariableBackup<decltype(MockOsLibrary::loadLibraryNewObject)> mockLibraryBackup{&MockOsLibrary::loadLibraryNewObject, nullptr};
MockOsLibrary::loadLibraryNewObject = nullptr;
globalDriverSetup();
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(nullptr, loaderTranslateHandleFunc);
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
globalDriverSetup();
@ -42,7 +48,7 @@ TEST(GlobalTearDownTests, whenCallingGlobalDriverSetupThenLoaderFunctionsAreLoad
osLibrary->procMap["zelSetDriverTeardown"] = mockSetDriverTeardownPtr;
globalDriverSetup();
EXPECT_EQ(mockSetDriverTeardownPtr, reinterpret_cast<void *>(setDriverTeardownFunc));
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(nullptr, loaderTranslateHandleFunc);
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
@ -59,10 +65,71 @@ TEST(GlobalTearDownTests, whenCallingGlobalDriverSetupThenLoaderFunctionsAreLoad
osLibrary->procMap["zelLoaderTranslateHandle"] = mockLoaderTranslateHandlePtr;
globalDriverSetup();
EXPECT_EQ(mockSetDriverTeardownPtr, reinterpret_cast<void *>(setDriverTeardownFunc));
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(mockLoaderTranslateHandlePtr, reinterpret_cast<void *>(loaderTranslateHandleFunc));
}
uint32_t loaderTearDownCalled = 0;
ze_result_t loaderTearDown() {
loaderTearDownCalled++;
return ZE_RESULT_ERROR_UNKNOWN;
};
TEST(GlobalTearDownTests, givenInitializedDriverWhenCallingGlobalDriverTeardownThenLoaderFunctionForTeardownIsLoadedAndCalledIfAvailable) {
void *mockLoaderTranslateHandlePtr = reinterpret_cast<void *>(static_cast<uintptr_t>(0x5678EF08));
VariableBackup<decltype(loaderTearDownCalled)> loaderTeardownCalledBackup{&loaderTearDownCalled, 0};
VariableBackup<decltype(levelZeroDriverInitialized)> driverInitializeBackup{&levelZeroDriverInitialized, true};
VariableBackup<decltype(setDriverTeardownFunc)> teardownFuncBackup{&setDriverTeardownFunc, nullptr};
VariableBackup<decltype(loaderTranslateHandleFunc)> translateFuncBackup{&loaderTranslateHandleFunc, nullptr};
VariableBackup<decltype(NEO::OsLibrary::loadFunc)> loadFuncBackup{&NEO::OsLibrary::loadFunc, MockOsLibraryCustom::load};
VariableBackup<decltype(MockOsLibrary::loadLibraryNewObject)> mockLibraryBackup{&MockOsLibrary::loadLibraryNewObject, nullptr};
loaderTranslateHandleFunc = reinterpret_cast<decltype(loaderTranslateHandleFunc)>(mockLoaderTranslateHandlePtr);
MockOsLibrary::loadLibraryNewObject = nullptr;
globalDriverTeardown();
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(nullptr, loaderTranslateHandleFunc);
loaderTranslateHandleFunc = reinterpret_cast<decltype(loaderTranslateHandleFunc)>(mockLoaderTranslateHandlePtr);
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
globalDriverTeardown();
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(mockLoaderTranslateHandlePtr, reinterpret_cast<void *>(loaderTranslateHandleFunc));
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
auto osLibrary = static_cast<MockOsLibraryCustom *>(MockOsLibrary::loadLibraryNewObject);
osLibrary->procMap["zelSetDriverTeardown"] = reinterpret_cast<void *>(&loaderTearDown);
globalDriverTeardown();
EXPECT_EQ(&loaderTearDown, reinterpret_cast<void *>(setDriverTeardownFunc));
EXPECT_EQ(mockLoaderTranslateHandlePtr, reinterpret_cast<void *>(loaderTranslateHandleFunc));
EXPECT_EQ(1u, loaderTearDownCalled);
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
osLibrary = static_cast<MockOsLibraryCustom *>(MockOsLibrary::loadLibraryNewObject);
osLibrary->procMap["zelLoaderTranslateHandle"] = mockLoaderTranslateHandlePtr;
globalDriverTeardown();
EXPECT_EQ(nullptr, setDriverTeardownFunc);
EXPECT_EQ(mockLoaderTranslateHandlePtr, reinterpret_cast<void *>(loaderTranslateHandleFunc));
EXPECT_EQ(1u, loaderTearDownCalled);
loaderTranslateHandleFunc = nullptr;
MockOsLibrary::loadLibraryNewObject = new MockOsLibraryCustom(nullptr, true);
osLibrary = static_cast<MockOsLibraryCustom *>(MockOsLibrary::loadLibraryNewObject);
osLibrary->procMap["zelSetDriverTeardown"] = reinterpret_cast<void *>(&loaderTearDown);
osLibrary->procMap["zelLoaderTranslateHandle"] = mockLoaderTranslateHandlePtr;
globalDriverTeardown();
EXPECT_EQ(&loaderTearDown, reinterpret_cast<void *>(setDriverTeardownFunc));
EXPECT_EQ(nullptr, loaderTranslateHandleFunc);
}
TEST(GlobalTearDownTests, givenInitializedDriverAndNoTeardownFunctionIsAvailableWhenCallGlobalTeardownThenDontCrash) {
VariableBackup<bool> initializedBackup{&levelZeroDriverInitialized};
VariableBackup<decltype(setDriverTeardownFunc)> teardownFuncBackup{&setDriverTeardownFunc};
@ -72,22 +139,6 @@ TEST(GlobalTearDownTests, givenInitializedDriverAndNoTeardownFunctionIsAvailable
EXPECT_NO_THROW(globalDriverTeardown());
}
TEST(GlobalTearDownTests, givenInitializedDriverAndTeardownFunctionIsAvailableWhenCallGlobalTeardownThenCallTeardownFunc) {
VariableBackup<bool> initializedBackup{&levelZeroDriverInitialized};
VariableBackup<decltype(setDriverTeardownFunc)> teardownFuncBackup{&setDriverTeardownFunc};
static uint32_t teardownCalled = 0u;
levelZeroDriverInitialized = true;
setDriverTeardownFunc = []() -> ze_result_t {
EXPECT_EQ(0u, teardownCalled);
teardownCalled++;
return ZE_RESULT_SUCCESS;
};
teardownCalled = 0u;
globalDriverTeardown();
EXPECT_EQ(1u, teardownCalled);
}
TEST(GlobalTearDownTests, givenNotInitializedDriverAndTeardownFunctionIsAvailableWhenCallGlobalTeardownThenDontCallTeardownFunc) {
VariableBackup<bool> initializedBackup{&levelZeroDriverInitialized};
VariableBackup<decltype(setDriverTeardownFunc)> teardownFuncBackup{&setDriverTeardownFunc};
@ -100,22 +151,6 @@ TEST(GlobalTearDownTests, givenNotInitializedDriverAndTeardownFunctionIsAvailabl
EXPECT_NO_THROW(globalDriverTeardown());
}
TEST(GlobalTearDownTests, givenInitializedDriverAndTeardownFunctionFailsWhenCallGlobalTeardownThenDontCrash) {
VariableBackup<bool> initializedBackup{&levelZeroDriverInitialized};
VariableBackup<decltype(setDriverTeardownFunc)> teardownFuncBackup{&setDriverTeardownFunc};
static uint32_t teardownCalled = 0u;
levelZeroDriverInitialized = true;
setDriverTeardownFunc = []() -> ze_result_t {
EXPECT_EQ(0u, teardownCalled);
teardownCalled++;
return ZE_RESULT_ERROR_UNKNOWN;
};
teardownCalled = 0u;
EXPECT_NO_THROW(globalDriverTeardown());
EXPECT_EQ(1u, teardownCalled);
}
TEST(GlobalTearDownTests, givenCallToGlobalTearDownFunctionThenGlobalDriversAreNull) {
globalDriverTeardown();
EXPECT_EQ(globalDriver, nullptr);

View File

@ -6084,7 +6084,7 @@ struct RTASDeviceTest : public ::testing::Test {
std::string getFullPath() override {
return std::string();
}
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const NEO::OsLibraryCreateProperties &properties) {
if (failLibraryLoad) {
return nullptr;
}

View File

@ -221,7 +221,7 @@ struct MockRTASOsLibrary : public OsLibrary {
std::string getFullPath() override {
return std::string();
}
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const OsLibraryCreateProperties &properties) {
if (mockLoad == true) {
auto ptr = new (std::nothrow) MockRTASOsLibrary();
return ptr;
@ -264,7 +264,7 @@ TEST_F(RTASTest, GivenLibraryLoadsSymbolsAndUnderlyingFunctionsSucceedThenSucces
std::string getFullPath() override {
return std::string();
}
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const OsLibraryCreateProperties &properties) {
auto ptr = new (std::nothrow) MockSymbolsLoadedOsLibrary();
return ptr;
}

View File

@ -208,8 +208,9 @@ FirmwareUtilImp::~FirmwareUtilImp() {
FirmwareUtil *FirmwareUtil::create(uint16_t domain, uint8_t bus, uint8_t device, uint8_t function) {
FirmwareUtilImp *pFwUtilImp = new FirmwareUtilImp(domain, bus, device, function);
UNRECOVERABLE_IF(nullptr == pFwUtilImp);
NEO::OsLibrary::loadFlagsOverwrite = &FirmwareUtilImp::fwUtilLoadFlags;
pFwUtilImp->libraryHandle = NEO::OsLibrary::loadFunc(FirmwareUtilImp::fwUtilLibraryName);
NEO::OsLibraryCreateProperties properties(FirmwareUtilImp::fwUtilLibraryName);
properties.customLoadFlags = &FirmwareUtilImp::fwUtilLoadFlags;
pFwUtilImp->libraryHandle = NEO::OsLibrary::loadFunc(properties);
if (pFwUtilImp->libraryHandle == nullptr || pFwUtilImp->loadEntryPoints() == false || pFwUtilImp->fwDeviceInit() != ZE_RESULT_SUCCESS) {
if (nullptr != pFwUtilImp->libraryHandle) {
delete pFwUtilImp->libraryHandle;

View File

@ -34,8 +34,6 @@ TEST(FwUtilTest, GivenLibraryWasSetWhenCreatingFirmwareUtilInterfaceThenLibraryI
VariableBackup<bool> dlOpenCalledBackup{&dlOpenCalled, false};
VariableBackup<int> dlOpenFlagsBackup{&dlOpenFlags, 0};
auto flags = RTLD_LAZY;
NEO::OsLibrary::loadFlagsOverwrite = &flags;
L0::Sysman::FirmwareUtil *pFwUtil = L0::Sysman::FirmwareUtil::create(0, 0, 0, 0);
EXPECT_EQ(dlOpenCalled, true);
EXPECT_EQ(dlOpenFlags, RTLD_LAZY);

View File

@ -43,8 +43,6 @@ struct MockFwUtilInterface : public L0::Sysman::FirmwareUtil {
struct MockFwUtilOsLibrary : public OsLibrary {
public:
static bool mockLoad;
MockFwUtilOsLibrary(const std::string &name, std::string *errorValue) {
}
MockFwUtilOsLibrary() {}
~MockFwUtilOsLibrary() override = default;
void *getProcAddress(const std::string &procName) override {
@ -61,7 +59,7 @@ struct MockFwUtilOsLibrary : public OsLibrary {
std::string getFullPath() override {
return std::string();
}
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const OsLibraryCreateProperties &properties) {
if (mockLoad == true) {
auto ptr = new (std::nothrow) MockFwUtilOsLibrary();
return ptr;

View File

@ -114,7 +114,7 @@ ze_result_t MetricEnumeration::loadMetricsDiscovery() {
getMetricsDiscoveryFilename(libnames);
for (auto &name : libnames) {
hMetricsDiscovery.reset(NEO::OsLibrary::loadFunc(name));
hMetricsDiscovery.reset(NEO::OsLibrary::loadFunc({name}));
// Load exported functions.
if (hMetricsDiscovery) {

View File

@ -179,7 +179,7 @@ void MetricsLibrary::release() {
bool MetricsLibrary::load() {
// Load library.
handle = NEO::OsLibrary::loadFunc(getFilename());
handle = NEO::OsLibrary::loadFunc({getFilename()});
// Load exported functions.
if (handle) {

View File

@ -192,8 +192,9 @@ FirmwareUtilImp::~FirmwareUtilImp() {
FirmwareUtil *FirmwareUtil::create(uint16_t domain, uint8_t bus, uint8_t device, uint8_t function) {
FirmwareUtilImp *pFwUtilImp = new FirmwareUtilImp(domain, bus, device, function);
UNRECOVERABLE_IF(nullptr == pFwUtilImp);
NEO::OsLibrary::loadFlagsOverwrite = &FirmwareUtilImp::fwUtilLoadFlags;
pFwUtilImp->libraryHandle = NEO::OsLibrary::loadFunc(FirmwareUtilImp::fwUtilLibraryName);
NEO::OsLibraryCreateProperties properties(FirmwareUtilImp::fwUtilLibraryName);
properties.customLoadFlags = &FirmwareUtilImp::fwUtilLoadFlags;
pFwUtilImp->libraryHandle = NEO::OsLibrary::loadFunc(properties);
if (pFwUtilImp->libraryHandle == nullptr || pFwUtilImp->loadEntryPoints() == false || pFwUtilImp->fwDeviceInit() != ZE_RESULT_SUCCESS) {
if (nullptr != pFwUtilImp->libraryHandle) {
delete pFwUtilImp->libraryHandle;

View File

@ -20,7 +20,7 @@ namespace ult {
class MockOsLibrary : public NEO::OsLibrary {
public:
MockOsLibrary(const std::string &name, std::string *errorValue) {
MockOsLibrary() {
}
void *getProcAddress(const std::string &procName) override {
@ -35,8 +35,8 @@ class MockOsLibrary : public NEO::OsLibrary {
return std::string();
}
static OsLibrary *load(const std::string &name) {
auto ptr = new (std::nothrow) MockOsLibrary(name, nullptr);
static OsLibrary *load(const OsLibraryCreateProperties &properties) {
auto ptr = new (std::nothrow) MockOsLibrary();
if (ptr == nullptr) {
return nullptr;
}

View File

@ -34,8 +34,6 @@ TEST(FwUtilTest, GivenLibraryWasSetWhenCreatingFirmwareUtilInterfaceThenLibraryI
VariableBackup<bool> dlOpenCalledBackup{&dlOpenCalled, false};
VariableBackup<int> dlOpenFlagsBackup{&dlOpenFlags, 0};
auto flags = RTLD_LAZY;
NEO::OsLibrary::loadFlagsOverwrite = &flags;
L0::Sysman::FirmwareUtil *pFwUtil = L0::Sysman::FirmwareUtil::create(0, 0, 0, 0);
EXPECT_EQ(dlOpenCalled, true);
EXPECT_EQ(dlOpenFlags, RTLD_LAZY);

View File

@ -61,7 +61,7 @@ struct MockFwUtilOsLibrary : public OsLibrary {
std::string getFullPath() override {
return std::string();
}
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const OsLibraryCreateProperties &properties) {
if (mockLoad == true) {
auto ptr = new (std::nothrow) MockFwUtilOsLibrary();
return ptr;

View File

@ -95,7 +95,7 @@ void GLSharingFunctionsLinux::removeGlArbSyncEventMapping(Event &baseEvent) {
}
GLboolean GLSharingFunctionsLinux::initGLFunctions() {
std::unique_ptr<OsLibrary> dynLibrary(OsLibrary::loadFunc(""));
std::unique_ptr<OsLibrary> dynLibrary(OsLibrary::loadFunc({""}));
GlFunctionHelper glXGetProc(dynLibrary.get(), "glXGetProcAddress");
if (glXGetProc.ready()) {

View File

@ -29,7 +29,7 @@ GLSharingFunctionsWindows::~GLSharingFunctionsWindows() {
}
bool GLSharingFunctionsWindows::isGlSharingEnabled() {
static bool oglLibAvailable = std::unique_ptr<OsLibrary>(OsLibrary::loadFunc(Os::openglDllName)).get() != nullptr;
static bool oglLibAvailable = std::unique_ptr<OsLibrary>(OsLibrary::loadFunc({Os::openglDllName})).get() != nullptr;
return oglLibAvailable;
}
@ -128,7 +128,7 @@ void GLSharingFunctionsWindows::removeGlArbSyncEventMapping(Event &baseEvent) {
}
GLboolean GLSharingFunctionsWindows::initGLFunctions() {
glLibrary.reset(OsLibrary::loadFunc(Os::openglDllName));
glLibrary.reset(OsLibrary::loadFunc({Os::openglDllName}));
if (glLibrary->isLoaded()) {
GlFunctionHelper wglLibrary(glLibrary.get(), "wglGetProcAddress");

View File

@ -26,7 +26,7 @@ class GlFunctionHelperMock : public GlFunctionHelper {
};
TEST(GlFunctionHelper, whenCreateGlFunctionHelperThenSetGlFunctionPtrToLoadAnotherFunctions) {
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc("mock_opengl32.dll"));
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc({"mock_opengl32.dll"}));
EXPECT_TRUE(glLibrary->isLoaded());
GlFunctionHelperMock loader(glLibrary.get(), "mockLoader");
funcType function1 = ConvertibleProcAddr{reinterpret_cast<void *>(loader.glFunctionPtr("realFunction"))};
@ -35,7 +35,7 @@ TEST(GlFunctionHelper, whenCreateGlFunctionHelperThenSetGlFunctionPtrToLoadAnoth
}
TEST(GlFunctionHelper, givenNonExistingFunctionNameWhenCreateGlFunctionHelperThenNullptr) {
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc("mock_opengl32.dll"));
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc({"mock_opengl32.dll"}));
EXPECT_TRUE(glLibrary->isLoaded());
GlFunctionHelper loader(glLibrary.get(), "mockLoader");
funcType function = loader["nonExistingFunction"];
@ -43,7 +43,7 @@ TEST(GlFunctionHelper, givenNonExistingFunctionNameWhenCreateGlFunctionHelperThe
}
TEST(GlFunctionHelper, givenRealFunctionNameWhenCreateGlFunctionHelperThenGetPointerToAppropriateFunction) {
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc("mock_opengl32.dll"));
std::unique_ptr<OsLibrary> glLibrary(OsLibrary::loadFunc({"mock_opengl32.dll"}));
EXPECT_TRUE(glLibrary->isLoaded());
GlFunctionHelper loader(glLibrary.get(), "mockLoader");
funcType function = loader["realFunction"];

View File

@ -38,7 +38,7 @@ using setGLMockValue = void (*)(GLMockReturnedValues);
struct GlDllHelper {
public:
GlDllHelper() {
glDllLoad.reset(OsLibrary::loadFunc(Os::openglDllName));
glDllLoad.reset(OsLibrary::loadFunc({Os::openglDllName}));
if (glDllLoad) {
glSetString = (*glDllLoad)["glSetString"];
UNRECOVERABLE_IF(glSetString == nullptr);

View File

@ -42,7 +42,7 @@ struct IgaWrapper::Impl {
iga.optsContext.cb = sizeof(igaLib.optsContext);
iga.optsContext.gen = igaGen;
iga.library.reset(NEO::OsLibrary::loadFunc(Os::igaDllName));
iga.library.reset(NEO::OsLibrary::loadFunc({Os::igaDllName}));
if (iga.library == nullptr) {
return;
}

View File

@ -81,7 +81,7 @@ int OclocFclFacade::initialize(const HardwareInfo &hwInfo) {
}
std::unique_ptr<OsLibrary> OclocFclFacade::loadFclLibrary() const {
return std::unique_ptr<OsLibrary>{OsLibrary::loadFunc(Os::frontEndDllName)};
return std::unique_ptr<OsLibrary>{OsLibrary::loadFunc({Os::frontEndDllName})};
}
CIF::CreateCIFMainFunc_t OclocFclFacade::loadCreateFclMainFunction() const {

View File

@ -112,7 +112,7 @@ int OclocIgcFacade::initialize(const HardwareInfo &hwInfo) {
}
std::unique_ptr<OsLibrary> OclocIgcFacade::loadIgcLibrary() const {
return std::unique_ptr<OsLibrary>{OsLibrary::loadFunc(Os::igcDllName)};
return std::unique_ptr<OsLibrary>{OsLibrary::loadFunc({Os::igcDllName})};
}
CIF::CreateCIFMainFunc_t OclocIgcFacade::loadCreateIgcMainFunction() const {

View File

@ -88,7 +88,7 @@ void SehException::getCallStack(unsigned int code, struct _EXCEPTION_POINTERS *e
DWORD displacement = 0;
DWORD64 displacement64 = 0;
std::unique_ptr<NEO::OsLibrary> psApiLib(NEO::OsLibrary::loadFunc("psapi.dll"));
std::unique_ptr<NEO::OsLibrary> psApiLib(NEO::OsLibrary::loadFunc({"psapi.dll"}));
auto getMappedFileName = reinterpret_cast<getMappedFileNameFunction>(psApiLib->getProcAddress("GetMappedFileNameA"));
size_t callstackCounter = 0;

View File

@ -90,7 +90,9 @@ template <template <CIF::Version_t> class EntryPointT>
inline bool loadCompiler(const char *libName, std::unique_ptr<OsLibrary> &outLib,
CIF::RAII::UPtr_t<CIF::CIFMain> &outLibMain) {
std::string loadLibraryError;
auto lib = std::unique_ptr<OsLibrary>(OsLibrary::loadAndCaptureError(libName, &loadLibraryError));
OsLibraryCreateProperties libraryProperties(libName);
libraryProperties.errorValue = &loadLibraryError;
auto lib = std::unique_ptr<OsLibrary>(OsLibrary::loadFunc(libraryProperties));
if (lib == nullptr) {
NEO::printDebugString(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Compiler Library %s could not be loaded with error: %s\n", libName, loadLibraryError.c_str());
DEBUG_BREAK_IF(true); // could not load library

View File

@ -15,8 +15,8 @@
namespace NEO {
OsLibrary *OsLibrary::loadAndCaptureError(const std::string &name, std::string *errorValue) {
auto ptr = new (std::nothrow) Linux::OsLibrary(name, errorValue);
OsLibrary *OsLibrary::load(const OsLibraryCreateProperties &properties) {
auto ptr = new (std::nothrow) Linux::OsLibrary(properties);
if (ptr == nullptr)
return nullptr;
@ -33,8 +33,8 @@ const std::string OsLibrary::createFullSystemPath(const std::string &name) {
namespace Linux {
OsLibrary::OsLibrary(const std::string &name, std::string *errorValue) {
if (name.empty()) {
OsLibrary::OsLibrary(const OsLibraryCreateProperties &properties) {
if (properties.libraryName.empty() || properties.performSelfLoad) {
this->handle = SysCalls::dlopen(0, RTLD_LAZY);
} else {
#ifdef SANITIZER_BUILD
@ -43,12 +43,11 @@ OsLibrary::OsLibrary(const std::string &name, std::string *errorValue) {
auto dlopenFlag = RTLD_LAZY | RTLD_DEEPBIND;
/* Background: https://github.com/intel/compute-runtime/issues/122 */
#endif
dlopenFlag = OsLibrary::loadFlagsOverwrite ? *OsLibrary::loadFlagsOverwrite : dlopenFlag;
OsLibrary::loadFlagsOverwrite = nullptr;
dlopenFlag = properties.customLoadFlags ? *properties.customLoadFlags : dlopenFlag;
adjustLibraryFlags(dlopenFlag);
this->handle = SysCalls::dlopen(name.c_str(), dlopenFlag);
if (!this->handle && (errorValue != nullptr)) {
errorValue->assign(dlerror());
this->handle = SysCalls::dlopen(properties.libraryName.c_str(), dlopenFlag);
if (!this->handle && (properties.errorValue != nullptr)) {
properties.errorValue->assign(dlerror());
}
}
}

View File

@ -18,7 +18,7 @@ class OsLibrary : public NEO::OsLibrary {
void *handle;
public:
OsLibrary(const std::string &name, std::string *errorValue);
OsLibrary(const OsLibraryCreateProperties &properties);
~OsLibrary() override;
bool isLoaded() override;

View File

@ -16,7 +16,7 @@ namespace NEO {
///////////////////////////////////////////////////////
MetricsLibrary::MetricsLibrary() {
api = std::make_unique<MetricsLibraryInterface>();
osLibrary.reset(OsLibrary::loadFunc(Os::metricsLibraryDllName));
osLibrary.reset(OsLibrary::loadFunc({Os::metricsLibraryDllName}));
}
//////////////////////////////////////////////////////

View File

@ -8,8 +8,6 @@
#include "shared/source/os_interface/os_library.h"
namespace NEO {
const int *OsLibrary::loadFlagsOverwrite = nullptr;
decltype(&OsLibrary::load) OsLibrary::loadFunc = OsLibrary::load;
} // namespace NEO

View File

@ -23,18 +23,25 @@ struct ConvertibleProcAddr {
void *ptr = nullptr;
};
struct OsLibraryCreateProperties {
OsLibraryCreateProperties(std::string name) {
libraryName = name;
}
std::string libraryName;
std::string *errorValue = nullptr;
bool performSelfLoad = false;
int *customLoadFlags = nullptr;
};
class OsLibrary {
protected:
OsLibrary() = default;
static OsLibrary *load(const std::string &name) { return loadAndCaptureError(name, nullptr); }
static OsLibrary *load(const OsLibraryCreateProperties &properties);
public:
virtual ~OsLibrary() = default;
static const int *loadFlagsOverwrite;
static decltype(&OsLibrary::load) loadFunc;
static OsLibrary *loadAndCaptureError(const std::string &name, std::string *errorValue);
static const std::string createFullSystemPath(const std::string &name);
ConvertibleProcAddr operator[](const std::string &name) {

View File

@ -19,7 +19,7 @@ namespace GmmInterface {
GMM_STATUS initialize(GMM_INIT_IN_ARGS *pInArgs, GMM_INIT_OUT_ARGS *pOutArgs) {
if (!gmmLib) {
gmmLib.reset(OsLibrary::loadFunc(GMM_UMD_DLL));
gmmLib.reset(OsLibrary::loadFunc({GMM_UMD_DLL}));
UNRECOVERABLE_IF(!gmmLib);
}
auto initGmmFunc = reinterpret_cast<decltype(&InitializeGmm)>(gmmLib->getProcAddress(GMM_ADAPTER_INIT_NAME));

View File

@ -9,8 +9,8 @@
namespace NEO {
OsLibrary *OsLibrary::loadAndCaptureError(const std::string &name, std::string *errorValue) {
Windows::OsLibrary *ptr = new Windows::OsLibrary(name, errorValue);
OsLibrary *OsLibrary::load(const OsLibraryCreateProperties &properties) {
Windows::OsLibrary *ptr = new Windows::OsLibrary(properties);
if (!ptr->isLoaded()) {
delete ptr;
@ -30,9 +30,11 @@ const std::string OsLibrary::createFullSystemPath(const std::string &name) {
}
namespace Windows {
decltype(&GetModuleHandleA) OsLibrary::getModuleHandleA = GetModuleHandleA;
decltype(&LoadLibraryExA) OsLibrary::loadLibraryExA = LoadLibraryExA;
decltype(&GetModuleFileNameA) OsLibrary::getModuleFileNameA = GetModuleFileNameA;
decltype(&GetSystemDirectoryA) OsLibrary::getSystemDirectoryA = GetSystemDirectoryA;
decltype(&FreeLibrary) OsLibrary::freeLibrary = FreeLibrary;
extern "C" IMAGE_DOS_HEADER __ImageBase; // NOLINT(readability-identifier-naming)
__inline HINSTANCE getModuleHINSTANCE() { return (HINSTANCE)&__ImageBase; }
@ -66,23 +68,29 @@ HMODULE OsLibrary::loadDependency(const std::string &dependencyFileName) const {
return loadLibraryExA(dllPath, NULL, 0);
}
OsLibrary::OsLibrary(const std::string &name, std::string *errorValue) {
if (name.empty()) {
this->handle = GetModuleHandleA(nullptr);
OsLibrary::OsLibrary(const OsLibraryCreateProperties &properties) {
if (properties.libraryName.empty()) {
this->handle = getModuleHandleA(nullptr);
this->selfOpen = true;
} else {
this->handle = loadDependency(name);
if (this->handle == nullptr) {
this->handle = loadLibraryExA(name.c_str(), NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
if ((this->handle == nullptr) && (errorValue != nullptr)) {
getLastErrorString(errorValue);
if (properties.performSelfLoad) {
this->handle = getModuleHandleA(properties.libraryName.c_str());
this->selfOpen = true;
} else {
this->handle = loadDependency(properties.libraryName);
if (this->handle == nullptr) {
this->handle = loadLibraryExA(properties.libraryName.c_str(), NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
}
}
if ((this->handle == nullptr) && (properties.errorValue != nullptr)) {
getLastErrorString(properties.errorValue);
}
}
}
OsLibrary::~OsLibrary() {
if ((this->handle != nullptr) && (this->handle != GetModuleHandleA(nullptr))) {
::FreeLibrary(this->handle);
if (!this->selfOpen && this->handle) {
freeLibrary(this->handle);
this->handle = nullptr;
}
}

View File

@ -20,7 +20,7 @@ class OsLibrary : public NEO::OsLibrary, NEO::NonCopyableOrMovableClass {
HMODULE handle;
public:
OsLibrary(const std::string &name, std::string *errorValue);
OsLibrary(const OsLibraryCreateProperties &properties);
~OsLibrary();
bool isLoaded();
@ -32,8 +32,11 @@ class OsLibrary : public NEO::OsLibrary, NEO::NonCopyableOrMovableClass {
protected:
HMODULE loadDependency(const std::string &dependencyFileName) const;
static decltype(&GetModuleHandleA) getModuleHandleA;
static decltype(&LoadLibraryExA) loadLibraryExA;
static decltype(&GetModuleFileNameA) getModuleFileNameA;
static decltype(&FreeLibrary) freeLibrary;
bool selfOpen = false;
};
} // namespace Windows
} // namespace NEO

View File

@ -19,7 +19,7 @@ namespace NEO {
DxCoreAdapterFactory::DxCoreAdapterFactory(AdapterFactory::CreateAdapterFactoryFcn createAdapterFactoryFcn) : createAdapterFactoryFcn(createAdapterFactoryFcn) {
if (nullptr == createAdapterFactoryFcn) {
dxCoreLibrary.reset(OsLibrary::loadFunc(Os::dxcoreDllName));
dxCoreLibrary.reset(OsLibrary::loadFunc({Os::dxcoreDllName}));
if (dxCoreLibrary && dxCoreLibrary->isLoaded()) {
auto func = dxCoreLibrary->getProcAddress(dXCoreCreateAdapterFactoryFuncName);
createAdapterFactoryFcn = reinterpret_cast<DxCoreAdapterFactory::CreateAdapterFactoryFcn>(func);

View File

@ -17,7 +17,7 @@
namespace NEO {
bool PinContext::init(const std::string &gtPinOpenFunctionName) {
auto hGtPinLibrary = std::unique_ptr<OsLibrary>(OsLibrary::loadFunc(PinContext::gtPinLibraryFilename.c_str()));
auto hGtPinLibrary = std::unique_ptr<OsLibrary>(OsLibrary::loadFunc({PinContext::gtPinLibraryFilename}));
if (hGtPinLibrary == nullptr) {
PRINT_DEBUG_STRING(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Unable to find gtpin library %s\n", PinContext::gtPinLibraryFilename.c_str());

View File

@ -35,7 +35,10 @@ class MockOsLibrary : public NEO::OsLibrary {
static OsLibrary *loadLibraryNewObject;
static OsLibrary *load(const std::string &name) {
static OsLibrary *load(const NEO::OsLibraryCreateProperties &properties) {
if (properties.errorValue) {
return OsLibrary::load(properties);
}
OsLibrary *ptr = loadLibraryNewObject;
loadLibraryNewObject = nullptr;
return ptr;

View File

@ -107,6 +107,8 @@ int (*sysCallsGetDevicePath)(int deviceFd, char *buf, size_t &bufSize) = nullptr
off_t lseekReturn = 4096u;
std::atomic<int> lseekCalledCount(0);
long sysconfReturn = 1ull << 30;
std::string dlOpenFilePathPassed;
bool captureDlOpenFilePath = false;
int mkdir(const std::string &path) {
if (sysCallsMkdir != nullptr) {
@ -168,6 +170,13 @@ int openWithMode(const char *file, int flags, int mode) {
void *dlopen(const char *filename, int flag) {
dlOpenFlags = flag;
dlOpenCalled = true;
if (captureDlOpenFilePath) {
if (filename) {
dlOpenFilePathPassed = filename;
} else {
dlOpenFilePathPassed = {};
}
}
return ::dlopen(filename, flag);
}

View File

@ -84,6 +84,8 @@ extern uint32_t munmapFuncCalled;
extern off_t lseekReturn;
extern std::atomic<int> lseekCalledCount;
extern bool captureDlOpenFilePath;
extern std::string dlOpenFilePathPassed;
extern long sysconfReturn;
} // namespace SysCalls

View File

@ -34,14 +34,14 @@ TEST(OsLibraryTest, GivenValidNameWhenGettingFullPathAndDlinfoFailsThenPathIsEmp
}
return 0;
});
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
std::string path = library->getFullPath();
EXPECT_EQ(0u, path.size());
}
TEST(OsLibraryTest, GivenValidLibNameWhenGettingFullPathThenPathIsNotEmpty) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
std::string path = library->getFullPath();
EXPECT_NE(0u, path.size());
@ -59,7 +59,8 @@ TEST(OsLibraryTest, GivenDisableDeepBindFlagWhenOpeningLibraryThenRtldDeepBindFl
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
debugManager.flags.DisableDeepBind.set(1);
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", nullptr);
OsLibraryCreateProperties properties("_abc.so");
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(0, NEO::SysCalls::dlOpenFlags & RTLD_DEEPBIND);
}
@ -68,7 +69,9 @@ TEST(OsLibraryTest, GivenInvalidLibraryWhenOpeningLibraryThenDlopenErrorIsReturn
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
std::string errorValue;
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", &errorValue);
OsLibraryCreateProperties properties("_abc.so");
properties.errorValue = &errorValue;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_FALSE(errorValue.empty());
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
}
@ -76,12 +79,33 @@ TEST(OsLibraryTest, GivenInvalidLibraryWhenOpeningLibraryThenDlopenErrorIsReturn
TEST(OsLibraryTest, GivenLoadFlagsOverwriteWhenOpeningLibraryThenDlOpenIsCalledWithExpectedFlags) {
VariableBackup<int> dlOpenFlagsBackup{&NEO::SysCalls::dlOpenFlags, 0};
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
VariableBackup<bool> dlOpenCaptureFileNameBackup{&NEO::SysCalls::captureDlOpenFilePath, true};
auto expectedFlag = RTLD_LAZY | RTLD_GLOBAL;
NEO::OsLibrary::loadFlagsOverwrite = &expectedFlag;
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", nullptr);
OsLibraryCreateProperties properties("_abc.so");
properties.customLoadFlags = &expectedFlag;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(NEO::SysCalls::dlOpenFlags, expectedFlag);
EXPECT_EQ(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
}
TEST(OsLibraryTest, WhenPerformSelfOpenThenIgnoreFileNameForDlOpenCall) {
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
VariableBackup<bool> dlOpenCaptureFileNameBackup{&NEO::SysCalls::captureDlOpenFilePath, true};
OsLibraryCreateProperties properties("_abc.so");
properties.performSelfLoad = false;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
NEO::SysCalls::dlOpenCalled = false;
properties.performSelfLoad = true;
lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_NE(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
EXPECT_TRUE(NEO::SysCalls::dlOpenFilePathPassed.empty());
}
} // namespace NEO

View File

@ -22,7 +22,7 @@ const std::string fnName = "testDynamicLibraryFunc";
using namespace NEO;
TEST(OSLibraryTest, whenLibraryNameIsEmptyThenCurrentProcesIsUsedAsLibrary) {
std::unique_ptr<OsLibrary> library{OsLibrary::loadFunc("")};
std::unique_ptr<OsLibrary> library{OsLibrary::loadFunc({""})};
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress("selfDynamicLibraryFunc");
EXPECT_NE(nullptr, ptr);
@ -35,32 +35,36 @@ TEST(OSLibraryTest, GivenFakeLibNameWhenLoadingLibraryThenNullIsReturned) {
TEST(OSLibraryTest, GivenFakeLibNameWhenLoadingLibraryThenNullIsReturnedAndErrorString) {
std::string errorValue;
OsLibrary *library = OsLibrary::loadAndCaptureError(fakeLibName, &errorValue);
OsLibraryCreateProperties properties(fakeLibName);
properties.errorValue = &errorValue;
OsLibrary *library = OsLibrary::loadFunc(properties);
EXPECT_FALSE(errorValue.empty());
EXPECT_EQ(nullptr, library);
}
TEST(OSLibraryTest, GivenValidLibNameWhenLoadingLibraryThenLibraryIsLoaded) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
TEST(OSLibraryTest, GivenValidLibNameWhenLoadingLibraryThenLibraryIsLoadedWithNoErrorString) {
std::string errorValue;
std::unique_ptr<OsLibrary> library(OsLibrary::loadAndCaptureError(Os::testDllName, &errorValue));
OsLibraryCreateProperties properties(Os::testDllName);
properties.errorValue = &errorValue;
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(properties));
EXPECT_TRUE(errorValue.empty());
EXPECT_NE(nullptr, library);
}
TEST(OSLibraryTest, whenSymbolNameIsValidThenGetProcAddressReturnsNonNullPointer) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress(fnName);
EXPECT_NE(nullptr, ptr);
}
TEST(OSLibraryTest, whenSymbolNameIsInvalidThenGetProcAddressReturnsNullPointer) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress(fnName + "invalid");
EXPECT_EQ(nullptr, ptr);

View File

@ -20,6 +20,11 @@ extern const char *testDllName;
using namespace NEO;
class OsLibraryBackup : public Windows::OsLibrary {
public:
using Windows::OsLibrary::freeLibrary;
using Windows::OsLibrary::getModuleHandleA;
using Windows::OsLibrary::loadLibraryExA;
using Type = decltype(Windows::OsLibrary::loadLibraryExA);
using BackupType = VariableBackup<Type>;
@ -35,7 +40,6 @@ class OsLibraryBackup : public Windows::OsLibrary {
std::unique_ptr<SystemDirectoryBackupType> bkp3 = nullptr;
};
public:
static std::unique_ptr<Backup> backup(Type newValue, ModuleNameType newModuleName, SystemDirectoryType newSystemDirectoryName) {
std::unique_ptr<Backup> bkp(new Backup());
bkp->bkp1.reset(new BackupType(&OsLibrary::loadLibraryExA, newValue));
@ -84,7 +88,7 @@ UINT WINAPI getSystemDirectoryAMock(LPSTR lpBuffer, UINT uSize) {
TEST(OSLibraryWinTest, WhenLoadDependencyFailsThenFallbackToSystem32) {
auto bkp = OsLibraryBackup::backup(loadLibraryExAMock, getModuleFileNameAMock, getSystemDirectoryAMock);
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
@ -92,7 +96,7 @@ TEST(OSLibraryWinTest, WhenDependencyLoadsThenProperPathIsConstructed) {
auto bkp = OsLibraryBackup::backup(loadLibraryExAMock, getModuleFileNameAMock, getSystemDirectoryAMock);
VariableBackup<bool> bkpM(&mockWillFailInNonSystem32, false);
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
@ -106,17 +110,85 @@ TEST(OSLibraryWinTest, WhenCreatingFullSystemPathThenProperPathIsConstructed) {
TEST(OSLibraryWinTest, GivenInvalidLibraryWhenOpeningLibraryThenLoadLibraryErrorIsReturned) {
std::string errorValue;
auto lib = std::make_unique<Windows::OsLibrary>("abc", &errorValue);
OsLibraryCreateProperties properties("abc");
properties.errorValue = &errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_FALSE(errorValue.empty());
}
TEST(OSLibraryWinTest, GivenNoLastErrorOnWindowsThenErrorStringisEmpty) {
std::string errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(Os::testDllName, &errorValue);
OsLibraryCreateProperties properties(Os::testDllName);
properties.errorValue = &errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_NE(nullptr, lib);
EXPECT_TRUE(errorValue.empty());
lib->getLastErrorString(&errorValue);
EXPECT_TRUE(errorValue.empty());
lib->getLastErrorString(nullptr);
}
TEST(OSLibraryWinTest, WhenCreateOsLibraryWithSelfOpenThenDontLoadLibraryOrFreeLibrary) {
OsLibraryCreateProperties properties(Os::testDllName);
VariableBackup<decltype(OsLibraryBackup::loadLibraryExA)> backupLoadLibrary(&OsLibraryBackup::loadLibraryExA, [](LPCSTR, HANDLE, DWORD) -> HMODULE {
UNRECOVERABLE_IF(true);
return nullptr;
});
VariableBackup<decltype(OsLibraryBackup::getModuleHandleA)> backupGetModuleHandle(&OsLibraryBackup::getModuleHandleA, [](LPCSTR moduleName) -> HMODULE {
return nullptr;
});
VariableBackup<decltype(OsLibraryBackup::freeLibrary)> backupFreeLibrary(&OsLibraryBackup::freeLibrary, [](HMODULE) -> BOOL {
UNRECOVERABLE_IF(true);
return FALSE;
});
properties.performSelfLoad = true;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_FALSE(lib->isLoaded());
OsLibraryBackup::getModuleHandleA = [](LPCSTR moduleName) -> HMODULE {
return reinterpret_cast<HMODULE>(0x1000);
};
lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_TRUE(lib->isLoaded());
properties.libraryName.clear();
lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_TRUE(lib->isLoaded());
}
TEST(OSLibraryWinTest, WhenCreateOsLibraryWithoutSelfOpenThenLoadAndFreeLibrary) {
static uint32_t loadLibraryCalled = 0;
static uint32_t freeLibraryCalled = 0;
VariableBackup<decltype(loadLibraryCalled)> backupLoadLibraryCalled(&loadLibraryCalled, 0);
VariableBackup<decltype(freeLibraryCalled)> backupFreeLibraryCalled(&freeLibraryCalled, 0);
static HMODULE hModule = reinterpret_cast<HMODULE>(0x1000);
OsLibraryCreateProperties properties(Os::testDllName);
VariableBackup<decltype(OsLibraryBackup::loadLibraryExA)> backupLoadLibrary(&OsLibraryBackup::loadLibraryExA, [](LPCSTR, HANDLE, DWORD) -> HMODULE {
loadLibraryCalled++;
return hModule;
});
VariableBackup<decltype(OsLibraryBackup::getModuleHandleA)> backupGetModuleHandle(&OsLibraryBackup::getModuleHandleA, [](LPCSTR moduleName) -> HMODULE {
UNRECOVERABLE_IF(true);
return nullptr;
});
VariableBackup<decltype(OsLibraryBackup::freeLibrary)> backupFreeLibrary(&OsLibraryBackup::freeLibrary, [](HMODULE input) -> BOOL {
EXPECT_EQ(hModule, input);
freeLibraryCalled++;
return FALSE;
});
properties.performSelfLoad = false;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_TRUE(lib->isLoaded());
EXPECT_EQ(1u, loadLibraryCalled);
EXPECT_EQ(0u, freeLibraryCalled);
lib.reset();
EXPECT_EQ(1u, freeLibraryCalled);
}