mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[lldb] Adding A new Binding helper for JSONTransport. (#159160)
This adds a new Binding helper class to allow mapping of incoming and outgoing requests / events to specific handlers. This should make it easier to create new protocol implementations and allow us to create a relay in the lldb-mcp binary.
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "lldb/Utility/IOObject.h"
|
||||
#include "lldb/Utility/Status.h"
|
||||
#include "lldb/lldb-forward.h"
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
@@ -25,13 +26,23 @@
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
#if __cplusplus >= 202002L
|
||||
#include <concepts>
|
||||
#endif
|
||||
|
||||
namespace lldb_private {
|
||||
namespace lldb_private::transport {
|
||||
|
||||
/// An error to indicate that the transport reached EOF but there were still
|
||||
/// unhandled contents in the read buffer.
|
||||
class TransportUnhandledContentsError
|
||||
: public llvm::ErrorInfo<TransportUnhandledContentsError> {
|
||||
public:
|
||||
@@ -50,17 +61,75 @@ private:
|
||||
std::string m_unhandled_contents;
|
||||
};
|
||||
|
||||
/// An error to indicate that the parameters of a Req, Resp or Evt could not be
|
||||
/// deserialized.
|
||||
class InvalidParams : public llvm::ErrorInfo<InvalidParams> {
|
||||
public:
|
||||
static char ID;
|
||||
|
||||
explicit InvalidParams(std::string method, std::string context)
|
||||
: m_method(std::move(method)), m_context(std::move(context)) {}
|
||||
|
||||
void log(llvm::raw_ostream &OS) const override;
|
||||
std::error_code convertToErrorCode() const override;
|
||||
|
||||
private:
|
||||
/// The JSONRPC remote method call.
|
||||
std::string m_method;
|
||||
|
||||
/// Additional context from the parsing failure, e.g. "missing value at
|
||||
/// (root)[1].str".
|
||||
std::string m_context;
|
||||
};
|
||||
|
||||
/// An error to indicate that no handler was registered for a given method.
|
||||
class MethodNotFound : public llvm::ErrorInfo<MethodNotFound> {
|
||||
public:
|
||||
static char ID;
|
||||
|
||||
static constexpr int kErrorCode = -32601;
|
||||
|
||||
explicit MethodNotFound(std::string method) : m_method(std::move(method)) {}
|
||||
|
||||
void log(llvm::raw_ostream &OS) const override;
|
||||
std::error_code convertToErrorCode() const override;
|
||||
|
||||
private:
|
||||
std::string m_method;
|
||||
};
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
/// A ProtocolDescriptor details the types used in a JSONTransport for handling
|
||||
/// transport communication.
|
||||
template <typename T>
|
||||
concept ProtocolDescriptor = requires {
|
||||
typename T::Id;
|
||||
typename T::Req;
|
||||
typename T::Resp;
|
||||
typename T::Evt;
|
||||
};
|
||||
#endif
|
||||
|
||||
/// A transport is responsible for maintaining the connection to a client
|
||||
/// application, and reading/writing structured messages to it.
|
||||
///
|
||||
/// Transports have limited thread safety requirements:
|
||||
/// JSONTransport have limited thread safety requirements:
|
||||
/// - Messages will not be sent concurrently.
|
||||
/// - Messages MAY be sent while Run() is reading, or its callback is active.
|
||||
template <typename Req, typename Resp, typename Evt> class Transport {
|
||||
///
|
||||
#if __cplusplus >= 202002L
|
||||
template <ProtocolDescriptor Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
class JSONTransport {
|
||||
public:
|
||||
using Req = typename Proto::Req;
|
||||
using Resp = typename Proto::Resp;
|
||||
using Evt = typename Proto::Evt;
|
||||
using Message = std::variant<Req, Resp, Evt>;
|
||||
|
||||
virtual ~Transport() = default;
|
||||
virtual ~JSONTransport() = default;
|
||||
|
||||
/// Sends an event, a message that does not require a response.
|
||||
virtual llvm::Error Send(const Evt &) = 0;
|
||||
@@ -69,7 +138,8 @@ public:
|
||||
/// Sends a response to a specific request.
|
||||
virtual llvm::Error Send(const Resp &) = 0;
|
||||
|
||||
/// Implemented to handle incoming messages. (See Run() below).
|
||||
/// Implemented to handle incoming messages. (See `RegisterMessageHandler()`
|
||||
/// below).
|
||||
class MessageHandler {
|
||||
public:
|
||||
virtual ~MessageHandler() = default;
|
||||
@@ -90,8 +160,6 @@ public:
|
||||
virtual void OnClosed() = 0;
|
||||
};
|
||||
|
||||
using MessageHandlerSP = std::shared_ptr<MessageHandler>;
|
||||
|
||||
/// RegisterMessageHandler registers the Transport with the given MainLoop and
|
||||
/// handles any incoming messages using the given MessageHandler.
|
||||
///
|
||||
@@ -108,18 +176,23 @@ protected:
|
||||
};
|
||||
|
||||
/// An IOTransport sends and receives messages using an IOObject.
|
||||
template <typename Req, typename Resp, typename Evt>
|
||||
class IOTransport : public Transport<Req, Resp, Evt> {
|
||||
template <typename Proto> class IOTransport : public JSONTransport<Proto> {
|
||||
public:
|
||||
using Transport<Req, Resp, Evt>::Transport;
|
||||
using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler;
|
||||
using Message = typename JSONTransport<Proto>::Message;
|
||||
using MessageHandler = typename JSONTransport<Proto>::MessageHandler;
|
||||
|
||||
IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out)
|
||||
: m_in(in), m_out(out) {}
|
||||
|
||||
llvm::Error Send(const Evt &evt) override { return Write(evt); }
|
||||
llvm::Error Send(const Req &req) override { return Write(req); }
|
||||
llvm::Error Send(const Resp &resp) override { return Write(resp); }
|
||||
llvm::Error Send(const typename Proto::Evt &evt) override {
|
||||
return Write(evt);
|
||||
}
|
||||
llvm::Error Send(const typename Proto::Req &req) override {
|
||||
return Write(req);
|
||||
}
|
||||
llvm::Error Send(const typename Proto::Resp &resp) override {
|
||||
return Write(resp);
|
||||
}
|
||||
|
||||
llvm::Expected<MainLoop::ReadHandleUP>
|
||||
RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override {
|
||||
@@ -139,7 +212,7 @@ public:
|
||||
/// detail.
|
||||
static constexpr size_t kReadBufferSize = 1024;
|
||||
|
||||
// FIXME: Write should be protected.
|
||||
protected:
|
||||
llvm::Error Write(const llvm::json::Value &message) {
|
||||
this->Logv("<-- {0}", message);
|
||||
std::string output = Encode(message);
|
||||
@@ -147,7 +220,6 @@ public:
|
||||
return m_out->Write(output.data(), bytes_written).takeError();
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual llvm::Expected<std::vector<std::string>> Parse() = 0;
|
||||
virtual std::string Encode(const llvm::json::Value &message) = 0;
|
||||
|
||||
@@ -174,9 +246,8 @@ private:
|
||||
}
|
||||
|
||||
for (const std::string &raw_message : *raw_messages) {
|
||||
llvm::Expected<typename Transport<Req, Resp, Evt>::Message> message =
|
||||
llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>(
|
||||
raw_message);
|
||||
llvm::Expected<Message> message =
|
||||
llvm::json::parse<Message>(raw_message);
|
||||
if (!message) {
|
||||
handler.OnError(message.takeError());
|
||||
return;
|
||||
@@ -201,10 +272,14 @@ private:
|
||||
};
|
||||
|
||||
/// A transport class for JSON with a HTTP header.
|
||||
template <typename Req, typename Resp, typename Evt>
|
||||
class HTTPDelimitedJSONTransport : public IOTransport<Req, Resp, Evt> {
|
||||
#if __cplusplus >= 202002L
|
||||
template <ProtocolDescriptor Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
class HTTPDelimitedJSONTransport : public IOTransport<Proto> {
|
||||
public:
|
||||
using IOTransport<Req, Resp, Evt>::IOTransport;
|
||||
using IOTransport<Proto>::IOTransport;
|
||||
|
||||
protected:
|
||||
/// Encodes messages based on
|
||||
@@ -230,8 +305,8 @@ protected:
|
||||
for (const llvm::StringRef &header :
|
||||
llvm::split(headers, kHeaderSeparator)) {
|
||||
auto [key, value] = header.split(kHeaderFieldSeparator);
|
||||
// 'Content-Length' is the only meaningful key at the moment. Others are
|
||||
// ignored.
|
||||
// 'Content-Length' is the only meaningful key at the moment. Others
|
||||
// are ignored.
|
||||
if (!key.equals_insensitive(kHeaderContentLength))
|
||||
continue;
|
||||
|
||||
@@ -268,10 +343,14 @@ protected:
|
||||
};
|
||||
|
||||
/// A transport class for JSON RPC.
|
||||
template <typename Req, typename Resp, typename Evt>
|
||||
class JSONRPCTransport : public IOTransport<Req, Resp, Evt> {
|
||||
#if __cplusplus >= 202002L
|
||||
template <ProtocolDescriptor Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
class JSONRPCTransport : public IOTransport<Proto> {
|
||||
public:
|
||||
using IOTransport<Req, Resp, Evt>::IOTransport;
|
||||
using IOTransport<Proto>::IOTransport;
|
||||
|
||||
protected:
|
||||
std::string Encode(const llvm::json::Value &message) override {
|
||||
@@ -297,6 +376,497 @@ protected:
|
||||
static constexpr llvm::StringLiteral kMessageSeparator = "\n";
|
||||
};
|
||||
|
||||
} // namespace lldb_private
|
||||
/// A handler for the response to an outgoing request.
|
||||
template <typename T>
|
||||
using Reply =
|
||||
std::conditional_t<std::is_void_v<T>,
|
||||
llvm::unique_function<void(llvm::Error)>,
|
||||
llvm::unique_function<void(llvm::Expected<T>)>>;
|
||||
|
||||
namespace detail {
|
||||
template <typename R, typename P> struct request_t final {
|
||||
using type = llvm::unique_function<void(const P &, Reply<R>)>;
|
||||
};
|
||||
template <typename R> struct request_t<R, void> final {
|
||||
using type = llvm::unique_function<void(Reply<R>)>;
|
||||
};
|
||||
template <typename P> struct event_t final {
|
||||
using type = llvm::unique_function<void(const P &)>;
|
||||
};
|
||||
template <> struct event_t<void> final {
|
||||
using type = llvm::unique_function<void()>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename R, typename P>
|
||||
using OutgoingRequest = typename detail::request_t<R, P>::type;
|
||||
|
||||
/// A function to send an outgoing event.
|
||||
template <typename P> using OutgoingEvent = typename detail::event_t<P>::type;
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
/// This represents a protocol description that includes additional helpers
|
||||
/// for constructing requests, responses and events to work with `Binder`.
|
||||
template <typename T>
|
||||
concept BindingBuilder =
|
||||
ProtocolDescriptor<T> &&
|
||||
requires(T::Id id, T::Req req, T::Resp resp, T::Evt evt,
|
||||
llvm::StringRef method, std::optional<llvm::json::Value> params,
|
||||
std::optional<llvm::json::Value> result, llvm::Error err) {
|
||||
/// For initializing the unique sequence identifier;
|
||||
{ T::InitialId() } -> std::same_as<typename T::Id>;
|
||||
/// Incrementing the sequence identifier.
|
||||
{ id++ } -> std::same_as<typename T::Id>;
|
||||
|
||||
/// Constructing protocol types
|
||||
/// @{
|
||||
/// Construct a new request.
|
||||
{ T::Make(id, method, params) } -> std::same_as<typename T::Req>;
|
||||
/// Construct a new error response.
|
||||
{ T::Make(req, std::move(err)) } -> std::same_as<typename T::Resp>;
|
||||
/// Construct a new success response.
|
||||
{ T::Make(req, result) } -> std::same_as<typename T::Resp>;
|
||||
/// Construct a new event.
|
||||
{ T::Make(method, params) } -> std::same_as<typename T::Evt>;
|
||||
/// @}
|
||||
|
||||
/// Keys for associated types.
|
||||
/// @{
|
||||
/// Looking up in flight responses.
|
||||
{ T::KeyFor(resp) } -> std::same_as<typename T::Id>;
|
||||
/// Extract method from request.
|
||||
{ T::KeyFor(req) } -> std::same_as<std::string>;
|
||||
/// Extract method from event.
|
||||
{ T::KeyFor(evt) } -> std::same_as<std::string>;
|
||||
/// @}
|
||||
|
||||
/// Extracting information from associated types.
|
||||
/// @{
|
||||
/// Extract parameters from a request.
|
||||
{ T::Extract(req) } -> std::same_as<std::optional<llvm::json::Value>>;
|
||||
/// Extract result from a response.
|
||||
{ T::Extract(resp) } -> std::same_as<llvm::Expected<llvm::json::Value>>;
|
||||
/// Extract parameters from an event.
|
||||
{ T::Extract(evt) } -> std::same_as<std::optional<llvm::json::Value>>;
|
||||
/// @}
|
||||
};
|
||||
#endif
|
||||
|
||||
/// Binder collects a table of functions that handle calls.
|
||||
///
|
||||
/// The wrapper takes care of parsing/serializing responses.
|
||||
///
|
||||
/// This allows a JSONTransport to handle incoming and outgoing requests and
|
||||
/// events.
|
||||
///
|
||||
/// A bind of an incoming request to a lambda.
|
||||
/// \code{cpp}
|
||||
/// Binder binder{transport};
|
||||
/// binder.bind<int, vector<int>>("adder", [](const vector<int> ¶ms) {
|
||||
/// int sum = 0;
|
||||
/// for (int v : params)
|
||||
/// sum += v;
|
||||
/// return sum;
|
||||
/// });
|
||||
/// \endcode
|
||||
///
|
||||
/// A bind of an outgoing request.
|
||||
/// \code{cpp}
|
||||
/// OutgoingRequest<int, vector<int>> call_add =
|
||||
/// binder.bind<int, vector<int>>("add");
|
||||
/// call_add({1,2,3}, [](Expected<int> result) {
|
||||
/// cout << *result << "\n";
|
||||
/// });
|
||||
/// \endcode
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
class Binder : public JSONTransport<Proto>::MessageHandler {
|
||||
using Req = typename Proto::Req;
|
||||
using Resp = typename Proto::Resp;
|
||||
using Evt = typename Proto::Evt;
|
||||
using Id = typename Proto::Id;
|
||||
using Transport = JSONTransport<Proto>;
|
||||
using MessageHandler = typename Transport::MessageHandler;
|
||||
|
||||
public:
|
||||
explicit Binder(Transport &transport) : m_transport(transport), m_seq(0) {}
|
||||
|
||||
Binder(const Binder &) = delete;
|
||||
Binder &operator=(const Binder &) = delete;
|
||||
|
||||
/// Bind a handler on transport disconnect.
|
||||
template <typename Fn, typename... Args>
|
||||
void OnDisconnect(Fn &&fn, Args &&...args);
|
||||
|
||||
/// Bind a handler on error when communicating with the transport.
|
||||
template <typename Fn, typename... Args>
|
||||
void OnError(Fn &&fn, Args &&...args);
|
||||
|
||||
/// Bind a handler for an incoming request.
|
||||
/// e.g. `bind("peek", &ThisModule::peek, this);`.
|
||||
/// Handler should be e.g. `Expected<PeekResult> peek(const PeekParams&);`
|
||||
/// PeekParams must be JSON parsable and PeekResult must be serializable.
|
||||
template <typename Result, typename Params, typename Fn, typename... Args>
|
||||
void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args);
|
||||
|
||||
/// Bind a handler for an incoming event.
|
||||
/// e.g. `bind("peek", &ThisModule::peek, this);`
|
||||
/// Handler should be e.g. `void peek(const PeekParams&);`
|
||||
/// PeekParams must be JSON parsable.
|
||||
template <typename Params, typename Fn, typename... Args>
|
||||
void Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args);
|
||||
|
||||
/// Bind a function object to be used for outgoing requests.
|
||||
/// e.g. `OutgoingRequest<Params, Result> Edit = bind("edit");`
|
||||
/// Params must be JSON-serializable, Result must be parsable.
|
||||
template <typename Result, typename Params>
|
||||
OutgoingRequest<Result, Params> Bind(llvm::StringLiteral method);
|
||||
|
||||
/// Bind a function object to be used for outgoing events.
|
||||
/// e.g. `OutgoingEvent<LogParams> Log = bind("log");`
|
||||
/// LogParams must be JSON-serializable.
|
||||
template <typename Params>
|
||||
OutgoingEvent<Params> Bind(llvm::StringLiteral method);
|
||||
|
||||
void Received(const Evt &evt) override {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
auto it = m_event_handlers.find(Proto::KeyFor(evt));
|
||||
if (it == m_event_handlers.end()) {
|
||||
OnError(llvm::createStringError(
|
||||
llvm::formatv("no handled for event {0}", toJSON(evt))));
|
||||
return;
|
||||
}
|
||||
it->second(evt);
|
||||
}
|
||||
|
||||
void Received(const Req &req) override {
|
||||
ReplyOnce reply(req, &m_transport, this);
|
||||
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
auto it = m_request_handlers.find(Proto::KeyFor(req));
|
||||
if (it == m_request_handlers.end()) {
|
||||
reply(Proto::Make(req, llvm::createStringError("method not found")));
|
||||
return;
|
||||
}
|
||||
|
||||
it->second(req, std::move(reply));
|
||||
}
|
||||
|
||||
void Received(const Resp &resp) override {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
|
||||
Id id = Proto::KeyFor(resp);
|
||||
auto it = m_pending_responses.find(id);
|
||||
if (it == m_pending_responses.end()) {
|
||||
OnError(llvm::createStringError(
|
||||
llvm::formatv("no pending request for {0}", toJSON(resp))));
|
||||
return;
|
||||
}
|
||||
|
||||
it->second(resp);
|
||||
m_pending_responses.erase(it);
|
||||
}
|
||||
|
||||
void OnError(llvm::Error err) override {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
if (m_error_handler)
|
||||
m_error_handler(std::move(err));
|
||||
}
|
||||
|
||||
void OnClosed() override {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
if (m_disconnect_handler)
|
||||
m_disconnect_handler();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
llvm::Expected<T> static Parse(const llvm::json::Value &raw,
|
||||
llvm::StringRef method);
|
||||
|
||||
template <typename T> using Callback = llvm::unique_function<T>;
|
||||
|
||||
std::recursive_mutex m_mutex;
|
||||
Transport &m_transport;
|
||||
Id m_seq;
|
||||
std::map<Id, Callback<void(const Resp &)>> m_pending_responses;
|
||||
llvm::StringMap<Callback<void(const Req &, Callback<void(const Resp &)>)>>
|
||||
m_request_handlers;
|
||||
llvm::StringMap<Callback<void(const Evt &)>> m_event_handlers;
|
||||
Callback<void()> m_disconnect_handler;
|
||||
Callback<void(llvm::Error)> m_error_handler;
|
||||
|
||||
/// Function object to reply to a call.
|
||||
/// Each instance must be called exactly once, otherwise:
|
||||
/// - the bug is logged, and (in debug mode) an assert will fire
|
||||
/// - if there was no reply, an error reply is sent
|
||||
/// - if there were multiple replies, only the first is sent
|
||||
class ReplyOnce {
|
||||
std::atomic<bool> replied = {false};
|
||||
const Req req;
|
||||
Transport *transport; // Null when moved-from.
|
||||
MessageHandler *handler; // Null when moved-from.
|
||||
|
||||
public:
|
||||
ReplyOnce(const Req req, Transport *transport, MessageHandler *handler)
|
||||
: req(req), transport(transport), handler(handler) {
|
||||
assert(handler);
|
||||
}
|
||||
ReplyOnce(ReplyOnce &&other)
|
||||
: replied(other.replied.load()), req(other.req),
|
||||
transport(other.transport), handler(other.handler) {
|
||||
other.transport = nullptr;
|
||||
other.handler = nullptr;
|
||||
}
|
||||
ReplyOnce &operator=(ReplyOnce &&) = delete;
|
||||
ReplyOnce(const ReplyOnce &) = delete;
|
||||
ReplyOnce &operator=(const ReplyOnce &) = delete;
|
||||
|
||||
~ReplyOnce() {
|
||||
if (transport && handler && !replied) {
|
||||
assert(false && "must reply to all calls!");
|
||||
(*this)(Proto::Make(req, llvm::createStringError("failed to reply")));
|
||||
}
|
||||
}
|
||||
|
||||
void operator()(const Resp &resp) {
|
||||
assert(transport && handler && "moved-from!");
|
||||
if (replied.exchange(true)) {
|
||||
assert(false && "must reply to each call only once!");
|
||||
return;
|
||||
}
|
||||
|
||||
if (llvm::Error error = transport->Send(resp))
|
||||
handler->OnError(std::move(error));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Fn, typename... Args>
|
||||
void Binder<Proto>::OnDisconnect(Fn &&fn, Args &&...args) {
|
||||
m_disconnect_handler = [fn, args...]() mutable {
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
|
||||
};
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Fn, typename... Args>
|
||||
void Binder<Proto>::OnError(Fn &&fn, Args &&...args) {
|
||||
m_error_handler = [fn, args...](llvm::Error error) mutable {
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...,
|
||||
std::move(error));
|
||||
};
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Result, typename Params, typename Fn, typename... Args>
|
||||
void Binder<Proto>::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) {
|
||||
assert(m_request_handlers.find(method) == m_request_handlers.end() &&
|
||||
"request already bound");
|
||||
if constexpr (std::is_void_v<Result> && std::is_void_v<Params>) {
|
||||
m_request_handlers[method] =
|
||||
[fn, args...](const Req &req,
|
||||
llvm::unique_function<void(const Resp &)> reply) mutable {
|
||||
llvm::Error result =
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
|
||||
reply(Proto::Make(req, std::move(result)));
|
||||
};
|
||||
} else if constexpr (std::is_void_v<Params>) {
|
||||
m_request_handlers[method] =
|
||||
[fn, args...](const Req &req,
|
||||
llvm::unique_function<void(const Resp &)> reply) mutable {
|
||||
llvm::Expected<Result> result =
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
|
||||
if (!result)
|
||||
return reply(Proto::Make(req, result.takeError()));
|
||||
reply(Proto::Make(req, toJSON(*result)));
|
||||
};
|
||||
} else if constexpr (std::is_void_v<Result>) {
|
||||
m_request_handlers[method] =
|
||||
[method, fn,
|
||||
args...](const Req &req,
|
||||
llvm::unique_function<void(const Resp &)> reply) mutable {
|
||||
llvm::Expected<Params> params =
|
||||
Parse<Params>(Proto::Extract(req), method);
|
||||
if (!params)
|
||||
return reply(Proto::Make(req, params.takeError()));
|
||||
|
||||
llvm::Error result = std::invoke(
|
||||
std::forward<Fn>(fn), std::forward<Args>(args)..., *params);
|
||||
reply(Proto::Make(req, std::move(result)));
|
||||
};
|
||||
} else {
|
||||
m_request_handlers[method] =
|
||||
[method, fn,
|
||||
args...](const Req &req,
|
||||
llvm::unique_function<void(const Resp &)> reply) mutable {
|
||||
llvm::Expected<Params> params =
|
||||
Parse<Params>(Proto::Extract(req), method);
|
||||
if (!params)
|
||||
return reply(Proto::Make(req, params.takeError()));
|
||||
|
||||
llvm::Expected<Result> result = std::invoke(
|
||||
std::forward<Fn>(fn), std::forward<Args>(args)..., *params);
|
||||
if (!result)
|
||||
return reply(Proto::Make(req, result.takeError()));
|
||||
|
||||
reply(Proto::Make(req, toJSON(*result)));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Params, typename Fn, typename... Args>
|
||||
void Binder<Proto>::Bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) {
|
||||
assert(m_event_handlers.find(method) == m_event_handlers.end() &&
|
||||
"event already bound");
|
||||
if constexpr (std::is_void_v<Params>) {
|
||||
m_event_handlers[method] = [fn, args...](const Evt &) mutable {
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...);
|
||||
};
|
||||
} else {
|
||||
m_event_handlers[method] = [this, method, fn,
|
||||
args...](const Evt &evt) mutable {
|
||||
llvm::Expected<Params> params =
|
||||
Parse<Params>(Proto::Extract(evt), method);
|
||||
if (!params)
|
||||
return OnError(params.takeError());
|
||||
std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)..., *params);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Result, typename Params>
|
||||
OutgoingRequest<Result, Params>
|
||||
Binder<Proto>::Bind(llvm::StringLiteral method) {
|
||||
if constexpr (std::is_void_v<Result> && std::is_void_v<Params>) {
|
||||
return [this, method](Reply<Result> fn) {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
Id id = ++m_seq;
|
||||
Req req = Proto::Make(id, method, std::nullopt);
|
||||
m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable {
|
||||
llvm::Expected<llvm::json::Value> result = Proto::Extract(resp);
|
||||
if (!result)
|
||||
return fn(result.takeError());
|
||||
fn(llvm::Error::success());
|
||||
};
|
||||
if (llvm::Error error = m_transport.Send(req))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
} else if constexpr (std::is_void_v<Params>) {
|
||||
return [this, method](Reply<Result> fn) {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
Id id = ++m_seq;
|
||||
Req req = Proto::Make(id, method, std::nullopt);
|
||||
m_pending_responses[id] = [fn = std::move(fn),
|
||||
method](const Resp &resp) mutable {
|
||||
llvm::Expected<llvm::json::Value> result = Proto::Extract(resp);
|
||||
if (!result)
|
||||
return fn(result.takeError());
|
||||
fn(Parse<Result>(*result, method));
|
||||
};
|
||||
if (llvm::Error error = m_transport.Send(req))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
} else if constexpr (std::is_void_v<Result>) {
|
||||
return [this, method](const Params ¶ms, Reply<Result> fn) {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
Id id = ++m_seq;
|
||||
Req req = Proto::Make(id, method, llvm::json::Value(params));
|
||||
m_pending_responses[id] = [fn = std::move(fn)](const Resp &resp) mutable {
|
||||
llvm::Expected<llvm::json::Value> result = Proto::Extract(resp);
|
||||
if (!result)
|
||||
return fn(result.takeError());
|
||||
fn(llvm::Error::success());
|
||||
};
|
||||
if (llvm::Error error = m_transport.Send(req))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
} else {
|
||||
return [this, method](const Params ¶ms, Reply<Result> fn) {
|
||||
std::scoped_lock<std::recursive_mutex> guard(m_mutex);
|
||||
Id id = ++m_seq;
|
||||
Req req = Proto::Make(id, method, llvm::json::Value(params));
|
||||
m_pending_responses[id] = [fn = std::move(fn),
|
||||
method](const Resp &resp) mutable {
|
||||
llvm::Expected<llvm::json::Value> result = Proto::Extract(resp);
|
||||
if (llvm::Error err = result.takeError())
|
||||
return fn(std::move(err));
|
||||
fn(Parse<Result>(*result, method));
|
||||
};
|
||||
if (llvm::Error error = m_transport.Send(req))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename Params>
|
||||
OutgoingEvent<Params> Binder<Proto>::Bind(llvm::StringLiteral method) {
|
||||
if constexpr (std::is_void_v<Params>) {
|
||||
return [this, method]() {
|
||||
if (llvm::Error error =
|
||||
m_transport.Send(Proto::Make(method, std::nullopt)))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
} else {
|
||||
return [this, method](const Params ¶ms) {
|
||||
if (llvm::Error error =
|
||||
m_transport.Send(Proto::Make(method, toJSON(params))))
|
||||
OnError(std::move(error));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
template <BindingBuilder Proto>
|
||||
#else
|
||||
template <typename Proto>
|
||||
#endif
|
||||
template <typename T>
|
||||
llvm::Expected<T> Binder<Proto>::Parse(const llvm::json::Value &raw,
|
||||
llvm::StringRef method) {
|
||||
T result;
|
||||
llvm::json::Path::Root root;
|
||||
if (!fromJSON(raw, result, root)) {
|
||||
// Dump the relevant parts of the broken message.
|
||||
std::string context;
|
||||
llvm::raw_string_ostream OS(context);
|
||||
root.printErrorContext(raw, OS);
|
||||
return llvm::make_error<InvalidParams>(method.str(), context);
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
} // namespace lldb_private::transport
|
||||
|
||||
#endif
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#ifndef LLDB_PROTOCOL_MCP_MCPERROR_H
|
||||
#define LLDB_PROTOCOL_MCP_MCPERROR_H
|
||||
|
||||
#include "lldb/Protocol/MCP/Protocol.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <string>
|
||||
|
||||
@@ -26,14 +25,12 @@ public:
|
||||
|
||||
const std::string &getMessage() const { return m_message; }
|
||||
|
||||
lldb_protocol::mcp::Error toProtocolError() const;
|
||||
|
||||
static constexpr int64_t kResourceNotFound = -32002;
|
||||
static constexpr int64_t kInternalError = -32603;
|
||||
|
||||
private:
|
||||
std::string m_message;
|
||||
int64_t m_error_code;
|
||||
int m_error_code;
|
||||
};
|
||||
|
||||
class UnsupportedURI : public llvm::ErrorInfo<UnsupportedURI> {
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H
|
||||
#define LLDB_PROTOCOL_MCP_PROTOCOL_H
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@@ -322,6 +323,10 @@ struct CallToolResult {
|
||||
llvm::json::Value toJSON(const CallToolResult &);
|
||||
bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path);
|
||||
|
||||
lldb_protocol::mcp::Request
|
||||
MakeRequest(int64_t id, llvm::StringRef method,
|
||||
std::optional<llvm::json::Value> params);
|
||||
|
||||
} // namespace lldb_protocol::mcp
|
||||
|
||||
#endif
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#ifndef LLDB_PROTOCOL_MCP_SERVER_H
|
||||
#define LLDB_PROTOCOL_MCP_SERVER_H
|
||||
|
||||
#include "lldb/Host/JSONTransport.h"
|
||||
#include "lldb/Host/MainLoop.h"
|
||||
#include "lldb/Protocol/MCP/Protocol.h"
|
||||
#include "lldb/Protocol/MCP/Resource.h"
|
||||
@@ -19,75 +18,66 @@
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace lldb_protocol::mcp {
|
||||
|
||||
class Server : public MCPTransport::MessageHandler {
|
||||
using ClosedCallback = llvm::unique_function<void()>;
|
||||
class Server {
|
||||
|
||||
using MCPTransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
|
||||
|
||||
using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP;
|
||||
|
||||
public:
|
||||
Server(std::string name, std::string version, MCPTransport &client,
|
||||
LogCallback log_callback = {}, ClosedCallback closed_callback = {});
|
||||
Server(std::string name, std::string version, LogCallback log_callback = {});
|
||||
~Server() = default;
|
||||
|
||||
using NotificationHandler = std::function<void(const Notification &)>;
|
||||
|
||||
void AddTool(std::unique_ptr<Tool> tool);
|
||||
void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider);
|
||||
void AddNotificationHandler(llvm::StringRef method,
|
||||
NotificationHandler handler);
|
||||
|
||||
llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP);
|
||||
|
||||
protected:
|
||||
MCPBinderUP Bind(MCPTransport &);
|
||||
|
||||
ServerCapabilities GetCapabilities();
|
||||
|
||||
using RequestHandler =
|
||||
std::function<llvm::Expected<Response>(const Request &)>;
|
||||
llvm::Expected<InitializeResult> InitializeHandler(const InitializeParams &);
|
||||
|
||||
void AddRequestHandlers();
|
||||
llvm::Expected<ListToolsResult> ToolsListHandler();
|
||||
llvm::Expected<CallToolResult> ToolsCallHandler(const CallToolParams &);
|
||||
|
||||
void AddRequestHandler(llvm::StringRef method, RequestHandler handler);
|
||||
llvm::Expected<ListResourcesResult> ResourcesListHandler();
|
||||
llvm::Expected<ReadResourceResult>
|
||||
ResourcesReadHandler(const ReadResourceParams &);
|
||||
|
||||
llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data);
|
||||
|
||||
llvm::Expected<Response> Handle(const Request &request);
|
||||
void Handle(const Notification ¬ification);
|
||||
|
||||
llvm::Expected<Response> InitializeHandler(const Request &);
|
||||
|
||||
llvm::Expected<Response> ToolsListHandler(const Request &);
|
||||
llvm::Expected<Response> ToolsCallHandler(const Request &);
|
||||
|
||||
llvm::Expected<Response> ResourcesListHandler(const Request &);
|
||||
llvm::Expected<Response> ResourcesReadHandler(const Request &);
|
||||
|
||||
void Received(const Request &) override;
|
||||
void Received(const Response &) override;
|
||||
void Received(const Notification &) override;
|
||||
void OnError(llvm::Error) override;
|
||||
void OnClosed() override;
|
||||
|
||||
protected:
|
||||
void Log(llvm::StringRef);
|
||||
template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) {
|
||||
Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str());
|
||||
}
|
||||
void Log(llvm::StringRef message) {
|
||||
if (m_log_callback)
|
||||
m_log_callback(message);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string m_name;
|
||||
const std::string m_version;
|
||||
|
||||
MCPTransport &m_client;
|
||||
LogCallback m_log_callback;
|
||||
ClosedCallback m_closed_callback;
|
||||
struct Client {
|
||||
ReadHandleUP handle;
|
||||
MCPTransportUP transport;
|
||||
MCPBinderUP binder;
|
||||
};
|
||||
std::map<MCPTransport *, Client> m_instances;
|
||||
|
||||
llvm::StringMap<std::unique_ptr<Tool>> m_tools;
|
||||
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;
|
||||
|
||||
llvm::StringMap<RequestHandler> m_request_handlers;
|
||||
llvm::StringMap<NotificationHandler> m_notification_handlers;
|
||||
};
|
||||
|
||||
class ServerInfoHandle;
|
||||
@@ -121,7 +111,7 @@ public:
|
||||
ServerInfoHandle &operator=(const ServerInfoHandle &) = delete;
|
||||
/// @}
|
||||
|
||||
/// Remove the file.
|
||||
/// Remove the file on disk, if one is tracked.
|
||||
void Remove();
|
||||
|
||||
private:
|
||||
|
||||
@@ -10,22 +10,78 @@
|
||||
#define LLDB_PROTOCOL_MCP_TRANSPORT_H
|
||||
|
||||
#include "lldb/Host/JSONTransport.h"
|
||||
#include "lldb/Protocol/MCP/MCPError.h"
|
||||
#include "lldb/Protocol/MCP/Protocol.h"
|
||||
#include "lldb/lldb-forward.h"
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <sys/types.h>
|
||||
|
||||
namespace lldb_protocol::mcp {
|
||||
|
||||
struct ProtocolDescriptor {
|
||||
using Id = int64_t;
|
||||
using Req = Request;
|
||||
using Resp = Response;
|
||||
using Evt = Notification;
|
||||
|
||||
static inline Id InitialId() { return 0; }
|
||||
static inline Request Make(Id id, llvm::StringRef method,
|
||||
std::optional<llvm::json::Value> params) {
|
||||
return Request{id, method.str(), params};
|
||||
}
|
||||
static inline Notification Make(llvm::StringRef method,
|
||||
std::optional<llvm::json::Value> params) {
|
||||
return Notification{method.str(), params};
|
||||
}
|
||||
static inline Response Make(Req req, llvm::Error error) {
|
||||
lldb_protocol::mcp::Error protocol_error;
|
||||
llvm::handleAllErrors(
|
||||
std::move(error), [&](const llvm::ErrorInfoBase &err) {
|
||||
std::error_code cerr = err.convertToErrorCode();
|
||||
protocol_error.code =
|
||||
cerr == llvm::inconvertibleErrorCode()
|
||||
? lldb_protocol::mcp::eErrorCodeInternalError
|
||||
: cerr.value();
|
||||
protocol_error.message = err.message();
|
||||
});
|
||||
|
||||
return Response{req.id, std::move(protocol_error)};
|
||||
}
|
||||
static inline Response Make(Req req,
|
||||
std::optional<llvm::json::Value> result) {
|
||||
return Response{req.id, std::move(result)};
|
||||
}
|
||||
static inline Id KeyFor(Response r) { return std::get<Id>(r.id); }
|
||||
static inline std::string KeyFor(Request r) { return r.method; }
|
||||
static inline std::string KeyFor(Notification n) { return n.method; }
|
||||
static inline std::optional<llvm::json::Value> Extract(Request r) {
|
||||
return r.params;
|
||||
}
|
||||
static inline llvm::Expected<llvm::json::Value> Extract(Response r) {
|
||||
if (const lldb_protocol::mcp::Error *error =
|
||||
std::get_if<lldb_protocol::mcp::Error>(&r.result))
|
||||
return llvm::make_error<lldb_protocol::mcp::MCPError>(error->message,
|
||||
error->code);
|
||||
return std::get<llvm::json::Value>(r.result);
|
||||
}
|
||||
static inline std::optional<llvm::json::Value> Extract(Notification n) {
|
||||
return n.params;
|
||||
}
|
||||
};
|
||||
|
||||
/// Generic transport that uses the MCP protocol.
|
||||
using MCPTransport = lldb_private::Transport<Request, Response, Notification>;
|
||||
using MCPTransport = lldb_private::transport::JSONTransport<ProtocolDescriptor>;
|
||||
using MCPBinder = lldb_private::transport::Binder<ProtocolDescriptor>;
|
||||
using MCPBinderUP = std::unique_ptr<MCPBinder>;
|
||||
|
||||
/// Generic logging callback, to allow the MCP server / client / transport layer
|
||||
/// to be independent of the lldb log implementation.
|
||||
using LogCallback = llvm::unique_function<void(llvm::StringRef message)>;
|
||||
|
||||
class Transport final
|
||||
: public lldb_private::JSONRPCTransport<Request, Response, Notification> {
|
||||
: public lldb_private::transport::JSONRPCTransport<ProtocolDescriptor> {
|
||||
public:
|
||||
Transport(lldb::IOObjectSP in, lldb::IOObjectSP out,
|
||||
LogCallback log_callback = {});
|
||||
|
||||
@@ -14,8 +14,7 @@
|
||||
#include <string>
|
||||
|
||||
using namespace llvm;
|
||||
using namespace lldb;
|
||||
using namespace lldb_private;
|
||||
using namespace lldb_private::transport;
|
||||
|
||||
char TransportUnhandledContentsError::ID;
|
||||
|
||||
@@ -23,10 +22,31 @@ TransportUnhandledContentsError::TransportUnhandledContentsError(
|
||||
std::string unhandled_contents)
|
||||
: m_unhandled_contents(unhandled_contents) {}
|
||||
|
||||
void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const {
|
||||
void TransportUnhandledContentsError::log(raw_ostream &OS) const {
|
||||
OS << "transport EOF with unhandled contents: '" << m_unhandled_contents
|
||||
<< "'";
|
||||
}
|
||||
std::error_code TransportUnhandledContentsError::convertToErrorCode() const {
|
||||
return std::make_error_code(std::errc::bad_message);
|
||||
}
|
||||
|
||||
char InvalidParams::ID;
|
||||
|
||||
void InvalidParams::log(raw_ostream &OS) const {
|
||||
OS << "invalid parameters for method '" << m_method << "': '" << m_context
|
||||
<< "'";
|
||||
}
|
||||
std::error_code InvalidParams::convertToErrorCode() const {
|
||||
return std::make_error_code(std::errc::invalid_argument);
|
||||
}
|
||||
|
||||
char MethodNotFound::ID;
|
||||
|
||||
void MethodNotFound::log(raw_ostream &OS) const {
|
||||
OS << "method not found: '" << m_method << "'";
|
||||
}
|
||||
|
||||
std::error_code MethodNotFound::convertToErrorCode() const {
|
||||
// JSON-RPC Method not found
|
||||
return std::error_code(MethodNotFound::kErrorCode, std::generic_category());
|
||||
}
|
||||
|
||||
@@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
|
||||
}
|
||||
|
||||
void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
|
||||
server.AddNotificationHandler("notifications/initialized",
|
||||
[](const lldb_protocol::mcp::Notification &) {
|
||||
LLDB_LOG(GetLog(LLDBLog::Host),
|
||||
"MCP initialization complete");
|
||||
});
|
||||
server.AddTool(
|
||||
std::make_unique<CommandTool>("command", "Run an lldb command."));
|
||||
server.AddTool(std::make_unique<DebuggerListTool>(
|
||||
@@ -74,26 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
|
||||
io_sp, io_sp, [client_name](llvm::StringRef message) {
|
||||
LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message);
|
||||
});
|
||||
MCPTransport *transport_ptr = transport_up.get();
|
||||
auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
|
||||
std::string(kName), std::string(kVersion), *transport_up,
|
||||
/*log_callback=*/
|
||||
[client_name](llvm::StringRef message) {
|
||||
LLDB_LOG(GetLog(LLDBLog::Host), "{0} Server: {1}", client_name,
|
||||
message);
|
||||
},
|
||||
/*closed_callback=*/
|
||||
[this, transport_ptr]() { m_instances.erase(transport_ptr); });
|
||||
Extend(*instance_up);
|
||||
llvm::Expected<MainLoop::ReadHandleUP> handle =
|
||||
transport_up->RegisterMessageHandler(m_loop, *instance_up);
|
||||
if (!handle) {
|
||||
LLDB_LOG_ERROR(log, handle.takeError(), "Failed to run MCP server: {0}");
|
||||
return;
|
||||
}
|
||||
m_instances[transport_ptr] =
|
||||
std::make_tuple<ServerUP, ReadHandleUP, TransportUP>(
|
||||
std::move(instance_up), std::move(*handle), std::move(transport_up));
|
||||
|
||||
if (auto error = m_server->Accept(m_loop, std::move(transport_up)))
|
||||
LLDB_LOG_ERROR(log, std::move(error), "{0}:");
|
||||
}
|
||||
|
||||
llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
|
||||
@@ -124,14 +102,21 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
|
||||
llvm::join(m_listener->GetListeningConnectionURI(), ", ");
|
||||
|
||||
ServerInfo info{listening_uris[0]};
|
||||
llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info);
|
||||
if (!handle)
|
||||
return handle.takeError();
|
||||
llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info);
|
||||
if (!server_info_handle)
|
||||
return server_info_handle.takeError();
|
||||
|
||||
m_client_count = 0;
|
||||
m_server = std::make_unique<lldb_protocol::mcp::Server>(
|
||||
std::string(kName), std::string(kVersion), [](StringRef message) {
|
||||
LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message);
|
||||
});
|
||||
Extend(*m_server);
|
||||
|
||||
m_running = true;
|
||||
m_server_info_handle = std::move(*handle);
|
||||
m_listen_handlers = std::move(*handles);
|
||||
m_loop_thread = std::thread([=] {
|
||||
m_server_info_handle = std::move(*server_info_handle);
|
||||
m_accept_handles = std::move(*handles);
|
||||
m_loop_thread = std::thread([this] {
|
||||
llvm::set_thread_name("protocol-server.mcp");
|
||||
m_loop.Run();
|
||||
});
|
||||
@@ -155,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() {
|
||||
if (m_loop_thread.joinable())
|
||||
m_loop_thread.join();
|
||||
|
||||
m_accept_handles.clear();
|
||||
|
||||
m_server.reset(nullptr);
|
||||
m_server_info_handle.Remove();
|
||||
m_listen_handlers.clear();
|
||||
m_instances.clear();
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
@@ -23,16 +23,17 @@
|
||||
namespace lldb_private::mcp {
|
||||
|
||||
class ProtocolServerMCP : public ProtocolServer {
|
||||
using ReadHandleUP = MainLoopBase::ReadHandleUP;
|
||||
using TransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>;
|
||||
|
||||
using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>;
|
||||
|
||||
using ReadHandleUP = MainLoop::ReadHandleUP;
|
||||
|
||||
public:
|
||||
ProtocolServerMCP();
|
||||
virtual ~ProtocolServerMCP() override;
|
||||
~ProtocolServerMCP() override;
|
||||
|
||||
virtual llvm::Error Start(ProtocolServer::Connection connection) override;
|
||||
virtual llvm::Error Stop() override;
|
||||
llvm::Error Start(ProtocolServer::Connection connection) override;
|
||||
llvm::Error Stop() override;
|
||||
|
||||
static void Initialize();
|
||||
static void Terminate();
|
||||
@@ -56,19 +57,18 @@ private:
|
||||
|
||||
bool m_running = false;
|
||||
|
||||
lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
|
||||
lldb_private::MainLoop m_loop;
|
||||
std::thread m_loop_thread;
|
||||
std::mutex m_mutex;
|
||||
size_t m_client_count = 0;
|
||||
|
||||
std::unique_ptr<Socket> m_listener;
|
||||
std::vector<ReadHandleUP> m_accept_handles;
|
||||
|
||||
std::vector<ReadHandleUP> m_listen_handlers;
|
||||
std::map<lldb_protocol::mcp::MCPTransport *,
|
||||
std::tuple<ServerUP, ReadHandleUP, TransportUP>>
|
||||
m_instances;
|
||||
ServerUP m_server;
|
||||
lldb_protocol::mcp::ServerInfoHandle m_server_info_handle;
|
||||
};
|
||||
|
||||
} // namespace lldb_private::mcp
|
||||
|
||||
#endif
|
||||
|
||||
@@ -22,14 +22,7 @@ MCPError::MCPError(std::string message, int64_t error_code)
|
||||
void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; }
|
||||
|
||||
std::error_code MCPError::convertToErrorCode() const {
|
||||
return llvm::inconvertibleErrorCode();
|
||||
}
|
||||
|
||||
lldb_protocol::mcp::Error MCPError::toProtocolError() const {
|
||||
lldb_protocol::mcp::Error error;
|
||||
error.code = m_error_code;
|
||||
error.message = m_message;
|
||||
return error;
|
||||
return std::error_code(m_error_code, std::generic_category());
|
||||
}
|
||||
|
||||
UnsupportedURI::UnsupportedURI(std::string uri) : m_uri(uri) {}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "lldb/Host/HostInfo.h"
|
||||
#include "lldb/Protocol/MCP/MCPError.h"
|
||||
#include "lldb/Protocol/MCP/Protocol.h"
|
||||
#include "lldb/Protocol/MCP/Transport.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
@@ -108,48 +109,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() {
|
||||
return infos;
|
||||
}
|
||||
|
||||
Server::Server(std::string name, std::string version, MCPTransport &client,
|
||||
LogCallback log_callback, ClosedCallback closed_callback)
|
||||
: m_name(std::move(name)), m_version(std::move(version)), m_client(client),
|
||||
m_log_callback(std::move(log_callback)),
|
||||
m_closed_callback(std::move(closed_callback)) {
|
||||
AddRequestHandlers();
|
||||
}
|
||||
|
||||
void Server::AddRequestHandlers() {
|
||||
AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this,
|
||||
std::placeholders::_1));
|
||||
AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this,
|
||||
std::placeholders::_1));
|
||||
AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this,
|
||||
std::placeholders::_1));
|
||||
AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler,
|
||||
this, std::placeholders::_1));
|
||||
AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler,
|
||||
this, std::placeholders::_1));
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::Handle(const Request &request) {
|
||||
auto it = m_request_handlers.find(request.method);
|
||||
if (it != m_request_handlers.end()) {
|
||||
llvm::Expected<Response> response = it->second(request);
|
||||
if (!response)
|
||||
return response;
|
||||
response->id = request.id;
|
||||
return *response;
|
||||
}
|
||||
|
||||
return llvm::make_error<MCPError>(
|
||||
llvm::formatv("no handler for request: {0}", request.method).str());
|
||||
}
|
||||
|
||||
void Server::Handle(const Notification ¬ification) {
|
||||
auto it = m_notification_handlers.find(notification.method);
|
||||
if (it != m_notification_handlers.end()) {
|
||||
it->second(notification);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Server::Server(std::string name, std::string version, LogCallback log_callback)
|
||||
: m_name(std::move(name)), m_version(std::move(version)),
|
||||
m_log_callback(std::move(log_callback)) {}
|
||||
|
||||
void Server::AddTool(std::unique_ptr<Tool> tool) {
|
||||
if (!tool)
|
||||
@@ -164,48 +126,64 @@ void Server::AddResourceProvider(
|
||||
m_resource_providers.push_back(std::move(resource_provider));
|
||||
}
|
||||
|
||||
void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) {
|
||||
m_request_handlers[method] = std::move(handler);
|
||||
MCPBinderUP Server::Bind(MCPTransport &transport) {
|
||||
MCPBinderUP binder_up = std::make_unique<MCPBinder>(transport);
|
||||
binder_up->Bind<InitializeResult, InitializeParams>(
|
||||
"initialize", &Server::InitializeHandler, this);
|
||||
binder_up->Bind<ListToolsResult, void>("tools/list",
|
||||
&Server::ToolsListHandler, this);
|
||||
binder_up->Bind<CallToolResult, CallToolParams>(
|
||||
"tools/call", &Server::ToolsCallHandler, this);
|
||||
binder_up->Bind<ListResourcesResult, void>(
|
||||
"resources/list", &Server::ResourcesListHandler, this);
|
||||
binder_up->Bind<ReadResourceResult, ReadResourceParams>(
|
||||
"resources/read", &Server::ResourcesReadHandler, this);
|
||||
binder_up->Bind<void>("notifications/initialized",
|
||||
[this]() { Log("MCP initialization complete"); });
|
||||
return binder_up;
|
||||
}
|
||||
|
||||
void Server::AddNotificationHandler(llvm::StringRef method,
|
||||
NotificationHandler handler) {
|
||||
m_notification_handlers[method] = std::move(handler);
|
||||
llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) {
|
||||
MCPBinderUP binder = Bind(*transport);
|
||||
MCPTransport *transport_ptr = transport.get();
|
||||
binder->OnDisconnect([this, transport_ptr]() {
|
||||
assert(m_instances.find(transport_ptr) != m_instances.end() &&
|
||||
"Client not found in m_instances");
|
||||
m_instances.erase(transport_ptr);
|
||||
});
|
||||
binder->OnError([this](llvm::Error err) {
|
||||
Logv("Transport error: {0}", llvm::toString(std::move(err)));
|
||||
});
|
||||
|
||||
auto handle = transport->RegisterMessageHandler(loop, *binder);
|
||||
if (!handle)
|
||||
return handle.takeError();
|
||||
|
||||
m_instances[transport_ptr] =
|
||||
Client{std::move(*handle), std::move(transport), std::move(binder)};
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
|
||||
Response response;
|
||||
Expected<InitializeResult>
|
||||
Server::InitializeHandler(const InitializeParams &request) {
|
||||
InitializeResult result;
|
||||
result.protocolVersion = mcp::kProtocolVersion;
|
||||
result.capabilities = GetCapabilities();
|
||||
result.serverInfo.name = m_name;
|
||||
result.serverInfo.version = m_version;
|
||||
response.result = std::move(result);
|
||||
return response;
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
|
||||
Response response;
|
||||
|
||||
llvm::Expected<ListToolsResult> Server::ToolsListHandler() {
|
||||
ListToolsResult result;
|
||||
for (const auto &tool : m_tools)
|
||||
result.tools.emplace_back(tool.second->GetDefinition());
|
||||
|
||||
response.result = std::move(result);
|
||||
|
||||
return response;
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
|
||||
Response response;
|
||||
|
||||
if (!request.params)
|
||||
return llvm::createStringError("no tool parameters");
|
||||
CallToolParams params;
|
||||
json::Path::Root root("params");
|
||||
if (!fromJSON(request.params, params, root))
|
||||
return root.getError();
|
||||
|
||||
llvm::Expected<CallToolResult>
|
||||
Server::ToolsCallHandler(const CallToolParams ¶ms) {
|
||||
llvm::StringRef tool_name = params.name;
|
||||
if (tool_name.empty())
|
||||
return llvm::createStringError("no tool name");
|
||||
@@ -222,113 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
|
||||
if (!text_result)
|
||||
return text_result.takeError();
|
||||
|
||||
response.result = toJSON(*text_result);
|
||||
|
||||
return response;
|
||||
return text_result;
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
|
||||
Response response;
|
||||
|
||||
llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() {
|
||||
ListResourcesResult result;
|
||||
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
|
||||
m_resource_providers)
|
||||
for (const Resource &resource : resource_provider_up->GetResources())
|
||||
result.resources.push_back(resource);
|
||||
|
||||
response.result = std::move(result);
|
||||
|
||||
return response;
|
||||
return result;
|
||||
}
|
||||
|
||||
llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
|
||||
Response response;
|
||||
|
||||
if (!request.params)
|
||||
return llvm::createStringError("no resource parameters");
|
||||
|
||||
ReadResourceParams params;
|
||||
json::Path::Root root("params");
|
||||
if (!fromJSON(request.params, params, root))
|
||||
return root.getError();
|
||||
|
||||
llvm::StringRef uri_str = params.uri;
|
||||
Expected<ReadResourceResult>
|
||||
Server::ResourcesReadHandler(const ReadResourceParams ¶ms) {
|
||||
StringRef uri_str = params.uri;
|
||||
if (uri_str.empty())
|
||||
return llvm::createStringError("no resource uri");
|
||||
return createStringError("no resource uri");
|
||||
|
||||
for (std::unique_ptr<ResourceProvider> &resource_provider_up :
|
||||
m_resource_providers) {
|
||||
llvm::Expected<ReadResourceResult> result =
|
||||
Expected<ReadResourceResult> result =
|
||||
resource_provider_up->ReadResource(uri_str);
|
||||
if (result.errorIsA<UnsupportedURI>()) {
|
||||
llvm::consumeError(result.takeError());
|
||||
consumeError(result.takeError());
|
||||
continue;
|
||||
}
|
||||
if (!result)
|
||||
return result.takeError();
|
||||
|
||||
Response response;
|
||||
response.result = std::move(*result);
|
||||
return response;
|
||||
return *result;
|
||||
}
|
||||
|
||||
return make_error<MCPError>(
|
||||
llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
|
||||
formatv("no resource handler for uri: {0}", uri_str).str(),
|
||||
MCPError::kResourceNotFound);
|
||||
}
|
||||
|
||||
ServerCapabilities Server::GetCapabilities() {
|
||||
lldb_protocol::mcp::ServerCapabilities capabilities;
|
||||
capabilities.supportsToolsList = true;
|
||||
capabilities.supportsResourcesList = true;
|
||||
// FIXME: Support sending notifications when a debugger/target are
|
||||
// added/removed.
|
||||
capabilities.supportsResourcesList = false;
|
||||
capabilities.supportsResourcesSubscribe = false;
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
void Server::Log(llvm::StringRef message) {
|
||||
if (m_log_callback)
|
||||
m_log_callback(message);
|
||||
}
|
||||
|
||||
void Server::Received(const Request &request) {
|
||||
auto SendResponse = [this](const Response &response) {
|
||||
if (llvm::Error error = m_client.Send(response))
|
||||
Log(llvm::toString(std::move(error)));
|
||||
};
|
||||
|
||||
llvm::Expected<Response> response = Handle(request);
|
||||
if (response)
|
||||
return SendResponse(*response);
|
||||
|
||||
lldb_protocol::mcp::Error protocol_error;
|
||||
llvm::handleAllErrors(
|
||||
response.takeError(),
|
||||
[&](const MCPError &err) { protocol_error = err.toProtocolError(); },
|
||||
[&](const llvm::ErrorInfoBase &err) {
|
||||
protocol_error.code = MCPError::kInternalError;
|
||||
protocol_error.message = err.message();
|
||||
});
|
||||
Response error_response;
|
||||
error_response.id = request.id;
|
||||
error_response.result = std::move(protocol_error);
|
||||
SendResponse(error_response);
|
||||
}
|
||||
|
||||
void Server::Received(const Response &response) {
|
||||
Log("unexpected MCP message: response");
|
||||
}
|
||||
|
||||
void Server::Received(const Notification ¬ification) {
|
||||
Handle(notification);
|
||||
}
|
||||
|
||||
void Server::OnError(llvm::Error error) {
|
||||
Log(llvm::toString(std::move(error)));
|
||||
}
|
||||
|
||||
void Server::OnClosed() {
|
||||
Log("EOF");
|
||||
if (m_closed_callback)
|
||||
m_closed_callback();
|
||||
}
|
||||
|
||||
@@ -78,11 +78,9 @@ enum DAPBroadcasterBits {
|
||||
|
||||
enum class ReplMode { Variable = 0, Command, Auto };
|
||||
|
||||
using DAPTransport =
|
||||
lldb_private::Transport<protocol::Request, protocol::Response,
|
||||
protocol::Event>;
|
||||
using DAPTransport = lldb_private::transport::JSONTransport<ProtocolDescriptor>;
|
||||
|
||||
struct DAP final : private DAPTransport::MessageHandler {
|
||||
struct DAP final : public DAPTransport::MessageHandler {
|
||||
/// Path to the lldb-dap binary itself.
|
||||
static llvm::StringRef debug_adapter_path;
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ namespace lldb_dap::protocol {
|
||||
|
||||
// MARK: Base Protocol
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
/// A client or debug adapter initiated request.
|
||||
struct Request {
|
||||
/// Sequence number of the message (also known as message ID). The `seq` for
|
||||
@@ -39,7 +41,7 @@ struct Request {
|
||||
/// associate requests with their corresponding responses. For protocol
|
||||
/// messages of type `request` the sequence number can be used to cancel the
|
||||
/// request.
|
||||
int64_t seq;
|
||||
Id seq;
|
||||
|
||||
/// The command to execute.
|
||||
std::string command;
|
||||
@@ -76,7 +78,7 @@ enum ResponseMessage : unsigned {
|
||||
/// Response for a request.
|
||||
struct Response {
|
||||
/// Sequence number of the corresponding request.
|
||||
int64_t request_seq;
|
||||
Id request_seq;
|
||||
|
||||
/// The command requested.
|
||||
std::string command;
|
||||
|
||||
@@ -22,11 +22,18 @@
|
||||
|
||||
namespace lldb_dap {
|
||||
|
||||
struct ProtocolDescriptor {
|
||||
using Id = protocol::Id;
|
||||
using Req = protocol::Request;
|
||||
using Resp = protocol::Response;
|
||||
using Evt = protocol::Event;
|
||||
};
|
||||
|
||||
/// A transport class that performs the Debug Adapter Protocol communication
|
||||
/// with the client.
|
||||
class Transport final
|
||||
: public lldb_private::HTTPDelimitedJSONTransport<
|
||||
protocol::Request, protocol::Response, protocol::Event> {
|
||||
: public lldb_private::transport::HTTPDelimitedJSONTransport<
|
||||
ProtocolDescriptor> {
|
||||
public:
|
||||
Transport(llvm::StringRef client_name, lldb_dap::Log *log,
|
||||
lldb::IOObjectSP input, lldb::IOObjectSP output);
|
||||
|
||||
@@ -9,13 +9,10 @@
|
||||
#include "DAP.h"
|
||||
#include "Protocol/ProtocolBase.h"
|
||||
#include "TestBase.h"
|
||||
#include "llvm/Testing/Support/Error.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <optional>
|
||||
|
||||
using namespace llvm;
|
||||
using namespace lldb;
|
||||
using namespace lldb_dap;
|
||||
using namespace lldb_dap_tests;
|
||||
using namespace lldb_dap::protocol;
|
||||
@@ -24,18 +21,7 @@ using namespace testing;
|
||||
class DAPTest : public TransportBase {};
|
||||
|
||||
TEST_F(DAPTest, SendProtocolMessages) {
|
||||
DAP dap{
|
||||
/*log=*/nullptr,
|
||||
/*default_repl_mode=*/ReplMode::Auto,
|
||||
/*pre_init_commands=*/{},
|
||||
/*no_lldbinit=*/false,
|
||||
/*client_name=*/"test_client",
|
||||
/*transport=*/*transport,
|
||||
/*loop=*/loop,
|
||||
};
|
||||
dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt});
|
||||
loop.AddPendingCallback(
|
||||
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
|
||||
EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt)));
|
||||
ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded());
|
||||
dap->Send(Event{/*event=*/"my-event", /*body=*/std::nullopt});
|
||||
EXPECT_CALL(client, Received(IsEvent("my-event")));
|
||||
Run();
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) {
|
||||
DisconnectRequestHandler handler(*dap);
|
||||
ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded());
|
||||
EXPECT_CALL(client, Received(IsEvent("terminated", _)));
|
||||
RunOnce();
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) {
|
||||
@@ -53,5 +53,5 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) {
|
||||
EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n")));
|
||||
EXPECT_CALL(client, Received(Output("Running terminateCommands:\n")));
|
||||
EXPECT_CALL(client, Received(IsEvent("terminated", _)));
|
||||
RunOnce();
|
||||
Run();
|
||||
}
|
||||
|
||||
@@ -32,23 +32,9 @@ using lldb_private::FileSystem;
|
||||
using lldb_private::MainLoop;
|
||||
using lldb_private::Pipe;
|
||||
|
||||
Expected<MainLoop::ReadHandleUP>
|
||||
TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) {
|
||||
Expected<lldb::FileUP> dummy_file = FileSystem::Instance().Open(
|
||||
FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite);
|
||||
if (!dummy_file)
|
||||
return dummy_file.takeError();
|
||||
m_dummy_file = std::move(*dummy_file);
|
||||
lldb_private::Status status;
|
||||
auto handle = loop.RegisterReadObject(
|
||||
m_dummy_file, [](lldb_private::MainLoopBase &) {}, status);
|
||||
if (status.Fail())
|
||||
return status.takeError();
|
||||
return handle;
|
||||
}
|
||||
void TransportBase::SetUp() {
|
||||
std::tie(to_client, to_server) = TestDAPTransport::createPair();
|
||||
|
||||
void DAPTestBase::SetUp() {
|
||||
TransportBase::SetUp();
|
||||
std::error_code EC;
|
||||
log = std::make_unique<Log>("-", EC);
|
||||
dap = std::make_unique<DAP>(
|
||||
@@ -57,16 +43,30 @@ void DAPTestBase::SetUp() {
|
||||
/*pre_init_commands=*/std::vector<std::string>(),
|
||||
/*no_lldbinit=*/false,
|
||||
/*client_name=*/"test_client",
|
||||
/*transport=*/*transport, /*loop=*/loop);
|
||||
/*transport=*/*to_client, /*loop=*/loop);
|
||||
|
||||
auto server_handle = to_server->RegisterMessageHandler(loop, *dap.get());
|
||||
EXPECT_THAT_EXPECTED(server_handle, Succeeded());
|
||||
handles[0] = std::move(*server_handle);
|
||||
|
||||
auto client_handle = to_client->RegisterMessageHandler(loop, client);
|
||||
EXPECT_THAT_EXPECTED(client_handle, Succeeded());
|
||||
handles[1] = std::move(*client_handle);
|
||||
}
|
||||
|
||||
void TransportBase::Run() {
|
||||
loop.AddPendingCallback(
|
||||
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
|
||||
EXPECT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded());
|
||||
}
|
||||
|
||||
void DAPTestBase::SetUp() { TransportBase::SetUp(); }
|
||||
|
||||
void DAPTestBase::TearDown() {
|
||||
if (core) {
|
||||
if (core)
|
||||
ASSERT_THAT_ERROR(core->discard(), Succeeded());
|
||||
}
|
||||
if (binary) {
|
||||
if (binary)
|
||||
ASSERT_THAT_ERROR(binary->discard(), Succeeded());
|
||||
}
|
||||
}
|
||||
|
||||
void DAPTestBase::SetUpTestSuite() {
|
||||
|
||||
@@ -7,73 +7,48 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "DAP.h"
|
||||
#include "DAPLog.h"
|
||||
#include "Protocol/ProtocolBase.h"
|
||||
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
|
||||
#include "TestingSupport/SubsystemRAII.h"
|
||||
#include "Transport.h"
|
||||
#include "lldb/Host/FileSystem.h"
|
||||
#include "lldb/Host/HostInfo.h"
|
||||
#include "lldb/Host/MainLoop.h"
|
||||
#include "lldb/Host/MainLoopBase.h"
|
||||
#include "lldb/lldb-forward.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Testing/Support/Error.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
/// Helpers for gtest printing.
|
||||
namespace lldb_dap::protocol {
|
||||
|
||||
inline void PrintTo(const Request &req, std::ostream *os) {
|
||||
*os << llvm::formatv("{0}", toJSON(req)).str();
|
||||
}
|
||||
|
||||
inline void PrintTo(const Response &resp, std::ostream *os) {
|
||||
*os << llvm::formatv("{0}", toJSON(resp)).str();
|
||||
}
|
||||
|
||||
inline void PrintTo(const Event &evt, std::ostream *os) {
|
||||
*os << llvm::formatv("{0}", toJSON(evt)).str();
|
||||
}
|
||||
|
||||
inline void PrintTo(const Message &message, std::ostream *os) {
|
||||
return std::visit([os](auto &&message) { return PrintTo(message, os); },
|
||||
message);
|
||||
}
|
||||
|
||||
} // namespace lldb_dap::protocol
|
||||
|
||||
namespace lldb_dap_tests {
|
||||
|
||||
class TestTransport final
|
||||
: public lldb_private::Transport<lldb_dap::protocol::Request,
|
||||
lldb_dap::protocol::Response,
|
||||
lldb_dap::protocol::Event> {
|
||||
public:
|
||||
using Message = lldb_private::Transport<lldb_dap::protocol::Request,
|
||||
lldb_dap::protocol::Response,
|
||||
lldb_dap::protocol::Event>::Message;
|
||||
|
||||
TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler)
|
||||
: m_loop(loop), m_handler(handler) {}
|
||||
|
||||
llvm::Error Send(const lldb_dap::protocol::Event &e) override {
|
||||
m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) {
|
||||
this->m_handler.Received(e);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Send(const lldb_dap::protocol::Request &r) override {
|
||||
m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) {
|
||||
this->m_handler.Received(r);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Send(const lldb_dap::protocol::Response &r) override {
|
||||
m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) {
|
||||
this->m_handler.Received(r);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<lldb_private::MainLoop::ReadHandleUP>
|
||||
RegisterMessageHandler(lldb_private::MainLoop &loop,
|
||||
MessageHandler &handler) override;
|
||||
|
||||
void Log(llvm::StringRef message) override {
|
||||
log_messages.emplace_back(message);
|
||||
}
|
||||
|
||||
std::vector<std::string> log_messages;
|
||||
|
||||
private:
|
||||
lldb_private::MainLoop &m_loop;
|
||||
MessageHandler &m_handler;
|
||||
lldb::FileSP m_dummy_file;
|
||||
};
|
||||
using TestDAPTransport = TestTransport<lldb_dap::ProtocolDescriptor>;
|
||||
|
||||
/// A base class for tests that need transport configured for communicating DAP
|
||||
/// messages.
|
||||
@@ -82,22 +57,36 @@ protected:
|
||||
lldb_private::SubsystemRAII<lldb_private::FileSystem, lldb_private::HostInfo>
|
||||
subsystems;
|
||||
lldb_private::MainLoop loop;
|
||||
std::unique_ptr<TestTransport> transport;
|
||||
MockMessageHandler<lldb_dap::protocol::Request, lldb_dap::protocol::Response,
|
||||
lldb_dap::protocol::Event>
|
||||
client;
|
||||
lldb_private::MainLoop::ReadHandleUP handles[2];
|
||||
|
||||
void SetUp() override {
|
||||
transport = std::make_unique<TestTransport>(loop, client);
|
||||
}
|
||||
std::unique_ptr<lldb_dap::Log> log;
|
||||
|
||||
std::unique_ptr<TestDAPTransport> to_client;
|
||||
MockMessageHandler<lldb_dap::ProtocolDescriptor> client;
|
||||
|
||||
std::unique_ptr<TestDAPTransport> to_server;
|
||||
std::unique_ptr<lldb_dap::DAP> dap;
|
||||
|
||||
void SetUp() override;
|
||||
|
||||
void Run();
|
||||
};
|
||||
|
||||
/// A matcher for a DAP event.
|
||||
template <typename M1, typename M2>
|
||||
template <typename EventMatcher, typename BodyMatcher>
|
||||
inline testing::Matcher<const lldb_dap::protocol::Event &>
|
||||
IsEvent(const M1 &m1, const M2 &m2) {
|
||||
return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1),
|
||||
testing::Field(&lldb_dap::protocol::Event::body, m2));
|
||||
IsEvent(const EventMatcher &event_matcher, const BodyMatcher &body_matcher) {
|
||||
return testing::AllOf(
|
||||
testing::Field(&lldb_dap::protocol::Event::event, event_matcher),
|
||||
testing::Field(&lldb_dap::protocol::Event::body, body_matcher));
|
||||
}
|
||||
|
||||
template <typename EventMatcher>
|
||||
inline testing::Matcher<const lldb_dap::protocol::Event &>
|
||||
IsEvent(const EventMatcher &event_matcher) {
|
||||
return testing::AllOf(
|
||||
testing::Field(&lldb_dap::protocol::Event::event, event_matcher),
|
||||
testing::Field(&lldb_dap::protocol::Event::body, std::nullopt));
|
||||
}
|
||||
|
||||
/// Matches an "output" event.
|
||||
@@ -110,8 +99,6 @@ inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") {
|
||||
/// A base class for tests that interact with a `lldb_dap::DAP` instance.
|
||||
class DAPTestBase : public TransportBase {
|
||||
protected:
|
||||
std::unique_ptr<lldb_dap::Log> log;
|
||||
std::unique_ptr<lldb_dap::DAP> dap;
|
||||
std::optional<llvm::sys::fs::TempFile> core;
|
||||
std::optional<llvm::sys::fs::TempFile> binary;
|
||||
|
||||
@@ -126,12 +113,6 @@ protected:
|
||||
bool GetDebuggerSupportsTarget(llvm::StringRef platform);
|
||||
void CreateDebugger();
|
||||
void LoadCore();
|
||||
|
||||
void RunOnce() {
|
||||
loop.AddPendingCallback(
|
||||
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
|
||||
ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace lldb_dap_tests
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "lldb/Host/JSONTransport.h"
|
||||
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
|
||||
#include "TestingSupport/Host/PipeTestUtilities.h"
|
||||
#include "TestingSupport/SubsystemRAII.h"
|
||||
#include "lldb/Host/File.h"
|
||||
#include "lldb/Host/MainLoop.h"
|
||||
#include "lldb/Host/MainLoopBase.h"
|
||||
@@ -25,27 +26,45 @@
|
||||
#include <chrono>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
|
||||
using namespace llvm;
|
||||
using namespace lldb_private;
|
||||
using namespace lldb_private::transport;
|
||||
using testing::_;
|
||||
using testing::HasSubstr;
|
||||
using testing::InSequence;
|
||||
using testing::Ref;
|
||||
|
||||
namespace llvm::json {
|
||||
static bool fromJSON(const Value &V, Value &T, Path P) {
|
||||
T = V;
|
||||
return true;
|
||||
}
|
||||
} // namespace llvm::json
|
||||
|
||||
namespace {
|
||||
|
||||
namespace test_protocol {
|
||||
|
||||
struct Req {
|
||||
int id = 0;
|
||||
std::string name;
|
||||
std::optional<json::Value> params;
|
||||
};
|
||||
json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; }
|
||||
json::Value toJSON(const Req &T) {
|
||||
return json::Object{{"name", T.name}, {"id", T.id}, {"params", T.params}};
|
||||
}
|
||||
bool fromJSON(const json::Value &V, Req &T, json::Path P) {
|
||||
json::ObjectMapper O(V, P);
|
||||
return O && O.map("req", T.name);
|
||||
return O && O.map("name", T.name) && O.map("id", T.id) &&
|
||||
O.map("params", T.params);
|
||||
}
|
||||
bool operator==(const Req &a, const Req &b) {
|
||||
return a.name == b.name && a.id == b.id && a.params == b.params;
|
||||
}
|
||||
bool operator==(const Req &a, const Req &b) { return a.name == b.name; }
|
||||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) {
|
||||
OS << toJSON(V);
|
||||
return OS;
|
||||
@@ -58,14 +77,22 @@ void PrintTo(const Req &message, std::ostream *os) {
|
||||
}
|
||||
|
||||
struct Resp {
|
||||
std::string name;
|
||||
int id = 0;
|
||||
int errorCode = 0;
|
||||
std::optional<json::Value> result;
|
||||
};
|
||||
json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; }
|
||||
json::Value toJSON(const Resp &T) {
|
||||
return json::Object{
|
||||
{"id", T.id}, {"errorCode", T.errorCode}, {"result", T.result}};
|
||||
}
|
||||
bool fromJSON(const json::Value &V, Resp &T, json::Path P) {
|
||||
json::ObjectMapper O(V, P);
|
||||
return O && O.map("resp", T.name);
|
||||
return O && O.map("id", T.id) && O.mapOptional("errorCode", T.errorCode) &&
|
||||
O.map("result", T.result);
|
||||
}
|
||||
bool operator==(const Resp &a, const Resp &b) {
|
||||
return a.id == b.id && a.errorCode == b.errorCode && a.result == b.result;
|
||||
}
|
||||
bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; }
|
||||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) {
|
||||
OS << toJSON(V);
|
||||
return OS;
|
||||
@@ -79,11 +106,14 @@ void PrintTo(const Resp &message, std::ostream *os) {
|
||||
|
||||
struct Evt {
|
||||
std::string name;
|
||||
std::optional<json::Value> params;
|
||||
};
|
||||
json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; }
|
||||
json::Value toJSON(const Evt &T) {
|
||||
return json::Object{{"name", T.name}, {"params", T.params}};
|
||||
}
|
||||
bool fromJSON(const json::Value &V, Evt &T, json::Path P) {
|
||||
json::ObjectMapper O(V, P);
|
||||
return O && O.map("evt", T.name);
|
||||
return O && O.map("name", T.name) && O.map("params", T.params);
|
||||
}
|
||||
bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; }
|
||||
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) {
|
||||
@@ -107,23 +137,8 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) {
|
||||
P.report("expected object");
|
||||
return false;
|
||||
}
|
||||
if (O->get("req")) {
|
||||
Req R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
|
||||
msg = std::move(R);
|
||||
return true;
|
||||
}
|
||||
if (O->get("resp")) {
|
||||
Resp R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
|
||||
msg = std::move(R);
|
||||
return true;
|
||||
}
|
||||
if (O->get("evt")) {
|
||||
if (O->find("id") == O->end()) {
|
||||
Evt E;
|
||||
if (!fromJSON(V, E, P))
|
||||
return false;
|
||||
@@ -131,17 +146,105 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) {
|
||||
msg = std::move(E);
|
||||
return true;
|
||||
}
|
||||
P.report("unknown message type");
|
||||
return false;
|
||||
|
||||
if (O->get("name")) {
|
||||
Req R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
|
||||
msg = std::move(R);
|
||||
return true;
|
||||
}
|
||||
|
||||
Resp R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
|
||||
msg = std::move(R);
|
||||
return true;
|
||||
}
|
||||
|
||||
struct MyFnParams {
|
||||
int a = 0;
|
||||
int b = 0;
|
||||
};
|
||||
json::Value toJSON(const MyFnParams &T) {
|
||||
return json::Object{{"a", T.a}, {"b", T.b}};
|
||||
}
|
||||
bool fromJSON(const json::Value &V, MyFnParams &T, json::Path P) {
|
||||
json::ObjectMapper O(V, P);
|
||||
return O && O.map("a", T.a) && O.map("b", T.b);
|
||||
}
|
||||
|
||||
struct MyFnResult {
|
||||
int c = 0;
|
||||
};
|
||||
json::Value toJSON(const MyFnResult &T) { return json::Object{{"c", T.c}}; }
|
||||
bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) {
|
||||
json::ObjectMapper O(V, P);
|
||||
return O && O.map("c", T.c);
|
||||
}
|
||||
|
||||
struct ProtoDesc {
|
||||
using Id = int;
|
||||
using Req = Req;
|
||||
using Resp = Resp;
|
||||
using Evt = Evt;
|
||||
|
||||
static inline Id InitialId() { return 0; }
|
||||
static inline Req Make(Id id, llvm::StringRef method,
|
||||
std::optional<llvm::json::Value> params) {
|
||||
return Req{id, method.str(), params};
|
||||
}
|
||||
static inline Evt Make(llvm::StringRef method,
|
||||
std::optional<llvm::json::Value> params) {
|
||||
return Evt{method.str(), params};
|
||||
}
|
||||
static inline Resp Make(Req req, llvm::Error error) {
|
||||
Resp resp;
|
||||
resp.id = req.id;
|
||||
llvm::handleAllErrors(
|
||||
std::move(error), [&](const llvm::ErrorInfoBase &err) {
|
||||
std::error_code cerr = err.convertToErrorCode();
|
||||
resp.errorCode =
|
||||
cerr == llvm::inconvertibleErrorCode() ? 1 : cerr.value();
|
||||
resp.result = err.message();
|
||||
});
|
||||
return resp;
|
||||
}
|
||||
static inline Resp Make(Req req, std::optional<llvm::json::Value> result) {
|
||||
return Resp{req.id, 0, std::move(result)};
|
||||
}
|
||||
static inline Id KeyFor(Resp r) { return r.id; }
|
||||
static inline std::string KeyFor(Req r) { return r.name; }
|
||||
static inline std::string KeyFor(Evt e) { return e.name; }
|
||||
static inline std::optional<llvm::json::Value> Extract(Req r) {
|
||||
return r.params;
|
||||
}
|
||||
static inline llvm::Expected<llvm::json::Value> Extract(Resp r) {
|
||||
if (r.errorCode != 0)
|
||||
return llvm::createStringError(
|
||||
std::error_code(r.errorCode, std::generic_category()),
|
||||
r.result && r.result->getAsString() ? *r.result->getAsString()
|
||||
: "no-message");
|
||||
return r.result;
|
||||
}
|
||||
static inline std::optional<llvm::json::Value> Extract(Evt e) {
|
||||
return e.params;
|
||||
}
|
||||
};
|
||||
|
||||
using Transport = TestTransport<ProtoDesc>;
|
||||
using Binder = lldb_private::transport::Binder<ProtoDesc>;
|
||||
using MessageHandler = MockMessageHandler<ProtoDesc>;
|
||||
|
||||
} // namespace test_protocol
|
||||
|
||||
template <typename T, typename Req, typename Resp, typename Evt>
|
||||
class JSONTransportTest : public PipePairTest {
|
||||
|
||||
template <typename T> class JSONTransportTest : public PipePairTest {
|
||||
protected:
|
||||
MockMessageHandler<Req, Resp, Evt> message_handler;
|
||||
SubsystemRAII<FileSystem> subsystems;
|
||||
|
||||
test_protocol::MessageHandler message_handler;
|
||||
std::unique_ptr<T> transport;
|
||||
MainLoop loop;
|
||||
|
||||
@@ -191,8 +294,7 @@ protected:
|
||||
};
|
||||
|
||||
class TestHTTPDelimitedJSONTransport final
|
||||
: public HTTPDelimitedJSONTransport<test_protocol::Req, test_protocol::Resp,
|
||||
test_protocol::Evt> {
|
||||
: public HTTPDelimitedJSONTransport<test_protocol::ProtoDesc> {
|
||||
public:
|
||||
using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport;
|
||||
|
||||
@@ -204,9 +306,7 @@ public:
|
||||
};
|
||||
|
||||
class HTTPDelimitedJSONTransportTest
|
||||
: public JSONTransportTest<TestHTTPDelimitedJSONTransport,
|
||||
test_protocol::Req, test_protocol::Resp,
|
||||
test_protocol::Evt> {
|
||||
: public JSONTransportTest<TestHTTPDelimitedJSONTransport> {
|
||||
public:
|
||||
using JSONTransportTest::JSONTransportTest;
|
||||
|
||||
@@ -222,8 +322,7 @@ public:
|
||||
};
|
||||
|
||||
class TestJSONRPCTransport final
|
||||
: public JSONRPCTransport<test_protocol::Req, test_protocol::Resp,
|
||||
test_protocol::Evt> {
|
||||
: public JSONRPCTransport<test_protocol::ProtoDesc> {
|
||||
public:
|
||||
using JSONRPCTransport::JSONRPCTransport;
|
||||
|
||||
@@ -234,9 +333,7 @@ public:
|
||||
std::vector<std::string> log_messages;
|
||||
};
|
||||
|
||||
class JSONRPCTransportTest
|
||||
: public JSONTransportTest<TestJSONRPCTransport, test_protocol::Req,
|
||||
test_protocol::Resp, test_protocol::Evt> {
|
||||
class JSONRPCTransportTest : public JSONTransportTest<TestJSONRPCTransport> {
|
||||
public:
|
||||
using JSONTransportTest::JSONTransportTest;
|
||||
|
||||
@@ -248,6 +345,33 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class TransportBinderTest : public testing::Test {
|
||||
protected:
|
||||
SubsystemRAII<FileSystem> subsystems;
|
||||
|
||||
std::unique_ptr<test_protocol::Transport> to_remote;
|
||||
std::unique_ptr<test_protocol::Transport> from_remote;
|
||||
std::unique_ptr<test_protocol::Binder> binder;
|
||||
test_protocol::MessageHandler remote;
|
||||
MainLoop loop;
|
||||
|
||||
void SetUp() override {
|
||||
std::tie(to_remote, from_remote) = test_protocol::Transport::createPair();
|
||||
binder = std::make_unique<test_protocol::Binder>(*to_remote);
|
||||
|
||||
auto binder_handle = to_remote->RegisterMessageHandler(loop, remote);
|
||||
EXPECT_THAT_EXPECTED(binder_handle, Succeeded());
|
||||
|
||||
auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder);
|
||||
EXPECT_THAT_EXPECTED(remote_handle, Succeeded());
|
||||
}
|
||||
|
||||
void Run() {
|
||||
loop.AddPendingCallback([](auto &loop) { loop.RequestTermination(); });
|
||||
EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446.
|
||||
@@ -269,35 +393,45 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) {
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, Read) {
|
||||
Write(Req{"foo"});
|
||||
EXPECT_CALL(message_handler, Received(Req{"foo"}));
|
||||
Write(Req{6, "foo", std::nullopt});
|
||||
EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) {
|
||||
InSequence seq;
|
||||
Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}});
|
||||
EXPECT_CALL(message_handler, Received(Req{"one"}));
|
||||
EXPECT_CALL(message_handler, Received(Evt{"two"}));
|
||||
EXPECT_CALL(message_handler, Received(Resp{"three"}));
|
||||
Write(
|
||||
Message{
|
||||
Req{6, "one", std::nullopt},
|
||||
},
|
||||
Message{
|
||||
Evt{"two", std::nullopt},
|
||||
},
|
||||
Message{
|
||||
Resp{2, 0, std::nullopt},
|
||||
});
|
||||
EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt}));
|
||||
EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt}));
|
||||
EXPECT_CALL(message_handler, Received(Resp{2, 0, std::nullopt}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) {
|
||||
std::string long_str = std::string(
|
||||
HTTPDelimitedJSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
|
||||
Write(Req{long_str});
|
||||
EXPECT_CALL(message_handler, Received(Req{long_str}));
|
||||
HTTPDelimitedJSONTransport<test_protocol::ProtoDesc>::kReadBufferSize * 2,
|
||||
'x');
|
||||
Write(Req{5, long_str, std::nullopt});
|
||||
EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) {
|
||||
std::string message = Encode(Req{"foo"});
|
||||
std::string message = Encode(Req{5, "foo", std::nullopt});
|
||||
auto split_at = message.size() / 2;
|
||||
std::string part1 = message.substr(0, split_at);
|
||||
std::string part2 = message.substr(split_at);
|
||||
|
||||
EXPECT_CALL(message_handler, Received(Req{"foo"}));
|
||||
EXPECT_CALL(message_handler, Received(Req{5, "foo", std::nullopt}));
|
||||
|
||||
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
|
||||
loop.AddPendingCallback(
|
||||
@@ -309,12 +443,12 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) {
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) {
|
||||
std::string message = Encode(Req{"foo"});
|
||||
std::string message = Encode(Req{6, "foo", std::nullopt});
|
||||
auto split_at = message.size() / 2;
|
||||
std::string part1 = message.substr(0, split_at);
|
||||
std::string part2 = message.substr(split_at);
|
||||
|
||||
EXPECT_CALL(message_handler, Received(Req{"foo"}));
|
||||
EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt}));
|
||||
|
||||
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
|
||||
|
||||
@@ -366,20 +500,21 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) {
|
||||
}
|
||||
|
||||
TEST_F(HTTPDelimitedJSONTransportTest, Write) {
|
||||
ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Resp{5, 0, "bar"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded());
|
||||
output.CloseWriteFileDescriptor();
|
||||
char buf[1024];
|
||||
Expected<size_t> bytes_read =
|
||||
output.Read(buf, sizeof(buf), std::chrono::milliseconds(1));
|
||||
ASSERT_THAT_EXPECTED(bytes_read, Succeeded());
|
||||
ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n"
|
||||
R"({"req":"foo"})"
|
||||
"Content-Length: 14\r\n\r\n"
|
||||
R"({"resp":"bar"})"
|
||||
"Content-Length: 13\r\n\r\n"
|
||||
R"({"evt":"baz"})"));
|
||||
ASSERT_EQ(StringRef(buf, *bytes_read),
|
||||
StringRef("Content-Length: 35\r\n\r\n"
|
||||
R"({"id":7,"name":"foo","params":null})"
|
||||
"Content-Length: 37\r\n\r\n"
|
||||
R"({"errorCode":0,"id":5,"result":"bar"})"
|
||||
"Content-Length: 28\r\n\r\n"
|
||||
R"({"name":"baz","params":null})"));
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, MalformedRequests) {
|
||||
@@ -395,37 +530,38 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) {
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, Read) {
|
||||
Write(Message{Req{"foo"}});
|
||||
EXPECT_CALL(message_handler, Received(Req{"foo"}));
|
||||
Write(Message{Req{1, "foo", std::nullopt}});
|
||||
EXPECT_CALL(message_handler, Received(Req{1, "foo", std::nullopt}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) {
|
||||
InSequence seq;
|
||||
Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}});
|
||||
EXPECT_CALL(message_handler, Received(Req{"one"}));
|
||||
EXPECT_CALL(message_handler, Received(Evt{"two"}));
|
||||
EXPECT_CALL(message_handler, Received(Resp{"three"}));
|
||||
Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}},
|
||||
Message{Resp{3, 0, "three"}});
|
||||
EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt}));
|
||||
EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt}));
|
||||
EXPECT_CALL(message_handler, Received(Resp{3, 0, "three"}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) {
|
||||
// Use a string longer than the chunk size to ensure we split the message
|
||||
// across the chunk boundary.
|
||||
std::string long_str =
|
||||
std::string(IOTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x');
|
||||
Write(Req{long_str});
|
||||
EXPECT_CALL(message_handler, Received(Req{long_str}));
|
||||
std::string long_str = std::string(
|
||||
IOTransport<test_protocol::ProtoDesc>::kReadBufferSize * 2, 'x');
|
||||
Write(Req{42, long_str, std::nullopt});
|
||||
EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt}));
|
||||
ASSERT_THAT_ERROR(Run(), Succeeded());
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, ReadPartialMessage) {
|
||||
std::string message = R"({"req": "foo"})"
|
||||
std::string message = R"({"id":42,"name":"foo","params":null})"
|
||||
"\n";
|
||||
std::string part1 = message.substr(0, 7);
|
||||
std::string part2 = message.substr(7);
|
||||
|
||||
EXPECT_CALL(message_handler, Received(Req{"foo"}));
|
||||
EXPECT_CALL(message_handler, Received(Req{42, "foo", std::nullopt}));
|
||||
|
||||
ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded());
|
||||
loop.AddPendingCallback(
|
||||
@@ -455,20 +591,21 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) {
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, Write) {
|
||||
ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Resp{14, 0, "bar"}), Succeeded());
|
||||
ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded());
|
||||
output.CloseWriteFileDescriptor();
|
||||
char buf[1024];
|
||||
Expected<size_t> bytes_read =
|
||||
output.Read(buf, sizeof(buf), std::chrono::milliseconds(1));
|
||||
ASSERT_THAT_EXPECTED(bytes_read, Succeeded());
|
||||
ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})"
|
||||
"\n"
|
||||
R"({"resp":"bar"})"
|
||||
"\n"
|
||||
R"({"evt":"baz"})"
|
||||
"\n"));
|
||||
ASSERT_EQ(StringRef(buf, *bytes_read),
|
||||
StringRef(R"({"id":11,"name":"foo","params":null})"
|
||||
"\n"
|
||||
R"({"errorCode":0,"id":14,"result":"bar"})"
|
||||
"\n"
|
||||
R"({"name":"baz","params":null})"
|
||||
"\n"));
|
||||
}
|
||||
|
||||
TEST_F(JSONRPCTransportTest, InvalidTransport) {
|
||||
@@ -477,4 +614,183 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) {
|
||||
FailedWithMessage("IO object is not valid."));
|
||||
}
|
||||
|
||||
// Out-bound binding request handler.
|
||||
TEST_F(TransportBinderTest, OutBoundRequests) {
|
||||
OutgoingRequest<MyFnResult, MyFnParams> addFn =
|
||||
binder->Bind<MyFnResult, MyFnParams>("add");
|
||||
bool replied = false;
|
||||
addFn(MyFnParams{1, 2}, [&](Expected<MyFnResult> result) {
|
||||
EXPECT_THAT_EXPECTED(result, Succeeded());
|
||||
EXPECT_EQ(result->c, 3);
|
||||
replied = true;
|
||||
});
|
||||
EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}}));
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}),
|
||||
Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(replied);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, OutBoundRequestsVoidParams) {
|
||||
OutgoingRequest<MyFnResult, void> voidParamFn =
|
||||
binder->Bind<MyFnResult, void>("voidParam");
|
||||
bool replied = false;
|
||||
voidParamFn([&](Expected<MyFnResult> result) {
|
||||
EXPECT_THAT_EXPECTED(result, Succeeded());
|
||||
EXPECT_EQ(result->c, 3);
|
||||
replied = true;
|
||||
});
|
||||
EXPECT_CALL(remote, Received(Req{1, "voidParam", std::nullopt}));
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, toJSON(MyFnResult{3})}),
|
||||
Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(replied);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, OutBoundRequestsVoidResult) {
|
||||
OutgoingRequest<void, MyFnParams> voidResultFn =
|
||||
binder->Bind<void, MyFnParams>("voidResult");
|
||||
bool replied = false;
|
||||
voidResultFn(MyFnParams{4, 5}, [&](llvm::Error error) {
|
||||
EXPECT_THAT_ERROR(std::move(error), Succeeded());
|
||||
replied = true;
|
||||
});
|
||||
EXPECT_CALL(remote, Received(Req{1, "voidResult", MyFnParams{4, 5}}));
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(replied);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, OutBoundRequestsVoidParamsAndVoidResult) {
|
||||
OutgoingRequest<void, void> voidParamAndResultFn =
|
||||
binder->Bind<void, void>("voidParamAndResult");
|
||||
bool replied = false;
|
||||
voidParamAndResultFn([&](llvm::Error error) {
|
||||
EXPECT_THAT_ERROR(std::move(error), Succeeded());
|
||||
replied = true;
|
||||
});
|
||||
EXPECT_CALL(remote, Received(Req{1, "voidParamAndResult", std::nullopt}));
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Resp{1, 0, std::nullopt}), Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(replied);
|
||||
}
|
||||
|
||||
// In-bound binding request handler.
|
||||
TEST_F(TransportBinderTest, InBoundRequests) {
|
||||
bool called = false;
|
||||
binder->Bind<MyFnResult, MyFnParams>(
|
||||
"add",
|
||||
[&](const int captured_param,
|
||||
const MyFnParams ¶ms) -> Expected<MyFnResult> {
|
||||
called = true;
|
||||
return MyFnResult{params.a + params.b + captured_param};
|
||||
},
|
||||
2);
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Req{1, "add", MyFnParams{3, 4}}),
|
||||
Succeeded());
|
||||
|
||||
EXPECT_CALL(remote, Received(Resp{1, 0, MyFnResult{9}}));
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, InBoundRequestsVoidParams) {
|
||||
bool called = false;
|
||||
binder->Bind<MyFnResult, void>(
|
||||
"voidParam",
|
||||
[&](const int captured_param) -> Expected<MyFnResult> {
|
||||
called = true;
|
||||
return MyFnResult{captured_param};
|
||||
},
|
||||
2);
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Req{2, "voidParam", std::nullopt}),
|
||||
Succeeded());
|
||||
EXPECT_CALL(remote, Received(Resp{2, 0, MyFnResult{2}}));
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, InBoundRequestsVoidResult) {
|
||||
bool called = false;
|
||||
binder->Bind<void, MyFnParams>(
|
||||
"voidResult",
|
||||
[&](const int captured_param, const MyFnParams ¶ms) -> llvm::Error {
|
||||
called = true;
|
||||
EXPECT_EQ(captured_param, 2);
|
||||
EXPECT_EQ(params.a, 3);
|
||||
EXPECT_EQ(params.b, 4);
|
||||
return llvm::Error::success();
|
||||
},
|
||||
2);
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Req{3, "voidResult", MyFnParams{3, 4}}),
|
||||
Succeeded());
|
||||
EXPECT_CALL(remote, Received(Resp{3, 0, std::nullopt}));
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
TEST_F(TransportBinderTest, InBoundRequestsVoidParamsAndResult) {
|
||||
bool called = false;
|
||||
binder->Bind<void, void>(
|
||||
"voidParamAndResult",
|
||||
[&](const int captured_param) -> llvm::Error {
|
||||
called = true;
|
||||
EXPECT_EQ(captured_param, 2);
|
||||
return llvm::Error::success();
|
||||
},
|
||||
2);
|
||||
EXPECT_THAT_ERROR(
|
||||
from_remote->Send(Req{4, "voidParamAndResult", std::nullopt}),
|
||||
Succeeded());
|
||||
EXPECT_CALL(remote, Received(Resp{4, 0, std::nullopt}));
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
// Out-bound binding event handler.
|
||||
TEST_F(TransportBinderTest, OutBoundEvents) {
|
||||
OutgoingEvent<MyFnParams> emitEvent = binder->Bind<MyFnParams>("evt");
|
||||
emitEvent(MyFnParams{1, 2});
|
||||
EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}}));
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, OutBoundEventsVoidParams) {
|
||||
OutgoingEvent<void> emitEvent = binder->Bind<void>("evt");
|
||||
emitEvent();
|
||||
EXPECT_CALL(remote, Received(Evt{"evt", std::nullopt}));
|
||||
Run();
|
||||
}
|
||||
|
||||
// In-bound binding event handler.
|
||||
TEST_F(TransportBinderTest, InBoundEvents) {
|
||||
bool called = false;
|
||||
binder->Bind<MyFnParams>(
|
||||
"evt",
|
||||
[&](const int captured_arg, const MyFnParams ¶ms) {
|
||||
EXPECT_EQ(captured_arg, 42);
|
||||
EXPECT_EQ(params.a, 3);
|
||||
EXPECT_EQ(params.b, 4);
|
||||
called = true;
|
||||
},
|
||||
42);
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", MyFnParams{3, 4}}),
|
||||
Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
TEST_F(TransportBinderTest, InBoundEventsVoidParams) {
|
||||
bool called = false;
|
||||
binder->Bind<void>(
|
||||
"evt",
|
||||
[&](const int captured_arg) {
|
||||
EXPECT_EQ(captured_arg, 42);
|
||||
called = true;
|
||||
},
|
||||
42);
|
||||
EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", std::nullopt}), Succeeded());
|
||||
Run();
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,9 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "ProtocolMCPTestUtilities.h"
|
||||
#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep
|
||||
#include "TestingSupport/Host/JSONTransportTestUtilities.h"
|
||||
#include "TestingSupport/Host/PipeTestUtilities.h"
|
||||
#include "TestingSupport/SubsystemRAII.h"
|
||||
#include "lldb/Host/FileSystem.h"
|
||||
#include "lldb/Host/HostInfo.h"
|
||||
@@ -28,20 +27,22 @@
|
||||
#include "llvm/Testing/Support/Error.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <system_error>
|
||||
|
||||
using namespace llvm;
|
||||
using namespace lldb;
|
||||
using namespace lldb_private;
|
||||
using namespace lldb_private::transport;
|
||||
using namespace lldb_protocol::mcp;
|
||||
|
||||
namespace {
|
||||
|
||||
class TestServer : public Server {
|
||||
public:
|
||||
using Server::Server;
|
||||
};
|
||||
template <typename T> Response make_response(T &&result, Id id = 1) {
|
||||
return Response{id, std::forward<T>(result)};
|
||||
}
|
||||
|
||||
/// Test tool that returns it argument as text.
|
||||
class TestTool : public Tool {
|
||||
@@ -101,7 +102,9 @@ public:
|
||||
using Tool::Tool;
|
||||
|
||||
llvm::Expected<CallToolResult> Call(const ToolArguments &args) override {
|
||||
return llvm::createStringError("error");
|
||||
return llvm::createStringError(
|
||||
std::error_code(eErrorCodeInternalError, std::generic_category()),
|
||||
"error");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -118,195 +121,207 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ProtocolServerMCPTest : public PipePairTest {
|
||||
class TestServer : public Server {
|
||||
public:
|
||||
using Server::Bind;
|
||||
using Server::Server;
|
||||
};
|
||||
|
||||
using Transport = TestTransport<lldb_protocol::mcp::ProtocolDescriptor>;
|
||||
|
||||
class ProtocolServerMCPTest : public testing::Test {
|
||||
public:
|
||||
SubsystemRAII<FileSystem, HostInfo, Socket> subsystems;
|
||||
|
||||
MainLoop loop;
|
||||
lldb_private::MainLoop::ReadHandleUP handles[2];
|
||||
|
||||
std::unique_ptr<lldb_protocol::mcp::Transport> from_client;
|
||||
std::unique_ptr<lldb_protocol::mcp::Transport> to_client;
|
||||
MainLoopBase::ReadHandleUP handles[2];
|
||||
|
||||
std::unique_ptr<Transport> to_server;
|
||||
MCPBinderUP binder;
|
||||
std::unique_ptr<TestServer> server_up;
|
||||
MockMessageHandler<Request, Response, Notification> message_handler;
|
||||
|
||||
llvm::Error Write(llvm::StringRef message) {
|
||||
llvm::Expected<json::Value> value = json::parse(message);
|
||||
if (!value)
|
||||
return value.takeError();
|
||||
return from_client->Write(*value);
|
||||
}
|
||||
std::unique_ptr<Transport> to_client;
|
||||
MockMessageHandler<lldb_protocol::mcp::ProtocolDescriptor> client;
|
||||
|
||||
llvm::Error Write(json::Value value) { return from_client->Write(value); }
|
||||
std::vector<std::string> logged_messages;
|
||||
|
||||
/// Run the transport MainLoop and return any messages received.
|
||||
llvm::Error Run() {
|
||||
loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); },
|
||||
std::chrono::milliseconds(10));
|
||||
return loop.Run().takeError();
|
||||
/// Runs the MainLoop a single time, executing any pending callbacks.
|
||||
void Run() {
|
||||
loop.AddPendingCallback(
|
||||
[](MainLoopBase &loop) { loop.RequestTermination(); });
|
||||
EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded());
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
PipePairTest::SetUp();
|
||||
std::tie(to_client, to_server) = Transport::createPair();
|
||||
|
||||
from_client = std::make_unique<lldb_protocol::mcp::Transport>(
|
||||
std::make_shared<NativeFile>(input.GetReadFileDescriptor(),
|
||||
File::eOpenOptionReadOnly,
|
||||
NativeFile::Unowned),
|
||||
std::make_shared<NativeFile>(output.GetWriteFileDescriptor(),
|
||||
File::eOpenOptionWriteOnly,
|
||||
NativeFile::Unowned),
|
||||
[](StringRef message) {
|
||||
// Uncomment for debugging
|
||||
// llvm::errs() << "from_client: " << message << '\n';
|
||||
});
|
||||
to_client = std::make_unique<lldb_protocol::mcp::Transport>(
|
||||
std::make_shared<NativeFile>(output.GetReadFileDescriptor(),
|
||||
File::eOpenOptionReadOnly,
|
||||
NativeFile::Unowned),
|
||||
std::make_shared<NativeFile>(input.GetWriteFileDescriptor(),
|
||||
File::eOpenOptionWriteOnly,
|
||||
NativeFile::Unowned),
|
||||
[](StringRef message) {
|
||||
// Uncomment for debugging
|
||||
// llvm::errs() << "to_client: " << message << '\n';
|
||||
});
|
||||
server_up = std::make_unique<TestServer>(
|
||||
"lldb-mcp", "0.1.0",
|
||||
[this](StringRef msg) { logged_messages.push_back(msg.str()); });
|
||||
binder = server_up->Bind(*to_client);
|
||||
auto server_handle = to_server->RegisterMessageHandler(loop, *binder);
|
||||
EXPECT_THAT_EXPECTED(server_handle, Succeeded());
|
||||
binder->OnError([](llvm::Error error) {
|
||||
llvm::errs() << formatv("Server transport error: {0}", error);
|
||||
});
|
||||
handles[0] = std::move(*server_handle);
|
||||
|
||||
server_up = std::make_unique<TestServer>("lldb-mcp", "0.1.0", *to_client,
|
||||
[](StringRef message) {
|
||||
// Uncomment for debugging
|
||||
// llvm::errs() << "server: " <<
|
||||
// message << '\n';
|
||||
});
|
||||
auto client_handle = to_client->RegisterMessageHandler(loop, client);
|
||||
EXPECT_THAT_EXPECTED(client_handle, Succeeded());
|
||||
handles[1] = std::move(*client_handle);
|
||||
}
|
||||
|
||||
auto maybe_from_client_handle =
|
||||
from_client->RegisterMessageHandler(loop, message_handler);
|
||||
EXPECT_THAT_EXPECTED(maybe_from_client_handle, Succeeded());
|
||||
handles[0] = std::move(*maybe_from_client_handle);
|
||||
template <typename Result, typename Params>
|
||||
Expected<json::Value> Call(StringRef method, const Params ¶ms) {
|
||||
std::promise<Response> promised_result;
|
||||
Request req =
|
||||
lldb_protocol::mcp::Request{/*id=*/1, method.str(), toJSON(params)};
|
||||
EXPECT_THAT_ERROR(to_server->Send(req), Succeeded());
|
||||
EXPECT_CALL(client, Received(testing::An<const Response &>()))
|
||||
.WillOnce(
|
||||
[&](const Response &resp) { promised_result.set_value(resp); });
|
||||
Run();
|
||||
Response resp = promised_result.get_future().get();
|
||||
return toJSON(resp);
|
||||
}
|
||||
|
||||
auto maybe_to_client_handle =
|
||||
to_client->RegisterMessageHandler(loop, *server_up);
|
||||
EXPECT_THAT_EXPECTED(maybe_to_client_handle, Succeeded());
|
||||
handles[1] = std::move(*maybe_to_client_handle);
|
||||
template <typename Result>
|
||||
Expected<json::Value>
|
||||
Capture(llvm::unique_function<void(Reply<Result>)> &fn) {
|
||||
std::promise<llvm::Expected<Result>> promised_result;
|
||||
fn([&promised_result](llvm::Expected<Result> result) {
|
||||
promised_result.set_value(std::move(result));
|
||||
});
|
||||
Run();
|
||||
llvm::Expected<Result> result = promised_result.get_future().get();
|
||||
if (!result)
|
||||
return result.takeError();
|
||||
return toJSON(*result);
|
||||
}
|
||||
|
||||
template <typename Result, typename Params>
|
||||
Expected<json::Value>
|
||||
Capture(llvm::unique_function<void(const Params &, Reply<Result>)> &fn,
|
||||
const Params ¶ms) {
|
||||
std::promise<llvm::Expected<Result>> promised_result;
|
||||
fn(params, [&promised_result](llvm::Expected<Result> result) {
|
||||
promised_result.set_value(std::move(result));
|
||||
});
|
||||
Run();
|
||||
llvm::Expected<Result> result = promised_result.get_future().get();
|
||||
if (!result)
|
||||
return result.takeError();
|
||||
return toJSON(*result);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Request make_request(StringLiteral method, T &¶ms, Id id = 1) {
|
||||
return Request{id, method.str(), toJSON(std::forward<T>(params))};
|
||||
}
|
||||
|
||||
template <typename T> Response make_response(T &&result, Id id = 1) {
|
||||
return Response{id, std::forward<T>(result)};
|
||||
inline testing::internal::EqMatcher<llvm::json::Value> HasJSON(T x) {
|
||||
return testing::internal::EqMatcher<llvm::json::Value>(toJSON(x));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, Initialization) {
|
||||
Request request = make_request(
|
||||
"initialize", InitializeParams{/*protocolVersion=*/"2024-11-05",
|
||||
/*capabilities=*/{},
|
||||
/*clientInfo=*/{"lldb-unit", "0.1.0"}});
|
||||
Response response = make_response(
|
||||
InitializeResult{/*protocolVersion=*/"2024-11-05",
|
||||
/*capabilities=*/{/*supportsToolsList=*/true},
|
||||
/*serverInfo=*/{"lldb-mcp", "0.1.0"}});
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED(
|
||||
(Call<InitializeResult, InitializeParams>(
|
||||
"initialize",
|
||||
InitializeParams{/*protocolVersion=*/"2024-11-05",
|
||||
/*capabilities=*/{},
|
||||
/*clientInfo=*/{"lldb-unit", "0.1.0"}})),
|
||||
HasValue(make_response(
|
||||
InitializeResult{/*protocolVersion=*/"2024-11-05",
|
||||
/*capabilities=*/
|
||||
{
|
||||
/*supportsToolsList=*/true,
|
||||
/*supportsResourcesList=*/true,
|
||||
},
|
||||
/*serverInfo=*/{"lldb-mcp", "0.1.0"}})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, ToolsList) {
|
||||
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
|
||||
|
||||
Request request = make_request("tools/list", Void{}, /*id=*/"one");
|
||||
|
||||
ToolDefinition test_tool;
|
||||
test_tool.name = "test";
|
||||
test_tool.description = "test tool";
|
||||
test_tool.inputSchema = json::Object{{"type", "object"}};
|
||||
|
||||
Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one");
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED(Call<ListToolsResult>("tools/list", Void{}),
|
||||
HasValue(make_response(ListToolsResult{{test_tool}})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, ResourcesList) {
|
||||
server_up->AddResourceProvider(std::make_unique<TestResourceProvider>());
|
||||
|
||||
Request request = make_request("resources/list", Void{});
|
||||
Response response = make_response(ListResourcesResult{
|
||||
{{/*uri=*/"lldb://foo/bar", /*name=*/"name",
|
||||
/*description=*/"description", /*mimeType=*/"application/json"}}});
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED(Call<ListResourcesResult>("resources/list", Void{}),
|
||||
HasValue(make_response(ListResourcesResult{{
|
||||
{
|
||||
/*uri=*/"lldb://foo/bar",
|
||||
/*name=*/"name",
|
||||
/*description=*/"description",
|
||||
/*mimeType=*/"application/json",
|
||||
},
|
||||
}})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, ToolsCall) {
|
||||
server_up->AddTool(std::make_unique<TestTool>("test", "test tool"));
|
||||
|
||||
Request request = make_request(
|
||||
"tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
}});
|
||||
Response response = make_response(CallToolResult{{{/*text=*/"foo"}}});
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED(
|
||||
(Call<CallToolResult, CallToolParams>("tools/call",
|
||||
CallToolParams{
|
||||
/*name=*/"test",
|
||||
/*arguments=*/
|
||||
json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
},
|
||||
})),
|
||||
HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, ToolsCallError) {
|
||||
server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool"));
|
||||
|
||||
Request request = make_request(
|
||||
"tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
}});
|
||||
Response response =
|
||||
make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError,
|
||||
/*message=*/"error"});
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
|
||||
"tools/call", CallToolParams{
|
||||
/*name=*/"error",
|
||||
/*arguments=*/
|
||||
json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
},
|
||||
})),
|
||||
HasValue(make_response(lldb_protocol::mcp::Error{
|
||||
eErrorCodeInternalError, "error"})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, ToolsCallFail) {
|
||||
server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool"));
|
||||
|
||||
Request request = make_request(
|
||||
"tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
}});
|
||||
Response response =
|
||||
make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true});
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_CALL(message_handler, Received(response));
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>(
|
||||
"tools/call", CallToolParams{
|
||||
/*name=*/"fail",
|
||||
/*arguments=*/
|
||||
json::Object{
|
||||
{"arguments", "foo"},
|
||||
{"debugger_id", 0},
|
||||
},
|
||||
})),
|
||||
HasValue(make_response(CallToolResult{
|
||||
{{/*text=*/"failed"}},
|
||||
/*isError=*/true,
|
||||
})));
|
||||
}
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, NotificationInitialized) {
|
||||
bool handler_called = false;
|
||||
std::condition_variable cv;
|
||||
|
||||
server_up->AddNotificationHandler(
|
||||
"notifications/initialized",
|
||||
[&](const Notification ¬ification) { handler_called = true; });
|
||||
llvm::StringLiteral request =
|
||||
R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json";
|
||||
|
||||
ASSERT_THAT_ERROR(Write(request), llvm::Succeeded());
|
||||
EXPECT_THAT_ERROR(Run(), Succeeded());
|
||||
EXPECT_TRUE(handler_called);
|
||||
EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{
|
||||
"notifications/initialized",
|
||||
std::nullopt,
|
||||
}),
|
||||
Succeeded());
|
||||
Run();
|
||||
EXPECT_THAT(logged_messages,
|
||||
testing::Contains("MCP initialization complete"));
|
||||
}
|
||||
|
||||
@@ -6,19 +6,105 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H
|
||||
#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H
|
||||
#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H
|
||||
#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H
|
||||
|
||||
#include "lldb/Host/FileSystem.h"
|
||||
#include "lldb/Host/JSONTransport.h"
|
||||
#include "lldb/Host/MainLoop.h"
|
||||
#include "lldb/Utility/FileSpec.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Testing/Support/Error.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
template <typename Req, typename Resp, typename Evt>
|
||||
class MockMessageHandler final
|
||||
: public lldb_private::Transport<Req, Resp, Evt>::MessageHandler {
|
||||
template <typename Proto>
|
||||
class TestTransport final
|
||||
: public lldb_private::transport::JSONTransport<Proto> {
|
||||
public:
|
||||
MOCK_METHOD(void, Received, (const Evt &), (override));
|
||||
MOCK_METHOD(void, Received, (const Req &), (override));
|
||||
MOCK_METHOD(void, Received, (const Resp &), (override));
|
||||
using MessageHandler =
|
||||
typename lldb_private::transport::JSONTransport<Proto>::MessageHandler;
|
||||
|
||||
static std::pair<std::unique_ptr<TestTransport<Proto>>,
|
||||
std::unique_ptr<TestTransport<Proto>>>
|
||||
createPair() {
|
||||
std::unique_ptr<TestTransport<Proto>> transports[2] = {
|
||||
std::make_unique<TestTransport<Proto>>(),
|
||||
std::make_unique<TestTransport<Proto>>()};
|
||||
return std::make_pair(std::move(transports[0]), std::move(transports[1]));
|
||||
}
|
||||
|
||||
explicit TestTransport() {
|
||||
llvm::Expected<lldb::FileUP> dummy_file =
|
||||
lldb_private::FileSystem::Instance().Open(
|
||||
lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL),
|
||||
lldb_private::File::eOpenOptionReadWrite);
|
||||
EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded());
|
||||
m_dummy_file = std::move(*dummy_file);
|
||||
}
|
||||
|
||||
llvm::Error Send(const typename Proto::Evt &evt) override {
|
||||
EXPECT_TRUE(m_loop && m_handler)
|
||||
<< "Send called before RegisterMessageHandler";
|
||||
m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) {
|
||||
m_handler->Received(evt);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Send(const typename Proto::Req &req) override {
|
||||
EXPECT_TRUE(m_loop && m_handler)
|
||||
<< "Send called before RegisterMessageHandler";
|
||||
m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) {
|
||||
m_handler->Received(req);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Error Send(const typename Proto::Resp &resp) override {
|
||||
EXPECT_TRUE(m_loop && m_handler)
|
||||
<< "Send called before RegisterMessageHandler";
|
||||
m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) {
|
||||
m_handler->Received(resp);
|
||||
});
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<lldb_private::MainLoop::ReadHandleUP>
|
||||
RegisterMessageHandler(lldb_private::MainLoop &loop,
|
||||
MessageHandler &handler) override {
|
||||
if (!m_loop)
|
||||
m_loop = &loop;
|
||||
if (!m_handler)
|
||||
m_handler = &handler;
|
||||
lldb_private::Status status;
|
||||
auto handle = loop.RegisterReadObject(
|
||||
m_dummy_file, [](lldb_private::MainLoopBase &) {}, status);
|
||||
if (status.Fail())
|
||||
return status.takeError();
|
||||
return handle;
|
||||
}
|
||||
|
||||
protected:
|
||||
void Log(llvm::StringRef message) override {};
|
||||
|
||||
private:
|
||||
lldb_private::MainLoop *m_loop = nullptr;
|
||||
MessageHandler *m_handler = nullptr;
|
||||
// Dummy file for registering with the MainLoop.
|
||||
lldb::FileSP m_dummy_file = nullptr;
|
||||
};
|
||||
|
||||
template <typename Proto>
|
||||
class MockMessageHandler final
|
||||
: public lldb_private::transport::JSONTransport<Proto>::MessageHandler {
|
||||
public:
|
||||
MOCK_METHOD(void, Received, (const typename Proto::Req &), (override));
|
||||
MOCK_METHOD(void, Received, (const typename Proto::Resp &), (override));
|
||||
MOCK_METHOD(void, Received, (const typename Proto::Evt &), (override));
|
||||
MOCK_METHOD(void, OnError, (llvm::Error), (override));
|
||||
MOCK_METHOD(void, OnClosed, (), (override));
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user