diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e11c5c393648..826df0012fb8 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -53,6 +53,32 @@ class GPU_IndexOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "$dimension attr-dict";
 }
 
+def GPU_ClusterDimOp : GPU_IndexOp<"cluster_dim"> {
+  let description = [{
+    Returns the number of thread blocks in the cluster along
+    the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cDimX = gpu.cluster_dim x
+    ```
+  }];
+}
+
+def GPU_ClusterIdOp : GPU_IndexOp<"cluster_id"> {
+  let description = [{
+    Returns the cluster id, i.e. the index of the current cluster within the
+    grid along the x, y, or z `dimension`.
+
+    Example:
+
+    ```mlir
+    %cIdY = gpu.cluster_id y
+    ```
+  }];
+}
+
 def GPU_BlockDimOp : GPU_IndexOp<"block_dim"> {
   let description = [{
     Returns the number of threads in the thread block (aka the block size) along
@@ -467,8 +493,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
                      "blockSizeY", "blockSizeZ"]>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                    SymbolRefAttr:$kernel,
-                   LaunchIndx:$gridSizeX, LaunchIndx:$gridSizeY, LaunchIndx:$gridSizeZ,
-                   LaunchIndx:$blockSizeX, LaunchIndx:$blockSizeY, LaunchIndx:$blockSizeZ,
+                   LaunchIndx:$gridSizeX,
+                   LaunchIndx:$gridSizeY,
+                   LaunchIndx:$gridSizeZ,
+                   LaunchIndx:$blockSizeX,
+                   LaunchIndx:$blockSizeY,
+                   LaunchIndx:$blockSizeZ,
+                   Optional<LaunchIndx>:$clusterSizeX,
+                   Optional<LaunchIndx>:$clusterSizeY,
+                   Optional<LaunchIndx>:$clusterSizeZ,
                    Optional<I32>:$dynamicSharedMemorySize,
                    Variadic<AnyType>:$kernelOperands,
                    Optional<AnyType>:$asyncObject)>,
@@ -506,6 +539,12 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     The remaining operands if present are passed as arguments to the kernel
     function.
 
+    The `gpu.launch_func` also supports launching kernels with clusters of
+    thread blocks when the target architecture supports them. The cluster size
+    can be set via the `clusterSizeX`, `clusterSizeY`, and `clusterSizeZ`
+    arguments. When all three are present, the op groups the launched thread
+    blocks into clusters; this is only available on certain architectures.
+
     Example:
 
     ```mlir
@@ -535,6 +574,15 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
         %gDimY = gpu.grid_dim y
         %gDimZ = gpu.grid_dim z
 
+        // (Optional) Cluster sizes, only on supported architectures.
+        %cIdX = gpu.cluster_id x
+        %cIdY = gpu.cluster_id y
+        %cIdZ = gpu.cluster_id z
+
+        %cDimX = gpu.cluster_dim x
+        %cDimY = gpu.cluster_dim y
+        %cDimZ = gpu.cluster_dim z
+
         "some_op"(%bx, %tx) : (index, index) -> ()
         %42 = load %arg1[%bx] : memref<?xf32, 1>
       }
@@ -545,6 +593,7 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
         async                           // (Optional) Don't block host, return token.
         [%t0]                           // (Optional) Execute only after %t0 has completed.
         @kernels::@kernel_1             // Kernel function.
+        clusters in (%cst, %cst, %cst)  // (Optional) Cluster size, only on supported architectures.
         blocks in (%cst, %cst, %cst)    // Grid size.
        threads in (%cst, %cst, %cst)   // Block size.
        dynamic_shared_memory_size %s   // (Optional) Amount of dynamic shared
@@ -562,11 +611,13 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands, CArg<"Type", "nullptr">:$asyncTokenType,
-      CArg<"ValueRange", "{}">:$asyncDependencies)>,
+      CArg<"ValueRange", "{}">:$asyncDependencies,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
     OpBuilder<(ins "SymbolRefAttr":$kernel, "KernelDim3":$gridSize,
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
       "ValueRange":$kernelOperands,
-      CArg<"Value", "nullptr">:$asyncObject)>
+      CArg<"Value", "nullptr">:$asyncObject,
+      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>
   ];
 
   let extraClassDeclaration = [{
@@ -576,12 +627,23 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
     /// The name of the kernel.
     StringAttr getKernelName();
 
+    /// Returns true if the cluster size is specified.
+    bool hasClusterSize() {
+      return getClusterSizeX() && getClusterSizeY() && getClusterSizeZ();
+    }
+
     /// The number of operands passed to the kernel function.
     unsigned getNumKernelOperands();
 
     /// The i-th operand passed to the kernel function.
     Value getKernelOperand(unsigned i);
 
+    /// Get the SSA values passed as operands to specify the cluster size.
+    /// Asserts if the cluster size is not specified; check hasClusterSize()
+    /// first.
+    KernelDim3 getClusterSizeOperandValues();
+
     /// Get the SSA values passed as operands to specify the grid size.
     KernelDim3 getGridSizeOperandValues();
 
@@ -597,10 +659,11 @@ def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
   let assemblyFormat = [{
       custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
       (`<` $asyncObject^ `:` type($asyncObject) `>`)?
-      $kernel
+      $kernel
+      (`clusters` `in` ` ` `(` $clusterSizeX^ `,` $clusterSizeY `,` $clusterSizeZ `)` )?
       `blocks` `in` ` ` `(` $gridSizeX `,` $gridSizeY `,` $gridSizeZ `)`
       `threads` `in` ` ` `(` $blockSizeX `,` $blockSizeY `,` $blockSizeZ `)`
-      custom<LaunchDimType>(type($gridSizeX))
+      custom<LaunchDimType>(type($gridSizeX), ref($clusterSizeX), type($clusterSizeX), type($clusterSizeY), type($clusterSizeZ))
       (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
       custom<LaunchFuncOperands>($kernelOperands, type($kernelOperands)) attr-dict
   }];
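For reference, the optional `clusters in` clause composes with the existing launch syntax as below (a sketch; the constants and kernel name are illustrative, and all three cluster dimensions must be given together or not at all):

```mlir
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
// Launch an 8x8x8 grid of thread blocks, grouped into 2x2x2 clusters.
gpu.launch_func @kernels::@kernel_1
    clusters in (%c2, %c2, %c2)
    blocks in (%c8, %c8, %c8)
    threads in (%c8, %c8, %c8)
```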
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 3dd8aae81c59..2da97c20e9c9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1128,13 +1128,19 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(), rewriter,
       /*useBarePtrCallConv=*/kernelBarePtrCallConv);
 
+  std::optional<gpu::KernelDim3> clusterSize = std::nullopt;
+  if (launchOp.hasClusterSize()) {
+    clusterSize =
+        gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(),
+                        adaptor.getClusterSizeZ()};
+  }
   rewriter.create<gpu::LaunchFuncOp>(
       launchOp.getLoc(), launchOp.getKernelAttr(),
       gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                       adaptor.getGridSizeZ()},
       gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                       adaptor.getBlockSizeZ()},
-      adaptor.getDynamicSharedMemorySize(), arguments, stream);
+      adaptor.getDynamicSharedMemorySize(), arguments, stream, clusterSize);
   if (launchOp.getAsyncToken())
     rewriter.replaceOp(launchOp, {stream});
   else
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9..9456784c406a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -313,17 +313,20 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
-  patterns
-      .add<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                       NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                       NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                       NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-           GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                       NVVM::GridDimYOp, NVVM::GridDimZOp>,
-           GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
-          converter);
+  patterns.add<
+      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
+                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
+                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
+                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
+                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
+      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit);
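Both new ops piggyback on the existing `GPUIndexIntrinsicOpLowering` template, so they lower to NVVM special-register reads exactly like the other index ops. Roughly, for a 64-bit index width (a sketch, not the verbatim pass output):

```mlir
// Input (gpu dialect):
%cid = gpu.cluster_id x

// After -convert-gpu-to-nvvm (sketch): a read of the %clusterid.x special
// register, sign-extended from i32 to the 64-bit index type.
%0 = nvvm.read.ptx.sreg.clusterid.x : i32
%1 = llvm.sext %0 : i32 to i64
```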
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 9517c053c836..1b6db1fb0c79 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/StringSaver.h"
+#include <cassert>
 
 using namespace mlir;
 using namespace mlir::gpu;
@@ -985,7 +986,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                          ValueRange kernelOperands, Type asyncTokenType,
-                         ValueRange asyncDependencies) {
+                         ValueRange asyncDependencies,
+                         std::optional<KernelDim3> clusterSize) {
   result.addOperands(asyncDependencies);
   if (asyncTokenType)
     result.types.push_back(builder.getType<AsyncTokenType>());
@@ -993,6 +995,8 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1008,6 +1012,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = asyncDependencies.size();
+  // The cluster size operands sit between the block sizes and the dynamic
+  // shared memory size; zero their segments when no cluster is requested.
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1018,10 +1027,13 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          SymbolRefAttr kernel, KernelDim3 gridSize,
                          KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
-                         ValueRange kernelOperands, Value asyncObject) {
+                         ValueRange kernelOperands, Value asyncObject,
+                         std::optional<KernelDim3> clusterSize) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                       getBlockSize.y, getBlockSize.z});
+  if (clusterSize.has_value())
+    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
@@ -1034,6 +1046,11 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
   for (auto &sz : prop.operandSegmentSizes)
     sz = 1;
   prop.operandSegmentSizes[0] = 0;
+  if (!clusterSize.has_value()) {
+    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
+    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
+  }
   prop.operandSegmentSizes[segmentSizesLen - 3] =
       dynamicSharedMemorySize ? 1 : 0;
   prop.operandSegmentSizes[segmentSizesLen - 2] =
@@ -1067,6 +1084,13 @@ KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
   return KernelDim3{operands[3], operands[4], operands[5]};
 }
 
+KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
+  assert(hasClusterSize() &&
+         "cluster size is not set, check hasClusterSize() first");
+  // The cluster operands follow the six grid/block size operands.
+  auto operands = getOperands().drop_front(getAsyncDependencies().size());
+  return KernelDim3{operands[6], operands[7], operands[8]};
+}
+
 LogicalResult LaunchFuncOp::verify() {
   auto module = (*this)->getParentOfType<ModuleOp>();
   if (!module)
@@ -1078,21 +1102,35 @@ LogicalResult LaunchFuncOp::verify() {
                          GPUDialect::getContainerModuleAttrName() +
                          "' attribute");
 
+  if (hasClusterSize()) {
+    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
+        getClusterSizeZ().getType() != getClusterSizeX().getType())
+      return emitOpError()
+             << "expects types of the cluster dimensions to be the same";
+  }
+
   return success();
 }
 
-static ParseResult parseLaunchDimType(OpAsmParser &parser, Type &dimTy) {
+static ParseResult
+parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
+                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
+                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
   if (succeeded(parser.parseOptionalColon())) {
     if (parser.parseType(dimTy))
       return failure();
   } else {
     dimTy = IndexType::get(parser.getContext());
   }
+  // A `clusters in` clause reuses the single launch dimension type.
+  if (clusterValue.has_value())
+    clusterXTy = clusterYTy = clusterZTy = dimTy;
   return success();
 }
 
-static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
-                               Type dimTy) {
+static void printLaunchDimType(OpAsmPrinter &printer, Operation *op,
+                               Type dimTy, Value clusterValue, Type clusterXTy,
+                               Type clusterYTy, Type clusterZTy) {
   if (!dimTy.isIndex())
     printer << ": " << dimTy;
 }
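With the updated `parseLaunchDimType`, the single optional trailing type now annotates the cluster operands as well as the grid and block sizes, which is what makes the verifier's same-type check hold for parsed IR. A sketch with illustrative constant names (mirroring the Target/LLVMIR test later in this patch):

```mlir
// The one trailing `: i64` types all nine dimension operands, clusters included.
gpu.launch_func @kernel_module::@kernel
    clusters in (%c2, %c1, %c1)
    blocks in (%c1, %c1, %c1)
    threads in (%c1, %c1, %c1) : i64
```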
diff --git a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
index cb2d66d5b0d3..69017efb9a0e 100644
--- a/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/GPU/IR/InferIntRangeInterfaceImpls.cpp
@@ -19,6 +19,8 @@ using namespace mlir::gpu;
 
 // Maximum grid and block dimensions of all known GPUs are less than 2^32.
 static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
+// Maximum cluster dimensions are no larger than 8 on current GPUs.
+static constexpr uint64_t kMaxClusterDim = 8;
 // Maximum subgroups are no larger than 128.
 static constexpr uint64_t kMaxSubgroupSize = 128;
 
@@ -82,6 +84,17 @@ static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
   return std::nullopt;
 }
 
+void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                     SetIntRangeFn setResultRange) {
+  setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
+}
+
+void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
+                                    SetIntRangeFn setResultRange) {
+  uint64_t max = kMaxClusterDim;
+  setResultRange(getResult(), getIndexRange(0, max - 1ULL));
+}
+
 void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
                                    SetIntRangeFn setResultRange) {
   std::optional<uint64_t> knownVal =
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index a8e743c51913..9b63d2a22a7a 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -194,6 +194,60 @@ mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                                   extra));
 }
 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
+    CUfunction function, intptr_t clusterX, intptr_t clusterY,
+    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
+    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
+    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
+  ScopedContext scopedContext;
+  if (smem > 0) {
+    // Query the device only when dynamic shared memory is requested; the
+    // driver call is more expensive than this check.
+    int32_t maxShmem = 0;
+    CUdevice device = getDefaultCuDevice();
+    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
+    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
+        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+        device));
+    if (maxShmem < smem) {
+      fprintf(stderr,
+              "Requested shared memory (%d bytes) is larger than maximum "
+              "allowed shared memory (%d bytes) for this device\n",
+              smem, maxShmem);
+    }
+    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
+        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
+  }
+  CUlaunchConfig config;
+  config.gridDimX = gridX;
+  config.gridDimY = gridY;
+  config.gridDimZ = gridZ;
+  config.blockDimX = blockX;
+  config.blockDimY = blockY;
+  config.blockDimZ = blockZ;
+  config.sharedMemBytes = smem;
+  config.hStream = stream;
+  CUlaunchAttribute launchAttr[2];
+  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+  launchAttr[0].value.clusterDim.x = clusterX;
+  launchAttr[0].value.clusterDim.y = clusterY;
+  launchAttr[0].value.clusterDim.z = clusterZ;
+  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
+  launchAttr[1].value.clusterSchedulingPolicyPreference =
+      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
+  config.numAttrs = 2;
+  config.attrs = launchAttr;
+
+  debug_print("Launching kernel, "
+              "cluster: %ld, %ld, %ld, "
+              "grid: %ld, %ld, %ld, "
+              "threads: %ld, %ld, %ld, "
+              "smem: %d bytes\n",
+              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX,
+              blockY, blockZ, smem);
+
+  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
+}
+
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
   ScopedContext scopedContext;
   CUstream stream = nullptr;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 47fe6973778c..2acccb7c2faf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -136,6 +136,9 @@ public:
   // Get the kernel launch callee.
   FunctionCallee getKernelLaunchFn();
 
+  // Get the cluster kernel launch callee.
+  FunctionCallee getClusterKernelLaunchFn();
+
   // Get the module function callee.
   FunctionCallee getModuleFunctionFn();
 
@@ -228,6 +231,17 @@ llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
           false));
 }
 
+llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
+  return module.getOrInsertFunction(
+      "mgpuLaunchClusterKernel",
+      FunctionType::get(
+          voidTy,
+          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
+                            i32Ty, ptrTy, ptrTy, ptrTy}),
+          false));
+}
+
 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
   return module.getOrInsertFunction(
       "mgpuModuleGetFunction",
@@ -401,10 +415,22 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
 
   // Create the launch call.
   Value *nullPtr = ConstantPointerNull::get(ptrTy);
-  builder.CreateCall(
-      getKernelLaunchFn(),
-      ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
-                         dynamicMemorySize, stream, argArray, nullPtr}));
+
+  // Launch the kernel with clusters if a cluster size was specified.
+  if (op.hasClusterSize()) {
+    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
+    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
+          *cz = llvmValue(cluster.z);
+    builder.CreateCall(
+        getClusterKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  } else {
+    builder.CreateCall(
+        getKernelLaunchFn(),
+        ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by, bz,
+                           dynamicMemorySize, stream, argArray, nullPtr}));
+  }
 
   // Sync & destroy the stream, for synchronous launches.
   if (handleStream) {
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index f5462b579b5e..c0b05ef08603 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -96,3 +96,41 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
+    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
+        %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
+        %arg5: i64) attributes {gpu.kernel} {
+      llvm.return
+    }
+  }
+
+  func.func @foo(%buffer: memref<?xf32>) {
+    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+    // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+    // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : i64
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : i32
+    %c256 = arith.constant 256 : i32
+    %c2 = arith.constant 2 : index
+
+    // CHECK: gpu.launch_func @kernel_module::@kernel
+    // CHECK: clusters in ([[C2]], [[C2]], [[C2]])
+    // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+    // CHECK: dynamic_shared_memory_size [[C256]]
+    // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
+    gpu.launch_func @kernel_module::@kernel
+        clusters in (%c2, %c2, %c2)
+        blocks in (%c8, %c8, %c8)
+        threads in (%c8, %c8, %c8)
+        dynamic_shared_memory_size %c256
+        args(%c32 : i32, %buffer : memref<?xf32>)
+    return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index df9921ef14d3..3a2197ad4d5a 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -57,7 +57,7 @@ module attributes {gpu.container_module} {
 func.func @launch_func_missing_callee_attribute(%sz : index) {
   // expected-error@+1 {{'gpu.launch_func' op requires attribute 'kernel'}}
   "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
-      {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0>}
+      {operandSegmentSizes = array<i32: 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0>}
       : (index, index, index, index, index, index) -> ()
   return
 }
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c638e0b21ab6..481934364156 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -152,6 +152,9 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    // CHECK: gpu.launch_func @kernels::@kernel_1 clusters in (%{{.*}}, %{{.*}}, %{{.*}}) blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 clusters in (%cst, %cst, %cst) blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
+
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
 
     // CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
new file mode 100644
index 000000000000..5beba4881348
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/cga_cluster.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s \
+// RUN:   -test-lower-to-nvvm="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// CHECK: clusterIdx: (1, 1, 0) in Cluster Dimension: (2, 2, 1) blockIdx: (3, 3, 0)
+
+module attributes {gpu.container_module} {
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    gpu.launch_func @gpumodule::@kernel_cluster
+        clusters in (%c2, %c2, %c1)
+        blocks in (%c4, %c4, %c1)
+        threads in (%c1, %c1, %c1)
+    return
+  }
+  gpu.module @gpumodule {
+    gpu.func @kernel_cluster() kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>,
+                                                  gpu.known_grid_size = array<i32: 4, 4, 1>} {
+      %cidX = gpu.cluster_id x
+      %cidY = gpu.cluster_id y
+      %cidZ = gpu.cluster_id z
+      %cdimX = gpu.cluster_dim x
+      %cdimY = gpu.cluster_dim y
+      %cdimZ = gpu.cluster_dim z
+      %bidX = gpu.block_id x
+      %bidY = gpu.block_id y
+      %bidZ = gpu.block_id z
+      %cidX_i32 = index.casts %cidX : index to i32
+      %cidY_i32 = index.casts %cidY : index to i32
+      %cidZ_i32 = index.casts %cidZ : index to i32
+      %cdimX_i32 = index.casts %cdimX : index to i32
+      %cdimY_i32 = index.casts %cdimY : index to i32
+      %cdimZ_i32 = index.casts %cdimZ : index to i32
+      %bidX_i32 = index.casts %bidX : index to i32
+      %bidY_i32 = index.casts %bidY : index to i32
+      %bidZ_i32 = index.casts %bidZ : index to i32
+
+      // Print from block (3, 3, 0) only, so FileCheck sees a single line.
+      %c3 = arith.constant 3 : index
+      %cnd1 = arith.cmpi eq, %bidX, %c3 : index
+      %cnd2 = arith.cmpi eq, %bidY, %c3 : index
+      scf.if %cnd1 {
+        scf.if %cnd2 {
"clusterIdx: (%d, %d, %d) in Cluster Dimension: (%d, %d, %d) blockIdx: (%d, %d, %d) \n" + %cidX_i32, + %cidY_i32, + %cidZ_i32, + %cdimX_i32, + %cdimY_i32, + %cdimZ_i32, + %bidX_i32, + %bidY_i32, + %bidZ_i32 + : + i32, i32, i32, i32, i32, i32, i32, i32, i32 + } + } + + gpu.return + } + } +} + diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir index fddbbee962c1..190b53bcf208 100644 --- a/mlir/test/Target/LLVMIR/gpu.mlir +++ b/mlir/test/Target/LLVMIR/gpu.mlir @@ -75,3 +75,22 @@ module attributes {gpu.container_module} { llvm.func @mgpuStreamSynchronize(!llvm.ptr) llvm.func @mgpuStreamDestroy(!llvm.ptr) } + +// ----- + +// Test cluster/block/thread syntax. +module attributes {gpu.container_module} { + // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8 + gpu.binary @kernel_module [#gpu.object<#nvvm.target, "BLOB">] + llvm.func @foo() { + // CHECK: [[S2:%.*]] = alloca ptr, i64 0, align 8 + // CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst) + // CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name) + // CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate() + // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null) + %0 = llvm.mlir.constant(1 : index) : i64 + %1 = llvm.mlir.constant(2 : index) : i64 + gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 + llvm.return + } +}