fix(zebin): allow for recursive function calls

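Allow loops (recursive and mutually recursive calls) when resolving external function dependencies in zebin; a cycle in the dependency graph is no longer reported as an error.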

Signed-off-by: Krystian Chmielewski <krystian.chmielewski@intel.com>
This commit is contained in:
Krystian Chmielewski 2022-11-21 14:54:41 +00:00 committed by Compute-Runtime-Automation
parent 67bfebb25e
commit 6ec5647b36
3 changed files with 48 additions and 45 deletions

View File

@ -12,8 +12,8 @@
#include <algorithm>
namespace NEO {
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies,
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor) {
uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
FuncNameToIdMapT funcNameToId;
for (size_t i = 0U; i < externalFunctionInfos.size(); i++) {
auto &extFuncInfo = externalFunctionInfos[i];
@ -29,7 +29,7 @@ uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, Kerne
return error;
}
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs,
uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
DependenciesT &outDependencies, CalledByT &outCalledBy) {
outDependencies.resize(numExternalFuncs);
outCalledBy.resize(numExternalFuncs);
@ -39,8 +39,8 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
funcNameToId.count(funcDep->usedFuncName) == 0) {
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
}
size_t callerId = funcNameToId[funcDep->callerFuncName];
size_t calleeId = funcNameToId[funcDep->usedFuncName];
size_t callerId = funcNameToId.at(funcDep->callerFuncName);
size_t calleeId = funcNameToId.at(funcDep->usedFuncName);
outDependencies[callerId].push_back(calleeId);
outCalledBy[calleeId].push_back(callerId);
@ -49,7 +49,7 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
return RESOLVE_SUCCESS;
}
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies) {
uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies) {
DependenciesT dependencies;
CalledByT calledBy;
auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy);
@ -59,9 +59,6 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
DependencyResolver depResolver(dependencies);
auto resolved = depResolver.resolveDependencies();
if (depResolver.hasLoop()) {
return ERROR_LOOP_DETECTED;
}
for (auto calleeId : resolved) {
const auto callee = externalFunctionInfos[calleeId];
for (auto callerId : calledBy[calleeId]) {
@ -72,16 +69,15 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
return RESOLVE_SUCCESS;
}
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor) {
for (size_t i = 0; i < kernelDependencies.size(); i++) {
auto kernelDep = kernelDependencies[i];
uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
for (auto &kernelDep : kernelDependencies) {
if (funcNameToId.count(kernelDep->usedFuncName) == 0) {
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
} else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) {
return ERROR_KERNEL_DESCRIPTOR_MISSING;
}
const auto functionBarrierCount = externalFunctionInfos[funcNameToId[kernelDep->usedFuncName]]->barrierCount;
auto &kernelBarrierCount = nameToKernelDescriptor[kernelDep->kernelName]->kernelAttributes.barrierCount;
const auto functionBarrierCount = externalFunctionInfos.at(funcNameToId.at(kernelDep->usedFuncName))->barrierCount;
auto &kernelBarrierCount = nameToKernelDescriptor.at(kernelDep->kernelName)->kernelAttributes.barrierCount;
kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount);
}
return RESOLVE_SUCCESS;
@ -89,25 +85,19 @@ uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos,
std::vector<size_t> DependencyResolver::resolveDependencies() {
for (size_t i = 0; i < graph.size(); i++) {
if (std::find(seen.begin(), seen.end(), i) == seen.end() && graph[i].empty() == false) {
if (std::find(seen.begin(), seen.end(), i) == seen.end()) {
resolveDependency(i, graph[i]);
}
}
if (loopDeteckted) {
return {};
}
return resolved;
}
void DependencyResolver::resolveDependency(size_t nodeId, const std::vector<size_t> &edges) {
seen.push_back(nodeId);
for (auto &edgeId : edges) {
if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end()) {
if (std::find(seen.begin(), seen.end(), edgeId) != seen.end()) {
loopDeteckted = true;
} else {
resolveDependency(edgeId, graph[edgeId]);
}
if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end() &&
std::find(seen.begin(), seen.end(), edgeId) == seen.end()) {
resolveDependency(edgeId, graph[edgeId]);
}
}
resolved.push_back(nodeId);

View File

@ -52,24 +52,22 @@ class DependencyResolver {
public:
DependencyResolver(const std::vector<std::vector<size_t>> &graph) : graph(graph) {}
std::vector<size_t> resolveDependencies();
inline bool hasLoop() { return loopDeteckted; }
protected:
void resolveDependency(size_t nodeId, const std::vector<size_t> &edges);
std::vector<size_t> seen;
std::vector<size_t> resolved;
const std::vector<std::vector<size_t>> &graph;
bool loopDeteckted = false;
};
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies,
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor);
uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs,
uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
DependenciesT &outDependencies, CalledByT &outCalledBy);
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies);
uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies);
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor);
uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
} // namespace NEO

View File

@ -21,7 +21,7 @@ TEST(DependencyResolverTests, GivenEmptyGraphReturnEmptyResolve) {
EXPECT_TRUE(resolve.empty());
}
TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDeteckted) {
TEST(DependencyResolverTests, GivenGraphWithLoopReturnCorrectResolve) {
/*
0 -> 1
^ |
@ -29,10 +29,23 @@ TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDete
3 <- 2
*/
std::vector<std::vector<size_t>> graph = {{1}, {2}, {3}, {0}};
std::vector<size_t> expectedResolve = {3, 2, 1, 0};
DependencyResolver resolver(graph);
const auto &resolve = resolver.resolveDependencies();
EXPECT_TRUE(resolve.empty());
EXPECT_TRUE(resolver.hasLoop());
EXPECT_EQ(expectedResolve, resolve);
}
TEST(DependencyResolverTests, GivenGraphWithNodeConnectedToItselfThenReturnCorrectResolve) {
    // Graph under test: a single node (0) whose only edge points back at itself.
    //
    //   0 --+
    //   ^   |
    //   +---+
    const std::vector<size_t> selfLoopOrder = {0};
    std::vector<std::vector<size_t>> selfLoopGraph = {{0}};

    DependencyResolver selfLoopResolver(selfLoopGraph);
    const std::vector<size_t> actualOrder = selfLoopResolver.resolveDependencies();

    // The self-referencing node must show up exactly once in the resolved order.
    EXPECT_EQ(selfLoopOrder, actualOrder);
}
TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) {
@ -56,7 +69,7 @@ TEST(DependencyResolverTests, GivenMultipleDisconnectedGraphsReturnCorrectResolv
4 -> 3
*/
std::vector<std::vector<size_t>> graph = {{}, {2}, {5}, {}, {3}, {}};
std::vector<size_t> expectedResolve = {5, 2, 1, 3, 4};
std::vector<size_t> expectedResolve = {0, 5, 2, 1, 3, 4};
DependencyResolver resolver(graph);
const auto &resolve = resolver.resolveDependencies();
EXPECT_EQ(expectedResolve, resolve);
@ -137,16 +150,6 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingExtFun
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
}
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnError) {
    // Register two external functions and make each one depend on the other,
    // which forms a two-node cycle in the function dependency list.
    addExternalFunction("fun0", 0);
    addExternalFunction("fun1", 0);
    addFuncDependency("fun0", "fun1");
    addFuncDependency("fun1", "fun0");
    set();

    // This test expects the cycle to be rejected with ERROR_LOOP_DETECTED.
    const auto status = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
    EXPECT_EQ(ERROR_LOOP_DETECTED, status);
}
TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) {
addKernel("kernel");
addKernelDependency("fun0", "kernel");
@ -184,6 +187,18 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInKernelDependenciesWhenResolv
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
}
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnSuccess) {
    // Two external functions that depend on each other (a two-node cycle).
    // The second argument is presumably the initial barrier count, since
    // barrierCount is what gets verified below -- confirm against the fixture.
    addExternalFunction("fun0", 4);
    addExternalFunction("fun1", 2);
    addFuncDependency("fun0", "fun1");
    addFuncDependency("fun1", "fun0");
    set();

    const auto status = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
    EXPECT_EQ(RESOLVE_SUCCESS, status);

    // Both members of the cycle are expected to end up with the larger count (4).
    const auto fun0Id = funcNameToId["fun0"];
    const auto fun1Id = funcNameToId["fun1"];
    EXPECT_EQ(4U, extFuncInfo[fun0Id]->barrierCount);
    EXPECT_EQ(4U, extFuncInfo[fun1Id]->barrierCount);
}
TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) {
addKernel("kernel");
addExternalFunction("fun0", 1U);