diff --git a/shared/source/os_interface/linux/drm_neo.cpp b/shared/source/os_interface/linux/drm_neo.cpp index e1ee9faf09..4724f219fb 100644 --- a/shared/source/os_interface/linux/drm_neo.cpp +++ b/shared/source/os_interface/linux/drm_neo.cpp @@ -1399,8 +1399,19 @@ int changeBufferObjectBinding(Drm *drm, OsContext *osContext, uint32_t vmHandleI if (!drm->hasPageFaultSupport() || bo->isExplicitResidencyRequired()) { auto nextExtension = vmBind.extensions; - auto address = castToUint64(drm->getFenceAddr(vmHandleId)); - auto value = drm->getNextFenceVal(vmHandleId); + + uint64_t address = 0; + uint64_t value = 0; + + if (drm->isPerContextVMRequired()) { + auto osContextLinux = static_cast(osContext); + address = castToUint64(osContextLinux->getFenceAddr(vmHandleId)); + value = osContextLinux->getNextFenceVal(vmHandleId); + } else { + address = castToUint64(drm->getFenceAddr(vmHandleId)); + value = drm->getNextFenceVal(vmHandleId); + } + incrementFenceValue = true; ioctlHelper->fillVmBindExtUserFence(vmBindExtUserFence, address, value, nextExtension); vmBind.extensions = castToUint64(vmBindExtUserFence); @@ -1422,7 +1433,12 @@ int changeBufferObjectBinding(Drm *drm, OsContext *osContext, uint32_t vmHandleI } } if (incrementFenceValue) { - drm->incFenceVal(vmHandleId); + if (drm->isPerContextVMRequired()) { + auto osContextLinux = static_cast(osContext); + osContextLinux->incFenceVal(vmHandleId); + } else { + drm->incFenceVal(vmHandleId); + } } } diff --git a/shared/source/os_interface/linux/os_context_linux.cpp b/shared/source/os_interface/linux/os_context_linux.cpp index ce8c32abb8..961d0858ff 100644 --- a/shared/source/os_interface/linux/os_context_linux.cpp +++ b/shared/source/os_interface/linux/os_context_linux.cpp @@ -12,6 +12,7 @@ #include "shared/source/execution_environment/root_device_environment.h" #include "shared/source/helpers/engine_node_helper.h" #include "shared/source/helpers/hw_info.h" +#include "shared/source/helpers/ptr_math.h" #include "shared/source/os_interface/linux/drm_neo.h" #include "shared/source/os_interface/linux/ioctl_helper.h" #include "shared/source/os_interface/os_context.h" @@ -30,7 +31,10 @@ OsContext *OsContextLinux::create(OSInterface *osInterface, uint32_t rootDeviceI OsContextLinux::OsContextLinux(Drm &drm, uint32_t rootDeviceIndex, uint32_t contextId, const EngineDescriptor &engineDescriptor) : OsContext(rootDeviceIndex, contextId, engineDescriptor), - drm(drm) {} + drm(drm) { + pagingFence.fill(0u); + fenceVal.fill(0u); +} bool OsContextLinux::initializeContext() { auto hwInfo = drm.getRootDeviceEnvironment().getHardwareInfo(); @@ -88,11 +92,28 @@ Drm &OsContextLinux::getDrm() const { void OsContextLinux::waitForPagingFence() { for (auto drmIterator = 0u; drmIterator < this->deviceBitfield.size(); drmIterator++) { if (this->deviceBitfield.test(drmIterator)) { - drm.waitForBind(drmIterator); + this->waitForBind(drmIterator); } } } +void OsContextLinux::waitForBind(uint32_t drmIterator) { + if (drm.isPerContextVMRequired()) { + if (pagingFence[drmIterator] >= fenceVal[drmIterator]) { + return; + } + auto lock = drm.lockBindFenceMutex(); + auto fenceAddress = castToUint64(&this->pagingFence[drmIterator]); + auto fenceValue = this->fenceVal[drmIterator]; + lock.unlock(); + + drm.waitUserFence(0u, fenceAddress, fenceValue, Drm::ValueWidth::U64, -1, drm.getIoctlHelper()->getWaitUserFenceSoftFlag()); + + } else { + drm.waitForBind(drmIterator); + } +} + void OsContextLinux::reInitializeContext() {} uint64_t OsContextLinux::getOfflineDumpContextId(uint32_t deviceIndex) const { diff --git a/shared/source/os_interface/linux/os_context_linux.h b/shared/source/os_interface/linux/os_context_linux.h index fa32de0995..a0fe7f2198 100644 --- a/shared/source/os_interface/linux/os_context_linux.h +++ b/shared/source/os_interface/linux/os_context_linux.h @@ -8,8 +8,10 @@ #pragma once #include "shared/source/helpers/mt_helpers.h" +#include "shared/source/memory_manager/definitions/engine_limits.h" #include "shared/source/os_interface/os_context.h" +#include #include #include @@ -52,6 +54,11 @@ class OsContextLinux : public OsContext { uint64_t getOfflineDumpContextId(uint32_t deviceIndex) const override; + uint64_t getNextFenceVal(uint32_t deviceIndex) { return fenceVal[deviceIndex] + 1; } + void incFenceVal(uint32_t deviceIndex) { fenceVal[deviceIndex]++; } + uint64_t *getFenceAddr(uint32_t deviceIndex) { return &pagingFence[deviceIndex]; } + void waitForBind(uint32_t drmIterator); + protected: bool initializeContext() override; @@ -60,6 +67,10 @@ class OsContextLinux : public OsContext { unsigned int engineFlag = 0; std::vector drmContextIds; std::vector drmVmIds; + + std::array pagingFence; + std::array fenceVal; + Drm &drm; bool contextHangDetected = false; }; diff --git a/shared/test/common/mocks/linux/mock_os_context_linux.h b/shared/test/common/mocks/linux/mock_os_context_linux.h index a341bce8b4..11444e6fa4 100644 --- a/shared/test/common/mocks/linux/mock_os_context_linux.h +++ b/shared/test/common/mocks/linux/mock_os_context_linux.h @@ -12,6 +12,8 @@ namespace NEO { class MockOsContextLinux : public OsContextLinux { public: using OsContextLinux::drmContextIds; + using OsContextLinux::fenceVal; + using OsContextLinux::pagingFence; MockOsContextLinux(Drm &drm, uint32_t rootDeviceIndex, uint32_t contextId, const EngineDescriptor &engineDescriptor) : OsContextLinux(drm, rootDeviceIndex, contextId, engineDescriptor) {} diff --git a/shared/test/unit_test/os_interface/linux/drm_vm_bind_prelim_tests.cpp b/shared/test/unit_test/os_interface/linux/drm_vm_bind_prelim_tests.cpp index 0c71478d0b..4e44044df5 100644 --- a/shared/test/unit_test/os_interface/linux/drm_vm_bind_prelim_tests.cpp +++ b/shared/test/unit_test/os_interface/linux/drm_vm_bind_prelim_tests.cpp @@ -13,6 +13,7 @@ #include "shared/test/common/helpers/engine_descriptor_helper.h" #include "shared/test/common/libult/linux/drm_query_mock.h" #include "shared/test/common/mocks/linux/mock_drm_allocation.h" +#include "shared/test/common/mocks/linux/mock_os_context_linux.h" #include "shared/test/common/mocks/mock_execution_environment.h" #include "gtest/gtest.h" @@ -60,6 +61,35 @@ TEST(DrmVmBindTest, givenBoRequiringExplicitResidencyWhenBindingThenMakeResident } } +TEST(DrmVmBindTest, givenPerContextVmsAndBoRequiringExplicitResidencyWhenBindingThenPagingFenceFromContextIsUsed) { + auto executionEnvironment = std::make_unique(); + executionEnvironment->rootDeviceEnvironments[0]->initGmm(); + executionEnvironment->initializeMemoryManager(); + DrmQueryMock drm{*executionEnvironment->rootDeviceEnvironments[0]}; + drm.pageFaultSupported = true; + drm.requirePerContextVM = true; + + for (auto requireResidency : {false, true}) { + MockBufferObject bo(0, &drm, 3, 0, 0, 1); + bo.requireExplicitResidency(requireResidency); + + MockOsContextLinux osContext(drm, 0, 0u, EngineDescriptorHelper::getDefaultDescriptor()); + osContext.ensureContextInitialized(); + uint32_t vmHandleId = 0; + bo.bind(&osContext, vmHandleId); + + if (requireResidency) { + EXPECT_EQ(DrmPrelimHelper::getImmediateVmBindFlag() | DrmPrelimHelper::getMakeResidentVmBindFlag(), drm.context.receivedVmBind->flags); + ASSERT_TRUE(drm.context.receivedVmBindUserFence); + EXPECT_EQ(castToUint64(osContext.getFenceAddr(vmHandleId)), drm.context.receivedVmBindUserFence->addr); + EXPECT_EQ(osContext.fenceVal[vmHandleId], drm.context.receivedVmBindUserFence->val); + EXPECT_EQ(1u, osContext.fenceVal[vmHandleId]); + } else { + EXPECT_EQ(DrmPrelimHelper::getImmediateVmBindFlag(), drm.context.receivedVmBind->flags); + } + } +} + TEST(DrmVmBindTest, givenBoNotRequiringExplicitResidencyWhenCallingWaitForBindThenDontWaitOnUserFence) { auto executionEnvironment = std::make_unique(); executionEnvironment->rootDeviceEnvironments[0]->initGmm(); diff --git a/shared/test/unit_test/os_interface/linux/os_context_linux_tests.cpp b/shared/test/unit_test/os_interface/linux/os_context_linux_tests.cpp index 11e9ddad35..ba39585972 100644 --- a/shared/test/unit_test/os_interface/linux/os_context_linux_tests.cpp +++ b/shared/test/unit_test/os_interface/linux/os_context_linux_tests.cpp @@ -6,6 +6,7 @@ */ #include "shared/source/os_interface/linux/drm_memory_operations_handler.h" +#include "shared/source/os_interface/linux/ioctl_helper.h" #include "shared/source/os_interface/linux/os_context_linux.h" #include "shared/test/common/helpers/engine_descriptor_helper.h" #include "shared/test/common/libult/linux/drm_mock.h" @@ -68,3 +69,50 @@ TEST(OSContextLinux, givenOsContextLinuxWhenQueryingForOfflineDumpContextIdThenC EXPECT_EQ(0u, osContext.getOfflineDumpContextId(10)); } + +TEST(OSContextLinux, givenPerContextVmsAndBindNotCompleteWhenWaitForPagingFenceThenContextFenceIsPassedToWaitUserFenceIoctl) { + auto executionEnvironment = std::make_unique(); + DrmMock drm{*executionEnvironment->rootDeviceEnvironments[0]}; + drm.requirePerContextVM = true; + + MockOsContextLinux osContext(drm, 0, 0u, EngineDescriptorHelper::getDefaultDescriptor()); + osContext.ensureContextInitialized(); + + drm.pagingFence[0] = 26u; + drm.fenceVal[0] = 31u; + + osContext.pagingFence[0] = 46u; + osContext.fenceVal[0] = 51u; + + osContext.waitForPagingFence(); + + EXPECT_EQ(1u, drm.waitUserFenceParams.size()); + EXPECT_EQ(0u, drm.waitUserFenceParams[0].ctxId); + EXPECT_EQ(castToUint64(&osContext.pagingFence[0]), drm.waitUserFenceParams[0].address); + EXPECT_EQ(drm.ioctlHelper->getWaitUserFenceSoftFlag(), drm.waitUserFenceParams[0].flags); + EXPECT_EQ(osContext.fenceVal[0], drm.waitUserFenceParams[0].value); + EXPECT_EQ(-1, drm.waitUserFenceParams[0].timeout); + + drm.requirePerContextVM = false; + osContext.waitForPagingFence(); + + EXPECT_EQ(castToUint64(&drm.pagingFence[0]), drm.waitUserFenceParams[1].address); + EXPECT_EQ(drm.ioctlHelper->getWaitUserFenceSoftFlag(), drm.waitUserFenceParams[1].flags); + EXPECT_EQ(drm.fenceVal[0], drm.waitUserFenceParams[1].value); +} + +TEST(OSContextLinux, givenPerContextVmsAndBindCompleteWhenWaitForPagingFenceThenWaitUserFenceIoctlIsNotCalled) { + auto executionEnvironment = std::make_unique(); + DrmMock drm{*executionEnvironment->rootDeviceEnvironments[0]}; + drm.requirePerContextVM = true; + + MockOsContextLinux osContext(drm, 0, 0u, EngineDescriptorHelper::getDefaultDescriptor()); + osContext.ensureContextInitialized(); + + osContext.pagingFence[0] = 3u; + osContext.fenceVal[0] = 3u; + + osContext.waitForPagingFence(); + + EXPECT_EQ(0u, drm.waitUserFenceParams.size()); +}