[MLIR] Adding gpu.host_register op and lower it to a runtime call.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85631
This commit is contained in:
Christian Sigg
2020-08-10 10:13:57 +02:00
parent 566a66703f
commit 2c48e3629c
21 changed files with 132 additions and 107 deletions

View File

@@ -440,6 +440,10 @@ public:
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &sizes) const;
/// Computes the size of type in bytes.
Value getSizeInBytes(Location loc, Type type,
ConversionPatternRewriter &rewriter) const;
/// Computes total size in bytes of to store the given shape.
Value getCumulativeSizeInBytes(Location loc, Type elementType,
ArrayRef<Value> shape,

View File

@@ -741,4 +741,16 @@ def GPU_ModuleEndOp : GPU_Op<"module_end", [
let printer = [{ p << getOperationName(); }];
}
def GPU_HostRegisterOp : GPU_Op<"host_register">,
Arguments<(ins AnyUnrankedMemRef:$value)> {
let summary = "Registers a memref for access from device.";
let description = [{
This op registers the host memory pointed to by a memref to be accessed from
a device.
}];
let assemblyFormat = "$value attr-dict `:` type($value)";
let verifier = [{ return success(); }];
}
#endif // GPU_OPS

View File

@@ -117,6 +117,26 @@ protected:
"mgpuStreamSynchronize",
llvmVoidType,
{llvmPointerType /* void *stream */}};
FunctionCallBuilder hostRegisterCallBuilder = {
"mgpuMemHostRegisterMemRef",
llvmVoidType,
{llvmIntPtrType /* intptr_t rank */,
llvmPointerType /* void *memrefDesc */,
llvmIntPtrType /* intptr_t elementSizeBytes */}};
};
/// A rewrite patter to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
: public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
: ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
private:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite patter to convert gpu.launch_func operations into a sequence of
@@ -192,6 +212,33 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
builder.getSymbolRefAttr(function), arguments);
}
// Returns whether value is of LLVM type.
static bool isLLVMType(Value value) {
return value.getType().isa<LLVM::LLVMType>();
}
LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!llvm::all_of(operands, isLLVMType))
return rewriter.notifyMatchFailure(
op, "Cannot convert if operands aren't of LLVM type.");
Location loc = op->getLoc();
auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
auto elementSize = getSizeInBytes(loc, elementType, rewriter);
auto arguments =
typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
arguments.push_back(elementSize);
hostRegisterCallBuilder.create(loc, rewriter, arguments);
rewriter.eraseOp(op);
return success();
}
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -269,11 +316,6 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
LLVM::Linkage::Internal);
}
// Returns whether value is of LLVM type.
static bool isLLVMType(Value value) {
return value.getType().isa<LLVM::LLVMType>();
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
@@ -351,6 +393,7 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
StringRef gpuBinaryAnnotation) {
patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
converter, gpuBinaryAnnotation);
patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());

View File

@@ -927,6 +927,22 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
: createIndexConstant(rewriter, loc, s));
}
Value ConvertToLLVMPattern::getSizeInBytes(
Location loc, Type type, ConversionPatternRewriter &rewriter) const {
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto convertedPtrType =
typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
Location loc, Type elementType, ArrayRef<Value> sizes,
ConversionPatternRewriter &rewriter) const {
@@ -936,21 +952,7 @@ Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
for (unsigned i = 1, e = sizes.size(); i < e; ++i)
cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});
// Compute the size of an individual element. This emits the MLIR equivalent
// of the following sizeof(...) implementation in LLVM IR:
// %0 = getelementptr %elementType* null, %indexType 1
// %1 = ptrtoint %elementType* %0 to %indexType
// which is a common pattern of getting the size of a type in bytes.
auto convertedPtrType = typeConverter.convertType(elementType)
.cast<LLVM::LLVMType>()
.getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
auto gep = rewriter.create<LLVM::GEPOp>(
loc, convertedPtrType,
ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
auto elementSize =
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
return rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
}

View File

@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@@ -11,7 +11,7 @@ func @main() {
%sy = dim %dst, %c1 : memref<?x?x?xf32>
%sz = dim %dst, %c0 : memref<?x?x?xf32>
%cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
%t0 = muli %tz, %block_y : index
@@ -28,5 +28,4 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@@ -8,7 +8,7 @@ func @main() {
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xf32>
%cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%val = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)

View File

@@ -25,9 +25,9 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
gpu.host_register %cast_data : memref<*xi32>
%cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
gpu.host_register %cast_sum : memref<*xi32>
store %cst0, %data[%c0, %c0] : memref<2x6xi32>
store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

View File

@@ -18,7 +18,7 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
gpu.host_register %23 : memref<*xf32>
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,4 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -26,11 +26,11 @@ func @main() {
%c6 = constant 6 : index
%cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
gpu.host_register %cast_data : memref<*xf32>
%cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
gpu.host_register %cast_sum : memref<*xf32>
%cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
gpu.host_register %cast_mul : memref<*xf32>
store %cst0, %data[%c0, %c0] : memref<2x6xf32>
store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,4 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)

View File

@@ -7,8 +7,8 @@ func @main() {
%one = constant 1 : index
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xf32>
%cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
%cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
gpu.host_register %cast_dst : memref<*xf32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -24,9 +24,8 @@ func @main() {
store %value, %dst[%tx] : memref<?xf32>
gpu.terminator
}
call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -8,7 +8,7 @@ func @main() {
%c0 = constant 0 : index
%sx = dim %dst, %c0 : memref<?xi32>
%cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
gpu.host_register %cast_dst : memref<*xi32>
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
%t0 = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
func @print_memref_i32(%memref : memref<*xi32>)

View File

@@ -18,7 +18,7 @@ func @main() {
%21 = constant 5 : i32
%22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
%cast = memref_cast %22 : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast) : (memref<*xf32>) -> ()
gpu.host_register %cast : memref<*xf32>
%23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
call @print_memref_f32(%23) : (memref<*xf32>) -> ()
%24 = constant 1.0 : f32
@@ -28,6 +28,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -8,7 +8,7 @@ func @main() {
%c1 = constant 1 : index
%sx = dim %dst, %c0 : memref<?xi32>
%cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
gpu.host_register %cast_dst : memref<*xi32>
%dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %c1, %block_z = %c1) {
@@ -26,6 +26,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
func @print_memref_i32(%ptr : memref<*xi32>)

View File

@@ -26,9 +26,9 @@ func @main() {
%6 = memref_cast %3 : memref<?xf32> to memref<*xf32>
%7 = memref_cast %4 : memref<?xf32> to memref<*xf32>
%8 = memref_cast %5 : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%6) : (memref<*xf32>) -> ()
call @mgpuMemHostRegisterFloat(%7) : (memref<*xf32>) -> ()
call @mgpuMemHostRegisterFloat(%8) : (memref<*xf32>) -> ()
gpu.host_register %6 : memref<*xf32>
gpu.host_register %7 : memref<*xf32>
gpu.host_register %8 : memref<*xf32>
%9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
%10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
%11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
@@ -38,6 +38,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -55,8 +55,8 @@ func @main() {
%cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
%cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>
call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
gpu.host_register %cast0 : memref<*xf32>
gpu.host_register %cast1 : memref<*xf32>
%24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
%26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
@@ -71,6 +71,5 @@ func @main() {
return
}
func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

View File

@@ -75,17 +75,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}
// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
// Allows to register a MemRef with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
int64_t elementSizeBytes) {
llvm::SmallVector<int64_t, 4> denseStrides(rank);
llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
auto count = denseStrides.front();
auto sizeBytes = denseStrides.front() * elementSizeBytes;
// Only densely packed tensors are currently supported.
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -93,17 +95,6 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
auto *pointer = memRef.data + memRef.offset;
std::fill_n(pointer, count, value);
mgpuMemHostRegister(pointer, count * sizeof(T));
}
extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
UnrankedMemRefType<float> memRef = {rank, ptr};
mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
}
extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
UnrankedMemRefType<int32_t> memRef = {rank, ptr};
mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
mgpuMemHostRegister(ptr, sizeBytes);
}

View File

@@ -76,17 +76,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}
// Allows to register a MemRef with the ROCM runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides, T value) {
assert(sizes.size() == strides.size());
llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
// Allows to register a MemRef with the ROCm runtime. Helpful until we have
// transfer functions implemented.
extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
int64_t elementSizeBytes) {
llvm::SmallVector<int64_t, 4> denseStrides(rank);
llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
std::multiplies<int64_t>());
auto count = denseStrides.front();
auto sizeBytes = denseStrides.front() * elementSizeBytes;
// Only densely packed tensors are currently supported.
std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -94,22 +96,8 @@ void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
denseStrides.back() = 1;
assert(strides == llvm::makeArrayRef(denseStrides));
std::fill_n(pointer, count, value);
mgpuMemHostRegister(pointer, count * sizeof(T));
}
extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
}
extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
mgpuMemHostRegister(ptr, sizeBytes);
}
template <typename T>