mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][nvvm] Introduce Synchronization Ops for WGMMA
This work introduces the `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned` and `wgmma.wait.group.sync.aligned` Ops. They are used to synchronize warpgroup-level matrix multiply-accumulate instructions, also known as WGMMA. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D155676
This commit is contained in:
@@ -1419,4 +1419,51 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : NVVM_Op<"cp.async.bulk.tenso
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM Wgmma Ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// `nvvm.wgmma.fence.aligned`: takes no operands or attributes. Its
// BasicPtxBuilderOpInterface hook, getPtx(), returns a fixed PTX string.
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Enforce an ordering of register accesses between warpgroup level matrix
    multiplication and other operations.
    See for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence
  }];
  let arguments = (ins);
  // Only an optional attribute dictionary follows the op name.
  let assemblyFormat = "attr-dict";
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return "wgmma.fence.sync.aligned;";
    }
  }];
}
|
||||
|
||||
// `nvvm.wgmma.commit.group.sync.aligned`: takes no operands or attributes.
// Its getPtx() hook returns the fixed `wgmma.commit_group` PTX string.
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Commits all prior uncommitted warpgroup level matrix multiplication operations.
    See for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group
  }];
  // Empty argument list, spelled as a `let` field rather than an
  // `Arguments<...>` mixin.
  let arguments = (ins);
  let assemblyFormat = "attr-dict";
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return "wgmma.commit_group.sync.aligned;";
    }
  }];
}
|
||||
|
||||
// `nvvm.wgmma.wait.group.sync.aligned`: carries one 32-bit integer attribute,
// `$group`, printed directly after the attribute dictionary.
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
  let description = [{
    Signal the completion of a preceding warpgroup operation.
    See for more information:
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group
  }];
  let arguments = (ins I32Attr:$group);
  let assemblyFormat = "attr-dict $group";
  let extraClassDefinition = [{
    // NOTE(review): `%0` presumably stands for the op's only argument
    // (`$group`) and is substituted by the PTX builder -- confirm.
    std::string $cppClass::getPtx() {
      return "wgmma.wait_group.sync.aligned %0;";
    }
  }];
}
|
||||
|
||||
#endif // NVVMIR_OPS
|
||||
|
||||
@@ -80,3 +80,23 @@ func.func @tma_load_5d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier
|
||||
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32, i32, i32, i32
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Verifies that each wgmma synchronization op is rewritten into an
// `llvm.inline_asm` carrying the PTX string returned by the op's getPtx().
// FileCheck only recognizes a directive when the colon immediately follows
// the prefix, so `CHECK :` / `CHECK-LABEL :` would be silently ignored.
// CHECK-LABEL: @wgmma_execute
func.func @wgmma_execute() {
  nvvm.wgmma.fence.aligned
  nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 0
  // NOTE(review): the operand placeholder is matched as either `%0` or `$0`;
  // confirm which spelling the lowering emits and tighten the pattern.
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned {{[%$]}}0;", "n" %{{.*}} : (i32)

  nvvm.wgmma.fence.aligned
  nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 1
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;", ""
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned {{[%$]}}0;", "n" %{{.*}} : (i32)
  return
}
|
||||
|
||||
@@ -407,3 +407,25 @@ llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i6
|
||||
%isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// Round-trip test: the op must be parsed and printed back under the same name.
// The colon must directly follow the FileCheck prefix for the directive to be
// recognized.
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
  // CHECK: nvvm.wgmma.fence.aligned
  nvvm.wgmma.fence.aligned
  return
}
|
||||
|
||||
// Round-trip test: the op must be parsed and printed back under the same name.
// The colon must directly follow the FileCheck prefix for the directive to be
// recognized.
// CHECK-LABEL: @wgmma_commit_group_sync_aligned
func.func @wgmma_commit_group_sync_aligned() {
  // CHECK: nvvm.wgmma.commit.group.sync.aligned
  nvvm.wgmma.commit.group.sync.aligned
  return
}
|
||||
|
||||
|
||||
// Round-trip test: the op must be parsed and printed back under the same name.
// The label names the function defined directly below, and the colon directly
// follows the FileCheck prefix so the directive is recognized.
// CHECK-LABEL: @wgmma_wait_group_sync_aligned
func.func @wgmma_wait_group_sync_aligned() {
  // CHECK: nvvm.wgmma.wait.group.sync.aligned
  nvvm.wgmma.wait.group.sync.aligned 0
  return
}
|
||||
|
||||
Reference in New Issue
Block a user