[mlir][nvvm] Add lowering of gpu.printf to nvvm

When converting to NVVM, lowering gpu.printf to vprintf allows us to
support printing from device code when running on CUDA.
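
For example, inside a gpu.module a printf such as

  gpu.printf "Hello: %d\n" %arg0 : i32

now lowers to an internal format-string global plus a call to the device-side
vprintf, which takes the format pointer and a pointer to a stack-allocated
struct holding the (promoted) arguments. A rough sketch of the generated IR,
based on the new gpu-to-nvvm test below (SSA names are illustrative; the global
name is uniqued from the printfFormat_ prefix):

  llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
  llvm.mlir.global internal constant @printfFormat_0("Hello: %d\0A\00")
  ...
  %fmt = llvm.mlir.addressof @printfFormat_0 : !llvm.ptr<array<11 x i8>>
  %fmtStart = llvm.getelementptr %fmt[0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
  %one = llvm.mlir.constant(1 : index) : i64
  %buf = llvm.alloca %one x !llvm.struct<(i32)> : (i64) -> !llvm.ptr<struct<(i32)>>
  %el0 = llvm.getelementptr %buf[0, 0] : (!llvm.ptr<struct<(i32)>>) -> !llvm.ptr<i32>
  llvm.store %arg0, %el0 : !llvm.ptr<i32>
  %args = llvm.bitcast %buf : !llvm.ptr<struct<(i32)>> to !llvm.ptr<i8>
  llvm.call @vprintf(%fmtStart, %args) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32

Float arguments are fpext'd to f64 before being packed, mirroring C varargs
promotion.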

Differential Revision: https://reviews.llvm.org/D141049
Thomas Raoux
2023-01-05 21:20:45 +00:00
parent 9b5f62685a
commit 7efdc117b1
5 changed files with 170 additions and 15 deletions


@@ -172,7 +172,17 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
  return success();
}

static const char formatStringPrefix[] = "printfFormat_";

static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
  const char formatStringPrefix[] = "printfFormat_";
  // Get a unique global name.
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));
  return stringConstName;
}

template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
@@ -225,13 +235,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
  Value printfDesc = printfBeginCall.getResult();

  // Create a global constant for the format string
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));
  // Get a unique global name for the format.
  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

  llvm::SmallString<20> formatString(adaptor.getFormat());
  formatString.push_back('\0'); // Null terminate for C
@@ -320,13 +325,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
  LLVM::LLVMFuncOp printfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);

  // Create a global constant for the format string
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));
  // Get a unique global name for the format.
  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

  llvm::SmallString<20> formatString(adaptor.getFormat());
  formatString.push_back('\0'); // Null terminate for C
@@ -359,6 +359,80 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
  return success();
}

LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();
  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);

  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto vprintfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr});
  LLVM::LLVMFuncOp vprintfDecl =
      getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

  // Get a unique global name for the format.
  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

  llvm::SmallString<20> formatString(adaptor.getFormat());
  formatString.push_back('\0'); // Null terminate for C
  auto globalType =
      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
  LLVM::GlobalOp global;
  {
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    global = rewriter.create<LLVM::GlobalOp>(
        loc, globalType,
        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
        rewriter.getStringAttr(formatString), /*alignment=*/0);
  }

  // Get a pointer to the format string's first element.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  Value stringStart = rewriter.create<LLVM::GEPOp>(
      loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  SmallVector<Type> types;
  SmallVector<Value> args;
  // Promote and pack the arguments into a stack allocation.
  for (Value arg : adaptor.getArgs()) {
    Type type = arg.getType();
    Value promotedArg = arg;
    assert(type.isIntOrFloat());
    if (type.isa<FloatType>()) {
      type = rewriter.getF64Type();
      promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg);
    }
    types.push_back(type);
    args.push_back(promotedArg);
  }
  Type structType =
      LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
  Type structPtrType = LLVM::LLVMPointerType::get(structType);
  Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                                rewriter.getIndexAttr(1));
  Value tempAlloc = rewriter.create<LLVM::AllocaOp>(loc, structPtrType, one,
                                                    /*alignment=*/0);
  for (auto [index, arg] : llvm::enumerate(args)) {
    Value ptr = rewriter.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
        ArrayRef<LLVM::GEPArg>{0, index});
    rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
  }
  tempAlloc = rewriter.create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);

  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
  rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);

  rewriter.eraseOp(gpuPrintfOp);
  return success();
}

/// Unrolls op if it's operating on vectors.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter,


@@ -67,6 +67,16 @@ private:
  int addressSpace;
};

/// Lowering of gpu.printf to a call to the vprintf standard library routine.
struct GPUPrintfOpToVPrintfLowering
    : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
  using ConvertOpToLLVMPattern<gpu::PrintfOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
  using ConvertOpToLLVMPattern<gpu::ReturnOp>::ConvertOpToLLVMPattern;


@@ -239,6 +239,7 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns
      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,


@@ -501,3 +501,42 @@ gpu.module @test_module {
    gpu.return
  }
}

// -----

gpu.module @test_module {
  // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00")
  // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00")
  // CHECK-DAG: llvm.func @vprintf(!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32

  // CHECK-LABEL: func @test_const_printf
  gpu.func @test_const_printf() {
    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr<array<14 x i8>>
    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<14 x i8>>) -> !llvm.ptr<i8>
    // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr<struct<()>>
    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<()>> to !llvm.ptr<i8>
    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
    gpu.printf "Hello, world\n"
    gpu.return
  }

  // CHECK-LABEL: func @test_printf
  // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
  gpu.func @test_printf(%arg0: i32, %arg1: f32) {
    // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL1]] : !llvm.ptr<array<11 x i8>>
    // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr<array<11 x i8>>) -> !llvm.ptr<i8>
    // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64
    // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64
    // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr<struct<(i32, f64)>>
    // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<i32>
    // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr<i32>
    // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr<struct<(i32, f64)>>) -> !llvm.ptr<f64>
    // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr<f64>
    // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr<struct<(i32, f64)>> to !llvm.ptr<i8>
    // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> i32
    gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32
    gpu.return
  }
}


@@ -0,0 +1,31 @@
// RUN: mlir-opt %s \
// RUN:   | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin))' \
// RUN:   | mlir-opt -gpu-to-llvm \
// RUN:   | mlir-cpu-runner \
// RUN:     --shared-libs=%mlir_lib_dir/libmlir_cuda_runtime%shlibext \
// RUN:     --shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \
// RUN:     --entry-point-result=void \
// RUN:   | FileCheck %s

// CHECK: Hello from 0, 2, 3.000000
// CHECK: Hello from 1, 2, 3.000000
module attributes {gpu.container_module} {
  gpu.module @kernels {
    gpu.func @hello() kernel {
      %0 = gpu.thread_id x
      %csti8 = arith.constant 2 : i8
      %cstf32 = arith.constant 3.0 : f32
      gpu.printf "Hello from %lld, %d, %f\n" %0, %csti8, %cstf32 : index, i8, f32
      gpu.return
    }
  }

  func.func @main() {
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    gpu.launch_func @kernels::@hello
        blocks in (%c1, %c1, %c1)
        threads in (%c2, %c1, %c1)
    return
  }
}