diff --git a/core/page_fault_manager/linux/cpu_page_fault_manager_linux.cpp b/core/page_fault_manager/linux/cpu_page_fault_manager_linux.cpp index 2d27ead7c3..e831aca765 100644 --- a/core/page_fault_manager/linux/cpu_page_fault_manager_linux.cpp +++ b/core/page_fault_manager/linux/cpu_page_fault_manager_linux.cpp @@ -21,7 +21,7 @@ std::function PageFaultManager PageFaultManagerLinux::PageFaultManagerLinux() { pageFaultHandler = [&](int signal, siginfo_t *info, void *context) { if (!this->verifyPageFault(info->si_addr)) { - previousHandler.sa_sigaction(signal, info, context); + callPreviousHandler(signal, info, context); } }; @@ -33,8 +33,10 @@ PageFaultManagerLinux::PageFaultManagerLinux() { } PageFaultManagerLinux::~PageFaultManagerLinux() { - auto retVal = sigaction(SIGSEGV, &previousHandler, nullptr); - UNRECOVERABLE_IF(retVal != 0); + if (!previousHandlerRestored) { + auto retVal = sigaction(SIGSEGV, &previousHandler, nullptr); + UNRECOVERABLE_IF(retVal != 0); + } } void PageFaultManagerLinux::pageFaultHandlerWrapper(int signal, siginfo_t *info, void *context) { @@ -50,4 +52,20 @@ void PageFaultManagerLinux::protectCPUMemoryAccess(void *ptr, size_t size) { auto retVal = mprotect(ptr, size, PROT_NONE); UNRECOVERABLE_IF(retVal != 0); } + +void PageFaultManagerLinux::callPreviousHandler(int signal, siginfo_t *info, void *context) { + if (previousHandler.sa_flags & SA_SIGINFO) { + previousHandler.sa_sigaction(signal, info, context); + } else { + if (previousHandler.sa_handler == SIG_DFL) { + auto retVal = sigaction(SIGSEGV, &previousHandler, nullptr); + UNRECOVERABLE_IF(retVal != 0); + previousHandlerRestored = true; + } else if (previousHandler.sa_handler == SIG_IGN) { + return; + } else { + previousHandler.sa_handler(signal); + } + } +} } // namespace NEO diff --git a/core/page_fault_manager/linux/cpu_page_fault_manager_linux.h b/core/page_fault_manager/linux/cpu_page_fault_manager_linux.h index 99d744c971..f06276d596 100644 --- a/core/page_fault_manager/linux/cpu_page_fault_manager_linux.h +++ b/core/page_fault_manager/linux/cpu_page_fault_manager_linux.h @@ -24,6 +24,9 @@ class PageFaultManagerLinux : public PageFaultManager { void allowCPUMemoryAccess(void *ptr, size_t size) override; void protectCPUMemoryAccess(void *ptr, size_t size) override; + void callPreviousHandler(int signal, siginfo_t *info, void *context); + bool previousHandlerRestored = false; + static std::function pageFaultHandler; struct sigaction previousHandler = {}; }; diff --git a/core/unit_tests/page_fault_manager/linux/cpu_page_fault_manager_linux_tests.cpp b/core/unit_tests/page_fault_manager/linux/cpu_page_fault_manager_linux_tests.cpp index a66f6c4755..fe31751898 100644 --- a/core/unit_tests/page_fault_manager/linux/cpu_page_fault_manager_linux_tests.cpp +++ b/core/unit_tests/page_fault_manager/linux/cpu_page_fault_manager_linux_tests.cpp @@ -42,7 +42,9 @@ TEST(PageFaultManagerLinuxTest, givenProtectedMemoryWhenTryingToAccessThenPageFa class MockFailPageFaultManager : public PageFaultManagerLinux { public: + using PageFaultManagerLinux::callPreviousHandler; using PageFaultManagerLinux::PageFaultManagerLinux; + using PageFaultManagerLinux::previousHandlerRestored; bool verifyPageFault(void *ptr) override { verifyCalled = true; @@ -53,15 +55,24 @@ class MockFailPageFaultManager : public PageFaultManagerLinux { mockCalled = true; } + static void mockPageFaultSimpleHandler(int signal) { + simpleMockCalled = true; + } + ~MockFailPageFaultManager() override { mockCalled = false; + simpleMockCalled = false; } + static bool mockCalled; + static bool simpleMockCalled; bool verifyCalled = false; }; -bool MockFailPageFaultManager::mockCalled = false; -TEST(PageFaultManagerLinuxTest, givenPageFaultThatNEOShouldNotHandleThenDefaultHandlerIsCalled) { +bool MockFailPageFaultManager::mockCalled = false; +bool MockFailPageFaultManager::simpleMockCalled = false; + +TEST(PageFaultManagerLinuxTest, givenPageFaultThatNEOShouldNotHandleAndSigInfoFlagSetThenSaSigactionIsCalled) { struct sigaction previousHandler = {}; struct sigaction mockHandler = {}; mockHandler.sa_flags = SA_SIGINFO; @@ -69,12 +80,67 @@ TEST(PageFaultManagerLinuxTest, givenPageFaultThatNEOShouldNotHandleThenDefaultH auto retVal = sigaction(SIGSEGV, &mockHandler, &previousHandler); EXPECT_EQ(retVal, 0); - MockFailPageFaultManager mockPageFaultManager; + auto mockPageFaultManager = std::make_unique(); EXPECT_FALSE(MockFailPageFaultManager::mockCalled); + EXPECT_FALSE(MockFailPageFaultManager::simpleMockCalled); std::raise(SIGSEGV); - EXPECT_TRUE(mockPageFaultManager.verifyCalled); + EXPECT_TRUE(mockPageFaultManager->verifyCalled); EXPECT_TRUE(MockFailPageFaultManager::mockCalled); + EXPECT_FALSE(MockFailPageFaultManager::simpleMockCalled); + mockPageFaultManager.reset(); sigaction(SIGSEGV, &previousHandler, nullptr); } + +TEST(PageFaultManagerLinuxTest, givenPageFaultThatNEOShouldNotHandleThenSaHandlerIsCalled) { + struct sigaction previousHandler = {}; + struct sigaction mockHandler = {}; + mockHandler.sa_handler = MockFailPageFaultManager::mockPageFaultSimpleHandler; + auto retVal = sigaction(SIGSEGV, &mockHandler, &previousHandler); + EXPECT_EQ(retVal, 0); + + auto mockPageFaultManager = std::make_unique(); + EXPECT_FALSE(MockFailPageFaultManager::mockCalled); + EXPECT_FALSE(MockFailPageFaultManager::simpleMockCalled); + + std::raise(SIGSEGV); + EXPECT_TRUE(mockPageFaultManager->verifyCalled); + EXPECT_FALSE(MockFailPageFaultManager::mockCalled); + EXPECT_TRUE(MockFailPageFaultManager::simpleMockCalled); + + mockPageFaultManager.reset(); + sigaction(SIGSEGV, &previousHandler, nullptr); +} + +TEST(PageFaultManagerLinuxTest, givenDefaultSaHandlerWhenInvokeCallPreviousSaHandlerThenPreviousHandlerIsRestored) { + struct sigaction originalHandler = {}; + struct sigaction mockDefaultHandler = {}; + mockDefaultHandler.sa_handler = SIG_DFL; + auto retVal = sigaction(SIGSEGV, &mockDefaultHandler, &originalHandler); + EXPECT_EQ(retVal, 0); + + auto mockPageFaultManager = std::make_unique(); + mockPageFaultManager->callPreviousHandler(0, nullptr, nullptr); + + EXPECT_TRUE(mockPageFaultManager->previousHandlerRestored); + + mockPageFaultManager.reset(); + sigaction(SIGSEGV, &originalHandler, nullptr); +} + +TEST(PageFaultManagerLinuxTest, givenIgnoringSaHandlerWhenInvokeCallPreviousSaHandlerThenNothingHappend) { + struct sigaction originalHandler = {}; + struct sigaction mockDefaultHandler = {}; + mockDefaultHandler.sa_handler = SIG_IGN; + auto retVal = sigaction(SIGSEGV, &mockDefaultHandler, &originalHandler); + EXPECT_EQ(retVal, 0); + + auto mockPageFaultManager = std::make_unique(); + mockPageFaultManager->callPreviousHandler(0, nullptr, nullptr); + + EXPECT_FALSE(mockPageFaultManager->previousHandlerRestored); + + mockPageFaultManager.reset(); + sigaction(SIGSEGV, &originalHandler, nullptr); +}