diff --git a/level_zero/sysman/source/shared/windows/pmt/sysman_pmt.cpp b/level_zero/sysman/source/shared/windows/pmt/sysman_pmt.cpp index 17a96ec519..53caf2f649 100644 --- a/level_zero/sysman/source/shared/windows/pmt/sysman_pmt.cpp +++ b/level_zero/sysman/source/shared/windows/pmt/sysman_pmt.cpp @@ -75,13 +75,16 @@ ze_result_t PlatformMonitoringTech::readValue(const std::string &key, uint64_t & } ze_result_t PlatformMonitoringTech::getGuid() { + // arbitrary upper-bound, tune if needed + constexpr size_t sizeNeededMax = sizeof(PmtSysman::PmtTelemetryDiscovery) + (2 * PmtSysman::PmtMaxInterfaces - 1) * sizeof(PmtSysman::PmtTelemetryEntry); + ze_result_t status; unsigned long sizeNeeded; PmtSysman::PmtTelemetryDiscovery *telemetryDiscovery = nullptr; // Get Telmetry Discovery size status = ioctlReadWriteData(deviceInterface, PmtSysman::IoctlPmtGetTelemetryDiscoverySize, NULL, 0, (void *)&sizeNeeded, sizeof(sizeNeeded), NULL); - if (status != ZE_RESULT_SUCCESS || sizeNeeded == 0) { + if (status != ZE_RESULT_SUCCESS || sizeNeeded == 0 || sizeNeeded > sizeNeededMax) { NEO::printDebugString(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Ioctl call could not return a valid value for the PMT interface telemetry size needed\n"); DEBUG_BREAK_IF(true); @@ -98,14 +101,15 @@ ze_result_t PlatformMonitoringTech::getGuid() { if (status != ZE_RESULT_SUCCESS) { NEO::printDebugString(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Ioctl call could not return a valid value for the PMT telemetry structure which provides the guids supported.\n"); - DEBUG_BREAK_IF(true); heapFreeFunction(GetProcessHeap(), 0, telemetryDiscovery); + DEBUG_BREAK_IF(true); return ZE_RESULT_ERROR_UNKNOWN; } auto maxEntriesCount = (sizeNeeded - offsetof(PmtSysman::PmtTelemetryDiscovery, telemetry)) / sizeof(PmtSysman::PmtTelemetryEntry); if (telemetryDiscovery->count > maxEntriesCount) { NEO::printDebugString(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Incorrect telemetry entries count.\n"); + heapFreeFunction(GetProcessHeap(), 0, telemetryDiscovery); DEBUG_BREAK_IF(true); return ZE_RESULT_ERROR_UNKNOWN; } @@ -116,8 +120,8 @@ ze_result_t PlatformMonitoringTech::getGuid() { } else { NEO::printDebugString(NEO::debugManager.flags.PrintDebugMessages.get(), stderr, "Telemetry index is out of range.\n"); - DEBUG_BREAK_IF(true); heapFreeFunction(GetProcessHeap(), 0, telemetryDiscovery); + DEBUG_BREAK_IF(true); return ZE_RESULT_ERROR_UNKNOWN; } }