/* * Copyright (C) 2019-2023 Intel Corporation * * SPDX-License-Identifier: MIT * */ #pragma once #include "shared/source/utilities/stackvec.h" #include "shared/test/common/cmd_parse/hw_parse.h" #include #include namespace NEO { struct CmdValidator { CmdValidator() { } virtual ~CmdValidator() = default; virtual bool operator()(GenCmdList::iterator it, size_t numInSection, const std::string &member, std::string &outFailReason) = 0; }; template struct CmdValidatorWithStaticStorage : CmdValidator { static ChildT *get() { static ChildT val; return &val; } }; template struct GenericCmdValidator : CmdValidatorWithStaticStorage> { bool operator()(GenCmdList::iterator it, size_t numInSection, const std::string &member, std::string &outFailReason) override { auto cmd = genCmdCast(*it); UNRECOVERABLE_IF(cmd == nullptr); if (expected != (cmd->*getter)()) { outFailReason = member + " - expected: " + std::to_string(expected) + ", got: " + std::to_string((cmd->*getter)()); return false; } return true; } }; struct NamedValidator { NamedValidator(CmdValidator *validator) : NamedValidator(validator, "Unspecified") { } NamedValidator(CmdValidator *validator, const char *name) : validator(validator), name(name) { } CmdValidator *validator; const char *name; }; #define EXPECT_MEMBER(TYPE, FUNC, EXPECTED) \ NamedValidator { GenericCmdValidator, &TYPE::FUNC, EXPECTED>::get(), #FUNC } using Expects = std::vector; struct MatchCmd { MatchCmd(int amount, bool matchesAny) : amount(amount), matchesAny(matchesAny) { } MatchCmd(int amount) : MatchCmd(amount, false) { } virtual ~MatchCmd() = default; virtual bool matches(GenCmdList::iterator it) const = 0; virtual bool validates(GenCmdList::iterator it, std::string &outReason) const = 0; virtual const char *getName() const = 0; virtual void capture(GenCmdList::iterator it) = 0; int getExpectedCount() const { return amount; } bool getMatchesAny() const { return matchesAny; } protected: int amount = 0; bool matchesAny = false; }; inline constexpr int32_t anyNumber = -1; inline constexpr int32_t atLeastOne = -2; inline std::string countToString(int32_t count) { if (count == anyNumber) { return "AnyNumber"; } else if (count == atLeastOne) { return "AtLeastOne"; } else { return std::to_string(count); } } inline bool notPreciseNumber(int32_t count) { return (count == anyNumber) || (count == atLeastOne); } struct MatchAnyCmd : MatchCmd { MatchAnyCmd(int amount) : MatchCmd(amount, true) { if (amount > 0) { captured.reserve(amount); } } bool matches(GenCmdList::iterator it) const override { return true; } bool validates(GenCmdList::iterator it, std::string &outReason) const override { return true; } void capture(GenCmdList::iterator it) override { captured.push_back(*it); } const char *getName() const override { return "AnyCommand"; } protected: StackVec captured; }; template struct MatchHwCmd : MatchCmd { MatchHwCmd(int amount) : MatchCmd(amount) { if (amount > 0) { captured.reserve(amount); } } MatchHwCmd(int amount, Expects &&validators) : MatchHwCmd(amount) { this->validators.swap(validators); } bool matches(GenCmdList::iterator it) const override { return nullptr != genCmdCast(*it); } bool validates(GenCmdList::iterator it, std::string &outReason) const override { for (auto &v : validators) { if (false == (*v.validator)(it, captured.size(), v.name, outReason)) { return false; } } return true; } void capture(GenCmdList::iterator it) override { UNRECOVERABLE_IF(false == matches(it)); UNRECOVERABLE_IF(captured.size() == static_cast(amount)); captured.push_back(genCmdCast(*it)); } const char *getName() const override { CmdType cmd; cmd.init(); return HardwareParse::getCommandName(&cmd); } protected: StackVec captured; Expects validators; }; template inline bool expectCmdBuff(GenCmdList::iterator begin, GenCmdList::iterator end, std::vector &&expectedCmdBuffMatchers, std::string *outReason = nullptr) { if (expectedCmdBuffMatchers.size() == 0) { return begin == end; } bool failed = false; std::string failReason; auto it = begin; int cmdNum = 0; size_t currentMatcher = 0; int currentMatcherCount = 0; StackVec, 32> matchedCommandNames; auto matchedCommandsString = [&]() -> std::string { if (matchedCommandNames.size() == 0) { return "EMPTY"; } std::string ret = ""; for (size_t i = 0; i < matchedCommandNames.size(); ++i) { if (matchedCommandNames[i].second) { ret += std::to_string(i) + ":ANY(" + matchedCommandNames[i].first + ") "; } else { ret += std::to_string(i) + ":" + matchedCommandNames[i].first + " "; } } return ret; }; while (it != end) { if (currentMatcher < expectedCmdBuffMatchers.size()) { auto currentMatcherExpectedCount = expectedCmdBuffMatchers[currentMatcher]->getExpectedCount(); if (expectedCmdBuffMatchers[currentMatcher]->getMatchesAny() && ((currentMatcherExpectedCount == anyNumber) || ((currentMatcherExpectedCount == atLeastOne) && (currentMatcherCount > 0)))) { if (expectedCmdBuffMatchers.size() > currentMatcher + 1) { // eat as many as possible but proceed to next matcher when possible if (expectedCmdBuffMatchers[currentMatcher + 1]->matches(it)) { ++currentMatcher; currentMatcherCount = 0; } } } else if ((notPreciseNumber(expectedCmdBuffMatchers[currentMatcher]->getExpectedCount())) && (false == expectedCmdBuffMatchers[currentMatcher]->matches(it))) { // proceed to next matcher if not matched if ((expectedCmdBuffMatchers[currentMatcher]->getExpectedCount() == atLeastOne) && (currentMatcherCount < 1)) { failed = true; failReason = "Unmatched cmd#" + std::to_string(cmdNum) + ":" + HardwareParse::getCommandName(*it) + " - expected " + std::string(expectedCmdBuffMatchers[currentMatcher]->getName()) + "(" + countToString(expectedCmdBuffMatchers[currentMatcher]->getExpectedCount()) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString(); break; } ++currentMatcher; currentMatcherCount = 0; } while ((currentMatcher < expectedCmdBuffMatchers.size()) && expectedCmdBuffMatchers[currentMatcher]->getExpectedCount() == 0) { if (expectedCmdBuffMatchers[currentMatcher]->matches(it)) { failed = true; failReason = "Unmatched cmd#" + std::to_string(cmdNum) + " - expected anything but " + std::string(expectedCmdBuffMatchers[currentMatcher]->getName()) + "(" + countToString(expectedCmdBuffMatchers[currentMatcher]->getExpectedCount()) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString(); break; } ++currentMatcher; currentMatcherCount = 0; } } if (currentMatcher >= expectedCmdBuffMatchers.size()) { failed = true; std::string unmatchedCommands; while (it != end) { unmatchedCommands += std::to_string(cmdNum) + ":" + HardwareParse::getCommandName(*it) + " "; ++it; ++cmdNum; } failReason = "Unexpected commands at the end of the command buffer : " + unmatchedCommands + ", AFTER : " + matchedCommandsString(); break; } if (false == expectedCmdBuffMatchers[currentMatcher]->matches(it)) { failed = true; failReason = "Unmatched cmd#" + std::to_string(cmdNum) + ":" + HardwareParse::getCommandName(*it) + " - expected " + std::string(expectedCmdBuffMatchers[currentMatcher]->getName()) + "(" + countToString(expectedCmdBuffMatchers[currentMatcher]->getExpectedCount()) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString(); break; } if (false == expectedCmdBuffMatchers[currentMatcher]->validates(it, failReason)) { failReason = "cmd#" + std::to_string(cmdNum) + " (" + HardwareParse::getCommandName(*it) + ") failed validation - reason : " + failReason + " after : " + matchedCommandsString(); failed = true; break; } matchedCommandNames.push_back(std::make_pair(HardwareParse::getCommandName(*it), expectedCmdBuffMatchers[currentMatcher]->getMatchesAny())); ++currentMatcherCount; if (currentMatcherCount == expectedCmdBuffMatchers[currentMatcher]->getExpectedCount()) { ++currentMatcher; currentMatcherCount = 0; } ++cmdNum; ++it; } if (failed == false) { while ((currentMatcher < expectedCmdBuffMatchers.size()) && ((expectedCmdBuffMatchers[currentMatcher]->getExpectedCount() == 0) || (expectedCmdBuffMatchers[currentMatcher]->getExpectedCount() == anyNumber))) { ++currentMatcher; currentMatcherCount = 0; } if (currentMatcher == expectedCmdBuffMatchers.size()) { // no more matchers } else if (currentMatcher + 1 == expectedCmdBuffMatchers.size()) { // last matcher auto currentMatcherExpectedCount = expectedCmdBuffMatchers[currentMatcher]->getExpectedCount(); if ((currentMatcherExpectedCount == atLeastOne) && (currentMatcherCount < 1)) { failReason = "Unexpected command buffer end at cmd#" + std::to_string(cmdNum) + " - expected " + expectedCmdBuffMatchers[currentMatcher]->getName() + "(" + countToString(currentMatcherExpectedCount) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString(); failed = true; } if ((false == notPreciseNumber(currentMatcherExpectedCount)) && (currentMatcherExpectedCount != currentMatcherCount)) { failReason = "Unexpected command buffer end at cmd#" + std::to_string(cmdNum) + " - expected " + expectedCmdBuffMatchers[currentMatcher]->getName() + "(" + countToString(currentMatcherExpectedCount) + " - " + std::to_string(currentMatcherCount) + ") after : " + matchedCommandsString(); failed = true; } } else { // many matchers left std::string expectedMatchers = ""; int32_t currentMatcherExpectedCount = expectedCmdBuffMatchers[currentMatcher]->getExpectedCount(); expectedMatchers = expectedCmdBuffMatchers[currentMatcher]->getName() + std::string("(") + countToString(currentMatcherExpectedCount) + " - " + std::to_string(currentMatcherCount) + "), "; ++currentMatcher; while (currentMatcher < expectedCmdBuffMatchers.size()) { currentMatcherExpectedCount = expectedCmdBuffMatchers[currentMatcher]->getExpectedCount(); expectedMatchers += expectedCmdBuffMatchers[currentMatcher]->getName() + std::string("(") + countToString(currentMatcherExpectedCount) + " - 0), "; ++currentMatcher; } failReason = "Unexpected command buffer end at cmd#" + std::to_string(cmdNum) + " - expected " + expectedMatchers + " after : " + matchedCommandsString(); failed = true; } } else { if ((it != end) && (++it != end)) { ++cmdNum; failReason += "\n Unconsumed commands after failed one : "; while (it != end) { failReason += std::to_string(cmdNum) + ":" + HardwareParse::getCommandName(*it) + " "; ++cmdNum; ++it; } } } if (failed) { if (outReason != nullptr) { failReason += "\n Note : Input command buffer was : "; it = begin; cmdNum = 0; while (it != end) { failReason += std::to_string(cmdNum) + ":" + HardwareParse::getCommandName(*it) + " "; ++cmdNum; ++it; } *outReason = failReason; } } for (auto *matcher : expectedCmdBuffMatchers) { delete matcher; } return (failed == false); } template inline bool expectCmdBuff(NEO::LinearStream &commandStream, size_t startOffset, std::vector &&expectedCmdBuffMatchers, std::string *outReason = nullptr) { HardwareParse hwParser; hwParser.parseCommands(commandStream, startOffset); return expectCmdBuff(hwParser.cmdList.begin(), hwParser.cmdList.end(), std::move(expectedCmdBuffMatchers), outReason); } } // namespace NEO