diff --git a/level_zero/api/driver_experimental/public/CMakeLists.txt b/level_zero/api/driver_experimental/public/CMakeLists.txt index 2fc38da9fc..8c0e9709c5 100644 --- a/level_zero/api/driver_experimental/public/CMakeLists.txt +++ b/level_zero/api/driver_experimental/public/CMakeLists.txt @@ -20,4 +20,6 @@ target_sources(${L0_STATIC_LIB_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/zex_memory.h ${CMAKE_CURRENT_SOURCE_DIR}/zex_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/zex_module.h + ${CMAKE_CURRENT_SOURCE_DIR}/zex_metric.h + ${CMAKE_CURRENT_SOURCE_DIR}/zex_metric.cpp ) diff --git a/level_zero/api/driver_experimental/public/zex_api.h b/level_zero/api/driver_experimental/public/zex_api.h index 27f7391a3e..2b2428f1b6 100644 --- a/level_zero/api/driver_experimental/public/zex_api.h +++ b/level_zero/api/driver_experimental/public/zex_api.h @@ -22,6 +22,7 @@ #include "zex_driver.h" #include "zex_event.h" #include "zex_memory.h" +#include "zex_metric.h" #include "zex_module.h" #endif // _ZEX_API_H diff --git a/level_zero/api/driver_experimental/public/zex_metric.cpp b/level_zero/api/driver_experimental/public/zex_metric.cpp new file mode 100644 index 0000000000..a3bbb608b3 --- /dev/null +++ b/level_zero/api/driver_experimental/public/zex_metric.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2024 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "level_zero/api/driver_experimental/public/zex_metric.h" + +#include "level_zero/core/source/device/device.h" +#include "level_zero/tools/source/metrics/metric.h" + +namespace L0 { + +ze_result_t ZE_APICALL +zexDeviceGetConcurrentMetricGroups( + zet_device_handle_t hDevice, + uint32_t metricGroupCount, + zet_metric_group_handle_t *phMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) { + + auto device = Device::fromHandle(hDevice); + return static_cast(device->getMetricDeviceContext()).getConcurrentMetricGroups(metricGroupCount, phMetricGroups, pConcurrentGroupCount, pCountPerConcurrentGroup); +} + +} // namespace L0 + +extern "C" { + +ZE_APIEXPORT ze_result_t ZE_APICALL +zexDeviceGetConcurrentMetricGroups( + zet_device_handle_t hDevice, + uint32_t metricGroupCount, + zet_metric_group_handle_t *phMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) { + return L0::zexDeviceGetConcurrentMetricGroups( + hDevice, metricGroupCount, phMetricGroups, + pConcurrentGroupCount, pCountPerConcurrentGroup); +} + +} // extern "C" \ No newline at end of file diff --git a/level_zero/api/driver_experimental/public/zex_metric.h b/level_zero/api/driver_experimental/public/zex_metric.h new file mode 100644 index 0000000000..3dd2053c3d --- /dev/null +++ b/level_zero/api/driver_experimental/public/zex_metric.h @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2024 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#ifndef _ZEX_METRIC_H +#define _ZEX_METRIC_H +#if defined(__cplusplus) +#pragma once +#endif + +#include +#include + +namespace L0 { +/////////////////////////////////////////////////////////////////////////////// +/// @brief Get sets of metric groups which could be collected concurrently. +/// +/// @details +/// - Re-arrange the input metric groups to provide sets of concurrent metric groups. +/// - The application may call this function from simultaneous threads. +/// - The implementation of this function must be thread-safe. +/// +/// @returns +/// - ::ZE_RESULT_SUCCESS +/// - ::ZE_RESULT_ERROR_INVALID_ARGUMENT +/// + `pConcurrentGroupCount` is not same as was returned by L0 using zexDeviceGetConcurrentMetricGroups +ZE_APIEXPORT ze_result_t ZE_APICALL +zexDeviceGetConcurrentMetricGroups( + zet_device_handle_t hDevice, // [in] handle of the device + uint32_t metricGroupCount, // [in] metric group count + zet_metric_group_handle_t *phMetricGroups, // [in, out] metrics groups to be re-arranged to be sets of concurrent groups + uint32_t *pConcurrentGroupCount, // [out] number of concurrent groups. + uint32_t *pCountPerConcurrentGroup); // [in,out][optional][*pConcurrentGroupCount] count of metric groups per concurrent group. + +} // namespace L0 + +#endif // _ZEX_METRIC_H diff --git a/level_zero/core/source/driver/extension_function_address.cpp b/level_zero/core/source/driver/extension_function_address.cpp index f415677d34..aae0e2f1bf 100644 --- a/level_zero/core/source/driver/extension_function_address.cpp +++ b/level_zero/core/source/driver/extension_function_address.cpp @@ -40,6 +40,7 @@ void *ExtensionFunctionAddressHelper::getExtensionFunctionAddress(const std::str RETURN_FUNC_PTR_IF_EXIST(zeMemGetPitchFor2dImage); RETURN_FUNC_PTR_IF_EXIST(zeImageGetDeviceOffsetExp); + RETURN_FUNC_PTR_IF_EXIST(zexDeviceGetConcurrentMetricGroups); #undef RETURN_FUNC_PTR_IF_EXIST return ExtensionFunctionAddressHelper::getAdditionalExtensionFunctionAddress(functionName); diff --git a/level_zero/tools/source/metrics/metric.cpp b/level_zero/tools/source/metrics/metric.cpp index f9925fbbce..19632a326a 100644 --- a/level_zero/tools/source/metrics/metric.cpp +++ b/level_zero/tools/source/metrics/metric.cpp @@ -202,6 +202,88 @@ ze_result_t MetricDeviceContext::enableMetricApi() { : ZE_RESULT_SUCCESS; } +ze_result_t MetricDeviceContext::getConcurrentMetricGroups(uint32_t metricGroupCount, + zet_metric_group_handle_t *phMetricGroups, + uint32_t *pConcurrentGroupCount, uint32_t *pCountPerConcurrentGroup) { + + std::map> metricGroupsPerMetricSourceMap{}; + for (auto index = 0u; index < metricGroupCount; index++) { + auto &metricGroupSource = + static_cast(MetricGroup::fromHandle(phMetricGroups[index]))->getMetricSource(); + metricGroupsPerMetricSourceMap[&metricGroupSource].push_back(phMetricGroups[index]); + } + + // Calculate the maximum concurrent group count + uint32_t maxConcurrentGroupCount = 0; + for (auto &[source, metricGroups] : metricGroupsPerMetricSourceMap) { + uint32_t perSourceConcurrentCount = 0; + auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, nullptr); + if (status != ZE_RESULT_SUCCESS) { + METRICS_LOG_ERR("Per source concurrent metric group query returned error status %d", status); + return status; + } + maxConcurrentGroupCount = std::max(maxConcurrentGroupCount, perSourceConcurrentCount); + } + + if (*pConcurrentGroupCount == 0) { + *pConcurrentGroupCount = maxConcurrentGroupCount; + return ZE_RESULT_SUCCESS; + } + + if (*pConcurrentGroupCount != maxConcurrentGroupCount) { + METRICS_LOG_ERR("Input Concurrent Group Count %d is not same as expected %d", *pConcurrentGroupCount, maxConcurrentGroupCount); + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + } + + std::vector> concurrentGroups(maxConcurrentGroupCount); + for (auto &entry : metricGroupsPerMetricSourceMap) { + + auto source = entry.first; + auto &metricGroups = entry.second; + + // Using maximum possible concurrent group count + uint32_t perSourceConcurrentCount = metricGroupCount; + std::vector countPerConcurrentGroup(perSourceConcurrentCount); + auto status = source->getConcurrentMetricGroups(metricGroups, &perSourceConcurrentCount, countPerConcurrentGroup.data()); + if (status != ZE_RESULT_SUCCESS) { + METRICS_LOG_ERR("getConcurrentMetricGroups returned error status %d", status); + return status; + } + + DEBUG_BREAK_IF(static_cast(concurrentGroups.size() < perSourceConcurrentCount)); + + auto metricGroupsStartOffset = metricGroups.begin(); + [[maybe_unused]] uint32_t totalMetricGroupCount = 0; + // Copy the handles to appropriate groups + for (uint32_t groupIndex = 0; groupIndex < perSourceConcurrentCount; groupIndex++) { + totalMetricGroupCount += countPerConcurrentGroup[groupIndex]; + DEBUG_BREAK_IF(totalMetricGroupCount > static_cast(metricGroups.size())); + concurrentGroups[groupIndex].insert(concurrentGroups[groupIndex].end(), + metricGroupsStartOffset, + metricGroupsStartOffset + countPerConcurrentGroup[groupIndex]); + metricGroupsStartOffset += countPerConcurrentGroup[groupIndex]; + } + } + + // Update the concurrent Group count and count per concurrent grup + *pConcurrentGroupCount = static_cast(concurrentGroups.size()); + for (uint32_t index = 0u; index < *pConcurrentGroupCount; index++) { + pCountPerConcurrentGroup[index] = static_cast(concurrentGroups[index].size()); + } + + // Update the output metric groups + size_t availableSize = metricGroupCount; + for (auto &concurrentGroup : concurrentGroups) { + memcpy_s(phMetricGroups, availableSize, concurrentGroup.data(), concurrentGroup.size()); + availableSize -= concurrentGroup.size(); + phMetricGroups += concurrentGroup.size(); + } + + DEBUG_BREAK_IF(availableSize != 0); + + return ZE_RESULT_SUCCESS; +} + ze_result_t metricGroupGet(zet_device_handle_t hDevice, uint32_t *pCount, zet_metric_group_handle_t *phMetricGroups) { auto device = Device::fromHandle(hDevice); return device->getMetricDeviceContext().metricGroupGet(pCount, phMetricGroups); diff --git a/level_zero/tools/source/metrics/metric.h b/level_zero/tools/source/metrics/metric.h index f426d19a12..72dc85e322 100644 --- a/level_zero/tools/source/metrics/metric.h +++ b/level_zero/tools/source/metrics/metric.h @@ -64,6 +64,9 @@ class MetricSource { zet_metric_group_handle_t *pMetricGroupHandle) { return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; }; + virtual ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) = 0; virtual ~MetricSource() = default; uint32_t getType() const { return type; @@ -113,6 +116,8 @@ class MetricDeviceContext { static std::unique_ptr create(Device &device); static ze_result_t enableMetricApi(); + ze_result_t getConcurrentMetricGroups(uint32_t metricGroupCount, zet_metric_group_handle_t *phMetricGroups, + uint32_t *pConcurrentGroupCount, uint32_t *pCountPerConcurrentGroup); bool isProgrammableMetricsEnabled = false; diff --git a/level_zero/tools/source/metrics/metric_ip_sampling_source.cpp b/level_zero/tools/source/metrics/metric_ip_sampling_source.cpp index 29357dfe70..98fbd2c117 100644 --- a/level_zero/tools/source/metrics/metric_ip_sampling_source.cpp +++ b/level_zero/tools/source/metrics/metric_ip_sampling_source.cpp @@ -189,6 +189,23 @@ ze_result_t IpSamplingMetricGroupBase::getExportData(const uint8_t *pRawData, si return ZE_RESULT_SUCCESS; } +ze_result_t IpSamplingMetricSourceImp::getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) { + + if (*pConcurrentGroupCount == 0) { + *pConcurrentGroupCount = static_cast(hMetricGroups.size()); + return ZE_RESULT_SUCCESS; + } + + *pConcurrentGroupCount = std::min(*pConcurrentGroupCount, static_cast(hMetricGroups.size())); + // Each metric group is in unique container + for (uint32_t index = 0; index < *pConcurrentGroupCount; index++) { + pCountPerConcurrentGroup[index] = 1; + } + return ZE_RESULT_SUCCESS; +} + IpSamplingMetricGroupImp::IpSamplingMetricGroupImp(IpSamplingMetricSourceImp &metricSource, std::vector &metrics) : IpSamplingMetricGroupBase(metricSource) { this->metrics.reserve(metrics.size()); diff --git a/level_zero/tools/source/metrics/metric_ip_sampling_source.h b/level_zero/tools/source/metrics/metric_ip_sampling_source.h index ddf333b57b..11e6e31767 100644 --- a/level_zero/tools/source/metrics/metric_ip_sampling_source.h +++ b/level_zero/tools/source/metrics/metric_ip_sampling_source.h @@ -31,6 +31,9 @@ class IpSamplingMetricSourceImp : public MetricSource { ze_result_t metricProgrammableGet(uint32_t *pCount, zet_metric_programmable_exp_handle_t *phMetricProgrammables) override { return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; } + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) override; bool isMetricGroupActivated(const zet_metric_group_handle_t hMetricGroup) const; void setMetricOsInterface(std::unique_ptr &metricIPSamplingpOsInterface); static std::unique_ptr create(const MetricDeviceContext &metricDeviceContext); diff --git a/level_zero/tools/source/metrics/metric_oa_source.cpp b/level_zero/tools/source/metrics/metric_oa_source.cpp index a434a85b59..f078b305bd 100644 --- a/level_zero/tools/source/metrics/metric_oa_source.cpp +++ b/level_zero/tools/source/metrics/metric_oa_source.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2022-2023 Intel Corporation + * Copyright (C) 2022-2024 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -183,6 +183,24 @@ ze_result_t OaMetricSourceImp::activateMetricGroupsAlreadyDeferred() { return activationTracker->activateMetricGroupsAlreadyDeferred(); } +ze_result_t OaMetricSourceImp::getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) { + + if (*pConcurrentGroupCount == 0) { + *pConcurrentGroupCount = static_cast(hMetricGroups.size()); + return ZE_RESULT_SUCCESS; + } + + *pConcurrentGroupCount = std::min(*pConcurrentGroupCount, static_cast(hMetricGroups.size())); + // Each metric group is in unique container + for (uint32_t index = 0; index < *pConcurrentGroupCount; index++) { + pCountPerConcurrentGroup[index] = 1; + } + + return ZE_RESULT_SUCCESS; +} + template <> OaMetricSourceImp &MetricDeviceContext::getMetricSource() const { return static_cast(*metricSources.at(MetricSource::metricSourceTypeOa)); diff --git a/level_zero/tools/source/metrics/metric_oa_source.h b/level_zero/tools/source/metrics/metric_oa_source.h index a0f452fd02..8e491575ea 100644 --- a/level_zero/tools/source/metrics/metric_oa_source.h +++ b/level_zero/tools/source/metrics/metric_oa_source.h @@ -45,6 +45,8 @@ class OaMetricSourceImp : public MetricSource { ze_result_t activateMetricGroupsPreferDeferred(const uint32_t count, zet_metric_group_handle_t *phMetricGroups) override; ze_result_t metricProgrammableGet(uint32_t *pCount, zet_metric_programmable_exp_handle_t *phMetricProgrammables) override; + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, uint32_t *pCountPerConcurrentGroup) override; bool isMetricGroupActivated(const zet_metric_group_handle_t hMetricGroup) const; bool isMetricGroupActivatedInHw() const; void setUseCompute(const bool useCompute); diff --git a/level_zero/tools/test/unit_tests/sources/metrics/CMakeLists.txt b/level_zero/tools/test/unit_tests/sources/metrics/CMakeLists.txt index f6c71af31e..6ceafaa40e 100644 --- a/level_zero/tools/test/unit_tests/sources/metrics/CMakeLists.txt +++ b/level_zero/tools/test/unit_tests/sources/metrics/CMakeLists.txt @@ -30,6 +30,7 @@ target_sources(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/test_metric_ip_sampling_streamer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_metric_oa_export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/${BRANCH_DIR_SUFFIX}/test_metric_programmable.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_metric_concurrent_groups.cpp ) diff --git a/level_zero/tools/test/unit_tests/sources/metrics/mock_metric_source.h b/level_zero/tools/test/unit_tests/sources/metrics/mock_metric_source.h index aea46a2a5c..abbc1b4939 100644 --- a/level_zero/tools/test/unit_tests/sources/metrics/mock_metric_source.h +++ b/level_zero/tools/test/unit_tests/sources/metrics/mock_metric_source.h @@ -25,6 +25,11 @@ class MockMetricSource : public L0::MetricSource { ze_result_t metricProgrammableGet(uint32_t *pCount, zet_metric_programmable_exp_handle_t *phMetricProgrammables) override { return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; } + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) override { + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; + } void setType(uint32_t type) { this->type = type; } diff --git a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp new file mode 100644 index 0000000000..1b781679fb --- /dev/null +++ b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_concurrent_groups.cpp @@ -0,0 +1,188 @@ +/* + * Copyright (C) 2020-2024 Intel Corporation + * + * SPDX-License-Identifier: MIT + * + */ + +#include "shared/test/common/test_macros/test.h" + +#include "level_zero/core/test/unit_tests/fixtures/device_fixture.h" +#include "level_zero/tools/source/metrics/metric.h" +#include "level_zero/tools/test/unit_tests/sources/metrics/mock_metric_source.h" + +#include "gtest/gtest.h" + +namespace L0 { +namespace ult { + +class ConcurrentMetricGroupFixture : public DeviceFixture, + public ::testing::Test { + public: + std::unique_ptr deviceContext = nullptr; + + protected: + void SetUp() override; + void TearDown() override; +}; + +void ConcurrentMetricGroupFixture::TearDown() { + DeviceFixture::tearDown(); + deviceContext.reset(); +} + +void ConcurrentMetricGroupFixture::SetUp() { + DeviceFixture::setUp(); + deviceContext = std::make_unique(*device); +} + +class MockMetricSourceType1 : public MockMetricSource { + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) override { + // No overlap possible + if (*pConcurrentGroupCount == 0) { + *pConcurrentGroupCount = static_cast(hMetricGroups.size()); + return ZE_RESULT_SUCCESS; + } + + *pConcurrentGroupCount = std::min(*pConcurrentGroupCount, static_cast(hMetricGroups.size())); + for (uint32_t index = 0; index < *pConcurrentGroupCount; index++) { + pCountPerConcurrentGroup[index] = 1; + } + return ZE_RESULT_SUCCESS; + } +}; + +class MockMetricSourceType2 : public MockMetricSource { + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) override { + // Requires the number of metric groups to be a multiple of 2 + if (hMetricGroups.size() & 1) { + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + // No overlap possible + if (*pConcurrentGroupCount == 0) { + *pConcurrentGroupCount = static_cast(hMetricGroups.size()); + return ZE_RESULT_SUCCESS; + } + + *pConcurrentGroupCount = std::min(*pConcurrentGroupCount, static_cast(hMetricGroups.size() / 2)); + for (uint32_t index = 0; index < *pConcurrentGroupCount; index++) { + pCountPerConcurrentGroup[index] = 2; + } + return ZE_RESULT_SUCCESS; + } +}; + +TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledForDifferentSourcesThenGeneratedGroupsAreCorrect) { + + MockMetricSourceType1 sourceType1; + sourceType1.setType(1); + MockMetricGroup metricGroupType1[] = {MockMetricGroup(sourceType1), MockMetricGroup(sourceType1), MockMetricGroup(sourceType1)}; + + MockMetricSourceType2 sourceType2; + sourceType2.setType(2); + MockMetricGroup metricGroupType2[] = {MockMetricGroup(sourceType2), MockMetricGroup(sourceType2), MockMetricGroup(sourceType2), MockMetricGroup(sourceType2)}; + + std::vector metricGroupHandles{}; + + metricGroupHandles.push_back(metricGroupType1[0].toHandle()); + metricGroupHandles.push_back(metricGroupType2[1].toHandle()); + metricGroupHandles.push_back(metricGroupType1[1].toHandle()); + metricGroupHandles.push_back(metricGroupType2[3].toHandle()); + metricGroupHandles.push_back(metricGroupType1[2].toHandle()); + metricGroupHandles.push_back(metricGroupType2[0].toHandle()); + metricGroupHandles.push_back(metricGroupType2[2].toHandle()); + + uint32_t concurrentGroupCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); + EXPECT_EQ(concurrentGroupCount, 4u); + std::vector countPerConcurrentGroup(concurrentGroupCount); + EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data())); + + EXPECT_EQ(countPerConcurrentGroup[0], 3u); + EXPECT_EQ(countPerConcurrentGroup[1], 3u); + EXPECT_EQ(countPerConcurrentGroup[2], 1u); +} + +TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledWithIncorrectGroupCountThenFailureIsReturned) { + + MockMetricSourceType1 sourceType1; + sourceType1.setType(1); + MockMetricGroup metricGroupType1[] = {MockMetricGroup(sourceType1), MockMetricGroup(sourceType1), MockMetricGroup(sourceType1)}; + + MockMetricSourceType2 sourceType2; + sourceType2.setType(2); + MockMetricGroup metricGroupType2[] = {MockMetricGroup(sourceType2), MockMetricGroup(sourceType2), MockMetricGroup(sourceType2), MockMetricGroup(sourceType2)}; + + std::vector metricGroupHandles{}; + + metricGroupHandles.push_back(metricGroupType1[0].toHandle()); + metricGroupHandles.push_back(metricGroupType2[1].toHandle()); + metricGroupHandles.push_back(metricGroupType1[1].toHandle()); + metricGroupHandles.push_back(metricGroupType2[3].toHandle()); + metricGroupHandles.push_back(metricGroupType1[2].toHandle()); + metricGroupHandles.push_back(metricGroupType2[0].toHandle()); + metricGroupHandles.push_back(metricGroupType2[2].toHandle()); + + uint32_t concurrentGroupCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); + EXPECT_EQ(concurrentGroupCount, 4u); + concurrentGroupCount -= 1; + std::vector countPerConcurrentGroup(concurrentGroupCount); + EXPECT_EQ(ZE_RESULT_ERROR_INVALID_ARGUMENT, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data())); +} + +class MockMetricSourceError : public MockMetricSource { + public: + static int32_t errorCallCount; + ze_result_t getConcurrentMetricGroups(std::vector &hMetricGroups, + uint32_t *pConcurrentGroupCount, + uint32_t *pCountPerConcurrentGroup) override { + errorCallCount -= 1; + errorCallCount = std::max(0, errorCallCount); + *pConcurrentGroupCount = 10; + if (errorCallCount) { + return ZE_RESULT_SUCCESS; + } + return ZE_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + ~MockMetricSourceError() override { + errorCallCount = 0; + } +}; + +int32_t MockMetricSourceError::errorCallCount = 1; + +TEST_F(ConcurrentMetricGroupFixture, WhenGetConcurrentMetricGroupsIsCalledAndSourceImplementationReturnsErrorThenErrorIsObserved) { + + MockMetricSourceError sourceTypeError; + sourceTypeError.setType(1); + MockMetricGroup metricGroupError[] = {MockMetricGroup(sourceTypeError), + MockMetricGroup(sourceTypeError), + MockMetricGroup(sourceTypeError)}; + + std::vector metricGroupHandles{}; + + metricGroupHandles.push_back(metricGroupError[0].toHandle()); + metricGroupHandles.push_back(metricGroupError[1].toHandle()); + metricGroupHandles.push_back(metricGroupError[1].toHandle()); + + uint32_t concurrentGroupCount = 0; + MockMetricSourceError::errorCallCount = 3; + EXPECT_EQ(ZE_RESULT_SUCCESS, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); + concurrentGroupCount = 10; + std::vector countPerConcurrentGroup(10); + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, countPerConcurrentGroup.data())); + + MockMetricSourceError::errorCallCount = 1; + concurrentGroupCount = 0; + EXPECT_EQ(ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, deviceContext->getConcurrentMetricGroups(static_cast(metricGroupHandles.size()), metricGroupHandles.data(), &concurrentGroupCount, nullptr)); +} + +} // namespace ult +} // namespace L0 diff --git a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_ip_sampling_streamer.cpp b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_ip_sampling_streamer.cpp index 888160e0ab..bb8299669b 100644 --- a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_ip_sampling_streamer.cpp +++ b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_ip_sampling_streamer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2022-2023 Intel Corporation + * Copyright (C) 2022-2024 Intel Corporation * * SPDX-License-Identifier: MIT * @@ -509,5 +509,26 @@ TEST_F(MetricIpSamplingStreamerTest, GivenNotEnoughMemoryWhileReadingWhenReadDat EXPECT_EQ(zetMetricStreamerClose(streamerHandle), ZE_RESULT_SUCCESS); } +TEST_F(MetricIpSamplingStreamerTest, whenGetConcurrentMetricGroupsIsCalledThenCorrectConcurrentGroupsAreRetrieved) { + + EXPECT_EQ(ZE_RESULT_SUCCESS, testDevices[0]->getMetricDeviceContext().enableMetricApi()); + for (auto device : testDevices) { + + auto &metricSource = device->getMetricDeviceContext().getMetricSource(); + zet_metric_group_handle_t metricGroupHandle = MetricIpSamplingStreamerTest::getMetricGroup(device); + std::vector metricGroupList{}; + metricGroupList.push_back(metricGroupHandle); + + uint32_t concurrentGroupCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, metricSource.getConcurrentMetricGroups(metricGroupList, &concurrentGroupCount, nullptr)); + EXPECT_EQ(concurrentGroupCount, 1u); + + std::vector countPerConcurrentGroup(concurrentGroupCount); + concurrentGroupCount += 1; + EXPECT_EQ(ZE_RESULT_SUCCESS, metricSource.getConcurrentMetricGroups(metricGroupList, &concurrentGroupCount, countPerConcurrentGroup.data())); + EXPECT_EQ(concurrentGroupCount, 1u); + } +} + } // namespace ult } // namespace L0 diff --git a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_oa_enumeration_2.cpp b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_oa_enumeration_2.cpp index 782b027070..e689e68acf 100644 --- a/level_zero/tools/test/unit_tests/sources/metrics/test_metric_oa_enumeration_2.cpp +++ b/level_zero/tools/test/unit_tests/sources/metrics/test_metric_oa_enumeration_2.cpp @@ -1174,5 +1174,62 @@ TEST_F(MetricEnumerationMultiDeviceTest, givenCorrectRawDataHeaderWhenBothSubDev EXPECT_EQ(metricCounts[1], 0u); } +TEST_F(MetricEnumerationMultiDeviceTest, givenOaMetricSourceWhenGetConcurrentMetricGroupsIsCalledThenCorrectConcurrentGroupsAreRetrieved) { + auto &metricSource = devices[0]->getMetricDeviceContext().getMetricSource(); + + metricsDeviceParams.ConcurrentGroupsCount = 1; + + Mock metricsConcurrentGroup; + TConcurrentGroupParams_1_0 metricsConcurrentGroupParams = {}; + metricsConcurrentGroupParams.MetricSetsCount = 1; + metricsConcurrentGroupParams.SymbolName = "OA"; + + Mock metricsSet; + MetricsDiscovery::TMetricSetParams_1_4 metricsSetParams = {}; + metricsSetParams.ApiMask = MetricsDiscovery::API_TYPE_IOSTREAM; + metricsSetParams.RawReportSize = 256; + metricsSetParams.MetricsCount = 11; + + Mock metric; + MetricsDiscovery::TMetricParams_1_0 metricParams = {}; + + zet_metric_group_handle_t metricGroupHandle = {}; + + uint32_t returnedMetricCount = 2; + + openMetricsAdapter(); + + setupDefaultMocksForMetricDevice(metricsDevice); + + metricsDevice.getConcurrentGroupResults.push_back(&metricsConcurrentGroup); + + metricsConcurrentGroup.GetParamsResult = &metricsConcurrentGroupParams; + metricsConcurrentGroup.getMetricSetResult = &metricsSet; + + metricsSet.GetParamsResult = &metricsSetParams; + metricsSet.GetMetricResult = &metric; + metricsSet.calculateMetricsOutReportCount = &returnedMetricCount; + + metric.GetParamsResult = &metricParams; + + // Metric group handles. + uint32_t metricGroupCount = 1; + EXPECT_EQ(zetMetricGroupGet(devices[0]->toHandle(), &metricGroupCount, &metricGroupHandle), ZE_RESULT_SUCCESS); + EXPECT_EQ(metricGroupCount, 1u); + EXPECT_NE(metricGroupHandle, nullptr); + + std::vector metricGroupList{}; + metricGroupList.push_back(metricGroupHandle); + + uint32_t concurrentGroupCount = 0; + EXPECT_EQ(ZE_RESULT_SUCCESS, metricSource.getConcurrentMetricGroups(metricGroupList, &concurrentGroupCount, nullptr)); + EXPECT_EQ(concurrentGroupCount, 1u); + + std::vector countPerConcurrentGroup(concurrentGroupCount); + concurrentGroupCount += 1; + EXPECT_EQ(ZE_RESULT_SUCCESS, metricSource.getConcurrentMetricGroups(metricGroupList, &concurrentGroupCount, countPerConcurrentGroup.data())); + EXPECT_EQ(concurrentGroupCount, 1u); +} + } // namespace ult } // namespace L0