[OpenMP][clang] Register vtables on device for indirect calls runtime (#167011)

This is a branch off of
https://github.com/llvm/llvm-project/pull/159856, in which consists of
the runtime portion of the changes required to support indirect function
and virtual function calls on an `omp target device` when the virtual
class / indirect function is mapped to the device from the host.

Key Changes

- Introduced a new flag OMP_DECLARE_TARGET_INDIRECT_VTABLE to mark
VTable registrations
- Modified setupIndirectCallTable to support both VTable entries and
indirect function pointers

Details:
The setupIndirectCallTable implementation was modified to support this
registration type by retrieving the first address of the VTable and
inferring the remaining data needed to build the indirect call table.
Since the Vtables / Classes registered as indirect can be larger than 8
bytes, and the vtables may not be at the first address we either need to
pass the size to __llvm_omp_indirect_call_lookup and have a check at
each step of the binary search, or add multiple entries to the indirect
table for each address registered. The latter was chosen.

Commit: a00def3f20e166d4fb9328e6f0bc0742cd0afa31 is not a part of this
PR and is handled / reviewed in:
https://github.com/llvm/llvm-project/pull/159856,

This is PR (2/3) 
Register Vtable PR (1/3):
https://github.com/llvm/llvm-project/pull/159856,
Codegen / _llvm_omp_indirect_call_lookup PR (3/3):
https://github.com/llvm/llvm-project/pull/159857
This commit is contained in:
Jason-VanBeusekom
2025-11-26 11:33:26 -06:00
committed by GitHub
parent 356479191c
commit 84d511df8d
4 changed files with 165 additions and 17 deletions

View File

@@ -94,6 +94,8 @@ enum OpenMPOffloadingDeclareTargetFlags {
OMP_DECLARE_TARGET_INDIRECT = 0x08,
/// This is an entry corresponding to a requirement to be registered.
OMP_REGISTER_REQUIRES = 0x10,
/// Mark the entry global as being an indirect vtable.
OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20,
};
enum TargetAllocTy : int32_t {

View File

@@ -437,20 +437,22 @@ static int loadImagesOntoDevice(DeviceTy &Device) {
llvm::offloading::EntryTy DeviceEntry = Entry;
if (Entry.Size) {
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName,
&DeviceEntry.Address) != OFFLOAD_SUCCESS)
REPORT("Failed to load symbol %s\n", Entry.SymbolName);
if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE))
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName,
&DeviceEntry.Address) != OFFLOAD_SUCCESS)
REPORT("Failed to load symbol %s\n", Entry.SymbolName);
// If unified memory is active, the corresponding global is a device
// reference to the host global. We need to initialize the pointer on
// the device to point to the memory on the host.
if ((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)) {
if (!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) &&
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) &&
((PM->getRequirements() & OMP_REQ_UNIFIED_SHARED_MEMORY) ||
(PM->getRequirements() & OMPX_REQ_AUTO_ZERO_COPY)))
if (Device.RTL->data_submit(DeviceId, DeviceEntry.Address,
Entry.Address,
Entry.Size) != OFFLOAD_SUCCESS)
REPORT("Failed to write symbol for USM %s\n", Entry.SymbolName);
}
} else if (Entry.Address) {
if (Device.RTL->get_function(Binary, Entry.SymbolName,
&DeviceEntry.Address) != OFFLOAD_SUCCESS)

View File

@@ -112,21 +112,58 @@ setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
for (const auto &Entry : Entries) {
if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
Entry.Size == 0 ||
(!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT) &&
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE)))
continue;
assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
size_t PtrSize = sizeof(void *);
if (Entry.Flags & OMP_DECLARE_TARGET_INDIRECT_VTABLE) {
// This is a VTable entry, the current entry is the first index of the
// VTable and Entry.Size is the total size of the VTable. Unlike the
// indirect function case below, the Global is not of size Entry.Size and
// is instead of size PtrSize (sizeof(void*)).
void *Vtable;
void *res;
if (Device.RTL->get_global(Binary, PtrSize, Entry.SymbolName, &Vtable))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
void *Ptr;
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
// HstPtr = Entry.Address;
if (Device.retrieveData(&res, Vtable, PtrSize, AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
if (Device.synchronize(AsyncInfo))
return error::createOffloadError(
error::ErrorCode::INVALID_BINARY,
"failed to synchronize after retrieving %s", Entry.SymbolName);
// Calculate and emplace entire Vtable from first Vtable byte
for (uint64_t i = 0; i < Entry.Size / PtrSize; ++i) {
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
HstPtr = reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(Entry.Address) + i * PtrSize);
DevPtr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(res) +
i * PtrSize);
}
} else {
// Indirect function case: Entry.Size should equal PtrSize since we're
// dealing with a single function pointer (not a VTable)
assert(Entry.Size == PtrSize && "Global not a function pointer?");
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
void *Ptr;
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
HstPtr = Entry.Address;
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
HstPtr = Entry.Address;
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
}
if (Device.synchronize(AsyncInfo))
return error::createOffloadError(
error::ErrorCode::INVALID_BINARY,
"failed to synchronize after retrieving %s", Entry.SymbolName);
}
// If we do not have any indirect globals we exit early.

View File

@@ -0,0 +1,107 @@
// RUN: %libomptarget-compile-run-and-check-generic
#include <assert.h>
#include <omp.h>
#include <stdio.h>
// ---------------------------------------------------------------------------
// Various definitions copied from OpenMP RTL
typedef struct {
uint64_t Reserved;
uint16_t Version;
uint16_t Kind; // OpenMP==1
uint32_t Flags;
void *Address;
char *SymbolName;
uint64_t Size;
uint64_t Data;
void *AuxAddr;
} __tgt_offload_entry;
enum OpenMPOffloadingDeclareTargetFlags {
/// Mark the entry global as having a 'link' attribute.
OMP_DECLARE_TARGET_LINK = 0x01,
/// Mark the entry global as being an indirectly callable function.
OMP_DECLARE_TARGET_INDIRECT = 0x08,
/// This is an entry corresponding to a requirement to be registered.
OMP_REGISTER_REQUIRES = 0x10,
/// Mark the entry global as being an indirect vtable.
OMP_DECLARE_TARGET_INDIRECT_VTABLE = 0x20,
};
#pragma omp begin declare variant match(device = {kind(gpu)})
// Provided by the runtime.
void *__llvm_omp_indirect_call_lookup(void *host_ptr);
#pragma omp declare target to(__llvm_omp_indirect_call_lookup) \
device_type(nohost)
#pragma omp end declare variant
#pragma omp begin declare variant match(device = {kind(cpu)})
// We assume unified addressing on the CPU target.
void *__llvm_omp_indirect_call_lookup(void *host_ptr) { return host_ptr; }
#pragma omp end declare variant
#pragma omp begin declare target
void foo(int *i) { *i += 1; }
void bar(int *i) { *i += 10; }
void baz(int *i) { *i += 100; }
#pragma omp end declare target
typedef void (*fptr_t)(int *i);
// Dispatch Table - declare separately on host and device to avoid
// registering with the library; this also allows us to use separate
// names, which is convenient for debugging. This dispatchTable is
// intended to mimic what Clang emits for C++ vtables.
fptr_t dispatchTable[] = {foo, bar, baz};
#pragma omp begin declare target device_type(nohost)
fptr_t GPUdispatchTable[] = {foo, bar, baz};
fptr_t *GPUdispatchTablePtr = GPUdispatchTable;
#pragma omp end declare target
// Define "manual" OpenMP offload entries, where we emit Clang
// offloading entry structure definitions in the appropriate ELF
// section. This allows us to emulate the offloading entries that Clang would
// normally emit for us
__attribute__((weak, section("llvm_offload_entries"), aligned(8)))
const __tgt_offload_entry __offloading_entry[] = {{
0ULL, // Reserved
1, // Version
1, // Kind
OMP_DECLARE_TARGET_INDIRECT_VTABLE, // Flags
&dispatchTable, // Address
"GPUdispatchTablePtr", // SymbolName
(size_t)(sizeof(dispatchTable)), // Size
0ULL, // Data
NULL // AuxAddr
}};
// Mimic how Clang emits vtable pointers for C++ classes
typedef struct {
fptr_t *dispatchPtr;
} myClass;
// ---------------------------------------------------------------------------
int main() {
myClass obj_foo = {dispatchTable + 0};
myClass obj_bar = {dispatchTable + 1};
myClass obj_baz = {dispatchTable + 2};
int aaa = 0;
#pragma omp target map(aaa) map(to : obj_foo, obj_bar, obj_baz)
{
// Lookup
fptr_t *foo_ptr = __llvm_omp_indirect_call_lookup(obj_foo.dispatchPtr);
fptr_t *bar_ptr = __llvm_omp_indirect_call_lookup(obj_bar.dispatchPtr);
fptr_t *baz_ptr = __llvm_omp_indirect_call_lookup(obj_baz.dispatchPtr);
foo_ptr[0](&aaa);
bar_ptr[0](&aaa);
baz_ptr[0](&aaa);
}
assert(aaa == 111);
// CHECK: PASS
printf("PASS\n");
return 0;
}