fix(zebin): allow for recursive function calls

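Allow loops (cycles) in the function dependency graph when resolving external function dependencies in zebin, so that recursive or mutually recursive function calls no longer make dependency resolution fail with a loop-detected error.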

Signed-off-by: Krystian Chmielewski <krystian.chmielewski@intel.com>
This commit is contained in:
Krystian Chmielewski
2022-11-21 14:54:41 +00:00
committed by Compute-Runtime-Automation
parent 67bfebb25e
commit 6ec5647b36
3 changed files with 48 additions and 45 deletions

View File

@@ -12,8 +12,8 @@
#include <algorithm> #include <algorithm>
namespace NEO { namespace NEO {
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies, uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor) { const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
FuncNameToIdMapT funcNameToId; FuncNameToIdMapT funcNameToId;
for (size_t i = 0U; i < externalFunctionInfos.size(); i++) { for (size_t i = 0U; i < externalFunctionInfos.size(); i++) {
auto &extFuncInfo = externalFunctionInfos[i]; auto &extFuncInfo = externalFunctionInfos[i];
@@ -29,7 +29,7 @@ uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, Kerne
return error; return error;
} }
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs, uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
DependenciesT &outDependencies, CalledByT &outCalledBy) { DependenciesT &outDependencies, CalledByT &outCalledBy) {
outDependencies.resize(numExternalFuncs); outDependencies.resize(numExternalFuncs);
outCalledBy.resize(numExternalFuncs); outCalledBy.resize(numExternalFuncs);
@@ -39,8 +39,8 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
funcNameToId.count(funcDep->usedFuncName) == 0) { funcNameToId.count(funcDep->usedFuncName) == 0) {
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING; return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
} }
size_t callerId = funcNameToId[funcDep->callerFuncName]; size_t callerId = funcNameToId.at(funcDep->callerFuncName);
size_t calleeId = funcNameToId[funcDep->usedFuncName]; size_t calleeId = funcNameToId.at(funcDep->usedFuncName);
outDependencies[callerId].push_back(calleeId); outDependencies[callerId].push_back(calleeId);
outCalledBy[calleeId].push_back(callerId); outCalledBy[calleeId].push_back(callerId);
@@ -49,7 +49,7 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
return RESOLVE_SUCCESS; return RESOLVE_SUCCESS;
} }
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies) { uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies) {
DependenciesT dependencies; DependenciesT dependencies;
CalledByT calledBy; CalledByT calledBy;
auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy); auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy);
@@ -59,9 +59,6 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
DependencyResolver depResolver(dependencies); DependencyResolver depResolver(dependencies);
auto resolved = depResolver.resolveDependencies(); auto resolved = depResolver.resolveDependencies();
if (depResolver.hasLoop()) {
return ERROR_LOOP_DETECTED;
}
for (auto calleeId : resolved) { for (auto calleeId : resolved) {
const auto callee = externalFunctionInfos[calleeId]; const auto callee = externalFunctionInfos[calleeId];
for (auto callerId : calledBy[calleeId]) { for (auto callerId : calledBy[calleeId]) {
@@ -72,16 +69,15 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
return RESOLVE_SUCCESS; return RESOLVE_SUCCESS;
} }
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor) { uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
for (size_t i = 0; i < kernelDependencies.size(); i++) { for (auto &kernelDep : kernelDependencies) {
auto kernelDep = kernelDependencies[i];
if (funcNameToId.count(kernelDep->usedFuncName) == 0) { if (funcNameToId.count(kernelDep->usedFuncName) == 0) {
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING; return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
} else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) { } else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) {
return ERROR_KERNEL_DESCRIPTOR_MISSING; return ERROR_KERNEL_DESCRIPTOR_MISSING;
} }
const auto functionBarrierCount = externalFunctionInfos[funcNameToId[kernelDep->usedFuncName]]->barrierCount; const auto functionBarrierCount = externalFunctionInfos.at(funcNameToId.at(kernelDep->usedFuncName))->barrierCount;
auto &kernelBarrierCount = nameToKernelDescriptor[kernelDep->kernelName]->kernelAttributes.barrierCount; auto &kernelBarrierCount = nameToKernelDescriptor.at(kernelDep->kernelName)->kernelAttributes.barrierCount;
kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount); kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount);
} }
return RESOLVE_SUCCESS; return RESOLVE_SUCCESS;
@@ -89,27 +85,21 @@ uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos,
std::vector<size_t> DependencyResolver::resolveDependencies() { std::vector<size_t> DependencyResolver::resolveDependencies() {
for (size_t i = 0; i < graph.size(); i++) { for (size_t i = 0; i < graph.size(); i++) {
if (std::find(seen.begin(), seen.end(), i) == seen.end() && graph[i].empty() == false) { if (std::find(seen.begin(), seen.end(), i) == seen.end()) {
resolveDependency(i, graph[i]); resolveDependency(i, graph[i]);
} }
} }
if (loopDeteckted) {
return {};
}
return resolved; return resolved;
} }
void DependencyResolver::resolveDependency(size_t nodeId, const std::vector<size_t> &edges) { void DependencyResolver::resolveDependency(size_t nodeId, const std::vector<size_t> &edges) {
seen.push_back(nodeId); seen.push_back(nodeId);
for (auto &edgeId : edges) { for (auto &edgeId : edges) {
if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end()) { if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end() &&
if (std::find(seen.begin(), seen.end(), edgeId) != seen.end()) { std::find(seen.begin(), seen.end(), edgeId) == seen.end()) {
loopDeteckted = true;
} else {
resolveDependency(edgeId, graph[edgeId]); resolveDependency(edgeId, graph[edgeId]);
} }
} }
}
resolved.push_back(nodeId); resolved.push_back(nodeId);
} }
} // namespace NEO } // namespace NEO

View File

@@ -52,24 +52,22 @@ class DependencyResolver {
public: public:
DependencyResolver(const std::vector<std::vector<size_t>> &graph) : graph(graph) {} DependencyResolver(const std::vector<std::vector<size_t>> &graph) : graph(graph) {}
std::vector<size_t> resolveDependencies(); std::vector<size_t> resolveDependencies();
inline bool hasLoop() { return loopDeteckted; }
protected: protected:
void resolveDependency(size_t nodeId, const std::vector<size_t> &edges); void resolveDependency(size_t nodeId, const std::vector<size_t> &edges);
std::vector<size_t> seen; std::vector<size_t> seen;
std::vector<size_t> resolved; std::vector<size_t> resolved;
const std::vector<std::vector<size_t>> &graph; const std::vector<std::vector<size_t>> &graph;
bool loopDeteckted = false;
}; };
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies, uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor); const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs, uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
DependenciesT &outDependencies, CalledByT &outCalledBy); DependenciesT &outDependencies, CalledByT &outCalledBy);
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies); uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies);
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor); uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
} // namespace NEO } // namespace NEO

View File

@@ -21,7 +21,7 @@ TEST(DependencyResolverTests, GivenEmptyGraphReturnEmptyResolve) {
EXPECT_TRUE(resolve.empty()); EXPECT_TRUE(resolve.empty());
} }
TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDeteckted) { TEST(DependencyResolverTests, GivenGraphWithLoopReturnCorrectResolve) {
/* /*
0 -> 1 0 -> 1
^ | ^ |
@@ -29,10 +29,23 @@ TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDete
3 <- 2 3 <- 2
*/ */
std::vector<std::vector<size_t>> graph = {{1}, {2}, {3}, {0}}; std::vector<std::vector<size_t>> graph = {{1}, {2}, {3}, {0}};
std::vector<size_t> expectedResolve = {3, 2, 1, 0};
DependencyResolver resolver(graph); DependencyResolver resolver(graph);
const auto &resolve = resolver.resolveDependencies(); const auto &resolve = resolver.resolveDependencies();
EXPECT_TRUE(resolve.empty()); EXPECT_EQ(expectedResolve, resolve);
EXPECT_TRUE(resolver.hasLoop()); }
TEST(DependencyResolverTests, GivenGraphWithNodeConnectedToItselfThenReturnCorrectResolve) {
/*
0-->
^ |
|<-v
*/
std::vector<std::vector<size_t>> graph = {{0}};
std::vector<size_t> expectedResolve = {0};
DependencyResolver resolver(graph);
const auto &resolve = resolver.resolveDependencies();
EXPECT_EQ(expectedResolve, resolve);
} }
TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) { TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) {
@@ -56,7 +69,7 @@ TEST(DependencyResolverTests, GivenMultipleDisconnectedGraphsReturnCorrectResolv
4 -> 3 4 -> 3
*/ */
std::vector<std::vector<size_t>> graph = {{}, {2}, {5}, {}, {3}, {}}; std::vector<std::vector<size_t>> graph = {{}, {2}, {5}, {}, {3}, {}};
std::vector<size_t> expectedResolve = {5, 2, 1, 3, 4}; std::vector<size_t> expectedResolve = {0, 5, 2, 1, 3, 4};
DependencyResolver resolver(graph); DependencyResolver resolver(graph);
const auto &resolve = resolver.resolveDependencies(); const auto &resolve = resolver.resolveDependencies();
EXPECT_EQ(expectedResolve, resolve); EXPECT_EQ(expectedResolve, resolve);
@@ -137,16 +150,6 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingExtFun
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error); EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
} }
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnError) {
addExternalFunction("fun0", 0);
addExternalFunction("fun1", 0);
addFuncDependency("fun0", "fun1");
addFuncDependency("fun1", "fun0");
set();
auto error = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
EXPECT_EQ(ERROR_LOOP_DETECTED, error);
}
TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) { TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) {
addKernel("kernel"); addKernel("kernel");
addKernelDependency("fun0", "kernel"); addKernelDependency("fun0", "kernel");
@@ -184,6 +187,18 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInKernelDependenciesWhenResolv
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error); EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
} }
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnSuccess) {
addExternalFunction("fun0", 4);
addExternalFunction("fun1", 2);
addFuncDependency("fun0", "fun1");
addFuncDependency("fun1", "fun0");
set();
auto retVal = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
EXPECT_EQ(RESOLVE_SUCCESS, retVal);
EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun0"]]->barrierCount);
EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun1"]]->barrierCount);
}
TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) { TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) {
addKernel("kernel"); addKernel("kernel");
addExternalFunction("fun0", 1U); addExternalFunction("fun0", 1U);