fix: correctly handle error return paths in getConcurrentMetricGroups

fix the size used to copy concurrent groups

Related-To: NEO-11382


Signed-off-by: Joshua Santosh Ranjan <joshua.santosh.ranjan@intel.com>
This commit is contained in:
Joshua Santosh Ranjan
2024-08-01 14:19:32 +00:00
committed by Compute-Runtime-Automation
parent a2fdc3d3f5
commit 43b81637df
2 changed files with 15 additions and 1 deletions

View File

@@ -220,6 +220,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC
auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, nullptr);
if (status != ZE_RESULT_SUCCESS) {
METRICS_LOG_ERR("Per source concurrent metric group query returned error status %d", status);
*pConcurrentGroupCount = 0;
return status;
}
maxConcurrentGroupCount = std::max(maxConcurrentGroupCount, perSourceConcurrentCount);
@@ -232,6 +233,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC
if (*pConcurrentGroupCount != maxConcurrentGroupCount) {
METRICS_LOG_ERR("Input Concurrent Group Count %d is not same as expected %d", *pConcurrentGroupCount, maxConcurrentGroupCount);
*pConcurrentGroupCount = 0;
return ZE_RESULT_ERROR_INVALID_ARGUMENT;
}
@@ -247,6 +249,7 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC
auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, countPerConcurrentGroup.data());
if (status != ZE_RESULT_SUCCESS) {
METRICS_LOG_ERR("getConcurrentMetricGroups returned error status %d", status);
*pConcurrentGroupCount = 0;
return status;
}
@@ -274,7 +277,8 @@ ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupC
// Update the output metric groups
size_t availableSize = metricGroupCount;
for (auto &concurrentGroup : concurrentGroups) {
memcpy_s(phMetricGroups, availableSize, concurrentGroup.data(), concurrentGroup.size());
memcpy_s(phMetricGroups, availableSize * sizeof(zet_metric_group_handle_t), concurrentGroup.data(), concurrentGroup.size() * sizeof(zet_metric_group_handle_t));
availableSize -= concurrentGroup.size();
phMetricGroups += concurrentGroup.size();
}

View File

@@ -97,6 +97,8 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledForDif
metricGroupHandles.push_back(metricGroupType2[0].toHandle());
metricGroupHandles.push_back(metricGroupType2[2].toHandle());
auto backupMetricGroupHandles = metricGroupHandles;
uint32_t concurrentGroupCount = 0;
EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast<uint32_t>(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr));
EXPECT_EQ(concurrentGroupCount, 4u);
@@ -106,6 +108,11 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledForDif
EXPECT_EQ(countPerConcurrentGroup[0], 3u);
EXPECT_EQ(countPerConcurrentGroup[1], 3u);
EXPECT_EQ(countPerConcurrentGroup[2], 1u);
// Ensure that re-arranged metric group handles were originally present in the vector
for (auto &metricGroupHandle : metricGroupHandles) {
EXPECT_TRUE(std::find(backupMetricGroupHandles.begin(), backupMetricGroupHandles.end(), metricGroupHandle) != backupMetricGroupHandles.end());
}
}
TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledWithIncorrectGroupCountThenFailureIsReturned) {
@@ -134,6 +141,7 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledWithIn
concurrentGroupCount -= 1;
std::vector<uint32_t> countPerConcurrentGroup(concurrentGroupCount);
EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, deviceContext->getConcurrentMetricGroups(static_cast<uint32_t>(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data()));
EXPECT_EQ(concurrentGroupCount, 0u);
}
class MockMetricSourceError : public MockMetricSource {
@@ -178,10 +186,12 @@ TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledAndSou
concurrentGroupCount = 10;
std::vector<uint32_t> countPerConcurrentGroup(10);
EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast<uint32_t>(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data()));
EXPECT_EQ(concurrentGroupCount, 0u);
MockMetricSourceError::errorCallCount = 1;
concurrentGroupCount = 0;
EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast<uint32_t>(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr));
EXPECT_EQ(concurrentGroupCount, 0u);
}
} // namespace ult