[mlir][gpu] Add support for integer types in gpu.subgroup_mma ops

The signedness is carried by `!gpu.mma_matrix` types to most closely
match the Cooperative Matrix specification, which determines signedness
from the type (and sometimes the operation).

See: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.html
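
For example, the signedness shows up directly in the element type of the
`!gpu.mma_matrix` values (a minimal sketch reusing the shapes from the tests
below; the buffer and index names are illustrative):

  // A/B operands carry an explicitly signed 8-bit element type.
  %a = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 16 : index}
      : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
  // The accumulator stays signless i32 and is interpreted as signed.
  %c = gpu.subgroup_mma_load_matrix %acc[%i, %j] {leadDimension = 16 : index}
      : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">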

To handle the lowering from the vector dialect to the gpu dialect, ops such as
`arith.extsi` are pattern-matched next to `vector.transfer_read` and
`vector.contract` to determine the signedness of the matrix type.
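
A condensed sketch of the matched pattern, adapted from the vector-to-gpu test
added in this change (`%Be`, `%C`, and the `#map` definitions are as in that
test):

  %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]}
      : memref<16x16xi8>, vector<16x16xi8>
  %Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32>
  %D = vector.contract {indexing_maps = [#map1, #map2, #map3],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>

Because the `arith.extsi` sits between the read and the contraction, the i8
read is treated as signed and converts to a load of
`!gpu.mma_matrix<16x16xsi8, "AOp">`; an i8 read without such an extension is
left untouched.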

Enables s8 and u8 WMMA types in NVVM for the GPUToNVVM conversion.
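
At the gpu dialect level the integer case then reads (quoting the added
GPUToNVVM and vector-to-gpu tests):

  %D = gpu.subgroup_mma_compute %A, %B, %C
      : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp">
      -> !gpu.mma_matrix<16x16xi32, "COp">

which the GPUToNVVM conversion lowers to `nvvm.wmma.mma` with
`#nvvm.mma_type<s8>` multiplicand fragments and an `s32` accumulator.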

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143223
Author: Quinn Dawkins
Date: 2023-02-05 23:53:38 -05:00
Parent: 622be09c81
Commit: 985f7ff632
11 changed files with 265 additions and 41 deletions

View File

@@ -101,7 +101,7 @@ def GPU_MMAMatrix : DialectType<
GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
class MMAMatrixOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,

View File

@@ -1150,6 +1150,10 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
matrix which eventually allows the lowering to determine the size of each
row. If the `transpose` attribute is present then the op does a transposed load.
For integer types, the resulting `!gpu.mma_matrix` type needs to specify the
signedness of the data if the matrix type is an `A` or `B` operand for
`gpu.subgroup_mma_compute`.
This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
`gpu.subgroup_mma_compute`.
@@ -1201,7 +1205,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
```
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension,
@@ -1227,11 +1231,15 @@ def GPU_SubgroupMmaComputeOp
as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
the operation held by all threads in a subgroup. `a_transpose` or
`b_transpose` if present, signify that the respective operand was loaded in a
transposed manner. The transpose opernads are required to map to correct
transposed manner. The transpose operands are required to map to correct
underlying intrinsics but they currently do not seem to affect correctness
even if they are absent given that the operands were loaded correctly using
the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op.
For integer types, the `A` and `B` matrices carry their signedness with their
types. The accumulator type is expected to be signless and is interpreted as a
signed integer with a greater width than the other two operands.
This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
`gpu.subgroup_mma_load_matrix` ops.
@@ -1244,9 +1252,9 @@ def GPU_SubgroupMmaComputeOp
```
}];
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
Arg<MMAMatrixOf<[F16, F32]>>:$opB,
Arg<MMAMatrixOf<[F16, F32]>>:$opC,
let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
OptionalAttr<UnitAttr>:$a_transpose,
OptionalAttr<UnitAttr>:$b_transpose);
@@ -1288,7 +1296,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
```
}];
let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value);
let results = (outs GPU_MMAMatrix:$res);

View File

@@ -37,7 +37,8 @@ enum NVVMMemorySpace {
/// of given characteristics. This matches the logic in IntrinsicsNVVM.td
/// WMMA_REGS structure.
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
mlir::NVVM::MMAFrag frag,
mlir::NVVM::MMAFrag frag, int nRow,
int nCol,
mlir::MLIRContext *context);
} // namespace NVVM
} // namespace mlir

View File

@@ -385,16 +385,20 @@ class NVVM_MMA_OPS {
list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
["f16"], [], ["f16", "f32"], []>.ret;
list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
["s8","u8"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
tf32_wmma_ops,
fp_wmma_ops);
fp_wmma_ops,
i8_wmma_ops);
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
["a", "b"], ["f16"]>.ret;
["a", "b"], ["f16","s8","u8"]>.ret;
list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
[GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
["c", "d"], ["f16", "f32"]>.ret;
["c", "d"], ["f16", "f32","s32"]>.ret;
list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
[GEOM<16, 16, 8>],
["a", "b"], ["tf32"]>.ret;

View File

@@ -57,6 +57,12 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF32())
return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
// Accumulator type is signless and implies signed.
if (type.getElementType().isInteger(32))
return NVVM::MMATypes::s32;
llvm_unreachable("Unsupported type");
}
@@ -106,8 +112,11 @@ struct WmmaLoadOpToNVVMLowering
}
NVVM::MMAFrag frag = convertOperand(retType.getOperand());
// Check that there is an existing instruction for the combination we need.
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
<< "\n";
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
Type resType = convertMMAToLLVMType(retType);
Location loc = op->getLoc();
@@ -366,8 +375,10 @@ struct WmmaElementwiseOpToNVVMLowering
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
auto nRow = type.getShape()[0];
auto nCol = type.getShape()[1];
std::pair<Type, unsigned> typeInfo =
NVVM::inferMMAType(eltType, frag, type.getContext());
NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}

View File

@@ -140,6 +140,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
return false;
if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
return false;
// Only allow integer types if the signedness can be inferred.
if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
return false;
AffineMap map = readOp.getPermutationMap();
OpBuilder b(readOp.getContext());
AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
@@ -185,8 +191,16 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
return broadcastOp.getVectorType().getRank() == 2 &&
broadcastOp.getSource().getType().isa<FloatType>();
return broadcastOp.getVectorType().getRank() == 2;
}
/// Return true if this signed extend op can be folded into a contract op.
static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
return false;
return llvm::all_of(extOp->getUsers(), [](Operation *user) {
return isa<vector::ContractionOp>(user);
});
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +282,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return constantSupportsMMAMatrixType(constant);
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcastSupportsMMAMatrixType(broadcast);
if (auto extend = dyn_cast<arith::ExtSIOp>(op))
return signedExtendSupportsMMAMatrixType(extend);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -411,8 +427,18 @@ struct CombineTransferReadOpTranspose final
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto transferReadOp =
op.getVector().getDefiningOp<vector::TransferReadOp>();
// Look through integer extend ops.
Value source = op.getVector();
auto extOp = source.getDefiningOp<arith::ExtSIOp>();
auto resultType = op.getVectorType();
if (extOp) {
source = extOp.getOperand();
resultType =
VectorType::get(resultType.getShape(),
source.getType().cast<VectorType>().getElementType());
}
auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
@@ -431,11 +457,23 @@ struct CombineTransferReadOpTranspose final
AffineMap::getPermutationMap(permU, op.getContext());
AffineMap newMap =
permutationMap.compose(transferReadOp.getPermutationMap());
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
op, op.getType(), transferReadOp.getSource(),
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
transferReadOp.getPadding(), transferReadOp.getMask(),
transferReadOp.getInBoundsAttr());
auto loc = op.getLoc();
Value result =
rewriter
.create<vector::TransferReadOp>(
loc, resultType, transferReadOp.getSource(),
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
transferReadOp.getPadding(), transferReadOp.getMask(),
transferReadOp.getInBoundsAttr())
.getResult();
// Fuse through the integer extend op.
if (extOp)
result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
.getResult();
rewriter.replaceOp(op, result);
return success();
}
};
@@ -479,14 +517,26 @@ static void convertTransferReadOp(vector::TransferReadOp op,
stride = 0;
}
assert(stride);
Value mappingResult = op.getResult();
auto elType = op.getVectorType().getElementType();
const char *fragType = inferFragType(op);
if (op->hasOneUse()) {
auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
// Infer the signedness of the mma type from the signed extend.
if (extOp) {
elType = IntegerType::get(op.getContext(),
elType.cast<IntegerType>().getWidth(),
IntegerType::Signed);
mappingResult = extOp.getResult();
fragType = inferFragType(extOp);
}
}
gpu::MMAMatrixType type =
gpu::MMAMatrixType::get(op.getVectorType().getShape(),
op.getVectorType().getElementType(), fragType);
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
op.getLoc(), type, op.getSource(), op.getIndices(),
b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
valueMapping[op.getResult()] = load;
valueMapping[mappingResult] = load;
}
static void convertTransferWriteOp(vector::TransferWriteOp op,

View File

@@ -78,7 +78,9 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
bool MMAMatrixType::isValidElementType(Type elementType) {
return elementType.isF16() || elementType.isF32();
return elementType.isF16() || elementType.isF32() ||
elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
elementType.isInteger(32);
}
LogicalResult
@@ -93,7 +95,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "MMAMatrixType must have exactly two dimensions";
if (!MMAMatrixType::isValidElementType(elementType))
return emitError() << "MMAMatrixType elements must be F16 or F32";
return emitError()
<< "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
return success();
}

View File

@@ -537,7 +537,8 @@ LogicalResult ShflOp::verify() {
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
NVVM::MMAFrag frag,
NVVM::MMAFrag frag, int nRow,
int nCol,
MLIRContext *context) {
unsigned numberElements = 0;
Type elementType;
@@ -555,11 +556,48 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
} else if (type == NVVM::MMATypes::tf32) {
elementType = builder.getI32Type();
numberElements = 4;
} else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
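// 8-bit operands are packed into 32-bit registers, so the element type is
// i32 and the per-thread register count scales with the fragment's parallel
// dimension (rows for A, columns for B).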
elementType = builder.getI32Type();
int parallelSize = 0;
if (frag == NVVM::MMAFrag::a)
parallelSize = nRow;
if (frag == NVVM::MMAFrag::b)
parallelSize = nCol;
// m == 16 && n == 16 && k == 16
if (parallelSize == 16)
numberElements = 2;
// m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
else if (parallelSize == 8)
numberElements = 1;
else if (parallelSize == 32)
numberElements = 4;
} else if (type == NVVM::MMATypes::s32) {
elementType = builder.getI32Type();
numberElements = 8;
}
assert(numberElements != 0 && elementType != nullptr);
return std::make_pair(elementType, numberElements);
}
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
int k, MLIRContext *context) {
int nRow, nCol;
if (frag == NVVM::MMAFrag::a) {
nRow = m;
nCol = k;
} else if (frag == NVVM::MMAFrag::b) {
nRow = k;
nCol = n;
} else {
nRow = m;
nCol = n;
}
assert(nRow && nCol);
return inferMMAType(type, frag, nRow, nCol, context);
}
LogicalResult NVVM::WMMALoadOp::verify() {
unsigned addressSpace =
getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
@@ -570,8 +608,8 @@ LogicalResult NVVM::WMMALoadOp::verify() {
if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
getEltype(), getFrag()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
inferMMAType(getEltype(), getFrag(), getContext());
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
Type dstType = LLVM::LLVMStructType::getLiteral(
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (getType() != dstType)
@@ -590,8 +628,8 @@ LogicalResult NVVM::WMMAStoreOp::verify() {
if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
getEltype()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
if (getArgs().size() != typeInfo.second)
return emitOpError() << "expected " << typeInfo.second << " data operands";
if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
@@ -606,12 +644,12 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
getLayoutB(), getEltypeA(),
getEltypeB()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfoA =
inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
std::pair<Type, unsigned> typeInfoB =
inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
std::pair<Type, unsigned> typeInfoC =
inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
SmallVector<Type, 32> arguments;
arguments.append(typeInfoA.second, typeInfoA.first);
arguments.append(typeInfoB.second, typeInfoB.first);

View File

@@ -40,6 +40,45 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
// CHECK-SAME: !llvm.struct<(i32, i32)>
// CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
%wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp">
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
// CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i64) -> !llvm.ptr<i8, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
// CHECK: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
// CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
// CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i32) -> !llvm.ptr<i8, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
// CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
}
}
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_store_op
@@ -124,6 +163,35 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_int8_op
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
%D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)>
// CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)>
// CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)>
// CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
// CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
// CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
// CHECK: llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
return %D : !gpu.mma_matrix<32x8xi32, "COp">
}
}
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_loop_op

View File

@@ -225,3 +225,44 @@ func.func @matmul_transposed_broadcasted_2d(%arg0: memref<32x32xf16>, %arg1: mem
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
// Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred.
// CHECK-LABEL: func @matmul_no_extend_int8
// CHECK-DAG: %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
// CHECK-DAG: %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
// CHECK-DAG: %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
%cst_0 = arith.constant dense<0> : vector<16x16xi8>
%c0 = arith.constant 0 : index
%cst_i8 = arith.constant 0 : i8
%cst_i32 = arith.constant 0 : i32
%A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
%B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
%C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
return
}
// CHECK-LABEL: func @matmul_int8
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
%cst_0 = arith.constant dense<0> : vector<16x16xi8>
%c0 = arith.constant 0 : index
%cst_i8 = arith.constant 0 : i8
%cst_i32 = arith.constant 0 : i32
%Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
%Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
%C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
%Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32>
%Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
return
}

View File

@@ -485,8 +485,8 @@ func.func @mmamatrix_operand_type(){
func.func @mmamatrix_invalid_element_type(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
// expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
// expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
return
}
@@ -505,7 +505,7 @@ func.func @mmaLoadOp_identity_layout(){
// -----
func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
// expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
// expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
%0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
return
}