[MLIR][GPU][NVVM] Add warp synchronous matrix-multiply accumulate ops

Add warp synchronous matrix-multiply accumulate ops in GPU and NVVM
dialect. Add following three ops to GPU dialect :-
  1.) subgroup_mma_load_matrix
  2.) subgroup_mma_store_matrix
  3.) subgroup_mma_compute
Add following three ops to NVVM dialect :-
  1.) wmma.m16n16k16.load.[a,b,c].[f16,f32].row.stride
  2.) wmma.m16n16k16.store.d.[f16,f32].row.stride
  3.) wmma.m16n16k16.mma.row.row.[f16,f32].[f16,f32]

Reviewed By: bondhugula, ftynse, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D95330
This commit is contained in:
Navdeep Kumar
2021-05-06 12:05:07 +05:30
committed by Uday Bondhugula
parent 16c7829784
commit 875eb523c1
13 changed files with 1264 additions and 11 deletions

View File

@@ -35,6 +35,7 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
@@ -300,6 +301,29 @@ llvm::Value *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(fn, args);
}
llvm::Value *
mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn;
if (llvm::Intrinsic::isOverloaded(intrinsic)) {
if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 &&
intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) {
// NVVM load and store instrinsic names are overloaded on the
// source/destination pointer type. Pointer is the first argument in the
// corresponding NVVM Op.
fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
{args[0]->getType()});
} else {
fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
}
} else {
fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
}
return builder.CreateCall(fn, args);
}
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult