mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[MLIR] change NVVM.mma.sync to the most useful variant.
Summary: the .row.col variant turns out to be the popular one, contrary to what I thought as .row.row. Since .row.col is so prevailing (as I inspect cuDNN's behavior), I'm going to remove the .row.row support here, which makes the patch a little bit easier. Reviewers: ftynse Subscribers: jholewinski, bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D74655
This commit is contained in:
@@ -125,7 +125,7 @@ def NVVM_MmaOp :
|
||||
Arguments<(ins Variadic<LLVM_Type>:$args)> {
|
||||
string llvmBuilder = [{
|
||||
$res = createIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_row_f32_f32, $args);
|
||||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
|
||||
}];
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
@@ -131,7 +131,7 @@ static LogicalResult verify(MmaOp op) {
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty} &&
|
||||
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
|
||||
blayout.getValue() == "row") {
|
||||
blayout.getValue() == "col") {
|
||||
return success();
|
||||
}
|
||||
return op.emitOpError("unimplemented mma.sync variant");
|
||||
|
||||
@@ -295,7 +295,7 @@ func @nvvm_invalid_mma_0(%a0 : !llvm.half, %a1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm.half, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
||||
@@ -307,7 +307,7 @@ func @nvvm_invalid_mma_1(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, half }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, half }">
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ func @nvvm_invalid_mma_3(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm<"<2 x half>">, %c1 : !llvm<"<2 x half>">,
|
||||
%c2 : !llvm<"<2 x half>">, %c3 : !llvm<"<2 x half>">) {
|
||||
// expected-error@+1 {{unimplemented mma.sync variant}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
||||
@@ -343,7 +343,7 @@ func @nvvm_invalid_mma_4(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// expected-error@+1 {{unimplemented mma.sync variant}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
|
||||
llvm.return %0 : !llvm<"{<2 x half>, <2 x half>, <2 x half>, <2 x half>}">
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
||||
@@ -68,8 +68,8 @@ llvm.func @nvvm_mma(%a0 : !llvm<"<2 x half>">, %a1 : !llvm<"<2 x half>">,
|
||||
%b0 : !llvm<"<2 x half>">, %b1 : !llvm<"<2 x half>">,
|
||||
%c0 : !llvm.float, %c1 : !llvm.float, %c2 : !llvm.float, %c3 : !llvm.float,
|
||||
%c4 : !llvm.float, %c5 : !llvm.float, %c6 : !llvm.float, %c7 : !llvm.float) {
|
||||
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.row.f32.f32
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="row"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (!llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm<"<2 x half>">, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float, !llvm.float) -> !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
llvm.return %0 : !llvm<"{ float, float, float, float, float, float, float, float }">
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user