mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 11:02:04 +08:00
[lldb] Add Socket::CreatePair (#145015)
It creates a pair of connected sockets using the simplest mechanism for the given platform (TCP on windows, socketpair(2) elsewhere). Main motivation is to remove the ugly platform-specific code in ProcessGDBRemote::LaunchAndConnectToDebugserver, but it can also be used in other places where we need to create a pair of connected sockets.
This commit is contained in:
@@ -106,6 +106,10 @@ public:
|
||||
static std::unique_ptr<Socket> Create(const SocketProtocol protocol,
|
||||
Status &error);
|
||||
|
||||
using Pair = std::pair<std::unique_ptr<Socket>, std::unique_ptr<Socket>>;
|
||||
static llvm::Expected<Pair>
|
||||
CreatePair(std::optional<SocketProtocol> protocol = std::nullopt);
|
||||
|
||||
virtual Status Connect(llvm::StringRef name) = 0;
|
||||
virtual Status Listen(llvm::StringRef name, int backlog) = 0;
|
||||
|
||||
|
||||
@@ -23,6 +23,10 @@ public:
|
||||
TCPSocket(NativeSocket socket, bool should_close);
|
||||
~TCPSocket() override;
|
||||
|
||||
using Pair =
|
||||
std::pair<std::unique_ptr<TCPSocket>, std::unique_ptr<TCPSocket>>;
|
||||
static llvm::Expected<Pair> CreatePair();
|
||||
|
||||
// returns port number or 0 if error
|
||||
uint16_t GetLocalPortNumber() const;
|
||||
|
||||
|
||||
@@ -19,6 +19,10 @@ public:
|
||||
DomainSocket(NativeSocket socket, bool should_close);
|
||||
explicit DomainSocket(bool should_close);
|
||||
|
||||
using Pair =
|
||||
std::pair<std::unique_ptr<DomainSocket>, std::unique_ptr<DomainSocket>>;
|
||||
static llvm::Expected<Pair> CreatePair();
|
||||
|
||||
Status Connect(llvm::StringRef name) override;
|
||||
Status Listen(llvm::StringRef name, int backlog) override;
|
||||
|
||||
|
||||
@@ -234,6 +234,23 @@ std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol,
|
||||
return socket_up;
|
||||
}
|
||||
|
||||
llvm::Expected<Socket::Pair>
|
||||
Socket::CreatePair(std::optional<SocketProtocol> protocol) {
|
||||
constexpr SocketProtocol kBestProtocol =
|
||||
LLDB_ENABLE_POSIX ? ProtocolUnixDomain : ProtocolTcp;
|
||||
switch (protocol.value_or(kBestProtocol)) {
|
||||
case ProtocolTcp:
|
||||
return TCPSocket::CreatePair();
|
||||
#if LLDB_ENABLE_POSIX
|
||||
case ProtocolUnixDomain:
|
||||
case ProtocolUnixAbstract:
|
||||
return DomainSocket::CreatePair();
|
||||
#endif
|
||||
default:
|
||||
return llvm::createStringError("Unsupported protocol");
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<Socket>>
|
||||
Socket::TcpConnect(llvm::StringRef host_and_port) {
|
||||
Log *log = GetLog(LLDBLog::Connection);
|
||||
|
||||
@@ -52,6 +52,32 @@ TCPSocket::TCPSocket(NativeSocket socket, bool should_close)
|
||||
|
||||
TCPSocket::~TCPSocket() { CloseListenSockets(); }
|
||||
|
||||
llvm::Expected<TCPSocket::Pair> TCPSocket::CreatePair() {
|
||||
auto listen_socket_up = std::make_unique<TCPSocket>(true);
|
||||
if (Status error = listen_socket_up->Listen("localhost:0", 5); error.Fail())
|
||||
return error.takeError();
|
||||
|
||||
std::string connect_address =
|
||||
llvm::StringRef(listen_socket_up->GetListeningConnectionURI()[0])
|
||||
.split("://")
|
||||
.second.str();
|
||||
|
||||
auto connect_socket_up = std::make_unique<TCPSocket>(true);
|
||||
if (Status error = connect_socket_up->Connect(connect_address); error.Fail())
|
||||
return error.takeError();
|
||||
|
||||
// Connection has already been made above, so a short timeout is sufficient.
|
||||
Socket *accept_socket;
|
||||
if (Status error =
|
||||
listen_socket_up->Accept(std::chrono::seconds(1), accept_socket);
|
||||
error.Fail())
|
||||
return error.takeError();
|
||||
|
||||
return Pair(
|
||||
std::move(connect_socket_up),
|
||||
std::unique_ptr<TCPSocket>(static_cast<TCPSocket *>(accept_socket)));
|
||||
}
|
||||
|
||||
bool TCPSocket::IsValid() const {
|
||||
return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
|
||||
}
|
||||
|
||||
@@ -13,9 +13,11 @@
|
||||
#endif
|
||||
|
||||
#include "llvm/Support/Errno.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <fcntl.h>
|
||||
#include <memory>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/un.h>
|
||||
@@ -76,6 +78,31 @@ DomainSocket::DomainSocket(SocketProtocol protocol, NativeSocket socket,
|
||||
m_socket = socket;
|
||||
}
|
||||
|
||||
llvm::Expected<DomainSocket::Pair> DomainSocket::CreatePair() {
|
||||
int sockets[2];
|
||||
int type = SOCK_STREAM;
|
||||
#ifdef SOCK_CLOEXEC
|
||||
type |= SOCK_CLOEXEC;
|
||||
#endif
|
||||
if (socketpair(AF_UNIX, type, 0, sockets) == -1)
|
||||
return llvm::errorCodeToError(llvm::errnoAsErrorCode());
|
||||
|
||||
#ifndef SOCK_CLOEXEC
|
||||
for (int s : sockets) {
|
||||
int r = fcntl(s, F_SETFD, FD_CLOEXEC | fcntl(s, F_GETFD));
|
||||
assert(r == 0);
|
||||
(void)r;
|
||||
}
|
||||
#endif
|
||||
|
||||
return Pair(std::unique_ptr<DomainSocket>(
|
||||
new DomainSocket(ProtocolUnixDomain, sockets[0],
|
||||
/*should_close=*/true)),
|
||||
std::unique_ptr<DomainSocket>(
|
||||
new DomainSocket(ProtocolUnixDomain, sockets[1],
|
||||
/*should_close=*/true)));
|
||||
}
|
||||
|
||||
Status DomainSocket::Connect(llvm::StringRef name) {
|
||||
sockaddr_un saddr_un;
|
||||
socklen_t saddr_un_len;
|
||||
|
||||
@@ -1141,34 +1141,14 @@ void GDBRemoteCommunication::DumpHistory(Stream &strm) { m_history.Dump(strm); }
|
||||
llvm::Error
|
||||
GDBRemoteCommunication::ConnectLocally(GDBRemoteCommunication &client,
|
||||
GDBRemoteCommunication &server) {
|
||||
const int backlog = 5;
|
||||
TCPSocket listen_socket(true);
|
||||
if (llvm::Error error =
|
||||
listen_socket.Listen("localhost:0", backlog).ToError())
|
||||
return error;
|
||||
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
|
||||
if (!pair)
|
||||
return pair.takeError();
|
||||
|
||||
llvm::SmallString<32> remote_addr;
|
||||
llvm::raw_svector_ostream(remote_addr)
|
||||
<< "connect://localhost:" << listen_socket.GetLocalPortNumber();
|
||||
|
||||
std::unique_ptr<ConnectionFileDescriptor> conn_up(
|
||||
new ConnectionFileDescriptor());
|
||||
Status status;
|
||||
if (conn_up->Connect(remote_addr, &status) != lldb::eConnectionStatusSuccess)
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Unable to connect: %s", status.AsCString());
|
||||
|
||||
// The connection was already established above, so a short timeout is
|
||||
// sufficient.
|
||||
Socket *accept_socket = nullptr;
|
||||
if (Status accept_status =
|
||||
listen_socket.Accept(std::chrono::seconds(1), accept_socket);
|
||||
accept_status.Fail())
|
||||
return accept_status.takeError();
|
||||
|
||||
client.SetConnection(std::move(conn_up));
|
||||
client.SetConnection(
|
||||
std::make_unique<ConnectionFileDescriptor>(pair->first.release()));
|
||||
server.SetConnection(
|
||||
std::make_unique<ConnectionFileDescriptor>(accept_socket));
|
||||
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
|
||||
@@ -7,14 +7,14 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "lldb/Core/Communication.h"
|
||||
#include "TestingSupport/SubsystemRAII.h"
|
||||
#include "lldb/Core/ThreadedCommunication.h"
|
||||
#include "lldb/Host/Config.h"
|
||||
#include "lldb/Host/ConnectionFileDescriptor.h"
|
||||
#include "lldb/Host/Pipe.h"
|
||||
#include "lldb/Host/Socket.h"
|
||||
#include "llvm/Testing/Support/Error.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "TestingSupport/Host/SocketTestUtilities.h"
|
||||
#include "TestingSupport/SubsystemRAII.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
@@ -31,15 +31,17 @@ private:
|
||||
};
|
||||
|
||||
static void CommunicationReadTest(bool use_read_thread) {
|
||||
std::unique_ptr<TCPSocket> a, b;
|
||||
ASSERT_TRUE(CreateTCPConnectedSockets("localhost", &a, &b));
|
||||
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
|
||||
ASSERT_THAT_EXPECTED(pair, llvm::Succeeded());
|
||||
Socket &a = *pair->first;
|
||||
|
||||
size_t num_bytes = 4;
|
||||
ASSERT_THAT_ERROR(a->Write("test", num_bytes).ToError(), llvm::Succeeded());
|
||||
ASSERT_THAT_ERROR(a.Write("test", num_bytes).ToError(), llvm::Succeeded());
|
||||
ASSERT_EQ(num_bytes, 4U);
|
||||
|
||||
ThreadedCommunication comm("test");
|
||||
comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(b.release()));
|
||||
comm.SetConnection(
|
||||
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
|
||||
comm.SetCloseOnEOF(true);
|
||||
|
||||
if (use_read_thread) {
|
||||
@@ -73,7 +75,7 @@ static void CommunicationReadTest(bool use_read_thread) {
|
||||
EXPECT_THAT_ERROR(error.ToError(), llvm::Failed());
|
||||
|
||||
// This read should return EOF.
|
||||
ASSERT_THAT_ERROR(a->Close().ToError(), llvm::Succeeded());
|
||||
ASSERT_THAT_ERROR(a.Close().ToError(), llvm::Succeeded());
|
||||
error.Clear();
|
||||
EXPECT_EQ(
|
||||
comm.Read(buf, sizeof(buf), std::chrono::seconds(5), status, &error), 0U);
|
||||
@@ -118,17 +120,19 @@ TEST_F(CommunicationTest, ReadThread) {
|
||||
}
|
||||
|
||||
TEST_F(CommunicationTest, SynchronizeWhileClosing) {
|
||||
std::unique_ptr<TCPSocket> a, b;
|
||||
ASSERT_TRUE(CreateTCPConnectedSockets("localhost", &a, &b));
|
||||
llvm::Expected<Socket::Pair> pair = Socket::CreatePair();
|
||||
ASSERT_THAT_EXPECTED(pair, llvm::Succeeded());
|
||||
Socket &a = *pair->first;
|
||||
|
||||
ThreadedCommunication comm("test");
|
||||
comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(b.release()));
|
||||
comm.SetConnection(
|
||||
std::make_unique<ConnectionFileDescriptor>(pair->second.release()));
|
||||
comm.SetCloseOnEOF(true);
|
||||
ASSERT_TRUE(comm.StartReadThread());
|
||||
|
||||
// Ensure that we can safely synchronize with the read thread while it is
|
||||
// closing the read end (in response to us closing the write end).
|
||||
ASSERT_THAT_ERROR(a->Close().ToError(), llvm::Succeeded());
|
||||
ASSERT_THAT_ERROR(a.Close().ToError(), llvm::Succeeded());
|
||||
comm.SynchronizeWithReadThread();
|
||||
|
||||
ASSERT_TRUE(comm.StopReadThread());
|
||||
|
||||
@@ -74,6 +74,41 @@ TEST_F(SocketTest, DecodeHostAndPort) {
|
||||
llvm::HasValue(Socket::HostAndPort{"abcd:12fg:AF58::1", 12345}));
|
||||
}
|
||||
|
||||
TEST_F(SocketTest, CreatePair) {
|
||||
std::vector<std::optional<Socket::SocketProtocol>> functional_protocols = {
|
||||
std::nullopt,
|
||||
Socket::ProtocolTcp,
|
||||
#if LLDB_ENABLE_POSIX
|
||||
Socket::ProtocolUnixDomain,
|
||||
Socket::ProtocolUnixAbstract,
|
||||
#endif
|
||||
};
|
||||
for (auto p : functional_protocols) {
|
||||
auto expected_socket_pair = Socket::CreatePair(p);
|
||||
ASSERT_THAT_EXPECTED(expected_socket_pair, llvm::Succeeded());
|
||||
Socket &a = *expected_socket_pair->first;
|
||||
Socket &b = *expected_socket_pair->second;
|
||||
size_t num_bytes = 1;
|
||||
ASSERT_THAT_ERROR(a.Write("a", num_bytes).takeError(), llvm::Succeeded());
|
||||
ASSERT_EQ(num_bytes, 1);
|
||||
char c;
|
||||
ASSERT_THAT_ERROR(b.Read(&c, num_bytes).takeError(), llvm::Succeeded());
|
||||
ASSERT_EQ(num_bytes, 1);
|
||||
ASSERT_EQ(c, 'a');
|
||||
}
|
||||
|
||||
std::vector<Socket::SocketProtocol> erroring_protocols = {
|
||||
#if !LLDB_ENABLE_POSIX
|
||||
Socket::ProtocolUnixDomain,
|
||||
Socket::ProtocolUnixAbstract,
|
||||
#endif
|
||||
};
|
||||
for (auto p : erroring_protocols) {
|
||||
ASSERT_THAT_EXPECTED(Socket::CreatePair(p),
|
||||
llvm::FailedWithMessage("Unsupported protocol"));
|
||||
}
|
||||
}
|
||||
|
||||
#if LLDB_ENABLE_POSIX
|
||||
TEST_F(SocketTest, DomainListenConnectAccept) {
|
||||
llvm::SmallString<64> Path;
|
||||
|
||||
Reference in New Issue
Block a user