mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
ControllerAccess provides an abstract interface for bidirectional RPC between the executor (running JIT'd code) and the controller (containing the llvm::orc::ExecutionSession). ControllerAccess implementations are expected to implement IPC / RPC using a concrete communication method (shared memory, pipes, sockets, native system IPC, etc.). Calls from executor to controller are made via callController, with "handler tags" (addresses in the executor) specifying the target handler in the controller. A handler must be associated in the controller with the given tag for the call to succeed. This ensures that only registered entry points in the controller can be used, and avoids leaking controller addresses into the executor. Calls in both directions are to "wrapper functions" that take a buffer of bytes as input and return a buffer of bytes as output. In the ORC runtime these must be `orc_rt_WrapperFunction`s (see Session::handleWrapperCall). The interpretation of the byte buffers is up to the wrapper functions: the ORC runtime imposes no restrictions on how the bytes are to be interpreted. ControllerAccess objects may be detached from the Session prior to Session shutdown, in which case no further calls may be made in either direction, and any pending results (from calls that have not yet returned) should return errors. If the ControllerAccess object is still attached at Session shutdown time it will be detached as part of the shutdown process. The ControllerAccess::disconnect method must support concurrent entry on multiple threads, and all callers must block until they can guarantee that no further calls will be received or accepted.
467 lines
14 KiB
C++
467 lines
14 KiB
C++
//===- SessionTest.cpp ----------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Tests for orc-rt's Session.h APIs.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "orc-rt/Session.h"
|
|
#include "orc-rt/SPSWrapperFunction.h"
|
|
#include "orc-rt/ThreadPoolTaskDispatcher.h"
|
|
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
#include <chrono>
|
|
#include <deque>
|
|
#include <future>
|
|
#include <optional>
|
|
|
|
using namespace orc_rt;
|
|
using ::testing::Eq;
|
|
using ::testing::Optional;
|
|
|
|
class MockResourceManager : public ResourceManager {
|
|
public:
|
|
enum class Op { Detach, Shutdown };
|
|
|
|
static Error alwaysSucceed(Op) { return Error::success(); }
|
|
|
|
MockResourceManager(std::optional<size_t> &DetachOpIdx,
|
|
std::optional<size_t> &ShutdownOpIdx, size_t &OpIdx,
|
|
move_only_function<Error(Op)> GenResult = alwaysSucceed)
|
|
: DetachOpIdx(DetachOpIdx), ShutdownOpIdx(ShutdownOpIdx), OpIdx(OpIdx),
|
|
GenResult(std::move(GenResult)) {}
|
|
|
|
void detach(OnCompleteFn OnComplete) override {
|
|
DetachOpIdx = OpIdx++;
|
|
OnComplete(GenResult(Op::Detach));
|
|
}
|
|
|
|
void shutdown(OnCompleteFn OnComplete) override {
|
|
ShutdownOpIdx = OpIdx++;
|
|
OnComplete(GenResult(Op::Shutdown));
|
|
}
|
|
|
|
private:
|
|
std::optional<size_t> &DetachOpIdx;
|
|
std::optional<size_t> &ShutdownOpIdx;
|
|
size_t &OpIdx;
|
|
move_only_function<Error(Op)> GenResult;
|
|
};
|
|
|
|
class NoDispatcher : public TaskDispatcher {
|
|
public:
|
|
void dispatch(std::unique_ptr<Task> T) override {
|
|
assert(false && "strictly no dispatching!");
|
|
}
|
|
void shutdown() override {}
|
|
};
|
|
|
|
class EnqueueingDispatcher : public TaskDispatcher {
|
|
public:
|
|
using OnShutdownRunFn = move_only_function<void()>;
|
|
EnqueueingDispatcher(std::deque<std::unique_ptr<Task>> &Tasks,
|
|
OnShutdownRunFn OnShutdownRun = {})
|
|
: Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {}
|
|
void dispatch(std::unique_ptr<Task> T) override {
|
|
Tasks.push_back(std::move(T));
|
|
}
|
|
void shutdown() override {
|
|
if (OnShutdownRun)
|
|
OnShutdownRun();
|
|
}
|
|
|
|
/// Run up to NumTasks (arbitrarily many if NumTasks == std::nullopt) tasks
|
|
/// from the front of the queue, returning the number actually run.
|
|
static size_t
|
|
runTasksFromFront(std::deque<std::unique_ptr<Task>> &Tasks,
|
|
std::optional<size_t> NumTasks = std::nullopt) {
|
|
size_t NumRun = 0;
|
|
|
|
while (!Tasks.empty() && (!NumTasks || NumRun != *NumTasks)) {
|
|
auto T = std::move(Tasks.front());
|
|
Tasks.pop_front();
|
|
T->run();
|
|
++NumRun;
|
|
}
|
|
|
|
return NumRun;
|
|
}
|
|
|
|
private:
|
|
std::deque<std::unique_ptr<Task>> &Tasks;
|
|
OnShutdownRunFn OnShutdownRun;
|
|
};
|
|
|
|
class MockControllerAccess : public Session::ControllerAccess {
|
|
public:
|
|
MockControllerAccess(Session &SS) : Session::ControllerAccess(SS), SS(SS) {}
|
|
|
|
void disconnect() override {
|
|
std::unique_lock<std::mutex> Lock(M);
|
|
Shutdown = true;
|
|
ShutdownCV.wait(Lock, [this]() { return Shutdown && Outstanding == 0; });
|
|
}
|
|
|
|
void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T,
|
|
WrapperFunctionBuffer ArgBytes) override {
|
|
// Simulate a call to the controller by dispatching a task to run the
|
|
// requested function.
|
|
size_t CId;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (Shutdown)
|
|
return;
|
|
CId = CallId++;
|
|
Pending[CId] = std::move(OnComplete);
|
|
++Outstanding;
|
|
}
|
|
|
|
SS.dispatch(makeGenericTask([this, CId, OnComplete = std::move(OnComplete),
|
|
T, ArgBytes = std::move(ArgBytes)]() mutable {
|
|
auto Fn = reinterpret_cast<orc_rt_WrapperFunction>(T);
|
|
Fn(reinterpret_cast<orc_rt_SessionRef>(this), CId, wfReturn,
|
|
ArgBytes.release());
|
|
}));
|
|
|
|
bool Notify = false;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (--Outstanding == 0 && Shutdown)
|
|
Notify = true;
|
|
}
|
|
if (Notify)
|
|
ShutdownCV.notify_all();
|
|
}
|
|
|
|
void sendWrapperResult(uint64_t CallId,
|
|
WrapperFunctionBuffer ResultBytes) override {
|
|
// Respond to a simulated call by the controller.
|
|
OnCallHandlerCompleteFn OnComplete;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (Shutdown) {
|
|
assert(Pending.empty() && "Shut down but results still pending?");
|
|
return;
|
|
}
|
|
auto I = Pending.find(CallId);
|
|
assert(I != Pending.end());
|
|
OnComplete = std::move(I->second);
|
|
Pending.erase(I);
|
|
++Outstanding;
|
|
}
|
|
|
|
SS.dispatch(
|
|
makeGenericTask([OnComplete = std::move(OnComplete),
|
|
ResultBytes = std::move(ResultBytes)]() mutable {
|
|
OnComplete(std::move(ResultBytes));
|
|
}));
|
|
|
|
bool Notify = false;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (--Outstanding == 0 && Shutdown)
|
|
Notify = true;
|
|
}
|
|
if (Notify)
|
|
ShutdownCV.notify_all();
|
|
}
|
|
|
|
void callFromController(OnCallHandlerCompleteFn OnComplete,
|
|
orc_rt_WrapperFunction Fn,
|
|
WrapperFunctionBuffer ArgBytes) {
|
|
size_t CId = 0;
|
|
bool BailOut = false;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (!Shutdown) {
|
|
CId = CallId++;
|
|
Pending[CId] = std::move(OnComplete);
|
|
++Outstanding;
|
|
} else
|
|
BailOut = true;
|
|
}
|
|
if (BailOut)
|
|
return OnComplete(WrapperFunctionBuffer::createOutOfBandError(
|
|
"Controller disconnected"));
|
|
|
|
handleWrapperCall(CId, Fn, std::move(ArgBytes));
|
|
|
|
bool Notify = false;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (--Outstanding == 0 && Shutdown)
|
|
Notify = true;
|
|
}
|
|
|
|
if (Notify)
|
|
ShutdownCV.notify_all();
|
|
}
|
|
|
|
/// Simulate start of outstanding operation.
|
|
void incOutstanding() {
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
++Outstanding;
|
|
}
|
|
|
|
/// Simulate end of outstanding operation.
|
|
void decOutstanding() {
|
|
bool Notify = false;
|
|
{
|
|
std::scoped_lock<std::mutex> Lock(M);
|
|
if (--Outstanding == 0 && Shutdown)
|
|
Notify = true;
|
|
}
|
|
if (Notify)
|
|
ShutdownCV.notify_all();
|
|
}
|
|
|
|
private:
|
|
static void wfReturn(orc_rt_SessionRef S, uint64_t CallId,
|
|
orc_rt_WrapperFunctionBuffer ResultBytes) {
|
|
// Abuse "session" to refer to the ControllerAccess object.
|
|
// We can just re-use sendFunctionResult for this.
|
|
reinterpret_cast<MockControllerAccess *>(S)->sendWrapperResult(CallId,
|
|
ResultBytes);
|
|
}
|
|
|
|
Session &SS;
|
|
|
|
std::mutex M;
|
|
bool Shutdown = false;
|
|
size_t Outstanding = 0;
|
|
size_t CallId = 0;
|
|
std::unordered_map<size_t, OnCallHandlerCompleteFn> Pending;
|
|
std::condition_variable ShutdownCV;
|
|
};
|
|
|
|
class CallViaMockControllerAccess {
|
|
public:
|
|
CallViaMockControllerAccess(MockControllerAccess &CA,
|
|
orc_rt_WrapperFunction Fn)
|
|
: CA(CA), Fn(Fn) {}
|
|
void operator()(Session::OnCallHandlerCompleteFn OnComplete,
|
|
WrapperFunctionBuffer ArgBytes) {
|
|
CA.callFromController(std::move(OnComplete), Fn, std::move(ArgBytes));
|
|
}
|
|
|
|
private:
|
|
MockControllerAccess &CA;
|
|
orc_rt_WrapperFunction Fn;
|
|
};
|
|
|
|
// Non-overloaded version of cantFail: allows easy construction of
|
|
// move_only_functions<void(Error)>s.
|
|
static void noErrors(Error Err) { cantFail(std::move(Err)); }
|
|
|
|
TEST(SessionTest, TrivialConstructionAndDestruction) {
  // Build a Session and let it go out of scope straight away. NoDispatcher
  // asserts if anything is dispatched along the way.
  auto Dispatcher = std::make_unique<NoDispatcher>();
  Session S(std::move(Dispatcher), noErrors);
}
|
|
|
|
TEST(SessionTest, ReportError) {
  // An Error handed to Session::reportError should arrive at the reporter
  // callback supplied at construction.
  Error Reported = Error::success();
  cantFail(std::move(Reported)); // Force error into checked state.

  Session S(std::make_unique<NoDispatcher>(),
            [&](Error Err) { Reported = std::move(Err); });
  S.reportError(make_error<StringError>("foo"));

  if (!Reported) {
    ADD_FAILURE() << "Missing error value";
    return;
  }
  EXPECT_EQ(toString(std::move(Reported)), "foo");
}
|
|
|
|
TEST(SessionTest, DispatchTask) {
  // A task passed to Session::dispatch should land in the dispatcher's queue
  // and take effect only once it is run.
  int X = 0;
  std::deque<std::unique_ptr<Task>> Tasks;
  Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);

  EXPECT_EQ(Tasks.size(), 0U);
  S.dispatch(makeGenericTask([&]() { ++X; }));

  // Must be fatal: front() on an empty deque is undefined behaviour, so stop
  // here if the task was not enqueued.
  ASSERT_EQ(Tasks.size(), 1U);
  auto T = std::move(Tasks.front());
  Tasks.pop_front();
  T->run();
  EXPECT_EQ(X, 1);
}
|
|
|
|
TEST(SessionTest, SingleResourceManager) {
  // Destroying a Session with one ResourceManager should shut that manager
  // down exactly once, and never detach it.
  size_t OpCounter = 0;
  std::optional<size_t> DetachedAt;
  std::optional<size_t> ShutDownAt;

  {
    Session S(std::make_unique<NoDispatcher>(), noErrors);
    auto RM =
        std::make_unique<MockResourceManager>(DetachedAt, ShutDownAt, OpCounter);
    S.addResourceManager(std::move(RM));
  }

  EXPECT_EQ(OpCounter, 1U);
  EXPECT_EQ(DetachedAt, std::nullopt);
  EXPECT_THAT(ShutDownAt, Optional(Eq(0)));
}
|
|
|
|
TEST(SessionTest, MultipleResourceManagers) {
  // With several ResourceManagers registered, destroying the Session should
  // shut each one down once, last-added first, without detaching any.
  constexpr size_t NumManagers = 3;

  size_t OpCounter = 0;
  std::optional<size_t> DetachedAt[NumManagers];
  std::optional<size_t> ShutDownAt[NumManagers];

  {
    Session S(std::make_unique<NoDispatcher>(), noErrors);
    for (size_t Idx = 0; Idx < NumManagers; ++Idx) {
      S.addResourceManager(std::make_unique<MockResourceManager>(
          DetachedAt[Idx], ShutDownAt[Idx], OpCounter));
    }
  }

  EXPECT_EQ(OpCounter, 3U);

  // Expect shutdown in reverse order: manager Idx is shut down as operation
  // number (NumManagers - 1 - Idx).
  for (size_t Idx = 0; Idx < NumManagers; ++Idx) {
    EXPECT_EQ(DetachedAt[Idx], std::nullopt);
    EXPECT_THAT(ShutDownAt[Idx], Optional(Eq(NumManagers - 1 - Idx)));
  }
}
|
|
|
|
TEST(SessionTest, ExpectedShutdownSequence) {
  // Check that Session shutdown results in...
  // 1. ResourceManagers being shut down.
  // 2. The TaskDispatcher being shut down.
  // 3. A call to OnShutdownComplete.

  size_t OpIdx = 0;
  std::optional<size_t> DetachOpIdx;
  std::optional<size_t> ShutdownOpIdx;

  bool DispatcherShutDown = false;
  bool SessionShutdownComplete = false;
  std::deque<std::unique_ptr<Task>> Tasks;
  Session S(std::make_unique<EnqueueingDispatcher>(
                Tasks,
                [&]() {
                  // Step 2 must see step 1 already done. The matcher checks
                  // "engaged and equal to 0" in one go, so a failure never
                  // dereferences an empty optional.
                  EXPECT_THAT(ShutdownOpIdx, Optional(Eq(0)));
                  EXPECT_FALSE(SessionShutdownComplete);
                  DispatcherShutDown = true;
                }),
            noErrors);
  S.addResourceManager(
      std::make_unique<MockResourceManager>(DetachOpIdx, ShutdownOpIdx, OpIdx));

  S.shutdown([&]() {
    // Step 3 must come after the dispatcher has been shut down.
    EXPECT_TRUE(DispatcherShutDown);
    SessionShutdownComplete = true;
  });
  S.waitForShutdown();

  EXPECT_TRUE(SessionShutdownComplete);
}
|
|
|
|
TEST(ControllerAccessTest, Basics) {
  // Attaching a ControllerAccess implementation must not get in the way of a
  // normal shutdown.
  std::deque<std::unique_ptr<Task>> Queue;
  Session S(std::make_unique<EnqueueingDispatcher>(Queue), noErrors);

  auto Controller = std::make_shared<MockControllerAccess>(S);
  S.setController(Controller);

  // Drain whatever was queued so far.
  EnqueueingDispatcher::runTasksFromFront(Queue);

  S.waitForShutdown();
}
|
|
|
|
/// Wrapper function used as a call target by the tests: handles an
/// int32_t(int32_t, int32_t) SPS call by returning the sum of its arguments.
static void add_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId,
                            orc_rt_WrapperFunctionReturn Return,
                            orc_rt_WrapperFunctionBuffer ArgBytes) {
  using AddSig = int32_t(int32_t, int32_t);

  auto Add = [](move_only_function<void(int32_t)> SendResult, int32_t LHS,
                int32_t RHS) { SendResult(LHS + RHS); };

  SPSWrapperFunction<AddSig>::handle(S, CallId, Return, ArgBytes,
                                     std::move(Add));
}
|
|
|
|
TEST(ControllerAccessTest, ValidCallToController) {
  // Call a controller handler through the Session and check that the result
  // arrives once the queued tasks have run.
  std::deque<std::unique_ptr<Task>> Queue;
  Session S(std::make_unique<EnqueueingDispatcher>(Queue), noErrors);
  auto Controller = std::make_shared<MockControllerAccess>(S);
  S.setController(Controller);

  using AddSig = int32_t(int32_t, int32_t);
  auto AddTag = reinterpret_cast<Session::HandlerTag>(add_sps_wrapper);

  int32_t Sum = 0;
  SPSWrapperFunction<AddSig>::call(
      CallViaSession(S, AddTag),
      [&](Expected<int32_t> R) { Sum = cantFail(std::move(R)); }, 41, 1);

  // Nothing has run yet: the call and its reply are both queued tasks.
  EnqueueingDispatcher::runTasksFromFront(Queue);

  EXPECT_EQ(Sum, 42);

  S.waitForShutdown();
}
|
|
|
|
TEST(ControllerAccessTest, CallToControllerBeforeAttach) {
  // With no controller ever attached, a call to the controller should fail
  // with "no controller attached".
  std::deque<std::unique_ptr<Task>> Queue;
  Session S(std::make_unique<EnqueueingDispatcher>(Queue), noErrors);

  using AddSig = int32_t(int32_t, int32_t);
  auto AddTag = reinterpret_cast<Session::HandlerTag>(add_sps_wrapper);

  Error CallErr = Error::success();
  SPSWrapperFunction<AddSig>::call(
      CallViaSession(S, AddTag),
      [&](Expected<int32_t> R) {
        ErrorAsOutParameter Guard(CallErr);
        CallErr = R.takeError();
      },
      41, 1);

  EXPECT_EQ(toString(std::move(CallErr)), "no controller attached");

  S.waitForShutdown();
}
|
|
|
|
TEST(ControllerAccessTest, CallToControllerAfterDetach) {
  // Expect calls to the controller after detaching to fail, with the same
  // error as if no controller had ever been attached.
  std::deque<std::unique_ptr<Task>> Tasks;
  Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
  auto CA = std::make_shared<MockControllerAccess>(S);
  S.setController(CA);

  S.detachFromController();

  Error Err = Error::success();
  SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
      CallViaSession(S, reinterpret_cast<Session::HandlerTag>(add_sps_wrapper)),
      [&](Expected<int32_t> R) {
        ErrorAsOutParameter _(Err);
        Err = R.takeError();
      },
      41, 1);

  EXPECT_EQ(toString(std::move(Err)), "no controller attached");

  S.waitForShutdown();
}
|
|
|
|
TEST(ControllerAccessTest, CallFromController) {
  // Have the mock controller call add_sps_wrapper and check that the result
  // arrives once the queued tasks have run.
  std::deque<std::unique_ptr<Task>> Queue;
  Session S(std::make_unique<EnqueueingDispatcher>(Queue), noErrors);
  auto Controller = std::make_shared<MockControllerAccess>(S);
  S.setController(Controller);

  using AddSig = int32_t(int32_t, int32_t);

  int32_t Sum = 0;
  SPSWrapperFunction<AddSig>::call(
      CallViaMockControllerAccess(*Controller, add_sps_wrapper),
      [&](Expected<int32_t> R) { Sum = cantFail(std::move(R)); }, 41, 1);

  EnqueueingDispatcher::runTasksFromFront(Queue);

  EXPECT_EQ(Sum, 42);

  S.waitForShutdown();
}
|