mirror of
https://github.com/intel/compute-runtime.git
synced 2025-09-15 13:01:45 +08:00
Update GMM API related to page table manager
Resolves: NEO-3155 Change-Id: I44a544a4ecd06e5769995eb1f67948ebb10a2cb5 Signed-off-by: Mateusz Jablonski <mateusz.jablonski@intel.com>
This commit is contained in:

committed by
sys_ocldev

parent
2da3e45867
commit
3c1c4cf695
@ -8,8 +8,8 @@
|
||||
#include "runtime/gmm_helper/page_table_mngr.h"
|
||||
|
||||
namespace NEO {
|
||||
GmmPageTableMngr *GmmPageTableMngr::create(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
return new GmmPageTableMngr(deviceCb, translationTableFlags, translationTableCb);
|
||||
GmmPageTableMngr *GmmPageTableMngr::create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
return new GmmPageTableMngr(translationTableFlags, translationTableCb);
|
||||
}
|
||||
|
||||
} // namespace NEO
|
||||
|
@ -16,7 +16,7 @@ class GmmPageTableMngr {
|
||||
public:
|
||||
MOCKABLE_VIRTUAL ~GmmPageTableMngr();
|
||||
|
||||
static GmmPageTableMngr *create(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb);
|
||||
static GmmPageTableMngr *create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb);
|
||||
|
||||
MOCKABLE_VIRTUAL GMM_STATUS initContextAuxTableRegister(HANDLE initialBBHandle, GMM_ENGINE_TYPE engineType) {
|
||||
return pageTableManager->InitContextAuxTableRegister(initialBBHandle, engineType);
|
||||
@ -29,11 +29,14 @@ class GmmPageTableMngr {
|
||||
MOCKABLE_VIRTUAL GMM_STATUS updateAuxTable(const GMM_DDI_UPDATEAUXTABLE *ddiUpdateAuxTable) {
|
||||
return pageTableManager->UpdateAuxTable(ddiUpdateAuxTable);
|
||||
}
|
||||
MOCKABLE_VIRTUAL void setCsrHandle(void *csrHandle) {
|
||||
pageTableManager->GmmSetCsrHandle(csrHandle);
|
||||
}
|
||||
|
||||
protected:
|
||||
GmmPageTableMngr() = default;
|
||||
|
||||
GmmPageTableMngr(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb);
|
||||
GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb);
|
||||
GMM_CLIENT_CONTEXT *clientContext = nullptr;
|
||||
GMM_PAGETABLE_MGR *pageTableManager = nullptr;
|
||||
};
|
||||
|
@ -17,9 +17,9 @@ GmmPageTableMngr::~GmmPageTableMngr() {
|
||||
}
|
||||
}
|
||||
|
||||
GmmPageTableMngr::GmmPageTableMngr(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
GmmPageTableMngr::GmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
clientContext = GmmHelper::getClientContext()->getHandle();
|
||||
pageTableManager = clientContext->CreatePageTblMgrObject(deviceCb, translationTableCb, translationTableFlags);
|
||||
pageTableManager = clientContext->CreatePageTblMgrObject(translationTableCb, translationTableFlags);
|
||||
}
|
||||
|
||||
} // namespace NEO
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
namespace NEO {
|
||||
|
||||
static long(__stdcall *notifyAubCaptureImpl)(void *csrHandle, uint64_t gfxAddress, size_t gfxSize, bool allocate) = nullptr;
|
||||
template <typename GfxFamily>
|
||||
struct DeviceCallbacks {
|
||||
static long __stdcall notifyAubCapture(void *csrHandle, uint64_t gfxAddress, size_t gfxSize, bool allocate);
|
||||
|
@ -35,6 +35,7 @@ template <typename GfxFamily>
|
||||
WddmCommandStreamReceiver<GfxFamily>::WddmCommandStreamReceiver(ExecutionEnvironment &executionEnvironment)
|
||||
: BaseClass(executionEnvironment) {
|
||||
|
||||
notifyAubCaptureImpl = DeviceCallbacks<GfxFamily>::notifyAubCapture;
|
||||
this->wddm = executionEnvironment.osInterface->get()->getWddm();
|
||||
this->osInterface = executionEnvironment.osInterface.get();
|
||||
|
||||
@ -139,34 +140,11 @@ bool WddmCommandStreamReceiver<GfxFamily>::waitForFlushStamp(FlushStamp &flushSt
|
||||
|
||||
template <typename GfxFamily>
|
||||
GmmPageTableMngr *WddmCommandStreamReceiver<GfxFamily>::createPageTableManager() {
|
||||
GMM_DEVICE_CALLBACKS_INT deviceCallbacks = {};
|
||||
GMM_TRANSLATIONTABLE_CALLBACKS ttCallbacks = {};
|
||||
auto gdi = wddm->getGdi();
|
||||
ttCallbacks.pfWriteL3Adr = TTCallbacks<GfxFamily>::writeL3Address;
|
||||
|
||||
// clang-format off
|
||||
deviceCallbacks.Adapter.KmtHandle = wddm->getAdapter();
|
||||
deviceCallbacks.hDevice.KmtHandle = wddm->getDevice();
|
||||
deviceCallbacks.hCsr = static_cast<CommandStreamReceiverHw<GfxFamily> *>(this);
|
||||
deviceCallbacks.PagingQueue = wddm->getPagingQueue();
|
||||
deviceCallbacks.PagingFence = wddm->getPagingQueueSyncObject();
|
||||
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnAllocate = gdi->createAllocation;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnDeallocate = gdi->destroyAllocation;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnMapGPUVA = gdi->mapGpuVirtualAddress;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnMakeResident = gdi->makeResident;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnEvict = gdi->evict;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnReserveGPUVA = gdi->reserveGpuVirtualAddress;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnUpdateGPUVA = gdi->updateGpuVirtualAddress;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnWaitFromCpu = gdi->waitForSynchronizationObjectFromCpu;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnLock = gdi->lock2;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnUnLock = gdi->unlock2;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnEscape = gdi->escape;
|
||||
deviceCallbacks.DevCbPtrs.KmtCbPtrs.pfnNotifyAubCapture = DeviceCallbacks<GfxFamily>::notifyAubCapture;
|
||||
|
||||
ttCallbacks.pfWriteL3Adr = TTCallbacks<GfxFamily>::writeL3Address;
|
||||
// clang-format on
|
||||
|
||||
GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(&deviceCallbacks, TT_TYPE::TRTT | TT_TYPE::AUXTT, &ttCallbacks);
|
||||
GmmPageTableMngr *gmmPageTableMngr = GmmPageTableMngr::create(TT_TYPE::TRTT | TT_TYPE::AUXTT, &ttCallbacks);
|
||||
gmmPageTableMngr->setCsrHandle(this);
|
||||
this->wddm->resetPageTableManager(gmmPageTableMngr);
|
||||
return gmmPageTableMngr;
|
||||
}
|
||||
|
@ -10,11 +10,15 @@
|
||||
namespace NEO {
|
||||
using namespace ::testing;
|
||||
|
||||
GmmPageTableMngr *GmmPageTableMngr::create(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
auto pageTableMngr = new ::testing::NiceMock<MockGmmPageTableMngr>(deviceCb, translationTableFlags, translationTableCb);
|
||||
GmmPageTableMngr *GmmPageTableMngr::create(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb) {
|
||||
auto pageTableMngr = new ::testing::NiceMock<MockGmmPageTableMngr>(translationTableFlags, translationTableCb);
|
||||
ON_CALL(*pageTableMngr, initContextAuxTableRegister(_, _)).WillByDefault(Return(GMM_SUCCESS));
|
||||
ON_CALL(*pageTableMngr, initContextTRTableRegister(_, _)).WillByDefault(Return(GMM_SUCCESS));
|
||||
ON_CALL(*pageTableMngr, updateAuxTable(_)).WillByDefault(Return(GMM_SUCCESS));
|
||||
return pageTableMngr;
|
||||
}
|
||||
void MockGmmPageTableMngr::setCsrHandle(void *csrHandle) {
|
||||
passedCsrHandle = csrHandle;
|
||||
setCsrHanleCalled++;
|
||||
}
|
||||
} // namespace NEO
|
||||
|
@ -16,8 +16,8 @@ class MockGmmPageTableMngr : public GmmPageTableMngr {
|
||||
public:
|
||||
MockGmmPageTableMngr() = default;
|
||||
|
||||
MockGmmPageTableMngr(GMM_DEVICE_CALLBACKS_INT *deviceCb, unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb)
|
||||
: deviceCb(*deviceCb), translationTableFlags(translationTableFlags), translationTableCb(*translationTableCb){};
|
||||
MockGmmPageTableMngr(unsigned int translationTableFlags, GMM_TRANSLATIONTABLE_CALLBACKS *translationTableCb)
|
||||
: translationTableFlags(translationTableFlags), translationTableCb(*translationTableCb){};
|
||||
|
||||
MOCK_METHOD2(initContextAuxTableRegister, GMM_STATUS(HANDLE initialBBHandle, GMM_ENGINE_TYPE engineType));
|
||||
|
||||
@ -25,7 +25,11 @@ class MockGmmPageTableMngr : public GmmPageTableMngr {
|
||||
|
||||
MOCK_METHOD1(updateAuxTable, GMM_STATUS(const GMM_DDI_UPDATEAUXTABLE *ddiUpdateAuxTable));
|
||||
|
||||
GMM_DEVICE_CALLBACKS_INT deviceCb = {};
|
||||
void setCsrHandle(void *csrHandle) override;
|
||||
|
||||
uint32_t setCsrHanleCalled = 0;
|
||||
void *passedCsrHandle = nullptr;
|
||||
|
||||
GMM_TRANSLATIONTABLE_CALLBACKS translationTableCb = {};
|
||||
unsigned int translationTableFlags = 0;
|
||||
};
|
||||
|
@ -855,38 +855,14 @@ HWTEST_P(WddmCsrCompressionParameterizedTest, givenEnabledCompressionWhenInitial
|
||||
ASSERT_NE(nullptr, myMockWddm->getPageTableManager());
|
||||
|
||||
auto mockMngr = reinterpret_cast<MockGmmPageTableMngr *>(myMockWddm->getPageTableManager());
|
||||
EXPECT_EQ(1u, mockMngr->setCsrHanleCalled);
|
||||
EXPECT_EQ(&mockWddmCsr, mockMngr->passedCsrHandle);
|
||||
|
||||
GMM_DEVICE_CALLBACKS_INT expectedDeviceCb = {};
|
||||
GMM_TRANSLATIONTABLE_CALLBACKS expectedTTCallbacks = {};
|
||||
unsigned int expectedFlags = (TT_TYPE::TRTT | TT_TYPE::AUXTT);
|
||||
auto myGdi = myMockWddm->getGdi();
|
||||
// clang-format off
|
||||
expectedDeviceCb.Adapter.KmtHandle = myMockWddm->getAdapter();
|
||||
expectedDeviceCb.hDevice.KmtHandle = myMockWddm->getDevice();
|
||||
expectedDeviceCb.hCsr = &mockWddmCsr;
|
||||
expectedDeviceCb.PagingQueue = myMockWddm->getPagingQueue();
|
||||
expectedDeviceCb.PagingFence = myMockWddm->getPagingQueueSyncObject();
|
||||
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnAllocate = myGdi->createAllocation;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnDeallocate = myGdi->destroyAllocation;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnMapGPUVA = myGdi->mapGpuVirtualAddress;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnMakeResident = myGdi->makeResident;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnEvict = myGdi->evict;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnReserveGPUVA = myGdi->reserveGpuVirtualAddress;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnUpdateGPUVA = myGdi->updateGpuVirtualAddress;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnWaitFromCpu = myGdi->waitForSynchronizationObjectFromCpu;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnLock = myGdi->lock2;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnUnLock = myGdi->unlock2;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnEscape = myGdi->escape;
|
||||
expectedDeviceCb.DevCbPtrs.KmtCbPtrs.pfnNotifyAubCapture = DeviceCallbacks<FamilyType>::notifyAubCapture;
|
||||
|
||||
expectedTTCallbacks.pfWriteL3Adr = TTCallbacks<FamilyType>::writeL3Address;
|
||||
// clang-format on
|
||||
|
||||
EXPECT_TRUE(memcmp(&expectedDeviceCb, &mockMngr->deviceCb, sizeof(GMM_DEVICE_CALLBACKS_INT)) == 0);
|
||||
EXPECT_TRUE(memcmp(&expectedDeviceCb.Adapter, &mockMngr->deviceCb.Adapter, sizeof(GMM_HANDLE_EXT)) == 0);
|
||||
EXPECT_TRUE(memcmp(&expectedDeviceCb.hDevice, &mockMngr->deviceCb.hDevice, sizeof(GMM_HANDLE_EXT)) == 0);
|
||||
EXPECT_TRUE(memcmp(&expectedDeviceCb.DevCbPtrs.KmtCbPtrs, &mockMngr->deviceCb.DevCbPtrs.KmtCbPtrs, sizeof(GMM_DEVICE_CB_PTRS::KmtCbPtrs)) == 0);
|
||||
EXPECT_TRUE(memcmp(&expectedTTCallbacks, &mockMngr->translationTableCb, sizeof(GMM_TRANSLATIONTABLE_CALLBACKS)) == 0);
|
||||
EXPECT_TRUE(memcmp(&expectedFlags, &mockMngr->translationTableFlags, sizeof(unsigned int)) == 0);
|
||||
}
|
||||
|
@ -36,9 +36,8 @@ void WddmMemoryManagerFixture::SetUp() {
|
||||
|
||||
wddm = static_cast<WddmMock *>(Wddm::createWddm());
|
||||
if (platformDevices[0]->capabilityTable.ftrRenderCompressedBuffers || platformDevices[0]->capabilityTable.ftrRenderCompressedImages) {
|
||||
GMM_DEVICE_CALLBACKS_INT dummyDeviceCallbacks = {};
|
||||
GMM_TRANSLATIONTABLE_CALLBACKS dummyTTCallbacks = {};
|
||||
wddm->resetPageTableManager(GmmPageTableMngr::create(&dummyDeviceCallbacks, 0, &dummyTTCallbacks));
|
||||
wddm->resetPageTableManager(GmmPageTableMngr::create(0, &dummyTTCallbacks));
|
||||
}
|
||||
EXPECT_TRUE(wddm->init(PreemptionHelper::getDefaultPreemptionMode(*platformDevices[0])));
|
||||
constexpr uint64_t heap32Base = (is32bit) ? 0x1000 : 0x800000000000;
|
||||
|
Reference in New Issue
Block a user