mirror of
https://github.com/intel/compute-runtime.git
synced 2025-12-23 11:03:02 +08:00
fix(zebin): allow for recursive function calls
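Allow loops (cycles in the call graph, i.e. recursive function calls) when detecting function dependencies in zebin; such a cycle is no longer reported as ERROR_LOOP_DETECTED. Signed-off-by: Krystian Chmielewski <krystian.chmielewski@intel.com>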
This commit is contained in:
committed by
Compute-Runtime-Automation
parent
67bfebb25e
commit
6ec5647b36
@@ -12,8 +12,8 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
namespace NEO {
|
namespace NEO {
|
||||||
|
|
||||||
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies,
|
uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
|
||||||
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor) {
|
const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
|
||||||
FuncNameToIdMapT funcNameToId;
|
FuncNameToIdMapT funcNameToId;
|
||||||
for (size_t i = 0U; i < externalFunctionInfos.size(); i++) {
|
for (size_t i = 0U; i < externalFunctionInfos.size(); i++) {
|
||||||
auto &extFuncInfo = externalFunctionInfos[i];
|
auto &extFuncInfo = externalFunctionInfos[i];
|
||||||
@@ -29,7 +29,7 @@ uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, Kerne
|
|||||||
return error;
|
return error;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs,
|
uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
|
||||||
DependenciesT &outDependencies, CalledByT &outCalledBy) {
|
DependenciesT &outDependencies, CalledByT &outCalledBy) {
|
||||||
outDependencies.resize(numExternalFuncs);
|
outDependencies.resize(numExternalFuncs);
|
||||||
outCalledBy.resize(numExternalFuncs);
|
outCalledBy.resize(numExternalFuncs);
|
||||||
@@ -39,8 +39,8 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
|
|||||||
funcNameToId.count(funcDep->usedFuncName) == 0) {
|
funcNameToId.count(funcDep->usedFuncName) == 0) {
|
||||||
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
|
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
|
||||||
}
|
}
|
||||||
size_t callerId = funcNameToId[funcDep->callerFuncName];
|
size_t callerId = funcNameToId.at(funcDep->callerFuncName);
|
||||||
size_t calleeId = funcNameToId[funcDep->usedFuncName];
|
size_t calleeId = funcNameToId.at(funcDep->usedFuncName);
|
||||||
|
|
||||||
outDependencies[callerId].push_back(calleeId);
|
outDependencies[callerId].push_back(calleeId);
|
||||||
outCalledBy[calleeId].push_back(callerId);
|
outCalledBy[calleeId].push_back(callerId);
|
||||||
@@ -49,7 +49,7 @@ uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependen
|
|||||||
return RESOLVE_SUCCESS;
|
return RESOLVE_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies) {
|
uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies) {
|
||||||
DependenciesT dependencies;
|
DependenciesT dependencies;
|
||||||
CalledByT calledBy;
|
CalledByT calledBy;
|
||||||
auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy);
|
auto error = getExtFuncDependencies(funcNameToId, funcDependencies, externalFunctionInfos.size(), dependencies, calledBy);
|
||||||
@@ -59,9 +59,6 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
|
|||||||
|
|
||||||
DependencyResolver depResolver(dependencies);
|
DependencyResolver depResolver(dependencies);
|
||||||
auto resolved = depResolver.resolveDependencies();
|
auto resolved = depResolver.resolveDependencies();
|
||||||
if (depResolver.hasLoop()) {
|
|
||||||
return ERROR_LOOP_DETECTED;
|
|
||||||
}
|
|
||||||
for (auto calleeId : resolved) {
|
for (auto calleeId : resolved) {
|
||||||
const auto callee = externalFunctionInfos[calleeId];
|
const auto callee = externalFunctionInfos[calleeId];
|
||||||
for (auto callerId : calledBy[calleeId]) {
|
for (auto callerId : calledBy[calleeId]) {
|
||||||
@@ -72,16 +69,15 @@ uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos
|
|||||||
return RESOLVE_SUCCESS;
|
return RESOLVE_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor) {
|
uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor) {
|
||||||
for (size_t i = 0; i < kernelDependencies.size(); i++) {
|
for (auto &kernelDep : kernelDependencies) {
|
||||||
auto kernelDep = kernelDependencies[i];
|
|
||||||
if (funcNameToId.count(kernelDep->usedFuncName) == 0) {
|
if (funcNameToId.count(kernelDep->usedFuncName) == 0) {
|
||||||
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
|
return ERROR_EXTERNAL_FUNCTION_INFO_MISSING;
|
||||||
} else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) {
|
} else if (nameToKernelDescriptor.count(kernelDep->kernelName) == 0) {
|
||||||
return ERROR_KERNEL_DESCRIPTOR_MISSING;
|
return ERROR_KERNEL_DESCRIPTOR_MISSING;
|
||||||
}
|
}
|
||||||
const auto functionBarrierCount = externalFunctionInfos[funcNameToId[kernelDep->usedFuncName]]->barrierCount;
|
const auto functionBarrierCount = externalFunctionInfos.at(funcNameToId.at(kernelDep->usedFuncName))->barrierCount;
|
||||||
auto &kernelBarrierCount = nameToKernelDescriptor[kernelDep->kernelName]->kernelAttributes.barrierCount;
|
auto &kernelBarrierCount = nameToKernelDescriptor.at(kernelDep->kernelName)->kernelAttributes.barrierCount;
|
||||||
kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount);
|
kernelBarrierCount = std::max(kernelBarrierCount, functionBarrierCount);
|
||||||
}
|
}
|
||||||
return RESOLVE_SUCCESS;
|
return RESOLVE_SUCCESS;
|
||||||
@@ -89,27 +85,21 @@ uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos,
|
|||||||
|
|
||||||
std::vector<size_t> DependencyResolver::resolveDependencies() {
|
std::vector<size_t> DependencyResolver::resolveDependencies() {
|
||||||
for (size_t i = 0; i < graph.size(); i++) {
|
for (size_t i = 0; i < graph.size(); i++) {
|
||||||
if (std::find(seen.begin(), seen.end(), i) == seen.end() && graph[i].empty() == false) {
|
if (std::find(seen.begin(), seen.end(), i) == seen.end()) {
|
||||||
resolveDependency(i, graph[i]);
|
resolveDependency(i, graph[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (loopDeteckted) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return resolved;
|
return resolved;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DependencyResolver::resolveDependency(size_t nodeId, const std::vector<size_t> &edges) {
|
void DependencyResolver::resolveDependency(size_t nodeId, const std::vector<size_t> &edges) {
|
||||||
seen.push_back(nodeId);
|
seen.push_back(nodeId);
|
||||||
for (auto &edgeId : edges) {
|
for (auto &edgeId : edges) {
|
||||||
if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end()) {
|
if (std::find(resolved.begin(), resolved.end(), edgeId) == resolved.end() &&
|
||||||
if (std::find(seen.begin(), seen.end(), edgeId) != seen.end()) {
|
std::find(seen.begin(), seen.end(), edgeId) == seen.end()) {
|
||||||
loopDeteckted = true;
|
|
||||||
} else {
|
|
||||||
resolveDependency(edgeId, graph[edgeId]);
|
resolveDependency(edgeId, graph[edgeId]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
resolved.push_back(nodeId);
|
resolved.push_back(nodeId);
|
||||||
}
|
}
|
||||||
} // namespace NEO
|
} // namespace NEO
|
||||||
@@ -52,24 +52,22 @@ class DependencyResolver {
|
|||||||
public:
|
public:
|
||||||
DependencyResolver(const std::vector<std::vector<size_t>> &graph) : graph(graph) {}
|
DependencyResolver(const std::vector<std::vector<size_t>> &graph) : graph(graph) {}
|
||||||
std::vector<size_t> resolveDependencies();
|
std::vector<size_t> resolveDependencies();
|
||||||
inline bool hasLoop() { return loopDeteckted; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void resolveDependency(size_t nodeId, const std::vector<size_t> &edges);
|
void resolveDependency(size_t nodeId, const std::vector<size_t> &edges);
|
||||||
std::vector<size_t> seen;
|
std::vector<size_t> seen;
|
||||||
std::vector<size_t> resolved;
|
std::vector<size_t> resolved;
|
||||||
const std::vector<std::vector<size_t>> &graph;
|
const std::vector<std::vector<size_t>> &graph;
|
||||||
bool loopDeteckted = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
uint32_t resolveBarrierCount(ExternalFunctionInfosT externalFunctionInfos, KernelDependenciesT kernelDependencies,
|
uint32_t resolveBarrierCount(const ExternalFunctionInfosT &externalFunctionInfos, const KernelDependenciesT &kernelDependencies,
|
||||||
FunctionDependenciesT funcDependencies, KernelDescriptorMapT &nameToKernelDescriptor);
|
const FunctionDependenciesT &funcDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
|
||||||
|
|
||||||
uint32_t getExtFuncDependencies(FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies, size_t numExternalFuncs,
|
uint32_t getExtFuncDependencies(const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies, size_t numExternalFuncs,
|
||||||
DependenciesT &outDependencies, CalledByT &outCalledBy);
|
DependenciesT &outDependencies, CalledByT &outCalledBy);
|
||||||
|
|
||||||
uint32_t resolveExtFuncDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, FunctionDependenciesT funcDependencies);
|
uint32_t resolveExtFuncDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const FunctionDependenciesT &funcDependencies);
|
||||||
|
|
||||||
uint32_t resolveKernelDependencies(ExternalFunctionInfosT externalFunctionInfos, FuncNameToIdMapT &funcNameToId, KernelDependenciesT kernelDependencies, KernelDescriptorMapT &nameToKernelDescriptor);
|
uint32_t resolveKernelDependencies(const ExternalFunctionInfosT &externalFunctionInfos, const FuncNameToIdMapT &funcNameToId, const KernelDependenciesT &kernelDependencies, const KernelDescriptorMapT &nameToKernelDescriptor);
|
||||||
|
|
||||||
} // namespace NEO
|
} // namespace NEO
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ TEST(DependencyResolverTests, GivenEmptyGraphReturnEmptyResolve) {
|
|||||||
EXPECT_TRUE(resolve.empty());
|
EXPECT_TRUE(resolve.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDeteckted) {
|
TEST(DependencyResolverTests, GivenGraphWithLoopReturnCorrectResolve) {
|
||||||
/*
|
/*
|
||||||
0 -> 1
|
0 -> 1
|
||||||
^ |
|
^ |
|
||||||
@@ -29,10 +29,23 @@ TEST(DependencyResolverTests, GivenGraphWithLoopReturnEmptyResolveAndSetLoopDete
|
|||||||
3 <- 2
|
3 <- 2
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<size_t>> graph = {{1}, {2}, {3}, {0}};
|
std::vector<std::vector<size_t>> graph = {{1}, {2}, {3}, {0}};
|
||||||
|
std::vector<size_t> expectedResolve = {3, 2, 1, 0};
|
||||||
DependencyResolver resolver(graph);
|
DependencyResolver resolver(graph);
|
||||||
const auto &resolve = resolver.resolveDependencies();
|
const auto &resolve = resolver.resolveDependencies();
|
||||||
EXPECT_TRUE(resolve.empty());
|
EXPECT_EQ(expectedResolve, resolve);
|
||||||
EXPECT_TRUE(resolver.hasLoop());
|
}
|
||||||
|
|
||||||
|
TEST(DependencyResolverTests, GivenGraphWithNodeConnectedToItselfThenReturnCorrectResolve) {
|
||||||
|
/*
|
||||||
|
0-->
|
||||||
|
^ |
|
||||||
|
|<-v
|
||||||
|
*/
|
||||||
|
std::vector<std::vector<size_t>> graph = {{0}};
|
||||||
|
std::vector<size_t> expectedResolve = {0};
|
||||||
|
DependencyResolver resolver(graph);
|
||||||
|
const auto &resolve = resolver.resolveDependencies();
|
||||||
|
EXPECT_EQ(expectedResolve, resolve);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) {
|
TEST(DependencyResolverTests, GivenOneConnectedGraphReturnCorrectResolve) {
|
||||||
@@ -56,7 +69,7 @@ TEST(DependencyResolverTests, GivenMultipleDisconnectedGraphsReturnCorrectResolv
|
|||||||
4 -> 3
|
4 -> 3
|
||||||
*/
|
*/
|
||||||
std::vector<std::vector<size_t>> graph = {{}, {2}, {5}, {}, {3}, {}};
|
std::vector<std::vector<size_t>> graph = {{}, {2}, {5}, {}, {3}, {}};
|
||||||
std::vector<size_t> expectedResolve = {5, 2, 1, 3, 4};
|
std::vector<size_t> expectedResolve = {0, 5, 2, 1, 3, 4};
|
||||||
DependencyResolver resolver(graph);
|
DependencyResolver resolver(graph);
|
||||||
const auto &resolve = resolver.resolveDependencies();
|
const auto &resolve = resolver.resolveDependencies();
|
||||||
EXPECT_EQ(expectedResolve, resolve);
|
EXPECT_EQ(expectedResolve, resolve);
|
||||||
@@ -137,16 +150,6 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingExtFun
|
|||||||
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
|
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnError) {
|
|
||||||
addExternalFunction("fun0", 0);
|
|
||||||
addExternalFunction("fun1", 0);
|
|
||||||
addFuncDependency("fun0", "fun1");
|
|
||||||
addFuncDependency("fun1", "fun0");
|
|
||||||
set();
|
|
||||||
auto error = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
|
|
||||||
EXPECT_EQ(ERROR_LOOP_DETECTED, error);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) {
|
TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInLookupMapWhenResolvingKernelDependenciesThenReturnError) {
|
||||||
addKernel("kernel");
|
addKernel("kernel");
|
||||||
addKernelDependency("fun0", "kernel");
|
addKernelDependency("fun0", "kernel");
|
||||||
@@ -184,6 +187,18 @@ TEST_F(ExternalFunctionsTests, GivenMissingExtFuncInKernelDependenciesWhenResolv
|
|||||||
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
|
EXPECT_EQ(ERROR_EXTERNAL_FUNCTION_INFO_MISSING, error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ExternalFunctionsTests, GivenLoopWhenResolvingExtFuncDependenciesThenReturnSuccess) {
|
||||||
|
addExternalFunction("fun0", 4);
|
||||||
|
addExternalFunction("fun1", 2);
|
||||||
|
addFuncDependency("fun0", "fun1");
|
||||||
|
addFuncDependency("fun1", "fun0");
|
||||||
|
set();
|
||||||
|
auto retVal = resolveExtFuncDependencies(extFuncInfo, funcNameToId, functionDependencies);
|
||||||
|
EXPECT_EQ(RESOLVE_SUCCESS, retVal);
|
||||||
|
EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun0"]]->barrierCount);
|
||||||
|
EXPECT_EQ(4U, extFuncInfo[funcNameToId["fun1"]]->barrierCount);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) {
|
TEST_F(ExternalFunctionsTests, GivenValidFunctionAndKernelDependenciesWhenResolvingBarrierCountThenSetAppropriateBarrierCountAndReturnSuccess) {
|
||||||
addKernel("kernel");
|
addKernel("kernel");
|
||||||
addExternalFunction("fun0", 1U);
|
addExternalFunction("fun0", 1U);
|
||||||
|
|||||||
Reference in New Issue
Block a user