From 43b81637df2a67f0636a2ed46876d4d7a649124e Mon Sep 17 00:00:00 2001 From: Joshua Santosh Ranjan Date: Thu, 1 Aug 2024 14:19:32 +0000 Subject: [PATCH] fix: correctly handle error return paths in getConcurrentMetricGroups fix the size used to copy concurrent groups Related-To: NEO-11382 Signed-off-by: Joshua Santosh Ranjan --- level_zero/tools/source/metrics/metric.cpp | 6 +++++- .../sources/metrics/test_metric_concurrent_groups.cpp | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/level_zero/tools/source/metrics/metric.cpp b/level_zero/tools/source/metrics/metric.cpp index 19632a326a..b6eeb7b51f 100644 --- a/level_zero/tools/source/metrics/metric.cpp +++ b/level_zero/tools/source/metrics/metric.cpp @@ -220,6 +220,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, nullptr); if (status != ZE_RESULT_SUCCESS) { METRICS_LOG_ERR("Per source concurrent metric group query returned error status %d", status); + *pConcurrentGroupCount = 0; return status; } maxConcurrentGroupCount = std::max(maxConcurrentGroupCount, perSourceConcurrentCount); @@ -232,6 +233,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC if (*pConcurrentGroupCount != maxConcurrentGroupCount) { METRICS_LOG_ERR("Input Concurrent Group Count %d is not same as expected %d", *pConcurrentGroupCount, maxConcurrentGroupCount); + *pConcurrentGroupCount = 0; return ZE_RESULT_ERROR_INVALID_ARGUMENT; } @@ -247,6 +249,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, countPerConcurrentGroup.data()); if (status != ZE_RESULT_SUCCESS) { METRICS_LOG_ERR("getConcurrentMetricGroups returned error status %d", status); + *pConcurrentGroupCount = 0; return status; } @@ -274,7 +277,8 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC // Update the output metric groups size_t availableSize = metricGroupCount; for (auto &concurrentGroup : concurrentGroups) { - memcpy_s(phMetricGroups, availableSize, concurrentGroup.data(), concurrentGroup.size()); + memcpy_s(phMetricGroups, availableSize * sizeof(zet_metric_group_handle_t), concurrentGroup.data(), concurrentGroup.size() * sizeof(zet_metric_group_handle_t)); + availableSize -= concurrentGroup.size(); phMetricGroups += concurrentGroup.size(); } diff --git a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp index 1b781679fb..a657541a32 100644 --- a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp +++ b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp @@ -97,6 +97,8 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledForDif metricGroupHandles.push_back(metricGroupType2[0].toHandle()); metricGroupHandles.push_back(metricGroupType2[2].toHandle()); + auto backupMetricGroupHandles = metricGroupHandles; + uint32_t concurrentGroupCount = 0; EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); EXPECT_EQ(concurrentGroupCount, 4u); @@ -106,6 +108,11 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledForDif EXPECT_EQ(countPerConcurrentGroup[0], 3u); EXPECT_EQ(countPerConcurrentGroup[1], 3u); EXPECT_EQ(countPerConcurrentGroup[2], 1u); + + // Ensure that re-arranged metric group handles were originally present in the vector + for (auto &metricGroupHandle : metricGroupHandles) { + EXPECT_TRUE(std::find(backupMetricGroupHandles.begin(), backupMetricGroupHandles.end(), metricGroupHandle) != backupMetricGroupHandles.end()); + } } TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledWithIncorrectGroupCountThenFailureIsReturned) { @@ -134,6 +141,7 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledWithIn concurrentGroupCount -= 1; std::vector countPerConcurrentGroup(concurrentGroupCount); EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data())); + EXPECT_EQ(concurrentGroupCount, 0u); } class MockMetricSourceError : public MockMetricSource { @@ -178,10 +186,12 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledAndSou concurrentGroupCount = 10; std::vector countPerConcurrentGroup(10); EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data())); + EXPECT_EQ(concurrentGroupCount, 0u); MockMetricSourceError::errorCallCount = 1; concurrentGroupCount = 0; EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); + EXPECT_EQ(concurrentGroupCount, 0u); } } // namespace ult