[mlir][nvvm] Introduce Synchronization Ops for WGMMA

This work introduces the `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned` and `wgmma.wait.group.sync.aligned` Ops. They are used to synchronize warpgroup-level matrix multiply-accumulate instructions, also known as WGMMA.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D155676
This commit is contained in:
Guray Ozen
2023-07-19 10:29:41 +02:00
parent 7fa7a08f21
commit bf62748342
3 changed files with 89 additions and 0 deletions

View File

@@ -1419,4 +1419,51 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Establishes an ordering of register accesses between warpgroup-level
    matrix multiplication (WGMMA) operations and other operations.

    For more information, see the PTX ISA documentation:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence
  }];

  // The op carries no operands or attributes; only the optional attribute
  // dictionary appears in its textual form.
  let arguments = (ins);
  let assemblyFormat = "attr-dict";

  // Implementation of the getPtx() method declared via
  // BasicPtxBuilderOpInterface: returns the PTX instruction text for this op.
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return std::string("wgmma.fence.sync.aligned;");
    }
  }];
}
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Commits all prior uncommitted warpgroup-level matrix multiplication
    operations.

    For more information, see the PTX ISA documentation:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group
  }];

  // The op carries no operands or attributes; only the optional attribute
  // dictionary appears in its textual form.
  let arguments = (ins);
  let assemblyFormat = "attr-dict";

  // Implementation of the getPtx() method declared via
  // BasicPtxBuilderOpInterface: returns the PTX instruction text for this op.
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return std::string("wgmma.commit_group.sync.aligned;");
    }
  }];
}
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Signals the completion of a preceding warpgroup operation.

    For more information, see the PTX ISA documentation:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group
  }];

  // `group` is a 32-bit integer attribute and is printed right after the
  // optional attribute dictionary, e.g. `nvvm.wgmma.wait.group.sync.aligned 0`.
  let arguments = (ins I32Attr:$group);
  let assemblyFormat = "attr-dict $group";

  // Implementation of the getPtx() method declared via
  // BasicPtxBuilderOpInterface. `%0` is the placeholder for the op's first
  // (and only) argument, `group`.
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return std::string("wgmma.wait_group.sync.aligned %0;");
    }
  }];
}
#endif // NVVMIR_OPS

View File

@@ -80,3 +80,23 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
return
}
// Lowering of the WGMMA synchronization ops to inline PTX assembly.
// FileCheck only recognizes a directive when the colon directly follows the
// prefix, so the directives below are written without a space before it.
// NOTE(review): the expected asm string uses `$0` on the assumption that the
// lowering rewrites the `%0` placeholder from the op definition into
// inline-asm operand syntax; confirm against the conversion pass output.
// CHECK-LABEL: @wgmma_execute
func.func @wgmma_execute() {
  nvvm.wgmma.fence.aligned
  nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 0
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %{{.*}} : (i32)
  nvvm.wgmma.fence.aligned
  nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 1
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %{{.*}} : (i32)
  return
}

View File

@@ -407,3 +407,25 @@ llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i6
%isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
llvm.return
}
// Round-trip test for the WGMMA fence op. The colon must directly follow the
// FileCheck prefix for the directive to be recognized.
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
  // CHECK: nvvm.wgmma.fence.aligned
  nvvm.wgmma.fence.aligned
  return
}
// Round-trip test for the WGMMA commit-group op. The colon must directly
// follow the FileCheck prefix for the directive to be recognized.
// CHECK-LABEL: @wgmma_commit_group_sync_aligned
func.func @wgmma_commit_group_sync_aligned() {
  // CHECK: nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.commit.group.sync.aligned
  return
}
// Round-trip test for the WGMMA wait-group op. The label names this function
// (it previously repeated the commit-group test's name), and the colon
// directly follows the FileCheck prefix so the directive is recognized.
// CHECK-LABEL: @wgmma_wait_group_sync_aligned
func.func @wgmma_wait_group_sync_aligned() {
  // CHECK: nvvm.wgmma.wait.group.sync.aligned 0
  nvvm.wgmma.wait.group.sync.aligned 0
  return
}