fix: correct loading L0 loader functions

on Windows use getModuleHandleA with proper module name
don't load ze_loader.dll from file system

Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:
Mateusz Jablonski
2024-10-16 17:52:41 +00:00
committed by Compute-Runtime-Automation
parent deb27d0363
commit 9c7b3c5e19
37 changed files with 167 additions and 86 deletions

View File

@@ -34,14 +34,14 @@ TEST(OsLibraryTest, GivenValidNameWhenGettingFullPathAndDlinfoFailsThenPathIsEmp
}
return 0;
});
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
std::string path = library->getFullPath();
EXPECT_EQ(0u, path.size());
}
TEST(OsLibraryTest, GivenValidLibNameWhenGettingFullPathThenPathIsNotEmpty) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
std::string path = library->getFullPath();
EXPECT_NE(0u, path.size());
@@ -59,7 +59,8 @@ TEST(OsLibraryTest, GivenDisableDeepBindFlagWhenOpeningLibraryThenRtldDeepBindFl
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
debugManager.flags.DisableDeepBind.set(1);
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", nullptr);
OsLibraryCreateProperties properties("_abc.so");
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(0, NEO::SysCalls::dlOpenFlags & RTLD_DEEPBIND);
}
@@ -68,7 +69,9 @@ TEST(OsLibraryTest, GivenInvalidLibraryWhenOpeningLibraryThenDlopenErrorIsReturn
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
std::string errorValue;
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", &errorValue);
OsLibraryCreateProperties properties("_abc.so");
properties.errorValue = &errorValue;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_FALSE(errorValue.empty());
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
}
@@ -76,12 +79,33 @@ TEST(OsLibraryTest, GivenInvalidLibraryWhenOpeningLibraryThenDlopenErrorIsReturn
TEST(OsLibraryTest, GivenLoadFlagsOverwriteWhenOpeningLibraryThenDlOpenIsCalledWithExpectedFlags) {
VariableBackup<int> dlOpenFlagsBackup{&NEO::SysCalls::dlOpenFlags, 0};
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
VariableBackup<bool> dlOpenCaptureFileNameBackup{&NEO::SysCalls::captureDlOpenFilePath, true};
auto expectedFlag = RTLD_LAZY | RTLD_GLOBAL;
NEO::OsLibrary::loadFlagsOverwrite = &expectedFlag;
auto lib = std::make_unique<Linux::OsLibrary>("_abc.so", nullptr);
OsLibraryCreateProperties properties("_abc.so");
properties.customLoadFlags = &expectedFlag;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(NEO::SysCalls::dlOpenFlags, expectedFlag);
EXPECT_EQ(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
}
TEST(OsLibraryTest, WhenPerformSelfOpenThenIgnoreFileNameForDlOpenCall) {
VariableBackup<bool> dlOpenCalledBackup{&NEO::SysCalls::dlOpenCalled, false};
VariableBackup<bool> dlOpenCaptureFileNameBackup{&NEO::SysCalls::captureDlOpenFilePath, true};
OsLibraryCreateProperties properties("_abc.so");
properties.performSelfLoad = false;
auto lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_EQ(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
NEO::SysCalls::dlOpenCalled = false;
properties.performSelfLoad = true;
lib = std::make_unique<Linux::OsLibrary>(properties);
EXPECT_TRUE(NEO::SysCalls::dlOpenCalled);
EXPECT_NE(properties.libraryName, NEO::SysCalls::dlOpenFilePathPassed);
EXPECT_TRUE(NEO::SysCalls::dlOpenFilePathPassed.empty());
}
} // namespace NEO

View File

@@ -22,7 +22,7 @@ const std::string fnName = "testDynamicLibraryFunc";
using namespace NEO;
TEST(OSLibraryTest, whenLibraryNameIsEmptyThenCurrentProcesIsUsedAsLibrary) {
std::unique_ptr<OsLibrary> library{OsLibrary::loadFunc("")};
std::unique_ptr<OsLibrary> library{OsLibrary::loadFunc({""})};
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress("selfDynamicLibraryFunc");
EXPECT_NE(nullptr, ptr);
@@ -35,32 +35,36 @@ TEST(OSLibraryTest, GivenFakeLibNameWhenLoadingLibraryThenNullIsReturned) {
TEST(OSLibraryTest, GivenFakeLibNameWhenLoadingLibraryThenNullIsReturnedAndErrorString) {
std::string errorValue;
OsLibrary *library = OsLibrary::loadAndCaptureError(fakeLibName, &errorValue);
OsLibraryCreateProperties properties(fakeLibName);
properties.errorValue = &errorValue;
OsLibrary *library = OsLibrary::loadFunc(properties);
EXPECT_FALSE(errorValue.empty());
EXPECT_EQ(nullptr, library);
}
TEST(OSLibraryTest, GivenValidLibNameWhenLoadingLibraryThenLibraryIsLoaded) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
TEST(OSLibraryTest, GivenValidLibNameWhenLoadingLibraryThenLibraryIsLoadedWithNoErrorString) {
std::string errorValue;
std::unique_ptr<OsLibrary> library(OsLibrary::loadAndCaptureError(Os::testDllName, &errorValue));
OsLibraryCreateProperties properties(Os::testDllName);
properties.errorValue = &errorValue;
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(properties));
EXPECT_TRUE(errorValue.empty());
EXPECT_NE(nullptr, library);
}
TEST(OSLibraryTest, whenSymbolNameIsValidThenGetProcAddressReturnsNonNullPointer) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress(fnName);
EXPECT_NE(nullptr, ptr);
}
TEST(OSLibraryTest, whenSymbolNameIsInvalidThenGetProcAddressReturnsNullPointer) {
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
void *ptr = library->getProcAddress(fnName + "invalid");
EXPECT_EQ(nullptr, ptr);

View File

@@ -20,6 +20,10 @@ extern const char *testDllName;
using namespace NEO;
class OsLibraryBackup : public Windows::OsLibrary {
public:
using Windows::OsLibrary::getModuleHandleA;
using Windows::OsLibrary::loadLibraryExA;
using Type = decltype(Windows::OsLibrary::loadLibraryExA);
using BackupType = VariableBackup<Type>;
@@ -35,7 +39,6 @@ class OsLibraryBackup : public Windows::OsLibrary {
std::unique_ptr<SystemDirectoryBackupType> bkp3 = nullptr;
};
public:
static std::unique_ptr<Backup> backup(Type newValue, ModuleNameType newModuleName, SystemDirectoryType newSystemDirectoryName) {
std::unique_ptr<Backup> bkp(new Backup());
bkp->bkp1.reset(new BackupType(&OsLibrary::loadLibraryExA, newValue));
@@ -84,7 +87,7 @@ UINT WINAPI getSystemDirectoryAMock(LPSTR lpBuffer, UINT uSize) {
TEST(OSLibraryWinTest, WhenLoadDependencyFailsThenFallbackToSystem32) {
auto bkp = OsLibraryBackup::backup(loadLibraryExAMock, getModuleFileNameAMock, getSystemDirectoryAMock);
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
@@ -92,7 +95,7 @@ TEST(OSLibraryWinTest, WhenDependencyLoadsThenProperPathIsConstructed) {
auto bkp = OsLibraryBackup::backup(loadLibraryExAMock, getModuleFileNameAMock, getSystemDirectoryAMock);
VariableBackup<bool> bkpM(&mockWillFailInNonSystem32, false);
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc(Os::testDllName));
std::unique_ptr<OsLibrary> library(OsLibrary::loadFunc({Os::testDllName}));
EXPECT_NE(nullptr, library);
}
@@ -106,17 +109,41 @@ TEST(OSLibraryWinTest, WhenCreatingFullSystemPathThenProperPathIsConstructed) {
TEST(OSLibraryWinTest, GivenInvalidLibraryWhenOpeningLibraryThenLoadLibraryErrorIsReturned) {
std::string errorValue;
auto lib = std::make_unique<Windows::OsLibrary>("abc", &errorValue);
OsLibraryCreateProperties properties("abc");
properties.errorValue = &errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_FALSE(errorValue.empty());
}
TEST(OSLibraryWinTest, GivenNoLastErrorOnWindowsThenErrorStringisEmpty) {
std::string errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(Os::testDllName, &errorValue);
OsLibraryCreateProperties properties(Os::testDllName);
properties.errorValue = &errorValue;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_NE(nullptr, lib);
EXPECT_TRUE(errorValue.empty());
lib->getLastErrorString(&errorValue);
EXPECT_TRUE(errorValue.empty());
lib->getLastErrorString(nullptr);
}
TEST(OSLibraryWinTest, WhenCreateOsLibraryWithSelfOpenThenDontLoadLibrary) {
OsLibraryCreateProperties properties(Os::testDllName);
VariableBackup<decltype(OsLibraryBackup::loadLibraryExA)> backupLoadLibrary(&OsLibraryBackup::loadLibraryExA, [](LPCSTR, HANDLE, DWORD) -> HMODULE {
UNRECOVERABLE_IF(true);
return nullptr;
});
VariableBackup<decltype(OsLibraryBackup::getModuleHandleA)> backupGetModuleHandle(&OsLibraryBackup::getModuleHandleA, [](LPCSTR moduleName) -> HMODULE {
return nullptr;
});
properties.performSelfLoad = true;
auto lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_FALSE(lib->isLoaded());
OsLibraryBackup::getModuleHandleA = [](LPCSTR moduleName) -> HMODULE {
return reinterpret_cast<HMODULE>(0x1000);
};
lib = std::make_unique<Windows::OsLibrary>(properties);
EXPECT_TRUE(lib->isLoaded());
}