[MLIR] Adding gpu.host_register op and lower it to a runtime call.
Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85631
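At a glance, the change lets IR use a dialect op instead of calling type-specific runtime wrappers directly. A minimal sketch of the new op (value name illustrative; see the op definition and test updates below):

    // Registers the host buffer behind %buf with the GPU runtime (CUDA or ROCm);
    // the GPU-to-LLVM lowering turns this into a call to mgpuMemHostRegisterMemRef.
    gpu.host_register %buf : memref<*xf32>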
@@ -440,6 +440,10 @@ public:
                                     ConversionPatternRewriter &rewriter,
                                     SmallVectorImpl<Value> &sizes) const;

  /// Computes the size of type in bytes.
  Value getSizeInBytes(Location loc, Type type,
                       ConversionPatternRewriter &rewriter) const;

  /// Computes the total size in bytes needed to store the given shape.
  Value getCumulativeSizeInBytes(Location loc, Type elementType,
                                 ArrayRef<Value> shape,

@@ -741,4 +741,16 @@ def GPU_ModuleEndOp : GPU_Op<"module_end", [
  let printer = [{ p << getOperationName(); }];
}

def GPU_HostRegisterOp : GPU_Op<"host_register">,
    Arguments<(ins AnyUnrankedMemRef:$value)> {
  let summary = "Registers a memref for access from device.";
  let description = [{
    This op registers the host memory pointed to by a memref to be accessed from
    a device.
  }];

  let assemblyFormat = "$value attr-dict `:` type($value)";
  let verifier = [{ return success(); }];
}

#endif // GPU_OPS

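With the definition above, registering a buffer from MLIR follows the shape used throughout the test updates below: cast the ranked memref to an unranked one and hand it to the op. A minimal sketch (function and value names are illustrative):

    func @register_example(%data : memref<2x6xi32>) {
      // The op only accepts unranked memrefs (AnyUnrankedMemRef), so cast first.
      %cast = memref_cast %data : memref<2x6xi32> to memref<*xi32>
      gpu.host_register %cast : memref<*xi32>
      return
    }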
@@ -117,6 +117,26 @@ protected:
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
@@ -192,6 +212,33 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                 builder.getSymbolRefAttr(function), arguments);
}

// Returns whether value is of LLVM type.
static bool isLLVMType(Value value) {
  return value.getType().isa<LLVM::LLVMType>();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(operands, isLLVMType))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");

  Location loc = op->getLoc();

  auto memRefType = cast<gpu::HostRegisterOp>(op).value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments =
      typeConverter.promoteOperands(loc, op->getOperands(), operands, rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

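The net effect of the pattern above is that the op disappears and a plain runtime call remains. Roughly, the emitted LLVM-dialect IR looks like the sketch below, where %rank and %desc come from promoteOperands applied to the unranked-memref operand and %elem_size from getSizeInBytes; the symbol name matches hostRegisterCallBuilder, while the exact type spellings depend on the LLVM dialect syntax of this revision:

    llvm.call @mgpuMemHostRegisterMemRef(%rank, %desc, %elem_size)
        : (!llvm.i64, !llvm.ptr<i8>, !llvm.i64) -> ()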
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -269,11 +316,6 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
                                  LLVM::Linkage::Internal);
}

// Returns whether value is of LLVM type.
static bool isLLVMType(Value value) {
  return value.getType().isa<LLVM::LLVMType>();
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
@@ -351,6 +393,7 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    StringRef gpuBinaryAnnotation) {
  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
  patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
      converter, gpuBinaryAnnotation);
  patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());

@@ -927,6 +927,22 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
                         : createIndexConstant(rewriter, loc, s));
}

Value ConvertToLLVMPattern::getSizeInBytes(
    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
  // Compute the size of an individual element. This emits the MLIR equivalent
  // of the following sizeof(...) implementation in LLVM IR:
  //   %0 = getelementptr %elementType* null, %indexType 1
  //   %1 = ptrtoint %elementType* %0 to %indexType
  // which is a common pattern of getting the size of a type in bytes.
  auto convertedPtrType =
      typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
  auto gep = rewriter.create<LLVM::GEPOp>(
      loc, convertedPtrType,
      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}

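For an f32 element, the null-pointer GEP trick used by getSizeInBytes corresponds to roughly the following LLVM-dialect operations (a sketch; the pointer and integer type spellings are assumed and may differ from this revision's exact syntax):

    %null = llvm.mlir.null : !llvm.ptr<f32>
    %one  = llvm.mlir.constant(1 : index) : !llvm.i64
    // The address of element #1 relative to a null base pointer is sizeof(f32).
    %gep  = llvm.getelementptr %null[%one] : (!llvm.ptr<f32>, !llvm.i64) -> !llvm.ptr<f32>
    %size = llvm.ptrtoint %gep : !llvm.ptr<f32> to !llvm.i64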
Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
    Location loc, Type elementType, ArrayRef<Value> sizes,
    ConversionPatternRewriter &rewriter) const {
@@ -936,21 +952,7 @@ Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
  for (unsigned i = 1, e = sizes.size(); i < e; ++i)
    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]});

  // Compute the size of an individual element. This emits the MLIR equivalent
  // of the following sizeof(...) implementation in LLVM IR:
  //   %0 = getelementptr %elementType* null, %indexType 1
  //   %1 = ptrtoint %elementType* %0 to %indexType
  // which is a common pattern of getting the size of a type in bytes.
  auto convertedPtrType = typeConverter.convertType(elementType)
                              .cast<LLVM::LLVMType>()
                              .getPointerTo();
  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
  auto gep = rewriter.create<LLVM::GEPOp>(
      loc, convertedPtrType,
      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
  auto elementSize =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
  auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
  return rewriter.create<LLVM::MulOp>(
      loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
}

@@ -25,9 +25,9 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
  gpu.host_register %cast_data : memref<*xi32>
  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
  gpu.host_register %cast_sum : memref<*xi32>

  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

@@ -25,9 +25,9 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
  gpu.host_register %cast_data : memref<*xi32>
  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
  gpu.host_register %cast_sum : memref<*xi32>

  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

@@ -25,9 +25,9 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
  gpu.host_register %cast_data : memref<*xi32>
  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
  gpu.host_register %cast_sum : memref<*xi32>

  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

@@ -11,7 +11,7 @@ func @main() {
  %sy = dim %dst, %c1 : memref<?x?x?xf32>
  %sz = dim %dst, %c0 : memref<?x?x?xf32>
  %cast_dst = memref_cast %dst : memref<?x?x?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
  gpu.host_register %cast_dst : memref<*xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) {
    %t0 = muli %tz, %block_y : index
@@ -28,5 +28,4 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -25,9 +25,9 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
  gpu.host_register %cast_data : memref<*xi32>
  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
  gpu.host_register %cast_sum : memref<*xi32>

  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

@@ -8,7 +8,7 @@ func @main() {
  %c0 = constant 0 : index
  %sx = dim %dst, %c0 : memref<?xf32>
  %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> ()
  gpu.host_register %cast_dst : memref<*xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
    %val = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)

@@ -25,9 +25,9 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> ()
  gpu.host_register %cast_data : memref<*xi32>
  %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> ()
  gpu.host_register %cast_sum : memref<*xi32>

  store %cst0, %data[%c0, %c0] : memref<2x6xi32>
  store %cst1, %data[%c0, %c1] : memref<2x6xi32>
@@ -58,6 +58,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @print_memref_i32(memref<*xi32>)

@@ -18,7 +18,7 @@ func @main() {
  %21 = constant 5 : i32
  %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
  %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%23) : (memref<*xf32>) -> ()
  gpu.host_register %23 : memref<*xf32>
  call @print_memref_f32(%23) : (memref<*xf32>) -> ()
  %24 = constant 1.0 : f32
  call @other_func(%24, %22) : (f32, memref<?xf32>) -> ()
@@ -26,5 +26,4 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -26,11 +26,11 @@ func @main() {
  %c6 = constant 6 : index

  %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> ()
  gpu.host_register %cast_data : memref<*xf32>
  %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> ()
  gpu.host_register %cast_sum : memref<*xf32>
  %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> ()
  gpu.host_register %cast_mul : memref<*xf32>

  store %cst0, %data[%c0, %c0] : memref<2x6xf32>
  store %cst1, %data[%c0, %c1] : memref<2x6xf32>
@@ -66,5 +66,4 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(memref<*xf32>)

@@ -7,8 +7,8 @@ func @main() {
  %one = constant 1 : index
  %c0 = constant 0 : index
  %sx = dim %dst, %c0 : memref<?xf32>
  %cast_dest = memref_cast %dst : memref<?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> ()
  %cast_dst = memref_cast %dst : memref<?xf32> to memref<*xf32>
  gpu.host_register %cast_dst : memref<*xf32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
    %t0 = index_cast %tx : index to i32
@@ -24,9 +24,8 @@ func @main() {
    store %value, %dst[%tx] : memref<?xf32>
    gpu.terminator
  }
  call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> ()
  call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> ()
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -8,7 +8,7 @@ func @main() {
  %c0 = constant 0 : index
  %sx = dim %dst, %c0 : memref<?xi32>
  %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
  gpu.host_register %cast_dst : memref<*xi32>
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
    %t0 = index_cast %tx : index to i32
@@ -25,5 +25,4 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%memref : memref<*xi32>)
func @print_memref_i32(%memref : memref<*xi32>)

@@ -18,7 +18,7 @@ func @main() {
  %21 = constant 5 : i32
  %22 = memref_cast %arg0 : memref<5xf32> to memref<?xf32>
  %cast = memref_cast %22 : memref<?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%cast) : (memref<*xf32>) -> ()
  gpu.host_register %cast : memref<*xf32>
  %23 = memref_cast %22 : memref<?xf32> to memref<*xf32>
  call @print_memref_f32(%23) : (memref<*xf32>) -> ()
  %24 = constant 1.0 : f32
@@ -28,6 +28,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -8,7 +8,7 @@ func @main() {
  %c1 = constant 1 : index
  %sx = dim %dst, %c0 : memref<?xi32>
  %cast_dst = memref_cast %dst : memref<?xi32> to memref<*xi32>
  call @mgpuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> ()
  gpu.host_register %cast_dst : memref<*xi32>
  %dst_device = call @mgpuMemGetDeviceMemRef1dInt32(%dst) : (memref<?xi32>) -> (memref<?xi32>)
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %c1, %block_z = %c1) {
@@ -26,6 +26,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterInt32(%ptr : memref<*xi32>)
func @mgpuMemGetDeviceMemRef1dInt32(%ptr : memref<?xi32>) -> (memref<?xi32>)
func @print_memref_i32(%ptr : memref<*xi32>)

@@ -26,9 +26,9 @@ func @main() {
  %6 = memref_cast %3 : memref<?xf32> to memref<*xf32>
  %7 = memref_cast %4 : memref<?xf32> to memref<*xf32>
  %8 = memref_cast %5 : memref<?xf32> to memref<*xf32>
  call @mgpuMemHostRegisterFloat(%6) : (memref<*xf32>) -> ()
  call @mgpuMemHostRegisterFloat(%7) : (memref<*xf32>) -> ()
  call @mgpuMemHostRegisterFloat(%8) : (memref<*xf32>) -> ()
  gpu.host_register %6 : memref<*xf32>
  gpu.host_register %7 : memref<*xf32>
  gpu.host_register %8 : memref<*xf32>
  %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref<?xf32>) -> (memref<?xf32>)
  %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref<?xf32>) -> (memref<?xf32>)
  %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref<?xf32>) -> (memref<?xf32>)
@@ -38,6 +38,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -55,8 +55,8 @@ func @main() {
  %cast0 = memref_cast %22 : memref<?xf32> to memref<*xf32>
  %cast1 = memref_cast %23 : memref<?xf32> to memref<*xf32>

  call @mgpuMemHostRegisterFloat(%cast0) : (memref<*xf32>) -> ()
  call @mgpuMemHostRegisterFloat(%cast1) : (memref<*xf32>) -> ()
  gpu.host_register %cast0 : memref<*xf32>
  gpu.host_register %cast1 : memref<*xf32>

  %24 = call @mgpuMemGetDeviceMemRef1dFloat(%22) : (memref<?xf32>) -> (memref<?xf32>)
  %26 = call @mgpuMemGetDeviceMemRef1dFloat(%23) : (memref<?xf32>) -> (memref<?xf32>)
@@ -71,6 +71,5 @@ func @main() {
  return
}

func @mgpuMemHostRegisterFloat(%ptr : memref<*xf32>)
func @mgpuMemGetDeviceMemRef1dFloat(%ptr : memref<?xf32>) -> (memref<?xf32>)
func @print_memref_f32(%ptr : memref<*xf32>)

@@ -75,17 +75,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}

// Allows to register a MemRef with the CUDA runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
  llvm::SmallVector<int64_t, 4> denseStrides(memRef.rank);
  llvm::ArrayRef<int64_t> sizes(memRef.sizes, memRef.rank);
  llvm::ArrayRef<int64_t> strides(memRef.strides, memRef.rank);
// Allows to register a MemRef with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {

  llvm::SmallVector<int64_t, 4> denseStrides(rank);
  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);

  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                   std::multiplies<int64_t>());
  auto count = denseStrides.front();
  auto sizeBytes = denseStrides.front() * elementSizeBytes;

  // Only densely packed tensors are currently supported.
  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -93,17 +95,6 @@ void mgpuMemHostRegisterMemRef(const DynamicMemRefType<T> &memRef, T value) {
  denseStrides.back() = 1;
  assert(strides == llvm::makeArrayRef(denseStrides));

  auto *pointer = memRef.data + memRef.offset;
  std::fill_n(pointer, count, value);
  mgpuMemHostRegister(pointer, count * sizeof(T));
}

extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
  UnrankedMemRefType<float> memRef = {rank, ptr};
  mgpuMemHostRegisterMemRef(DynamicMemRefType<float>(memRef), 1.23f);
}

extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
  UnrankedMemRefType<int32_t> memRef = {rank, ptr};
  mgpuMemHostRegisterMemRef(DynamicMemRefType<int32_t>(memRef), 123);
  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}

@@ -76,17 +76,19 @@ extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
}

// Allows to register a MemRef with the ROCM runtime. Initializes array with
// value. Helpful until we have transfer functions implemented.
template <typename T>
void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
                               llvm::ArrayRef<int64_t> strides, T value) {
  assert(sizes.size() == strides.size());
  llvm::SmallVector<int64_t, 4> denseStrides(strides.size());
// Allows to register a MemRef with the ROCm runtime. Helpful until we have
// transfer functions implemented.
extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {

  llvm::SmallVector<int64_t, 4> denseStrides(rank);
  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);

  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
                   std::multiplies<int64_t>());
  auto count = denseStrides.front();
  auto sizeBytes = denseStrides.front() * elementSizeBytes;

  // Only densely packed tensors are currently supported.
  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
@@ -94,22 +96,8 @@ void mgpuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
  denseStrides.back() = 1;
  assert(strides == llvm::makeArrayRef(denseStrides));

  std::fill_n(pointer, count, value);
  mgpuMemHostRegister(pointer, count * sizeof(T));
}

extern "C" void mgpuMemHostRegisterFloat(int64_t rank, void *ptr) {
  auto *desc = static_cast<StridedMemRefType<float, 1> *>(ptr);
  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f);
}

extern "C" void mgpuMemHostRegisterInt32(int64_t rank, void *ptr) {
  auto *desc = static_cast<StridedMemRefType<int32_t, 1> *>(ptr);
  auto sizes = llvm::ArrayRef<int64_t>(desc->sizes, rank);
  auto strides = llvm::ArrayRef<int64_t>(desc->sizes + rank, rank);
  mgpuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123);
  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}

template <typename T>