[MLIR] change NVVM.mma.sync to the most useful variant.

Summary:
the .row.col variant turns out to be the popular one, contrary to what I
thought as .row.row. Since .row.col is so prevailing (as I inspect
cuDNN's behavior), I'm going to remove the .row.row support here, which
makes the patch a little bit easier.

Reviewers: ftynse

Subscribers: jholewinski, bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74655
This commit is contained in:
Tim Shen
2020-02-14 15:07:44 -08:00
parent f581e655ec
commit b762bbd4c8
5 changed files with 10 additions and 10 deletions

View File

@@ -125,7 +125,7 @@ def NVVM_MmaOp :
Arguments<(ins Variadic<LLVM_Type>:$args)> {
string llvmBuilder = [{
$res = createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let verifier = [{ return ::verify(*this); }];

View File

@@ -131,7 +131,7 @@ static LogicalResult verify(MmaOp op) {
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty} &&
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "row") {
blayout.getValue() == "col") {
return success();
}
return op.emitOpError("unimplemented mma.sync variant");

View File

@@ -295,7 +295,7 @@ func @nvvm_invalid_mma_0(%a0 : !llvm.half, %a1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}
@@ -307,7 +307,7 @@ func @nvvm_invalid_mma_1(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, half }">
}
@@ -331,7 +331,7 @@ func @nvvm_invalid_mma_3(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%c0 : !llvm<"<2 x half>">, %c1 : !llvm<"<2 x half>">,
%c2 : !llvm<"<2 x half>">, %c3 : !llvm<"<2 x half>">) {
// expected-error@+1 {{unimplemented mma.sync variant}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}
@@ -343,7 +343,7 @@ func @nvvm_invalid_mma_4(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// expected-error@+1 {{unimplemented mma.sync variant}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
llvm.return %0 : !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
}

View File

@@ -64,7 +64,7 @@ func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}

View File

@@ -68,8 +68,8 @@ llvm.func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.row.f32.f32
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
}