From 6ec5647b36263f3b720d8bb33ccd99aabf50e49e Mon Sep 17 00:00:00 2001 From: Krystian Chmielewski Date: Mon, 21 Nov 2022 14:54:41 +0000 Subject: [PATCH] fix(zebin): allow for recursive function calls Allow for loops when detecting function dependency in zebin. Signed-off-by: Krystian Chmielewski --- .../compiler_interface/external_functions.cpp | 38 ++++++---------- .../compiler_interface/external_functions.h | 12 +++--- .../external_functions_tests.cpp | 43 +++++++++++++------ 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/shared/source/compiler_interface/external_functions.cpp b/shared/source/compiler_interface/external_functions.cpp index 2703117c9d..612d24d02e 100644 --- a/shared/source/compiler_interface/external_functions.cpp +++ b/shared/source/compiler_interface/external_functions.cpp @@ -12,8 +12,8 @@ #include namespace NEO { -uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies, - FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor) { +uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies, + const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) { FuncNameToIdMapT funcNameToId; for (size_t i = 0U; i < externalFunctionInfos.size(); i++) { auto &extFuncInfo = externalFunctionInfos[i]; @@ -29,7 +29,7 @@ uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, Kerne return error; } -uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs, +uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs, DependenciesT &outDependencies, CalledByT &outCalledBy) { outDependencies.resize(numExternalFuncs); outCalledBy.resize(numExternalFuncs); @@ -39,8 +39,8 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen funcNameToId.count(funcDep->usedFuncName) == 0) { return ERROR_EXTERNAL_FUNCTION_INFO_MISSING; } - size_t callerId = funcNameToId[funcDep->callerFuncName]; - size_t calleeId = funcNameToId[funcDep->usedFuncName]; + size_t callerId = funcNameToId.at(funcDep->callerFuncName); + size_t calleeId = funcNameToId.at(funcDep->usedFuncName); outDependencies[callerId].push_back(calleeId); outCalledBy[calleeId].push_back(callerId); @@ -49,7 +49,7 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen return RESOLVE_SUCCESS; } -uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies) { +uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies) { DependenciesT dependencies; CalledByT calledBy; auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy); @@ -59,9 +59,6 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos DependencyResolver depResolver(dependencies); auto resolved = depResolver.resolveDependencies(); - if (depResolver.hasLoop()) { - return ERROR_LOOP_DETECTED; - } for (auto calleeId : resolved) { const auto callee = externalFunctionInfos[calleeId]; for (auto callerId : calledBy[calleeId]) { @@ -72,16 +69,15 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos return RESOLVE_SUCCESS; } -uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor) { - for (size_t i = 0; i < kernelDependencies.size(); i++) { - auto kernelDep = kernelDependencies[i]; +uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) { + for (auto &kernelDep : kernelDependencies) { if (funcNameToId.count(kernelDep->usedFuncName) == 0) { return ERROR_EXTERNAL_FUNCTION_INFO_MISSING; } else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) { return ERROR_KERNEL_DESCRIPTOR_MISSING; } - const auto functionBarrierCount = externalFunctionInfos[funcNameToId[kernelDep->usedFuncName]]->barrierCount; - auto &kernelBarrierCount = nameToKernelDescriptor[kernelDep->kernelName]->kernelAttributes.barrierCount; + const auto functionBarrierCount = externalFunctionInfos.at(funcNameToId.at(kernelDep->usedFuncName))->barrierCount; + auto &kernelBarrierCount = nameToKernelDescriptor.at(kernelDep->kernelName)->kernelAttributes.barrierCount; kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount); } return RESOLVE_SUCCESS; @@ -89,25 +85,19 @@ uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, std::vector DependencyResolver::resolveDependencies() { for (size_t i = 0; i < graph.size(); i++) { - if (std::find(seen.begin(), seen.end(), i) == seen.end() && graph[i].empty() == false) { + if (std::find(seen.begin(), seen.end(), i) == seen.end()) { resolveDependency(i, graph[i]); } } - if (loopDeteckted) { - return {}; - } return resolved; } void DependencyResolver::resolveDependency(size_t nodeId, const std::vector &edges) { seen.push_back(nodeId); for (auto &edgeId : edges) { - if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end()) { - if (std::find(seen.begin(), seen.end(), edgeId) != seen.end()) { - loopDeteckted = true; - } else { - resolveDependency(edgeId, graph[edgeId]); - } + if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end() && + std::find(seen.begin(), seen.end(), edgeId) == seen.end()) { + resolveDependency(edgeId, graph[edgeId]); } } resolved.push_back(nodeId); diff --git a/shared/source/compiler_interface/external_functions.h b/shared/source/compiler_interface/external_functions.h index 8f80dae3b1..87edd10734 100644 --- a/shared/source/compiler_interface/external_functions.h +++ b/shared/source/compiler_interface/external_functions.h @@ -52,24 +52,22 @@ class DependencyResolver { public: DependencyResolver(const std::vector> &graph) : graph(graph) {} std::vector resolveDependencies(); - inline bool hasLoop() { return loopDeteckted; } protected: void resolveDependency(size_t nodeId, const std::vector &edges); std::vector seen; std::vector resolved; const std::vector> &graph; - bool loopDeteckted = false; }; -uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies, - FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor); +uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies, + const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor); -uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs, +uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs, DependenciesT &outDependencies, CalledByT &outCalledBy); -uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies); +uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies); -uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor); +uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor); } // namespace NEO diff --git a/shared/test/unit_test/compiler_interface/external_functions_tests.cpp b/shared/test/unit_test/compiler_interface/external_functions_tests.cpp index 224c650e20..e7a05a7609 100644 --- a/shared/test/unit_test/compiler_interface/external_functions_tests.cpp +++ b/shared/test/unit_test/compiler_interface/external_functions_tests.cpp @@ -21,7 +21,7 @@ TEST(DependencyResolverTests, GivenEmptyGraphReturnEmptyResolve) { EXPECT_TRUE(resolve.empty()); } -TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDeteckted) { +TEST(DependencyResolverTests, GivenGraphWithLoopReturnCorrectResolve) { /* 0 -> 1 ^ | @@ -29,10 +29,23 @@ TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDete 3 <- 2 */ std::vector> graph = {{1}, {2}, {3}, {0}}; + std::vector expectedResolve = {3, 2, 1, 0}; DependencyResolver resolver(graph); const auto &resolve = resolver.resolveDependencies(); - EXPECT_TRUE(resolve.empty()); - EXPECT_TRUE(resolver.hasLoop()); + EXPECT_EQ(expectedResolve, resolve); +} + +TEST(DependencyResolverTests, GivenGraphWithNodeConnectedToItselfThenReturnCorrectResolve) { + /* + 0--> + ^ | + |<-v + */ + std::vector> graph = {{0}}; + std::vector expectedResolve = {0}; + DependencyResolver resolver(graph); + const auto &resolve = resolver.resolveDependencies(); + EXPECT_EQ(expectedResolve, resolve); } TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) { @@ -56,7 +69,7 @@ TEST(DependencyResolverTests, GivenMultipleDisconnectedGraphsReturnCorrectResolv 4 -> 3 */ std::vector> graph = {{}, {2}, {5}, {}, {3}, {}}; - std::vector expectedResolve = {5, 2, 1, 3, 4}; + std::vector expectedResolve = {0, 5, 2, 1, 3, 4}; DependencyResolver resolver(graph); const auto &resolve = resolver.resolveDependencies(); EXPECT_EQ(expectedResolve, resolve); @@ -137,16 +150,6 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingExtFun EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error); } -TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnError) { - addExternalFunction("fun0", 0); - addExternalFunction("fun1", 0); - addFuncDependency("fun0", "fun1"); - addFuncDependency("fun1", "fun0"); - set(); - auto error = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies); - EXPECT_EQ(ERROR_LOOP_DETECTED, error); -} - TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) { addKernel("kernel"); addKernelDependency("fun0", "kernel"); @@ -184,6 +187,18 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInKernelDependenciesWhenResolv EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error); } +TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnSuccess) { + addExternalFunction("fun0", 4); + addExternalFunction("fun1", 2); + addFuncDependency("fun0", "fun1"); + addFuncDependency("fun1", "fun0"); + set(); + auto retVal = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies); + EXPECT_EQ(RESOLVE_SUCCESS, retVal); + EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun0"]]->barrierCount); + EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun1"]]->barrierCount); +} + TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) { addKernel("kernel"); addExternalFunction("fun0", 1U);