feature: Support for concurrent groups

Related-To: NEO-10306

Signed-off-by: Joshua Santosh Ranjan <joshua.santosh.ranjan@intel.com>
This commit is contained in:
Joshua Santosh Ranjan
2024-04-12 10:23:36 +00:00
committed by Compute-Runtime-Automation
parent 15a049cf9c
commit e0a580fce7
16 changed files with 488 additions and 2 deletions

View File

@@ -202,6 +202,88 @@ ze_result_t MetricDeviceContext::enableMetricApi() {
: ZE_RESULT_SUCCESS;
}
// Splits the supplied metric groups into concurrent groups and reports the result
// through the output parameters.
//
// The handles are first bucketed by the MetricSource that owns them. Each source is
// asked how many concurrent groups it needs for its own handles; the overall
// concurrent group count is the maximum over all sources.
//
// Two call modes, selected by the value pointed to by pConcurrentGroupCount:
//   - *pConcurrentGroupCount == 0: only the required concurrent group count is written
//     to *pConcurrentGroupCount; phMetricGroups and pCountPerConcurrentGroup are untouched.
//   - *pConcurrentGroupCount != 0: it must equal the required count, otherwise
//     ZE_RESULT_ERROR_INVALID_ARGUMENT is returned. On success pCountPerConcurrentGroup[i]
//     receives the number of handles in concurrent group i and phMetricGroups is
//     rewritten in place so that the handles of group 0 come first, then group 1, etc.
//
// Any error status returned by a per-source query is logged and returned unchanged.
//
// NOTE(review): pConcurrentGroupCount is dereferenced without a null check, and
// pCountPerConcurrentGroup is written without one in the second mode - presumably the
// API entry point validates them; confirm against the caller.
ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupCount,
                                                           zet_metric_group_handle_t *phMetricGroups,
                                                           uint32_t *pConcurrentGroupCount, uint32_t *pCountPerConcurrentGroup) {
    // Bucket the input handles by the metric source that owns them
    std::map<MetricSource *, std::vector<zet_metric_group_handle_t>> metricGroupsPerMetricSourceMap{};
    for (auto index = 0u; index < metricGroupCount; index++) {
        auto &metricGroupSource =
            static_cast<MetricGroupImp *>(MetricGroup::fromHandle(phMetricGroups[index]))->getMetricSource();
        metricGroupsPerMetricSourceMap[&metricGroupSource].push_back(phMetricGroups[index]);
    }

    // Calculate the maximum concurrent group count
    uint32_t maxConcurrentGroupCount = 0;
    for (auto &[source, metricGroups] : metricGroupsPerMetricSourceMap) {
        uint32_t perSourceConcurrentCount = 0;
        // Count-only query: no per-group count buffer is supplied
        auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, nullptr);
        if (status != ZE_RESULT_SUCCESS) {
            METRICS_LOG_ERR("Per source concurrent metric group query returned error status %d", status);
            return status;
        }
        maxConcurrentGroupCount = std::max(maxConcurrentGroupCount, perSourceConcurrentCount);
    }

    // Size query: report the required count and leave the other outputs alone
    if (*pConcurrentGroupCount == 0) {
        *pConcurrentGroupCount = maxConcurrentGroupCount;
        return ZE_RESULT_SUCCESS;
    }

    if (*pConcurrentGroupCount != maxConcurrentGroupCount) {
        METRICS_LOG_ERR("Input Concurrent Group Count %d is not same as expected %d", *pConcurrentGroupCount, maxConcurrentGroupCount);
        return ZE_RESULT_ERROR_INVALID_ARGUMENT;
    }

    // concurrentGroups[i] accumulates the i-th concurrent group of every source
    std::vector<std::vector<zet_metric_group_handle_t>> concurrentGroups(maxConcurrentGroupCount);
    for (auto &entry : metricGroupsPerMetricSourceMap) {
        auto source = entry.first;
        auto &metricGroups = entry.second;

        // Using maximum possible concurrent group count
        uint32_t perSourceConcurrentCount = metricGroupCount;
        std::vector<uint32_t> countPerConcurrentGroup(perSourceConcurrentCount);
        auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, countPerConcurrentGroup.data());
        if (status != ZE_RESULT_SUCCESS) {
            METRICS_LOG_ERR("getConcurrentMetricGroups returned error status %d", status);
            return status;
        }

        // A source must not report more groups than the maximum computed above
        DEBUG_BREAK_IF(static_cast<uint32_t>(concurrentGroups.size()) < perSourceConcurrentCount);

        // The handles of consecutive groups are taken as consecutive runs of metricGroups.
        // NOTE(review): this assumes the per-source query leaves metricGroups ordered
        // group by group - confirm against the MetricSource implementations.
        auto metricGroupsStartOffset = metricGroups.begin();
        [[maybe_unused]] uint32_t totalMetricGroupCount = 0;

        // Copy the handles to appropriate groups
        for (uint32_t groupIndex = 0; groupIndex < perSourceConcurrentCount; groupIndex++) {
            totalMetricGroupCount += countPerConcurrentGroup[groupIndex];
            DEBUG_BREAK_IF(totalMetricGroupCount > static_cast<uint32_t>(metricGroups.size()));
            concurrentGroups[groupIndex].insert(concurrentGroups[groupIndex].end(),
                                                metricGroupsStartOffset,
                                                metricGroupsStartOffset + countPerConcurrentGroup[groupIndex]);
            metricGroupsStartOffset += countPerConcurrentGroup[groupIndex];
        }
    }

    // Update the concurrent group count and count per concurrent group
    *pConcurrentGroupCount = static_cast<uint32_t>(concurrentGroups.size());
    for (uint32_t index = 0u; index < *pConcurrentGroupCount; index++) {
        pCountPerConcurrentGroup[index] = static_cast<uint32_t>(concurrentGroups[index].size());
    }

    // Update the output metric groups.
    // availableSize tracks the remaining capacity in handles; memcpy_s takes sizes in
    // bytes, so both the destination capacity and the copy length are scaled by the
    // handle size.
    constexpr size_t handleSize = sizeof(zet_metric_group_handle_t);
    size_t availableSize = metricGroupCount;
    for (auto &concurrentGroup : concurrentGroups) {
        memcpy_s(phMetricGroups, availableSize * handleSize, concurrentGroup.data(), concurrentGroup.size() * handleSize);
        availableSize -= concurrentGroup.size();
        phMetricGroups += concurrentGroup.size();
    }
    DEBUG_BREAK_IF(availableSize != 0);
    return ZE_RESULT_SUCCESS;
}
ze_result_t metricGroupGet(zet_device_handle_t hDevice, uint32_t *pCount, zet_metric_group_handle_t *phMetricGroups) {
auto device = Device::fromHandle(hDevice);
return device->getMetricDeviceContext().metricGroupGet(pCount, phMetricGroups);