From 6ba76147d0b3ccf086cca758ad0e6c2c6b84dfba Mon Sep 17 00:00:00 2001 From: Mateusz Jablonski Date: Mon, 23 Dec 2019 12:09:16 +0100 Subject: [PATCH] Pass proper handle to initContextAuxTableRegister function Resolves: NEO-4082, NEO-4080, NEO-4079 Change-Id: If8d0b69126d6442e8a9a102cd21f78944f8551e9 Signed-off-by: Mateusz Jablonski --- runtime/gmm_helper/page_table_mngr.h | 4 ++++ runtime/gmm_helper/page_table_mngr_impl.cpp | 7 ++++++- .../os_interface/linux/page_table_manager_functions.cpp | 6 ++++-- .../windows/page_table_manager_functions.cpp | 9 +++++---- unit_tests/mocks/mock_gmm_page_table_mngr.cpp | 4 +++- unit_tests/mocks/mock_gmm_page_table_mngr.h | 7 +++++-- .../os_interface/linux/drm_command_stream_tests.cpp | 4 +++- .../os_interface/windows/device_command_stream_tests.cpp | 2 +- 8 files changed, 31 insertions(+), 12 deletions(-) diff --git a/runtime/gmm_helper/page_table_mngr.h b/runtime/gmm_helper/page_table_mngr.h index 35e9075233..21e373f8f7 100644 --- a/runtime/gmm_helper/page_table_mngr.h +++ b/runtime/gmm_helper/page_table_mngr.h @@ -16,6 +16,8 @@ namespace NEO { class Gmm; class LinearStream; + +void gmmSetCsrHandle(GMM_PAGETABLE_MGR *pageTableManager, HANDLE csrHandle); class GmmPageTableMngr { public: MOCKABLE_VIRTUAL ~GmmPageTableMngr(); @@ -43,5 +45,7 @@ class GmmPageTableMngr { GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb); GMM_CLIENT_CONTEXT *clientContext = nullptr; GMM_PAGETABLE_MGR *pageTableManager = nullptr; + decltype(&gmmSetCsrHandle) gmmSetCsrHandleFunc = gmmSetCsrHandle; + void *csrHandle = nullptr; }; } // namespace NEO diff --git a/runtime/gmm_helper/page_table_mngr_impl.cpp b/runtime/gmm_helper/page_table_mngr_impl.cpp index 197308d322..46dd33de22 100644 --- a/runtime/gmm_helper/page_table_mngr_impl.cpp +++ b/runtime/gmm_helper/page_table_mngr_impl.cpp @@ -31,10 +31,15 @@ bool GmmPageTableMngr::updateAuxTable(uint64_t gpuVa, Gmm *gmm, bool map) { void GmmPageTableMngr::initPageTableManagerRegisters() { if (!initialized) { - initContextAuxTableRegister(this, GMM_ENGINE_TYPE::ENGINE_TYPE_RCS); + initContextAuxTableRegister(csrHandle, GMM_ENGINE_TYPE::ENGINE_TYPE_RCS); initialized = true; } } +void GmmPageTableMngr::setCsrHandle(void *csrHandleIn) { + csrHandle = csrHandleIn; + gmmSetCsrHandleFunc(pageTableManager, csrHandle); +} + } // namespace NEO diff --git a/runtime/os_interface/linux/page_table_manager_functions.cpp b/runtime/os_interface/linux/page_table_manager_functions.cpp index 446b34828b..97512e2ecd 100644 --- a/runtime/os_interface/linux/page_table_manager_functions.cpp +++ b/runtime/os_interface/linux/page_table_manager_functions.cpp @@ -12,10 +12,12 @@ #include "gmm_client_context.h" namespace NEO { + +void gmmSetCsrHandle(GMM_PAGETABLE_MGR *pageTableManager, HANDLE csrHandle) { +} + GmmPageTableMngr::GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { clientContext = platform()->peekGmmClientContext()->getHandle(); pageTableManager = clientContext->CreatePageTblMgrObject(translationTableFlags); } - -void GmmPageTableMngr::setCsrHandle(void *csrHandle) {} } // namespace NEO diff --git a/runtime/os_interface/windows/page_table_manager_functions.cpp b/runtime/os_interface/windows/page_table_manager_functions.cpp index 5ecde957c6..85ab8a6081 100644 --- a/runtime/os_interface/windows/page_table_manager_functions.cpp +++ b/runtime/os_interface/windows/page_table_manager_functions.cpp @@ -12,12 +12,13 @@ #include "gmm_client_context.h" namespace NEO { + +void gmmSetCsrHandle(GMM_PAGETABLE_MGR *pageTableManager, HANDLE csrHandle) { + pageTableManager->GmmSetCsrHandle(csrHandle); +} + GmmPageTableMngr::GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { clientContext = platform()->peekGmmClientContext()->getHandle(); pageTableManager = clientContext->CreatePageTblMgrObject(translationTableCb, translationTableFlags); } - -void GmmPageTableMngr::setCsrHandle(void *csrHandle) { - pageTableManager->GmmSetCsrHandle(csrHandle); -} } // namespace NEO diff --git a/unit_tests/mocks/mock_gmm_page_table_mngr.cpp b/unit_tests/mocks/mock_gmm_page_table_mngr.cpp index 136a7d37b0..96a1ddbbcc 100644 --- a/unit_tests/mocks/mock_gmm_page_table_mngr.cpp +++ b/unit_tests/mocks/mock_gmm_page_table_mngr.cpp @@ -7,6 +7,8 @@ #include "unit_tests/mocks/mock_gmm_page_table_mngr.h" +void dummySetCsrHandle(GMM_PAGETABLE_MGR *, HANDLE){}; + namespace NEO { using namespace ::testing; @@ -17,7 +19,7 @@ GmmPageTableMngr *GmmPageTableMngr::create(unsigned int translationTableFlags, G return pageTableMngr; } void MockGmmPageTableMngr::setCsrHandle(void *csrHandle) { - passedCsrHandle = csrHandle; + GmmPageTableMngr::setCsrHandle(csrHandle); setCsrHanleCalled++; } } // namespace NEO diff --git a/unit_tests/mocks/mock_gmm_page_table_mngr.h b/unit_tests/mocks/mock_gmm_page_table_mngr.h index 45d44658de..fc2def6a67 100644 --- a/unit_tests/mocks/mock_gmm_page_table_mngr.h +++ b/unit_tests/mocks/mock_gmm_page_table_mngr.h @@ -16,16 +16,20 @@ #pragma clang diagnostic ignored "-Winconsistent-missing-override" #endif +void dummySetCsrHandle(GMM_PAGETABLE_MGR *, HANDLE); + namespace NEO { class MockGmmPageTableMngr : public GmmPageTableMngr { public: - MockGmmPageTableMngr() = default; + using GmmPageTableMngr::csrHandle; + MockGmmPageTableMngr() : MockGmmPageTableMngr(0u, nullptr){}; MockGmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) : translationTableFlags(translationTableFlags) { if (translationTableCb) { this->translationTableCb = *translationTableCb; } + gmmSetCsrHandleFunc = dummySetCsrHandle; }; MOCK_METHOD2(initContextAuxTableRegister, GMM_STATUS(HANDLE initialBBHandle, GMM_ENGINE_TYPE engineType)); @@ -35,7 +39,6 @@ class MockGmmPageTableMngr : public GmmPageTableMngr { void setCsrHandle(void *csrHandle) override; uint32_t setCsrHanleCalled = 0; - void *passedCsrHandle = nullptr; unsigned int translationTableFlags = 0; GMM_TRANSLATIONTABLE_CALLBACKS translationTableCb = {}; diff --git a/unit_tests/os_interface/linux/drm_command_stream_tests.cpp b/unit_tests/os_interface/linux/drm_command_stream_tests.cpp index b16cd6f389..a0f3e32b00 100644 --- a/unit_tests/os_interface/linux/drm_command_stream_tests.cpp +++ b/unit_tests/os_interface/linux/drm_command_stream_tests.cpp @@ -1397,7 +1397,9 @@ HWTEST_TEMPLATED_F(DrmCommandStreamTest, givenDrmCommandStreamReceiverWhenInitia auto &rootDeviceEnvironment = executionEnvironment.rootDeviceEnvironments[1]; MockGmmPageTableMngr *mockMngr = static_cast(rootDeviceEnvironment->pageTableManager.get()); - EXPECT_CALL(*mockMngr, initContextAuxTableRegister(::testing::_, ::testing::_)).Times(1); + auto csrHandle = reinterpret_cast(0x1234); + mockMngr->setCsrHandle(csrHandle); + EXPECT_CALL(*mockMngr, initContextAuxTableRegister(csrHandle, ::testing::_)).Times(1); EXPECT_FALSE(rootDeviceEnvironment->pageTableManager->initialized); LinearStream linearStream = {}; diff --git a/unit_tests/os_interface/windows/device_command_stream_tests.cpp b/unit_tests/os_interface/windows/device_command_stream_tests.cpp index e9be4eb863..a88c4bbd02 100644 --- a/unit_tests/os_interface/windows/device_command_stream_tests.cpp +++ b/unit_tests/os_interface/windows/device_command_stream_tests.cpp @@ -872,7 +872,7 @@ HWTEST_P(WddmCsrCompressionParameterizedTest, givenEnabledCompressionWhenInitial auto mockMngr = reinterpret_cast(executionEnvironment->rootDeviceEnvironments[index]->pageTableManager.get()); EXPECT_EQ(1u, mockMngr->setCsrHanleCalled); - EXPECT_EQ(&mockWddmCsr, mockMngr->passedCsrHandle); + EXPECT_EQ(&mockWddmCsr, mockMngr->csrHandle); GMM_TRANSLATIONTABLE_CALLBACKS expectedTTCallbacks = {}; unsigned int expectedFlags = TT_TYPE::AUXTT;