Files
llvm/orc-rt/unittests/SessionTest.cpp
Lang Hames bfc732efbd [orc-rt] Add ControllerAccess interface. (#169598)
ControllerAccess provides an abstract interface for bidirectional RPC
between the executor (running JIT'd code) and the controller (containing
the llvm::orc::ExecutionSession). ControllerAccess implementations are
expected to implement IPC / RPC using a concrete communication method
(shared memory, pipes, sockets, native system IPC, etc).

Calls from executor to controller are made via callController, with
"handler tags" (addresses in the executor) specifying the target handler
in the controller. A handler must be associated in the controller with
the given tag for the call to succeed. This ensures that only registered
entry points in the controller can be used, and avoids leaking
controller addresses into the executor.

Calls in both directions are to "wrapper functions" that take a buffer
of bytes as input and return a buffer of bytes as output. In the ORC
runtime these must be `orc_rt_WrapperFunction`s (see
Session::handleWrapperCall). The interpretation of the byte buffers is
up to the wrapper functions: the ORC runtime imposes no restrictions on
how the bytes are to be interpreted.

ControllerAccess objects may be detached from the Session prior to
Session shutdown, in which case no further calls may be made in either
direction, and any pending results (from calls made that haven't
returned yet) should return errors. If the ControllerAccess class is
still attached at Session shutdown time it will be detached as part of
the shutdown process. The ControllerAccess::disconnect method must
support concurrent entry on multiple threads, and all callers must block
until they can guarantee that no further calls will be received or
accepted.
2025-11-26 14:53:00 +11:00

467 lines
14 KiB
C++

//===- SessionTest.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Tests for orc-rt's Session.h APIs.
//
//===----------------------------------------------------------------------===//
#include "orc-rt/Session.h"
#include "orc-rt/SPSWrapperFunction.h"
#include "orc-rt/ThreadPoolTaskDispatcher.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <chrono>
#include <deque>
#include <future>
#include <optional>
using namespace orc_rt;
using ::testing::Eq;
using ::testing::Optional;
class MockResourceManager : public ResourceManager {
public:
enum class Op { Detach, Shutdown };
static Error alwaysSucceed(Op) { return Error::success(); }
MockResourceManager(std::optional<size_t> &DetachOpIdx,
std::optional<size_t> &ShutdownOpIdx, size_t &OpIdx,
move_only_function<Error(Op)> GenResult = alwaysSucceed)
: DetachOpIdx(DetachOpIdx), ShutdownOpIdx(ShutdownOpIdx), OpIdx(OpIdx),
GenResult(std::move(GenResult)) {}
void detach(OnCompleteFn OnComplete) override {
DetachOpIdx = OpIdx++;
OnComplete(GenResult(Op::Detach));
}
void shutdown(OnCompleteFn OnComplete) override {
ShutdownOpIdx = OpIdx++;
OnComplete(GenResult(Op::Shutdown));
}
private:
std::optional<size_t> &DetachOpIdx;
std::optional<size_t> &ShutdownOpIdx;
size_t &OpIdx;
move_only_function<Error(Op)> GenResult;
};
class NoDispatcher : public TaskDispatcher {
public:
void dispatch(std::unique_ptr<Task> T) override {
assert(false && "strictly no dispatching!");
}
void shutdown() override {}
};
class EnqueueingDispatcher : public TaskDispatcher {
public:
using OnShutdownRunFn = move_only_function<void()>;
EnqueueingDispatcher(std::deque<std::unique_ptr<Task>> &Tasks,
OnShutdownRunFn OnShutdownRun = {})
: Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {}
void dispatch(std::unique_ptr<Task> T) override {
Tasks.push_back(std::move(T));
}
void shutdown() override {
if (OnShutdownRun)
OnShutdownRun();
}
/// Run up to NumTasks (arbitrarily many if NumTasks == std::nullopt) tasks
/// from the front of the queue, returning the number actually run.
static size_t
runTasksFromFront(std::deque<std::unique_ptr<Task>> &Tasks,
std::optional<size_t> NumTasks = std::nullopt) {
size_t NumRun = 0;
while (!Tasks.empty() && (!NumTasks || NumRun != *NumTasks)) {
auto T = std::move(Tasks.front());
Tasks.pop_front();
T->run();
++NumRun;
}
return NumRun;
}
private:
std::deque<std::unique_ptr<Task>> &Tasks;
OnShutdownRunFn OnShutdownRun;
};
class MockControllerAccess : public Session::ControllerAccess {
public:
MockControllerAccess(Session &SS) : Session::ControllerAccess(SS), SS(SS) {}
void disconnect() override {
std::unique_lock<std::mutex> Lock(M);
Shutdown = true;
ShutdownCV.wait(Lock, [this]() { return Shutdown && Outstanding == 0; });
}
void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T,
WrapperFunctionBuffer ArgBytes) override {
// Simulate a call to the controller by dispatching a task to run the
// requested function.
size_t CId;
{
std::scoped_lock<std::mutex> Lock(M);
if (Shutdown)
return;
CId = CallId++;
Pending[CId] = std::move(OnComplete);
++Outstanding;
}
SS.dispatch(makeGenericTask([this, CId, OnComplete = std::move(OnComplete),
T, ArgBytes = std::move(ArgBytes)]() mutable {
auto Fn = reinterpret_cast<orc_rt_WrapperFunction>(T);
Fn(reinterpret_cast<orc_rt_SessionRef>(this), CId, wfReturn,
ArgBytes.release());
}));
bool Notify = false;
{
std::scoped_lock<std::mutex> Lock(M);
if (--Outstanding == 0 && Shutdown)
Notify = true;
}
if (Notify)
ShutdownCV.notify_all();
}
void sendWrapperResult(uint64_t CallId,
WrapperFunctionBuffer ResultBytes) override {
// Respond to a simulated call by the controller.
OnCallHandlerCompleteFn OnComplete;
{
std::scoped_lock<std::mutex> Lock(M);
if (Shutdown) {
assert(Pending.empty() && "Shut down but results still pending?");
return;
}
auto I = Pending.find(CallId);
assert(I != Pending.end());
OnComplete = std::move(I->second);
Pending.erase(I);
++Outstanding;
}
SS.dispatch(
makeGenericTask([OnComplete = std::move(OnComplete),
ResultBytes = std::move(ResultBytes)]() mutable {
OnComplete(std::move(ResultBytes));
}));
bool Notify = false;
{
std::scoped_lock<std::mutex> Lock(M);
if (--Outstanding == 0 && Shutdown)
Notify = true;
}
if (Notify)
ShutdownCV.notify_all();
}
void callFromController(OnCallHandlerCompleteFn OnComplete,
orc_rt_WrapperFunction Fn,
WrapperFunctionBuffer ArgBytes) {
size_t CId = 0;
bool BailOut = false;
{
std::scoped_lock<std::mutex> Lock(M);
if (!Shutdown) {
CId = CallId++;
Pending[CId] = std::move(OnComplete);
++Outstanding;
} else
BailOut = true;
}
if (BailOut)
return OnComplete(WrapperFunctionBuffer::createOutOfBandError(
"Controller disconnected"));
handleWrapperCall(CId, Fn, std::move(ArgBytes));
bool Notify = false;
{
std::scoped_lock<std::mutex> Lock(M);
if (--Outstanding == 0 && Shutdown)
Notify = true;
}
if (Notify)
ShutdownCV.notify_all();
}
/// Simulate start of outstanding operation.
void incOutstanding() {
std::scoped_lock<std::mutex> Lock(M);
++Outstanding;
}
/// Simulate end of outstanding operation.
void decOutstanding() {
bool Notify = false;
{
std::scoped_lock<std::mutex> Lock(M);
if (--Outstanding == 0 && Shutdown)
Notify = true;
}
if (Notify)
ShutdownCV.notify_all();
}
private:
static void wfReturn(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionBuffer ResultBytes) {
// Abuse "session" to refer to the ControllerAccess object.
// We can just re-use sendFunctionResult for this.
reinterpret_cast<MockControllerAccess *>(S)->sendWrapperResult(CallId,
ResultBytes);
}
Session &SS;
std::mutex M;
bool Shutdown = false;
size_t Outstanding = 0;
size_t CallId = 0;
std::unordered_map<size_t, OnCallHandlerCompleteFn> Pending;
std::condition_variable ShutdownCV;
};
class CallViaMockControllerAccess {
public:
CallViaMockControllerAccess(MockControllerAccess &CA,
orc_rt_WrapperFunction Fn)
: CA(CA), Fn(Fn) {}
void operator()(Session::OnCallHandlerCompleteFn OnComplete,
WrapperFunctionBuffer ArgBytes) {
CA.callFromController(std::move(OnComplete), Fn, std::move(ArgBytes));
}
private:
MockControllerAccess &CA;
orc_rt_WrapperFunction Fn;
};
// Non-overloaded version of cantFail: allows easy construction of
// move_only_functions<void(Error)>s.
static void noErrors(Error Err) { cantFail(std::move(Err)); }
TEST(SessionTest, TrivialConstructionAndDestruction) {
Session S(std::make_unique<NoDispatcher>(), noErrors);
}
TEST(SessionTest, ReportError) {
Error E = Error::success();
cantFail(std::move(E)); // Force error into checked state.
Session S(std::make_unique<NoDispatcher>(),
[&](Error Err) { E = std::move(Err); });
S.reportError(make_error<StringError>("foo"));
if (E)
EXPECT_EQ(toString(std::move(E)), "foo");
else
ADD_FAILURE() << "Missing error value";
}
TEST(SessionTest, DispatchTask) {
int X = 0;
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
EXPECT_EQ(Tasks.size(), 0U);
S.dispatch(makeGenericTask([&]() { ++X; }));
EXPECT_EQ(Tasks.size(), 1U);
auto T = std::move(Tasks.front());
Tasks.pop_front();
T->run();
EXPECT_EQ(X, 1);
}
TEST(SessionTest, SingleResourceManager) {
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx;
std::optional<size_t> ShutdownOpIdx;
{
Session S(std::make_unique<NoDispatcher>(), noErrors);
S.addResourceManager(std::make_unique<MockResourceManager>(
DetachOpIdx, ShutdownOpIdx, OpIdx));
}
EXPECT_EQ(OpIdx, 1U);
EXPECT_EQ(DetachOpIdx, std::nullopt);
EXPECT_THAT(ShutdownOpIdx, Optional(Eq(0)));
}
TEST(SessionTest, MultipleResourceManagers) {
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx[3];
std::optional<size_t> ShutdownOpIdx[3];
{
Session S(std::make_unique<NoDispatcher>(), noErrors);
for (size_t I = 0; I != 3; ++I)
S.addResourceManager(std::make_unique<MockResourceManager>(
DetachOpIdx[I], ShutdownOpIdx[I], OpIdx));
}
EXPECT_EQ(OpIdx, 3U);
// Expect shutdown in reverse order.
for (size_t I = 0; I != 3; ++I) {
EXPECT_EQ(DetachOpIdx[I], std::nullopt);
EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I)));
}
}
TEST(SessionTest, ExpectedShutdownSequence) {
// Check that Session shutdown results in...
// 1. ResourceManagers being shut down.
// 2. The TaskDispatcher being shut down.
// 3. A call to OnShutdownComplete.
size_t OpIdx = 0;
std::optional<size_t> DetachOpIdx;
std::optional<size_t> ShutdownOpIdx;
bool DispatcherShutDown = false;
bool SessionShutdownComplete = false;
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(
Tasks,
[&]() {
EXPECT_TRUE(ShutdownOpIdx);
EXPECT_EQ(*ShutdownOpIdx, 0);
EXPECT_FALSE(SessionShutdownComplete);
DispatcherShutDown = true;
}),
noErrors);
S.addResourceManager(
std::make_unique<MockResourceManager>(DetachOpIdx, ShutdownOpIdx, OpIdx));
S.shutdown([&]() {
EXPECT_TRUE(DispatcherShutDown);
SessionShutdownComplete = true;
});
S.waitForShutdown();
EXPECT_TRUE(SessionShutdownComplete);
}
TEST(ControllerAccessTest, Basics) {
// Test that we can set the ControllerAccess implementation and still shut
// down as expected.
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
auto CA = std::make_shared<MockControllerAccess>(S);
S.setController(CA);
EnqueueingDispatcher::runTasksFromFront(Tasks);
S.waitForShutdown();
}
static void add_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
S, CallId, Return, ArgBytes,
[](move_only_function<void(int32_t)> Return, int32_t X, int32_t Y) {
Return(X + Y);
});
}
TEST(ControllerAccessTest, ValidCallToController) {
// Simulate a call to a controller handler.
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
auto CA = std::make_shared<MockControllerAccess>(S);
S.setController(CA);
int32_t Result = 0;
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
CallViaSession(S, reinterpret_cast<Session::HandlerTag>(add_sps_wrapper)),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EnqueueingDispatcher::runTasksFromFront(Tasks);
EXPECT_EQ(Result, 42);
S.waitForShutdown();
}
TEST(ControllerAccessTest, CallToControllerBeforeAttach) {
// Expect calls to the controller prior to attaching to fail.
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
Error Err = Error::success();
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
CallViaSession(S, reinterpret_cast<Session::HandlerTag>(add_sps_wrapper)),
[&](Expected<int32_t> R) {
ErrorAsOutParameter _(Err);
Err = R.takeError();
},
41, 1);
EXPECT_EQ(toString(std::move(Err)), "no controller attached");
S.waitForShutdown();
}
TEST(ControllerAccessTest, CallToControllerAfterDetach) {
// Expect calls to the controller prior to attaching to fail.
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
auto CA = std::make_shared<MockControllerAccess>(S);
S.setController(CA);
S.detachFromController();
Error Err = Error::success();
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
CallViaSession(S, reinterpret_cast<Session::HandlerTag>(add_sps_wrapper)),
[&](Expected<int32_t> R) {
ErrorAsOutParameter _(Err);
Err = R.takeError();
},
41, 1);
EXPECT_EQ(toString(std::move(Err)), "no controller attached");
S.waitForShutdown();
}
TEST(ControllerAccessTest, CallFromController) {
// Simulate a call from the controller.
std::deque<std::unique_ptr<Task>> Tasks;
Session S(std::make_unique<EnqueueingDispatcher>(Tasks), noErrors);
auto CA = std::make_shared<MockControllerAccess>(S);
S.setController(CA);
int32_t Result = 0;
SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
CallViaMockControllerAccess(*CA, add_sps_wrapper),
[&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
EnqueueingDispatcher::runTasksFromFront(Tasks);
EXPECT_EQ(Result, 42);
S.waitForShutdown();
}