fix: allow kernel access across multiple virtual regions

Related to: NEO-8350

Signed-off-by: John Falkowski <john.falkowski@intel.com>
This commit is contained in:
John Falkowski
2024-04-23 07:14:31 +00:00
committed by Compute-Runtime-Automation
parent b499973658
commit b9c1ef65dd
3 changed files with 7 additions and 25 deletions

View File

@@ -39,22 +39,11 @@ struct KernelHw : public KernelImp {
auto misalignedSize = ptrDiff(alloc->getGpuAddressToPatch(), baseAddress);
auto offset = ptrDiff(address, reinterpret_cast<void *>(baseAddress));
size_t bufferSizeForSsh = alloc->getUnderlyingBufferSize();
// If the allocation is part of a mapped virtual range, then check to see if the buffer size needs to be extended to include more physical buffers.
// If the allocation is part of a mapped virtual range, then set size to maximum to allow for access across multiple virtual ranges.
Device *device = module->getDevice();
auto allocData = device->getDriverHandle()->getSvmAllocsManager()->getSVMAlloc(reinterpret_cast<void *>(alloc->getGpuAddress()));
if (allocData && allocData->virtualReservationData) {
size_t calcBufferSizeForSsh = bufferSizeForSsh;
for (const auto &mappedAllocationData : allocData->virtualReservationData->mappedAllocations) {
// Add each additional allocation's buffer size to the size being programmed, to allow full usage of the memory range, if that allocation is located after this starting address.
if (address != mappedAllocationData.second->ptr && mappedAllocationData.second->ptr > address) {
calcBufferSizeForSsh += mappedAllocationData.second->mappedAllocation->allocation->getUnderlyingBufferSize();
// Only allow for the surface state to be extended up to 4GB in size.
bufferSizeForSsh = std::min(calcBufferSizeForSsh, MemoryConstants::gigaByte * 4);
if (bufferSizeForSsh == MemoryConstants::gigaByte * 4) {
break;
}
}
}
bufferSizeForSsh = MemoryConstants::fullStatefulRegion;
}
auto argInfo = kernelImmData->getDescriptor().payloadMappings.explicitArgs[argIndex].as<NEO::ArgDescPointer>();
bool offsetWasPatched = NEO::patchNonPointer<uint32_t, uint32_t>(ArrayRef<uint8_t>(this->crossThreadData.get(), this->crossThreadDataSize),