[MLIR][NVVM] Fix results-check for mbarrier Op (#171657)

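This patch fixes the lowering of the newly added
mbarrier.arrive Op (and of mbarrier.arrive_drop, which
uses the same pattern) with respect to the return value:
the intrinsic call is bound to the Op's result only when
the Op actually has one (`op.getNumResults() > 0`).
(Follow-up to PR #170545)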

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
This commit is contained in:
Durgadoss R
2025-12-12 22:45:09 +05:30
committed by GitHub
parent 2af693bbec
commit 9dc6f18a3e
2 changed files with 15 additions and 4 deletions

View File

@@ -785,8 +785,7 @@ def NVVM_MBarrierArriveOp : NVVM_Op<"mbarrier.arrive"> {
auto [id, args] = NVVM::MBarrierArriveOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
int addrSpace = llvm::cast<LLVMPointerType>(op.getAddr().getType()).getAddressSpace();
if (addrSpace != static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster))
if (op.getNumResults() > 0)
$res = createIntrinsicCall(builder, id, args);
else
createIntrinsicCall(builder, id, args);
@@ -827,8 +826,7 @@ def NVVM_MBarrierArriveDropOp : NVVM_Op<"mbarrier.arrive_drop"> {
auto [id, args] = NVVM::MBarrierArriveDropOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
int addrSpace = llvm::cast<LLVMPointerType>(op.getAddr().getType()).getAddressSpace();
if (addrSpace != static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster))
if (op.getNumResults() > 0)
$res = createIntrinsicCall(builder, id, args);
else
createIntrinsicCall(builder, id, args);

View File

@@ -101,3 +101,16 @@ llvm.func @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
%0 = nvvm.mbarrier.arrive.nocomplete %barrier, %count : !llvm.ptr<3>, i32 -> i64
llvm.return
}
// Translation test for nvvm.mbarrier.arrive written without a result
// (no trailing `-> i64`). The expected LLVM IR still contains a call to the
// i64-returning intrinsic for each op; the returned value is left unused.
llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) {
// CHECK-LABEL: define void @mbarrier_arrive_ignore_retval(i32 %0, ptr addrspace(3) %1) {
// CHECK-NEXT: %3 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
// CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
// The op is issued twice with identical operands; the expectations above
// require one intrinsic call per op, in order, followed directly by the return.
nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
llvm.return
}