From 8481fb1698a4a54b3965a1654046df736a57e144 Mon Sep 17 00:00:00 2001
From: Zahi Moudallal
Date: Thu, 14 Mar 2024 08:43:48 -0700
Subject: [PATCH] [MLIR][ROCDL] Fix BallotOp LLVM translation and add doc
 (#85116)

This modifies the return type of the intrinsic call to handle 32 and 64 bits
properly and documents the MLIR operation.
---
NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Line structure was rebuilt from the hunk headers and diffstat; the
leading indentation of context/added lines is reconstructed and should be
verified against the tree (or apply with `git apply --ignore-whitespace`).

 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 10 +++++++++-
 mlir/test/Target/LLVMIR/rocdl.mlir           | 11 +++++++++--
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index abb38a3df806..1dabf5d7979b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -162,10 +162,18 @@ def ROCDL_BallotOp :
   ROCDL_Op<"ballot">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins I1:$pred)> {
+  let summary = "Vote across thread group";
+
+  let description = [{
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+      The nth bit of the result contains the 1 bit contributed by the nth warp lane.
+  }];
+
   string llvmBuilder = [{
       $res = createIntrinsicCall(builder,
-            llvm::Intrinsic::amdgcn_ballot, {$pred}, {llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext())});
+            llvm::Intrinsic::amdgcn_ballot, {$pred}, {$_resultType});
   }];
+
   let assemblyFormat = "$pred attr-dict `:` type($res)";
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 93550f5c7bd5..ce6b56d48437 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -88,13 +88,20 @@ llvm.func @rocdl.bpermute(%src : i32) -> i32 {
   llvm.return %0 : i32
 }
 
-llvm.func @rocdl.ballot(%pred : i1) -> i32 {
-  // CHECK-LABEL: rocdl.ballot
+llvm.func @rocdl.ballot32(%pred : i1) -> i32 {
+  // CHECK-LABEL: rocdl.ballot32
   // CHECK: call i32 @llvm.amdgcn.ballot
   %0 = rocdl.ballot %pred : i32
   llvm.return %0 : i32
 }
 
+llvm.func @rocdl.ballot64(%pred : i1) -> i64 {
+  // CHECK-LABEL: rocdl.ballot64
+  // CHECK: call i64 @llvm.amdgcn.ballot
+  %0 = rocdl.ballot %pred : i64
+  llvm.return %0 : i64
+}
+
 llvm.func @rocdl.waitcnt() {
   // CHECK-LABEL: rocdl.waitcnt
   // CHECK-NEXT: call void @llvm.amdgcn.s.waitcnt(i32 0)