[orc-rt] Hoist DirectCaller test utility into header to enable re-use. (#162405)

The DirectCaller utility allows "direct" calls (with arguments
serialized into, and then immediately back out of a
WrapperFunctionBuffer) to wrapper functions. It was introduced for the
SPSWrapperFunction tests, but will be useful for testing WrapperFunction
interfaces for various orc-rt APIs too, so this commit hoists it
somewhere where it can be reused.
This commit is contained in:
Lang Hames
2025-10-08 11:48:27 +11:00
committed by GitHub
parent aed53d19f9
commit 70b7a3502e
2 changed files with 73 additions and 54 deletions

View File

@@ -0,0 +1,71 @@
//===- DirectCaller.h -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef ORC_RT_UNITTEST_DIRECTCALLER_H
#define ORC_RT_UNITTEST_DIRECTCALLER_H
#include "orc-rt/WrapperFunction.h"
#include <memory>
#include <utility>
/// Make calls and call result handlers directly on the current thread.
class DirectCaller {
private:
class DirectResultSender {
public:
virtual ~DirectResultSender() {}
virtual void send(orc_rt_SessionRef Session,
orc_rt::WrapperFunctionBuffer ResultBytes) = 0;
static void send(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionBuffer ResultBytes) {
std::unique_ptr<DirectResultSender>(
reinterpret_cast<DirectResultSender *>(CallCtx))
->send(Session, ResultBytes);
}
};
template <typename ImplFn>
class DirectResultSenderImpl : public DirectResultSender {
public:
DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
void send(orc_rt_SessionRef Session,
orc_rt::WrapperFunctionBuffer ResultBytes) override {
Fn(Session, std::move(ResultBytes));
}
private:
std::decay_t<ImplFn> Fn;
};
template <typename ImplFn>
static std::unique_ptr<DirectResultSender>
makeDirectResultSender(ImplFn &&Fn) {
return std::make_unique<DirectResultSenderImpl<ImplFn>>(
std::forward<ImplFn>(Fn));
}
public:
DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
: Session(Session), Fn(Fn) {}
template <typename HandleResultFn>
void operator()(HandleResultFn &&HandleResult,
orc_rt::WrapperFunctionBuffer ArgBytes) {
auto DR =
makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
Fn(Session, reinterpret_cast<void *>(DR.release()),
DirectResultSender::send, ArgBytes.release());
}
private:
orc_rt_SessionRef Session;
orc_rt_WrapperFunction Fn;
};
#endif // ORC_RT_UNITTEST_DIRECTCALLER_H

View File

@@ -16,64 +16,12 @@
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
#include "DirectCaller.h"
#include "gtest/gtest.h"
using namespace orc_rt;
/// Make calls and call result handlers directly on the current thread.
class DirectCaller {
private:
class DirectResultSender {
public:
virtual ~DirectResultSender() {}
virtual void send(orc_rt_SessionRef Session,
WrapperFunctionBuffer ResultBytes) = 0;
static void send(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionBuffer ResultBytes) {
std::unique_ptr<DirectResultSender>(
reinterpret_cast<DirectResultSender *>(CallCtx))
->send(Session, ResultBytes);
}
};
template <typename ImplFn>
class DirectResultSenderImpl : public DirectResultSender {
public:
DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
void send(orc_rt_SessionRef Session,
WrapperFunctionBuffer ResultBytes) override {
Fn(Session, std::move(ResultBytes));
}
private:
std::decay_t<ImplFn> Fn;
};
template <typename ImplFn>
static std::unique_ptr<DirectResultSender>
makeDirectResultSender(ImplFn &&Fn) {
return std::make_unique<DirectResultSenderImpl<ImplFn>>(
std::forward<ImplFn>(Fn));
}
public:
DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
: Session(Session), Fn(Fn) {}
template <typename HandleResultFn>
void operator()(HandleResultFn &&HandleResult,
WrapperFunctionBuffer ArgBytes) {
auto DR =
makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
Fn(Session, reinterpret_cast<void *>(DR.release()),
DirectResultSender::send, ArgBytes.release());
}
private:
orc_rt_SessionRef Session;
orc_rt_WrapperFunction Fn;
};
static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {