[libc] Export a pointer to the RPC client directly (#117913)

Summary:
We currently have an unnecessary level of indirection when initializing
the RPC client. This is a holdover from when the RPC client was not
trivially copyable and simply makes it more complicated. Here we use the
`asm` syntax to give the C++ variable a valid name so that we can just
copy to it directly.

Another advantage to this, is that if users want to piggy-back on the
same RPC interface they need only declare theirs as extern with the same
symbol name, or make it weak to optionally use it if LIBC isn't
avaialb.e
This commit is contained in:
Joseph Huber
2024-11-27 14:57:38 -06:00
committed by GitHub
parent 175051b05e
commit 89d8e70031
8 changed files with 20 additions and 48 deletions

View File

@@ -13,14 +13,10 @@
namespace LIBC_NAMESPACE_DECL {
namespace rpc {
/// The libc client instance used to communicate with the server.
Client client;
/// Externally visible symbol to signify the usage of an RPC client to
/// whomever needs to run the server as well as provide a way to initialize
/// the client with a copy..
extern "C" [[gnu::visibility("protected")]] const void *__llvm_libc_rpc_client =
&client;
/// The libc client instance used to communicate with the server. Externally
/// visible symbol to signify the usage of an RPC client to whomever needs to
/// run the server as well as provide a way to initialize the client.
[[gnu::visibility("protected")]] Client client;
} // namespace rpc
} // namespace LIBC_NAMESPACE_DECL

View File

@@ -29,7 +29,7 @@ static_assert(cpp::is_trivially_copyable<Client>::value &&
"The client is not trivially copyable from the server");
/// The libc client instance used to communicate with the server.
extern Client client;
[[gnu::visibility("protected")]] extern Client client asm("__llvm_rpc_client");
} // namespace rpc
} // namespace LIBC_NAMESPACE_DECL

View File

@@ -477,27 +477,15 @@ int load(int argc, const char **argv, const char **envp, void *image,
// device's internal pointer.
hsa_executable_symbol_t rpc_client_sym;
if (hsa_status_t err = hsa_executable_get_symbol_by_name(
executable, "__llvm_libc_rpc_client", &dev_agent, &rpc_client_sym))
executable, "__llvm_rpc_client", &dev_agent, &rpc_client_sym))
handle_error(err);
void *rpc_client_host;
if (hsa_status_t err =
hsa_amd_memory_pool_allocate(finegrained_pool, sizeof(void *),
/*flags=*/0, &rpc_client_host))
handle_error(err);
hsa_amd_agents_allow_access(1, &dev_agent, nullptr, rpc_client_host);
void *rpc_client_dev;
if (hsa_status_t err = hsa_executable_symbol_get_info(
rpc_client_sym, HSA_EXECUTABLE_SYMBOL_INFO_VARIABLE_ADDRESS,
&rpc_client_dev))
handle_error(err);
// Copy the address of the client buffer from the device to the host.
if (hsa_status_t err = hsa_memcpy(rpc_client_host, host_agent, rpc_client_dev,
dev_agent, sizeof(void *)))
handle_error(err);
void *rpc_client_buffer;
if (hsa_status_t err =
hsa_amd_memory_lock(&client, sizeof(rpc::Client),
@@ -506,14 +494,12 @@ int load(int argc, const char **argv, const char **envp, void *image,
// Copy the RPC client buffer to the address pointed to by the symbol.
if (hsa_status_t err =
hsa_memcpy(*reinterpret_cast<void **>(rpc_client_host), dev_agent,
rpc_client_buffer, host_agent, sizeof(rpc::Client)))
hsa_memcpy(rpc_client_dev, dev_agent, rpc_client_buffer, host_agent,
sizeof(rpc::Client)))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_unlock(&client))
handle_error(err);
if (hsa_status_t err = hsa_amd_memory_pool_free(rpc_client_host))
handle_error(err);
// Obtain the GPU's fixed-frequency clock rate and copy it to the GPU.
// If the clock_freq symbol is missing, no work to do.

View File

@@ -314,15 +314,10 @@ int load(int argc, const char **argv, const char **envp, void *image,
CUdeviceptr rpc_client_dev = 0;
uint64_t client_ptr_size = sizeof(void *);
if (CUresult err = cuModuleGetGlobal(&rpc_client_dev, &client_ptr_size,
binary, "__llvm_libc_rpc_client"))
binary, "__llvm_rpc_client"))
handle_error(err);
CUdeviceptr rpc_client_host = 0;
if (CUresult err =
cuMemcpyDtoH(&rpc_client_host, rpc_client_dev, sizeof(void *)))
handle_error(err);
if (CUresult err =
cuMemcpyHtoD(rpc_client_host, &client, sizeof(rpc::Client)))
if (CUresult err = cuMemcpyHtoD(rpc_client_dev, &client, sizeof(rpc::Client)))
handle_error(err);
LaunchParameters single_threaded_params = {1, 1, 1, 1, 1, 1};

View File

@@ -1538,7 +1538,7 @@ private:
// required an RPC server. If its users were all optimized out then we can
// safely remove it.
// TODO: This should be somewhere more common in the future.
if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
if (!GV->getType()->isPointerTy())
return false;

View File

@@ -3,14 +3,14 @@
; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=PRELINK
@client = internal addrspace(1) global i64 zeroinitializer, align 8
@__llvm_libc_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
@__llvm_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
;.
; POSTLINK: @client = internal addrspace(1) global i64 0, align 8
; POSTLINK: @__llvm_libc_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
; POSTLINK: @__llvm_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
;.
; PRELINK: @client = internal addrspace(1) global i64 0, align 8
; PRELINK: @__llvm_libc_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
; PRELINK: @__llvm_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
;.
define i64 @a() {
; POSTLINK-LABEL: define {{[^@]+}}@a

View File

@@ -3,11 +3,11 @@
; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=PRELINK
@client = internal addrspace(1) global i32 zeroinitializer, align 8
@__llvm_libc_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
@__llvm_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
;.
; PRELINK: @client = internal addrspace(1) global i32 0, align 8
; PRELINK: @__llvm_libc_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
; PRELINK: @__llvm_rpc_client = protected local_unnamed_addr addrspace(1) global ptr addrspacecast (ptr addrspace(1) @client to ptr), align 8
;.
define void @a() {
; POSTLINK-LABEL: define {{[^@]+}}@a() {

View File

@@ -31,7 +31,7 @@ RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
plugin::GenericGlobalHandlerTy &Handler,
plugin::DeviceImageTy &Image) {
#ifdef LIBOMPTARGET_RPC_SUPPORT
return Handler.isSymbolInImage(Device, Image, "__llvm_libc_rpc_client");
return Handler.isSymbolInImage(Device, Image, "__llvm_rpc_client");
#else
return false;
#endif
@@ -51,19 +51,14 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
"Failed to initialize RPC server for device %d", Device.getDeviceId());
// Get the address of the RPC client from the device.
void *ClientPtr;
plugin::GlobalTy ClientGlobal("__llvm_libc_rpc_client", sizeof(void *));
plugin::GlobalTy ClientGlobal("__llvm_rpc_client", sizeof(rpc::Client));
if (auto Err =
Handler.getGlobalMetadataFromDevice(Device, Image, ClientGlobal))
return Err;
if (auto Err = Device.dataRetrieve(&ClientPtr, ClientGlobal.getPtr(),
sizeof(void *), nullptr))
return Err;
rpc::Client client(NumPorts, RPCBuffer);
if (auto Err =
Device.dataSubmit(ClientPtr, &client, sizeof(rpc::Client), nullptr))
if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
sizeof(rpc::Client), nullptr))
return Err;
Buffers[Device.getDeviceId()] = RPCBuffer;