diff --git a/core/gmm_helper/page_table_mngr.cpp b/core/gmm_helper/page_table_mngr.cpp index d284337079..ca72ba9ae3 100644 --- a/core/gmm_helper/page_table_mngr.cpp +++ b/core/gmm_helper/page_table_mngr.cpp @@ -8,8 +8,8 @@ #include "core/gmm_helper/page_table_mngr.h" namespace NEO { -GmmPageTableMngr *GmmPageTableMngr::create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { - return new GmmPageTableMngr(translationTableFlags, translationTableCb); +GmmPageTableMngr *GmmPageTableMngr::create(GmmClientContext *clientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { + return new GmmPageTableMngr(clientContext, translationTableFlags, translationTableCb); } } // namespace NEO diff --git a/core/gmm_helper/page_table_mngr.h b/core/gmm_helper/page_table_mngr.h index 0fea50171e..2b50f24ff0 100644 --- a/core/gmm_helper/page_table_mngr.h +++ b/core/gmm_helper/page_table_mngr.h @@ -15,12 +15,13 @@ namespace NEO { class Gmm; +class GmmClientContext; class LinearStream; class GmmPageTableMngr { public: MOCKABLE_VIRTUAL ~GmmPageTableMngr(); - static GmmPageTableMngr *create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb); + static GmmPageTableMngr *create(GmmClientContext *clientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb); MOCKABLE_VIRTUAL void setCsrHandle(void *csrHandle); @@ -38,7 +39,7 @@ class GmmPageTableMngr { return pageTableManager->InitContextAuxTableRegister(initialBBHandle, engineType); } - GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb); + GmmPageTableMngr(GmmClientContext *clientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb); GMM_CLIENT_CONTEXT *clientContext = nullptr; GMM_PAGETABLE_MGR *pageTableManager = nullptr; }; diff --git a/core/os_interface/linux/page_table_manager_functions.cpp b/core/os_interface/linux/page_table_manager_functions.cpp index 37bf83c250..28024beb83 100644 --- a/core/os_interface/linux/page_table_manager_functions.cpp +++ b/core/os_interface/linux/page_table_manager_functions.cpp @@ -12,8 +12,7 @@ #include "gmm_client_context.h" namespace NEO { -GmmPageTableMngr::GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { - clientContext = platform()->peekGmmClientContext()->getHandle(); +GmmPageTableMngr::GmmPageTableMngr(GmmClientContext *gmmClientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) : clientContext(gmmClientContext->getHandle()) { pageTableManager = clientContext->CreatePageTblMgrObject(translationTableFlags); } diff --git a/core/os_interface/windows/page_table_manager_functions.cpp b/core/os_interface/windows/page_table_manager_functions.cpp index e6703cd2fb..83048ae33b 100644 --- a/core/os_interface/windows/page_table_manager_functions.cpp +++ b/core/os_interface/windows/page_table_manager_functions.cpp @@ -12,8 +12,7 @@ #include "gmm_client_context.h" namespace NEO { -GmmPageTableMngr::GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { - clientContext = platform()->peekGmmClientContext()->getHandle(); +GmmPageTableMngr::GmmPageTableMngr(GmmClientContext *gmmClientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) : clientContext(gmmClientContext->getHandle()) { pageTableManager = clientContext->CreatePageTblMgrObject(translationTableCb, translationTableFlags); } diff --git a/runtime/os_interface/linux/drm_command_stream.inl b/runtime/os_interface/linux/drm_command_stream.inl index 4669fb1537..a4e29da66c 100644 --- a/runtime/os_interface/linux/drm_command_stream.inl +++ b/runtime/os_interface/linux/drm_command_stream.inl @@ -151,7 +151,7 @@ DrmMemoryManager *DrmCommandStreamReceiver::getMemoryManager() const template GmmPageTableMngr *DrmCommandStreamReceiver::createPageTableManager() { - GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(TT_TYPE::AUXTT, nullptr); + GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(this->executionEnvironment.getGmmClientContext(), TT_TYPE::AUXTT, nullptr); gmmPageTableMngr->setCsrHandle(this); this->executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex]->pageTableManager.reset(gmmPageTableMngr); return gmmPageTableMngr; diff --git a/runtime/os_interface/windows/wddm_device_command_stream.inl b/runtime/os_interface/windows/wddm_device_command_stream.inl index 78344b59b8..1a7e1cbe14 100644 --- a/runtime/os_interface/windows/wddm_device_command_stream.inl +++ b/runtime/os_interface/windows/wddm_device_command_stream.inl @@ -134,9 +134,11 @@ GmmPageTableMngr *WddmCommandStreamReceiver::createPageTableManager() GMM_TRANSLATIONTABLE_CALLBACKS ttCallbacks = {}; ttCallbacks.pfWriteL3Adr = TTCallbacks::writeL3Address; - GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(TT_TYPE::AUXTT, &ttCallbacks); + auto rootDeviceEnvironment = executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex].get(); + + GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(executionEnvironment.getGmmClientContext(), TT_TYPE::AUXTT, &ttCallbacks); gmmPageTableMngr->setCsrHandle(this); - this->executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex]->pageTableManager.reset(gmmPageTableMngr); + rootDeviceEnvironment->pageTableManager.reset(gmmPageTableMngr); return gmmPageTableMngr; } diff --git a/unit_tests/mocks/mock_gmm_page_table_mngr.cpp b/unit_tests/mocks/mock_gmm_page_table_mngr.cpp index 6a2c270b7c..3ec8d545a3 100644 --- a/unit_tests/mocks/mock_gmm_page_table_mngr.cpp +++ b/unit_tests/mocks/mock_gmm_page_table_mngr.cpp @@ -10,7 +10,7 @@ namespace NEO { using namespace ::testing; -GmmPageTableMngr *GmmPageTableMngr::create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { +GmmPageTableMngr *GmmPageTableMngr::create(GmmClientContext *clientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) { auto pageTableMngr = new ::testing::NiceMock(translationTableFlags, translationTableCb); ON_CALL(*pageTableMngr, initContextAuxTableRegister(_, _)).WillByDefault(Return(GMM_SUCCESS)); ON_CALL(*pageTableMngr, updateAuxTable(_)).WillByDefault(Return(GMM_SUCCESS)); diff --git a/unit_tests/os_interface/linux/drm_command_stream_tests.cpp b/unit_tests/os_interface/linux/drm_command_stream_tests.cpp index 8987a16d77..fce22e20f1 100644 --- a/unit_tests/os_interface/linux/drm_command_stream_tests.cpp +++ b/unit_tests/os_interface/linux/drm_command_stream_tests.cpp @@ -1385,6 +1385,7 @@ struct MockDrmCsr : public DrmCommandStreamReceiver { HWTEST_TEMPLATED_F(DrmCommandStreamTest, givenDrmCommandStreamReceiverWhenCreatePageTableMngrIsCalledThenCreatePageTableManager) { executionEnvironment.prepareRootDeviceEnvironments(2); + executionEnvironment.initGmm(); executionEnvironment.rootDeviceEnvironments[1]->osInterface = std::make_unique(); executionEnvironment.rootDeviceEnvironments[1]->osInterface->get()->setDrm(mock.get()); auto csr = std::make_unique>(executionEnvironment, 1, gemCloseWorkerMode::gemCloseWorkerActive); diff --git a/unit_tests/os_interface/windows/wddm_memory_manager_tests.cpp b/unit_tests/os_interface/windows/wddm_memory_manager_tests.cpp index a2a132d182..9c6e3cc9f7 100644 --- a/unit_tests/os_interface/windows/wddm_memory_manager_tests.cpp +++ b/unit_tests/os_interface/windows/wddm_memory_manager_tests.cpp @@ -44,7 +44,7 @@ void WddmMemoryManagerFixture::SetUp() { wddm = static_cast(Wddm::createWddm(*executionEnvironment->rootDeviceEnvironments[0].get())); if (platformDevices[0]->capabilityTable.ftrRenderCompressedBuffers || platformDevices[0]->capabilityTable.ftrRenderCompressedImages) { GMM_TRANSLATIONTABLE_CALLBACKS dummyTTCallbacks = {}; - executionEnvironment->rootDeviceEnvironments[0]->pageTableManager.reset(GmmPageTableMngr::create(0, &dummyTTCallbacks)); + executionEnvironment->rootDeviceEnvironments[0]->pageTableManager.reset(GmmPageTableMngr::create(nullptr, 0, &dummyTTCallbacks)); } auto hwInfo = *platformDevices[0]; wddm->init(hwInfo);