diff --git a/opencl/source/os_interface/linux/drm_command_stream.inl b/opencl/source/os_interface/linux/drm_command_stream.inl index ab9bdc71e6..d16cbb8bb1 100644 --- a/opencl/source/os_interface/linux/drm_command_stream.inl +++ b/opencl/source/os_interface/linux/drm_command_stream.inl @@ -8,6 +8,7 @@ #include "shared/source/command_stream/linear_stream.h" #include "shared/source/direct_submission/linux/drm_direct_submission.h" #include "shared/source/execution_environment/execution_environment.h" +#include "shared/source/gmm_helper/client_context/gmm_client_context.h" #include "shared/source/gmm_helper/gmm_helper.h" #include "shared/source/gmm_helper/page_table_mngr.h" #include "shared/source/helpers/aligned_memory.h" @@ -211,9 +212,19 @@ DrmMemoryManager *DrmCommandStreamReceiver::getMemoryManager() const template GmmPageTableMngr *DrmCommandStreamReceiver::createPageTableManager() { - GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(this->executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex]->getGmmClientContext(), TT_TYPE::AUXTT, nullptr); + auto rootDeviceEnvironment = this->executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex].get(); + auto gmmClientContext = rootDeviceEnvironment->getGmmClientContext(); + + GMM_DEVICE_INFO deviceInfo{}; + GMM_DEVICE_CALLBACKS_INT deviceCallbacks{}; + deviceInfo.pDeviceCb = &deviceCallbacks; + gmmClientContext->setGmmDeviceInfo(&deviceInfo); + + auto gmmPageTableMngr = GmmPageTableMngr::create(gmmClientContext, TT_TYPE::AUXTT, nullptr); gmmPageTableMngr->setCsrHandle(this); - this->executionEnvironment.rootDeviceEnvironments[this->rootDeviceIndex]->pageTableManager.reset(gmmPageTableMngr); + + rootDeviceEnvironment->pageTableManager.reset(gmmPageTableMngr); + return gmmPageTableMngr; } diff --git a/shared/source/gmm_helper/client_context/gmm_client_context.cpp b/shared/source/gmm_helper/client_context/gmm_client_context.cpp index 3dda085a8f..9557ad5a3d 100644 --- a/shared/source/gmm_helper/client_context/gmm_client_context.cpp +++ b/shared/source/gmm_helper/client_context/gmm_client_context.cpp @@ -75,4 +75,8 @@ uint8_t GmmClientContext::getMediaSurfaceStateCompressionFormat(GMM_RESOURCE_FOR return clientContext->GetMediaSurfaceStateCompressionFormat(format); } +void GmmClientContext::setGmmDeviceInfo(GMM_DEVICE_INFO *deviceInfo) { + clientContext->GmmSetDeviceInfo(deviceInfo); +} + } // namespace NEO diff --git a/shared/source/gmm_helper/client_context/gmm_client_context.h b/shared/source/gmm_helper/client_context/gmm_client_context.h index b8d23526a0..b13ad008da 100644 --- a/shared/source/gmm_helper/client_context/gmm_client_context.h +++ b/shared/source/gmm_helper/client_context/gmm_client_context.h @@ -42,6 +42,8 @@ class GmmClientContext { this->handleAllocator = std::move(allocator); } + MOCKABLE_VIRTUAL void setGmmDeviceInfo(GMM_DEVICE_INFO *deviceInfo); + GmmHandleAllocator *getHandleAllocator() { return handleAllocator.get(); } diff --git a/shared/source/os_interface/linux/page_table_manager_functions.cpp b/shared/source/os_interface/linux/page_table_manager_functions.cpp index 94de688470..d2583349ec 100644 --- a/shared/source/os_interface/linux/page_table_manager_functions.cpp +++ b/shared/source/os_interface/linux/page_table_manager_functions.cpp @@ -14,6 +14,7 @@ namespace NEO { GmmPageTableMngr::GmmPageTableMngr(GmmClientContext *gmmClientContext, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) : clientContext(gmmClientContext->getHandle()) { pageTableManager = clientContext->CreatePageTblMgrObject(translationTableFlags); + DEBUG_BREAK_IF(pageTableManager == nullptr); } void GmmPageTableMngr::setCsrHandle(void *csrHandle) {} diff --git a/shared/test/common/mocks/mock_gmm_client_context_base.cpp b/shared/test/common/mocks/mock_gmm_client_context_base.cpp index 6f94b58f04..c5d01c2495 100644 --- a/shared/test/common/mocks/mock_gmm_client_context_base.cpp +++ b/shared/test/common/mocks/mock_gmm_client_context_base.cpp @@ -7,6 +7,8 @@ #include "shared/test/common/mocks/mock_gmm_client_context_base.h" +#include "gtest/gtest.h" + namespace NEO { GMM_RESOURCE_INFO *MockGmmClientContextBase::createResInfoObject(GMM_RESCREATE_PARAMS *pCreateParams) { @@ -33,4 +35,10 @@ uint8_t MockGmmClientContextBase::getMediaSurfaceStateCompressionFormat(GMM_RESO return compressionFormatToReturn; } +void MockGmmClientContextBase::setGmmDeviceInfo(GMM_DEVICE_INFO *deviceInfo) { + EXPECT_NE(deviceInfo, nullptr); + + GMM_DEVICE_CALLBACKS_INT emptyStruct{}; + EXPECT_EQ(0, memcmp(deviceInfo->pDeviceCb, &emptyStruct, sizeof(GMM_DEVICE_CALLBACKS_INT))); +} } // namespace NEO diff --git a/shared/test/common/mocks/mock_gmm_client_context_base.h b/shared/test/common/mocks/mock_gmm_client_context_base.h index 6e02c930d5..83004ef3f6 100644 --- a/shared/test/common/mocks/mock_gmm_client_context_base.h +++ b/shared/test/common/mocks/mock_gmm_client_context_base.h @@ -17,6 +17,7 @@ class MockGmmClientContextBase : public GmmClientContext { void destroyResInfoObject(GMM_RESOURCE_INFO *pResInfo) override; uint8_t getSurfaceStateCompressionFormat(GMM_RESOURCE_FORMAT format) override; uint8_t getMediaSurfaceStateCompressionFormat(GMM_RESOURCE_FORMAT format) override; + void setGmmDeviceInfo(GMM_DEVICE_INFO *deviceInfo) override; GMM_RESOURCE_FORMAT capturedFormat = GMM_FORMAT_INVALID; uint8_t compressionFormatToReturn = 1;