mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[orc-rt] Introduce Task and TaskDispatcher APIs and implementations. (#168514)
Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatch.h), and updates Session to include a TaskDispatcher instance that can be used to run tasks. TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd code initiated from the controller process: Incoming calls will be wrapped in Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown, ensuring that all Tasks are run or destroyed prior to the Session being destroyed.
This commit is contained in:
@@ -22,6 +22,8 @@ set(ORC_RT_HEADERS
|
||||
orc-rt/SPSMemoryFlags.h
|
||||
orc-rt/SPSWrapperFunction.h
|
||||
orc-rt/SPSWrapperFunctionBuffer.h
|
||||
orc-rt/TaskDispatcher.h
|
||||
orc-rt/ThreadPoolTaskDispatcher.h
|
||||
orc-rt/WrapperFunction.h
|
||||
orc-rt/bind.h
|
||||
orc-rt/bit.h
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
|
||||
#include "orc-rt/Error.h"
|
||||
#include "orc-rt/ResourceManager.h"
|
||||
#include "orc-rt/TaskDispatcher.h"
|
||||
#include "orc-rt/move_only_function.h"
|
||||
|
||||
#include "orc-rt-c/CoreTypes.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
@@ -39,7 +41,10 @@ public:
|
||||
///
|
||||
/// Note that entry into the reporter is not synchronized: it may be
|
||||
/// called from multiple threads concurrently.
|
||||
Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {}
|
||||
Session(std::unique_ptr<TaskDispatcher> Dispatcher,
|
||||
ErrorReporterFn ReportError)
|
||||
: Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) {
|
||||
}
|
||||
|
||||
// Sessions are not copyable or moveable.
|
||||
Session(const Session &) = delete;
|
||||
@@ -49,6 +54,9 @@ public:
|
||||
|
||||
~Session();
|
||||
|
||||
/// Dispatch a task using the Session's TaskDispatcher.
|
||||
void dispatch(std::unique_ptr<Task> T) { Dispatcher->dispatch(std::move(T)); }
|
||||
|
||||
/// Report an error via the ErrorReporter function.
|
||||
void reportError(Error Err) { ReportError(std::move(Err)); }
|
||||
|
||||
@@ -67,12 +75,21 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err,
|
||||
void shutdownNext(Error Err,
|
||||
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs);
|
||||
|
||||
std::mutex M;
|
||||
void shutdownComplete();
|
||||
|
||||
std::unique_ptr<TaskDispatcher> Dispatcher;
|
||||
ErrorReporterFn ReportError;
|
||||
|
||||
enum class SessionState { Running, ShuttingDown, Shutdown };
|
||||
|
||||
std::mutex M;
|
||||
SessionState State = SessionState::Running;
|
||||
std::condition_variable StateCV;
|
||||
std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
|
||||
std::vector<OnShutdownCompleteFn> ShutdownCallbacks;
|
||||
};
|
||||
|
||||
inline orc_rt_SessionRef wrap(Session *S) noexcept {
|
||||
|
||||
64
orc-rt/include/orc-rt/TaskDispatcher.h
Normal file
64
orc-rt/include/orc-rt/TaskDispatcher.h
Normal file
@@ -0,0 +1,64 @@
|
||||
//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Task and TaskDispatcher classes.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ORC_RT_TASKDISPATCHER_H
|
||||
#define ORC_RT_TASKDISPATCHER_H
|
||||
|
||||
#include "orc-rt/RTTI.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
namespace orc_rt {
|
||||
|
||||
/// Represents an abstract task to be run.
|
||||
class Task : public RTTIExtends<Task, RTTIRoot> {
|
||||
public:
|
||||
virtual ~Task();
|
||||
virtual void run() = 0;
|
||||
};
|
||||
|
||||
/// Base class for generic tasks.
|
||||
class GenericTask : public RTTIExtends<GenericTask, Task> {};
|
||||
|
||||
/// Generic task implementation.
|
||||
template <typename FnT> class GenericTaskImpl : public GenericTask {
|
||||
public:
|
||||
GenericTaskImpl(FnT &&Fn) : Fn(std::forward<FnT>(Fn)) {}
|
||||
void run() override { Fn(); }
|
||||
|
||||
private:
|
||||
FnT Fn;
|
||||
};
|
||||
|
||||
/// Create a generic task from a function object.
|
||||
template <typename FnT> std::unique_ptr<GenericTask> makeGenericTask(FnT &&Fn) {
|
||||
return std::make_unique<GenericTaskImpl<std::decay_t<FnT>>>(
|
||||
std::forward<FnT>(Fn));
|
||||
}
|
||||
|
||||
/// Abstract base for classes that dispatch Tasks.
|
||||
class TaskDispatcher {
|
||||
public:
|
||||
virtual ~TaskDispatcher();
|
||||
|
||||
/// Run the given task.
|
||||
virtual void dispatch(std::unique_ptr<Task> T) = 0;
|
||||
|
||||
/// Called by Session. Should cause further dispatches to be rejected, and
|
||||
/// wait until all previously dispatched tasks have completed.
|
||||
virtual void shutdown() = 0;
|
||||
};
|
||||
|
||||
} // End namespace orc_rt
|
||||
|
||||
#endif // ORC_RT_TASKDISPATCHER_H
|
||||
48
orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
Normal file
48
orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
Normal file
@@ -0,0 +1,48 @@
|
||||
//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// ThreadPoolTaskDispatcher implementation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H
|
||||
#define ORC_RT_THREADPOOLTASKDISPATCHER_H
|
||||
|
||||
#include "orc-rt/TaskDispatcher.h"
|
||||
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
namespace orc_rt {
|
||||
|
||||
/// Thread-pool based TaskDispatcher.
|
||||
///
|
||||
/// Will spawn NumThreads threads to run dispatched Tasks.
|
||||
class ThreadPoolTaskDispatcher : public TaskDispatcher {
|
||||
public:
|
||||
ThreadPoolTaskDispatcher(size_t NumThreads);
|
||||
~ThreadPoolTaskDispatcher() override;
|
||||
void dispatch(std::unique_ptr<Task> T) override;
|
||||
void shutdown() override;
|
||||
|
||||
private:
|
||||
void taskLoop();
|
||||
|
||||
std::vector<std::thread> Threads;
|
||||
|
||||
std::mutex M;
|
||||
bool AcceptingTasks = true;
|
||||
std::condition_variable CV;
|
||||
std::vector<std::unique_ptr<Task>> PendingTasks;
|
||||
};
|
||||
|
||||
} // End namespace orc_rt
|
||||
|
||||
#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H
|
||||
@@ -4,6 +4,8 @@ set(files
|
||||
RTTI.cpp
|
||||
Session.cpp
|
||||
SimpleNativeMemoryMap.cpp
|
||||
TaskDispatcher.cpp
|
||||
ThreadPoolTaskDispatcher.cpp
|
||||
)
|
||||
|
||||
add_library(orc-rt-executor STATIC ${files})
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
|
||||
#include "orc-rt/Session.h"
|
||||
|
||||
#include <future>
|
||||
|
||||
namespace orc_rt {
|
||||
|
||||
Session::~Session() { waitForShutdown(); }
|
||||
@@ -23,38 +21,62 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {
|
||||
|
||||
{
|
||||
std::scoped_lock<std::mutex> Lock(M);
|
||||
ShutdownCallbacks.push_back(std::move(OnShutdownComplete));
|
||||
|
||||
// If somebody else has already called shutdown then there's nothing further
|
||||
// for us to do here.
|
||||
if (State >= SessionState::ShuttingDown)
|
||||
return;
|
||||
|
||||
State = SessionState::ShuttingDown;
|
||||
std::swap(ResourceMgrs, ToShutdown);
|
||||
}
|
||||
|
||||
shutdownNext(std::move(OnShutdownComplete), Error::success(),
|
||||
std::move(ToShutdown));
|
||||
shutdownNext(Error::success(), std::move(ToShutdown));
|
||||
}
|
||||
|
||||
void Session::waitForShutdown() {
|
||||
std::promise<void> P;
|
||||
auto F = P.get_future();
|
||||
|
||||
shutdown([P = std::move(P)]() mutable { P.set_value(); });
|
||||
|
||||
F.wait();
|
||||
shutdown([]() {});
|
||||
std::unique_lock<std::mutex> Lock(M);
|
||||
StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; });
|
||||
}
|
||||
|
||||
void Session::shutdownNext(
|
||||
OnShutdownCompleteFn OnComplete, Error Err,
|
||||
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
|
||||
Error Err, std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
|
||||
if (Err)
|
||||
reportError(std::move(Err));
|
||||
|
||||
if (RemainingRMs.empty())
|
||||
return OnComplete();
|
||||
return shutdownComplete();
|
||||
|
||||
auto NextRM = std::move(RemainingRMs.back());
|
||||
RemainingRMs.pop_back();
|
||||
NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs),
|
||||
OnComplete = std::move(OnComplete)](Error Err) mutable {
|
||||
shutdownNext(std::move(OnComplete), std::move(Err),
|
||||
std::move(RemainingRMs));
|
||||
});
|
||||
NextRM->shutdown(
|
||||
[this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable {
|
||||
shutdownNext(std::move(Err), std::move(RemainingRMs));
|
||||
});
|
||||
}
|
||||
|
||||
void Session::shutdownComplete() {
|
||||
|
||||
std::unique_ptr<TaskDispatcher> TmpDispatcher;
|
||||
std::vector<OnShutdownCompleteFn> TmpShutdownCallbacks;
|
||||
{
|
||||
std::lock_guard<std::mutex> Lock(M);
|
||||
TmpDispatcher = std::move(Dispatcher);
|
||||
TmpShutdownCallbacks = std::move(ShutdownCallbacks);
|
||||
}
|
||||
|
||||
TmpDispatcher->shutdown();
|
||||
|
||||
for (auto &OnShutdownComplete : TmpShutdownCallbacks)
|
||||
OnShutdownComplete();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> Lock(M);
|
||||
State = SessionState::Shutdown;
|
||||
}
|
||||
StateCV.notify_all();
|
||||
}
|
||||
|
||||
} // namespace orc_rt
|
||||
|
||||
20
orc-rt/lib/executor/TaskDispatcher.cpp
Normal file
20
orc-rt/lib/executor/TaskDispatcher.cpp
Normal file
@@ -0,0 +1,20 @@
|
||||
//===- TaskDispatch.cpp ---------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Contains the implementation of APIs in the orc-rt/TaskDispatch.h header.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "orc-rt/TaskDispatcher.h"
|
||||
|
||||
namespace orc_rt {
|
||||
|
||||
Task::~Task() = default;
|
||||
TaskDispatcher::~TaskDispatcher() = default;
|
||||
|
||||
} // namespace orc_rt
|
||||
70
orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
Normal file
70
orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
Normal file
@@ -0,0 +1,70 @@
|
||||
//===- ThreadPoolTaskDispatch.cpp -----------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatch.h
|
||||
// header.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "orc-rt/ThreadPoolTaskDispatcher.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
namespace orc_rt {
|
||||
|
||||
ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() {
|
||||
assert(!AcceptingTasks && "shutdown was not run");
|
||||
}
|
||||
|
||||
ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) {
|
||||
Threads.reserve(NumThreads);
|
||||
for (size_t I = 0; I < NumThreads; ++I)
|
||||
Threads.emplace_back([this]() { taskLoop(); });
|
||||
}
|
||||
|
||||
void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
|
||||
{
|
||||
std::scoped_lock<std::mutex> Lock(M);
|
||||
if (!AcceptingTasks)
|
||||
return;
|
||||
PendingTasks.push_back(std::move(T));
|
||||
}
|
||||
CV.notify_one();
|
||||
}
|
||||
|
||||
void ThreadPoolTaskDispatcher::shutdown() {
|
||||
{
|
||||
std::scoped_lock<std::mutex> Lock(M);
|
||||
assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?");
|
||||
AcceptingTasks = false;
|
||||
}
|
||||
CV.notify_all();
|
||||
for (auto &Thread : Threads)
|
||||
Thread.join();
|
||||
}
|
||||
|
||||
void ThreadPoolTaskDispatcher::taskLoop() {
|
||||
while (true) {
|
||||
std::unique_ptr<Task> T;
|
||||
{
|
||||
std::unique_lock<std::mutex> Lock(M);
|
||||
CV.wait(Lock,
|
||||
[this]() { return !PendingTasks.empty() || !AcceptingTasks; });
|
||||
|
||||
if (!AcceptingTasks && PendingTasks.empty())
|
||||
return;
|
||||
|
||||
T = std::move(PendingTasks.back());
|
||||
PendingTasks.pop_back();
|
||||
}
|
||||
|
||||
T->run();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace orc_rt
|
||||
@@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests
|
||||
SPSMemoryFlagsTest.cpp
|
||||
SPSWrapperFunctionTest.cpp
|
||||
SPSWrapperFunctionBufferTest.cpp
|
||||
ThreadPoolTaskDispatcherTest.cpp
|
||||
WrapperFunctionBufferTest.cpp
|
||||
bind-test.cpp
|
||||
bit-test.cpp
|
||||
|
||||
@@ -11,11 +11,17 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "orc-rt/Session.h"
|
||||
#include "orc-rt/ThreadPoolTaskDispatcher.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <deque>
|
||||
#include <future>
|
||||
#include <optional>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace orc_rt;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Optional;
|
||||
@@ -49,17 +55,47 @@ private:
|
||||
move_only_function<Error(Op)> GenResult;
|
||||
};
|
||||
|
||||
class NoDispatcher : public TaskDispatcher {
|
||||
public:
|
||||
void dispatch(std::unique_ptr<Task> T) override {
|
||||
assert(false && "strictly no dispatching!");
|
||||
}
|
||||
void shutdown() override {}
|
||||
};
|
||||
|
||||
class EnqueueingDispatcher : public TaskDispatcher {
|
||||
public:
|
||||
using OnShutdownRunFn = move_only_function<void()>;
|
||||
EnqueueingDispatcher(std::deque<std::unique_ptr<Task>> &Tasks,
|
||||
OnShutdownRunFn OnShutdownRun = {})
|
||||
: Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {}
|
||||
void dispatch(std::unique_ptr<Task> T) override {
|
||||
Tasks.push_back(std::move(T));
|
||||
}
|
||||
void shutdown() override {
|
||||
if (OnShutdownRun)
|
||||
OnShutdownRun();
|
||||
}
|
||||
|
||||
private:
|
||||
std::deque<std::unique_ptr<Task>> &Tasks;
|
||||
OnShutdownRunFn OnShutdownRun;
|
||||
};
|
||||
|
||||
// Non-overloaded version of cantFail: allows easy construction of
|
||||
// move_only_functions<void(Error)>s.
|
||||
static void noErrors(Error Err) { cantFail(std::move(Err)); }
|
||||
|
||||
TEST(SessionTest, TrivialConstructionAndDestruction) { Session S(noErrors); }
|
||||
TEST(SessionTest, TrivialConstructionAndDestruction) {
|
||||
Session S(std::make_unique<NoDispatcher>(), noErrors);
|
||||
}
|
||||
|
||||
TEST(SessionTest, ReportError) {
|
||||
Error E = Error::success();
|
||||
cantFail(std::move(E)); // Force error into checked state.
|
||||
|
||||
Session S([&](Error Err) { E = std::move(Err); });
|
||||
Session S(std::make_unique<NoDispatcher>(),
|
||||
[&](Error Err) { E = std::move(Err); });
|
||||
S.reportError(make_error<StringError>("foo"));
|
||||
|
||||
if (E)
|
||||
@@ -68,13 +104,27 @@ TEST(SessionTest, ReportError) {
|
||||
ADD_FAILURE() << "Missing error value";
|
||||
}
|
||||
|
||||
TEST(SessionTest, DispatchTask) {
|
||||
int X = 0;
|
||||
std::deque<std::unique_ptr<Task>> Tasks;
|
||||
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
|
||||
|
||||
EXPECT_EQ(Tasks.size(), 0U);
|
||||
S.dispatch(makeGenericTask([&]() { ++X; }));
|
||||
EXPECT_EQ(Tasks.size(), 1U);
|
||||
auto T = std::move(Tasks.front());
|
||||
Tasks.pop_front();
|
||||
T->run();
|
||||
EXPECT_EQ(X, 1);
|
||||
}
|
||||
|
||||
TEST(SessionTest, SingleResourceManager) {
|
||||
size_t OpIdx = 0;
|
||||
std::optional<size_t> DetachOpIdx;
|
||||
std::optional<size_t> ShutdownOpIdx;
|
||||
|
||||
{
|
||||
Session S(noErrors);
|
||||
Session S(std::make_unique<NoDispatcher>(), noErrors);
|
||||
S.addResourceManager(std::make_unique<MockResourceManager>(
|
||||
DetachOpIdx, ShutdownOpIdx, OpIdx));
|
||||
}
|
||||
@@ -90,7 +140,7 @@ TEST(SessionTest, MultipleResourceManagers) {
|
||||
std::optional<size_t> ShutdownOpIdx[3];
|
||||
|
||||
{
|
||||
Session S(noErrors);
|
||||
Session S(std::make_unique<NoDispatcher>(), noErrors);
|
||||
for (size_t I = 0; I != 3; ++I)
|
||||
S.addResourceManager(std::make_unique<MockResourceManager>(
|
||||
DetachOpIdx[I], ShutdownOpIdx[I], OpIdx));
|
||||
@@ -103,3 +153,39 @@ TEST(SessionTest, MultipleResourceManagers) {
|
||||
EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SessionTest, ExpectedShutdownSequence) {
|
||||
// Check that Session shutdown results in...
|
||||
// 1. ResourceManagers being shut down.
|
||||
// 2. The TaskDispatcher being shut down.
|
||||
// 3. A call to OnShutdownComplete.
|
||||
|
||||
size_t OpIdx = 0;
|
||||
std::optional<size_t> DetachOpIdx;
|
||||
std::optional<size_t> ShutdownOpIdx;
|
||||
|
||||
bool DispatcherShutDown = false;
|
||||
bool SessionShutdownComplete = false;
|
||||
std::deque<std::unique_ptr<Task>> Tasks;
|
||||
Session S(std::make_unique<EnqueueingDispatcher>(
|
||||
Tasks,
|
||||
[&]() {
|
||||
std::cerr << "Running dispatcher shutdown.\n";
|
||||
EXPECT_TRUE(ShutdownOpIdx);
|
||||
EXPECT_EQ(*ShutdownOpIdx, 0);
|
||||
EXPECT_FALSE(SessionShutdownComplete);
|
||||
DispatcherShutDown = true;
|
||||
}),
|
||||
noErrors);
|
||||
S.addResourceManager(
|
||||
std::make_unique<MockResourceManager>(DetachOpIdx, ShutdownOpIdx, OpIdx));
|
||||
|
||||
S.shutdown([&]() {
|
||||
EXPECT_TRUE(DispatcherShutDown);
|
||||
std::cerr << "Running shutdown callback.\n";
|
||||
SessionShutdownComplete = true;
|
||||
});
|
||||
S.waitForShutdown();
|
||||
|
||||
EXPECT_TRUE(SessionShutdownComplete);
|
||||
}
|
||||
|
||||
110
orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
Normal file
110
orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp
Normal file
@@ -0,0 +1,110 @@
|
||||
//===-- ThreadPoolTaskDispatcherTest.cpp ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "orc-rt/ThreadPoolTaskDispatcher.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
using namespace orc_rt;
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(ThreadPoolTaskDispatcherTest, NoTasks) {
|
||||
// Check that immediate shutdown works as expected.
|
||||
ThreadPoolTaskDispatcher Dispatcher(1);
|
||||
Dispatcher.shutdown();
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTaskDispatcherTest, BasicTaskExecution) {
|
||||
// Smoke test: Check that we can run a single task on a single-threaded pool.
|
||||
ThreadPoolTaskDispatcher Dispatcher(1);
|
||||
std::atomic<bool> TaskRan = false;
|
||||
|
||||
Dispatcher.dispatch(makeGenericTask([&]() { TaskRan = true; }));
|
||||
|
||||
Dispatcher.shutdown();
|
||||
|
||||
EXPECT_TRUE(TaskRan);
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTaskDispatcherTest, SingleThreadMultipleTasks) {
|
||||
// Check that multiple tasks in a single threaded pool run as expected.
|
||||
ThreadPoolTaskDispatcher Dispatcher(1);
|
||||
size_t NumTasksToRun = 10;
|
||||
std::atomic<size_t> TasksRun = 0;
|
||||
|
||||
for (size_t I = 0; I != NumTasksToRun; ++I)
|
||||
Dispatcher.dispatch(makeGenericTask([&]() { ++TasksRun; }));
|
||||
|
||||
Dispatcher.shutdown();
|
||||
|
||||
EXPECT_EQ(TasksRun, NumTasksToRun);
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTaskDispatcherTest, ConcurrentTasks) {
|
||||
// Check that tasks are run concurrently when multiple workers are available.
|
||||
// Adds two tasks that communicate a value back and forth using futures.
|
||||
// Neither task should be able to complete without the other having started.
|
||||
ThreadPoolTaskDispatcher Dispatcher(2);
|
||||
|
||||
std::promise<int> PInit;
|
||||
std::future<int> FInit = PInit.get_future();
|
||||
std::promise<int> P1;
|
||||
std::future<int> F1 = P1.get_future();
|
||||
std::promise<int> P2;
|
||||
std::future<int> F2 = P2.get_future();
|
||||
std::promise<int> PResult;
|
||||
std::future<int> FResult = PResult.get_future();
|
||||
|
||||
// Task A gets the initial value, sends it via P1, waits for response on F2.
|
||||
Dispatcher.dispatch(makeGenericTask([&]() {
|
||||
P1.set_value(FInit.get());
|
||||
PResult.set_value(F2.get());
|
||||
}));
|
||||
|
||||
// Task B gets value from F1, sends it back on P2.
|
||||
Dispatcher.dispatch(makeGenericTask([&]() { P2.set_value(F1.get()); }));
|
||||
|
||||
int ExpectedValue = 42;
|
||||
PInit.set_value(ExpectedValue);
|
||||
|
||||
Dispatcher.shutdown();
|
||||
|
||||
EXPECT_EQ(FResult.get(), ExpectedValue);
|
||||
}
|
||||
|
||||
TEST(ThreadPoolTaskDispatcherTest, TasksRejectedAfterShutdown) {
|
||||
class TaskToReject : public Task {
|
||||
public:
|
||||
TaskToReject(bool &BodyRun, bool &DestructorRun)
|
||||
: BodyRun(BodyRun), DestructorRun(DestructorRun) {}
|
||||
~TaskToReject() { DestructorRun = true; }
|
||||
void run() override { BodyRun = true; }
|
||||
|
||||
private:
|
||||
bool &BodyRun;
|
||||
bool &DestructorRun;
|
||||
};
|
||||
|
||||
ThreadPoolTaskDispatcher Dispatcher(1);
|
||||
Dispatcher.shutdown();
|
||||
|
||||
bool BodyRun = false;
|
||||
bool DestructorRun = false;
|
||||
|
||||
Dispatcher.dispatch(std::make_unique<TaskToReject>(BodyRun, DestructorRun));
|
||||
|
||||
EXPECT_FALSE(BodyRun);
|
||||
EXPECT_TRUE(DestructorRun);
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
Reference in New Issue
Block a user