[orc-rt] WrapperFunction::handle: add by-ref args, minimize temporaries. (#161999)

This adds support for WrapperFunction::handle handlers that take their
arguments by reference, rather than by value.

This commit also reduces the number of temporary objects created to
support SPS-transparent conversion in SPSWrapperFunction.
This commit is contained in:
Lang Hames
2025-10-05 18:21:00 +11:00
committed by GitHub
parent 5284c83a8f
commit e8489c162b
3 changed files with 107 additions and 8 deletions

View File

@@ -57,8 +57,8 @@ private:
template <typename... Ts>
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
template <typename T> static T fromSerializable(T &&Arg) noexcept {
return Arg;
template <typename T> static T &&fromSerializable(T &&Arg) noexcept {
return std::forward<T>(Arg);
}
static Error fromSerializable(SPSSerializableError Err) noexcept {
@@ -86,7 +86,10 @@ public:
decltype(Args)>::deserialize(IB, Args))
return std::nullopt;
return std::apply(
[](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
[](auto &&...A) {
return std::optional<ArgTuple>(std::in_place,
std::move(fromSerializable(A))...);
},
std::move(Args));
}
};

View File

@@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl {
static_assert(std::is_void_v<RetT>,
"Async wrapper function handler must return void");
typedef ReturnT YieldType;
typedef std::tuple<ArgTs...> ArgTupleType;
typedef std::tuple<std::decay_t<ArgTs>...> ArgTupleType;
// Forwards arguments based on the parameter types of the handler.
template <typename FnT> class ForwardArgsAsRequested {
public:
ForwardArgsAsRequested(FnT &&Fn) : Fn(std::move(Fn)) {}
void operator()(ArgTs &...Args) { Fn(std::forward<ArgTs>(Args)...); }
private:
FnT Fn;
};
template <typename FnT>
static ForwardArgsAsRequested<std::decay_t<FnT>>
forwardArgsAsRequested(FnT &&Fn) {
return ForwardArgsAsRequested<std::decay_t<FnT>>(std::forward<FnT>(Fn));
}
};
template <typename C>
@@ -244,10 +260,11 @@ struct WrapperFunction {
if (auto Args =
S.arguments().template deserialize<ArgTuple>(std::move(ArgBytes)))
std::apply(bind_front(std::forward<Handler>(H),
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallCtx, Return, std::move(S))),
std::move(*Args));
std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
std::forward<Handler>(H),
detail::StructuredYield<RetTupleType, Serializer>(
Session, CallCtx, Return, std::move(S)))),
*Args);
else
Return(Session, CallCtx,
WrapperFunctionBuffer::createOutOfBandError(

View File

@@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
#include "CommonTestUtils.h"
#include "orc-rt/SPSWrapperFunction.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
@@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
}
template <size_t N> struct SPSOpCounter {};
namespace orc_rt {
template <size_t N>
class SPSSerializationTraits<SPSOpCounter<N>, OpCounter<N>> {
public:
static size_t size(const OpCounter<N> &O) { return 0; }
static bool serialize(SPSOutputBuffer &OB, const OpCounter<N> &O) {
return true;
}
static bool deserialize(SPSInputBuffer &OB, OpCounter<N> &O) { return true; }
};
} // namespace orc_rt
static void
handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session,
void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes) {
SPSWrapperFunction<void(
SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
SPSOpCounter<3>)>::handle(Session, CallCtx, Return, ArgBytes,
[](move_only_function<void()> Return,
OpCounter<0>, OpCounter<1> &,
const OpCounter<2> &,
OpCounter<3> &&) { Return(); });
}
TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
// Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref
// arguments, and that we generate the expected number of moves.
OpCounter<0>::reset();
OpCounter<1>::reset();
OpCounter<2>::reset();
OpCounter<3>::reset();
bool DidRun = false;
SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
SPSOpCounter<3>)>::
call(
DirectCaller(nullptr, handle_with_reference_types_sps_wrapper),
[&](Error R) {
cantFail(std::move(R));
DidRun = true;
},
OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>());
EXPECT_TRUE(DidRun);
// We expect two default constructions for each parameter: one for the
// argument to call, and one for the object to deserialize into.
EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U);
EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U);
// Pass-by-value: we expect two moves (one for SPS transparent conversion,
// one to copy the value to the parameter), and no copies.
EXPECT_EQ(OpCounter<0>::moves(), 2U);
EXPECT_EQ(OpCounter<0>::copies(), 0U);
// Pass-by-lvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<1>::moves(), 1U);
EXPECT_EQ(OpCounter<1>::copies(), 0U);
// Pass-by-const-lvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<2>::moves(), 1U);
EXPECT_EQ(OpCounter<2>::copies(), 0U);
// Pass-by-rvalue-reference: we expect one move (for SPS transparent
// conversion), no copies.
EXPECT_EQ(OpCounter<3>::moves(), 1U);
EXPECT_EQ(OpCounter<3>::copies(), 0U);
}