mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 09:31:59 +08:00
[mlir][nvvm] Add lowering of gpu.printf to nvvm
When converting to nvvm lowering gpu.printf to vprintf allows us to support printing when running on cuda. Differential Revision: https://reviews.llvm.org/D141049
This commit is contained in:
@@ -172,7 +172,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
return success();
|
||||
}
|
||||
|
||||
static const char formatStringPrefix[] = "printfFormat_";
|
||||
static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
|
||||
const char formatStringPrefix[] = "printfFormat_";
|
||||
// Get a unique global name.
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
return stringConstName;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
|
||||
@@ -225,13 +235,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
|
||||
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
|
||||
Value printfDesc = printfBeginCall.getResult();
|
||||
|
||||
// Create a global constant for the format string
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
llvm::SmallString<20> formatString(adaptor.getFormat());
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
@@ -320,13 +325,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
|
||||
LLVM::LLVMFuncOp printfDecl =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
|
||||
|
||||
// Create a global constant for the format string
|
||||
unsigned stringNumber = 0;
|
||||
SmallString<16> stringConstName;
|
||||
do {
|
||||
stringConstName.clear();
|
||||
(formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
|
||||
} while (moduleOp.lookupSymbol(stringConstName));
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
llvm::SmallString<20> formatString(adaptor.getFormat());
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
@@ -359,6 +359,80 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
|
||||
gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = gpuPrintfOp->getLoc();
|
||||
|
||||
mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
|
||||
mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
|
||||
|
||||
// Note: this is the GPUModule op, not the ModuleOp that surrounds it
|
||||
// This ensures that global constants and declarations are placed within
|
||||
// the device code, not the host code
|
||||
auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
|
||||
|
||||
auto vprintfType =
|
||||
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
|
||||
LLVM::LLVMFuncOp vprintfDecl =
|
||||
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
|
||||
|
||||
// Get a unique global name for the format.
|
||||
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
|
||||
|
||||
llvm::SmallString<20> formatString(adaptor.getFormat());
|
||||
formatString.push_back('\0'); // Null terminate for C
|
||||
auto globalType =
|
||||
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
|
||||
LLVM::GlobalOp global;
|
||||
{
|
||||
ConversionPatternRewriter::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
global = rewriter.create<LLVM::GlobalOp>(
|
||||
loc, globalType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
|
||||
rewriter.getStringAttr(formatString), /*allignment=*/0);
|
||||
}
|
||||
|
||||
// Get a pointer to the format string's first element
|
||||
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
|
||||
Value stringStart = rewriter.create<LLVM::GEPOp>(
|
||||
loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
|
||||
SmallVector<Type> types;
|
||||
SmallVector<Value> args;
|
||||
// Promote and pack the arguments into a stack allocation.
|
||||
for (Value arg : adaptor.getArgs()) {
|
||||
Type type = arg.getType();
|
||||
Value promotedArg = arg;
|
||||
assert(type.isIntOrFloat());
|
||||
if (type.isa<FloatType>()) {
|
||||
type = rewriter.getF64Type();
|
||||
promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
|
||||
}
|
||||
types.push_back(type);
|
||||
args.push_back(promotedArg);
|
||||
}
|
||||
Type structType =
|
||||
LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
|
||||
Type structPtrType = LLVM::LLVMPointerType::get(structType);
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
|
||||
rewriter.getIndexAttr(1));
|
||||
Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
|
||||
/*alignment=*/0);
|
||||
for (auto [index, arg] : llvm::enumerate(args)) {
|
||||
Value ptr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
|
||||
ArrayRef<LLVM::GEPArg>{0, index});
|
||||
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
|
||||
}
|
||||
tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
|
||||
std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
|
||||
|
||||
rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
|
||||
rewriter.eraseOp(gpuPrintfOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Unrolls op if it's operating on vectors.
|
||||
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
|
||||
@@ -67,6 +67,16 @@ private:
|
||||
int addressSpace;
|
||||
};
|
||||
|
||||
/// Lowering of gpu.printf to a vprintf standard library.
|
||||
struct GPUPrintfOpToVPrintfLowering
|
||||
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
|
||||
using ConvertOpToLLVMPattern<gpu::PrintfOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
|
||||
using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
|
||||
@@ -239,6 +239,7 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
|
||||
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
|
||||
patterns
|
||||
.add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
|
||||
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
|
||||
|
||||
@@ -501,3 +501,42 @@ gpu.module @test_module {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
gpu.module @test_module {
|
||||
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
|
||||
// CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
|
||||
// CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
|
||||
|
||||
// CHECK-LABEL: func @test_const_printf
|
||||
gpu.func @test_const_printf() {
|
||||
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
|
||||
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
|
||||
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
|
||||
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
|
||||
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
|
||||
gpu.printf "Hello, world\n"
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_printf
|
||||
// CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
|
||||
gpu.func @test_printf(%arg0: i32, %arg1: f32) {
|
||||
// CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
|
||||
// CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
|
||||
// CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
|
||||
// CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
|
||||
// CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
|
||||
// CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
|
||||
// CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
|
||||
// CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
|
||||
// CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
|
||||
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
|
||||
gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
31
mlir/test/Integration/GPU/CUDA/printf.mlir
Normal file
31
mlir/test/Integration/GPU/CUDA/printf.mlir
Normal file
@@ -0,0 +1,31 @@
|
||||
// RUN: mlir-opt %s \
|
||||
// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin))' \
|
||||
// RUN: | mlir-opt -gpu-to-llvm \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: --shared-libs=%mlir_lib_dir/libmlir_cuda_runtime%shlibext \
|
||||
// RUN: --shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
|
||||
// RUN: --entry-point-result=void \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: Hello from 0, 2, 3.000000
|
||||
// CHECK: Hello from 1, 2, 3.000000
|
||||
module attributes {gpu.container_module} {
|
||||
gpu.module @kernels {
|
||||
gpu.func @hello() kernel {
|
||||
%0 = gpu.thread_id x
|
||||
%csti8 = arith.constant 2 : i8
|
||||
%cstf32 = arith.constant 3.0 : f32
|
||||
gpu.printf "Hello from %lld, %d, %f\n" %0, %csti8, %cstf32 : index, i8, f32
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func.func @main() {
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
gpu.launch_func @kernels::@hello
|
||||
blocks in (%c1, %c1, %c1)
|
||||
threads in (%c2, %c1, %c1)
|
||||
return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user