[mlir][gpu] Add support for integer types in gpu.subgroup_mma ops
The signedness is carried by `!gpu.mma_matrix` types to most closely match the Cooperative Matrix specification, which determines signedness with the type (and sometimes the operation). See: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.html

To handle the lowering from vector to gpu, ops such as arith.extsi are pattern matched next to `vector.transfer_read` and `vector.contract` to determine the signedness of the matrix type.

Enables s8 and u8 WMMA types in NVVM for the GPUToNVVM conversion.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143223
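To make the signedness inference concrete, here is a minimal sketch of the pattern the vector-to-GPU conversion looks for, adapted from the matmul_int8 test added further down (the surrounding function, constants, and the consuming vector.contract are assumed to be set up as in that test):

```
%a  = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
%ae = arith.extsi %a : vector<16x16xi8> to vector<16x16xi32>
// %ae feeds a vector.contract, so the read can be given a signed matrix type:
%A  = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
```

Without the `arith.extsi`, the signedness of an i8 read is ambiguous and the conversion leaves it alone (see the matmul_no_extend_int8 test below).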
@@ -101,7 +101,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;

 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;

 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
@@ -1150,6 +1150,10 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     matrix which eventually allows the lowering to determine the size of each
     row. If the `transpose` attribute is present then the op does a transposed load.

+    For integer types, the resulting `!gpu.mma_matrix` type needs to specify the
+    signedness of the data if the matrix type is an `A` or `B` operand for
+    `gpu.subgroup_mma_compute`.
+
     This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_compute`.
@@ -1201,7 +1205,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     ```
   }];

-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
                    Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
                    Variadic<Index>:$indices,
                    IndexAttr:$leadDimension,
@@ -1227,11 +1231,15 @@ def GPU_SubgroupMmaComputeOp
     as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
     the operation held by all threads in a subgroup. `a_transpose` or
     `b_transpose` if present, signify that the respective operand was loaded in a
-    transposed manner. The transpose opernads are required to map to correct
+    transposed manner. The transpose operands are required to map to correct
     underlying intrisics but they currently do not seem to affect correctness
     even if they are absent given that the operands were loaded correctly using
     the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op.

+    For integer types, the `A` and `B` matrices carry their signedness with their
+    types. The accumulator type is expected to be signless and imply a signed integer
+    with a greater width than the other two operands.
+
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_load_matrix` ops.
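As a sketch of the signedness rule added above, mirroring the GPUToNVVM test later in this change: the `A` and `B` operands carry `si8` in their `!gpu.mma_matrix` types, while the accumulator stays a signless (signed-in-effect) `i32` of wider bit width.

```
%D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
```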
@@ -1244,9 +1252,9 @@ def GPU_SubgroupMmaComputeOp
     ```
   }];

-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
-                   Arg<MMAMatrixOf<[F16, F32]>>:$opB,
-                   Arg<MMAMatrixOf<[F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
+                   Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
+                   Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
                    OptionalAttr<UnitAttr>:$a_transpose,
                    OptionalAttr<UnitAttr>:$b_transpose);
@@ -1288,7 +1296,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
     ```
   }];

-  let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
+  let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value);

   let results = (outs GPU_MMAMatrix:$res);
@@ -37,7 +37,8 @@ enum NVVMMemorySpace {
 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
 /// WMMA_REGS structure.
 std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
-                                             mlir::NVVM::MMAFrag frag,
+                                             mlir::NVVM::MMAFrag frag, int nRow,
+                                             int nCol,
                                              mlir::MLIRContext *context);
 } // namespace NVVM
 } // namespace mlir
@@ -385,16 +385,20 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
             ["f16"], [], ["f16", "f32"], []>.ret;
+  list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
+            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+            ["s8","u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
     tf32_wmma_ops,
-    fp_wmma_ops);
+    fp_wmma_ops,
+    i8_wmma_ops);

   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-            ["a", "b"], ["f16"]>.ret;
+            ["a", "b"], ["f16","s8","u8"]>.ret;
   list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-            ["c", "d"], ["f16", "f32"]>.ret;
+            ["c", "d"], ["f16", "f32","s32"]>.ret;
   list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 8>],
             ["a", "b"], ["tf32"]>.ret;
@@ -57,6 +57,12 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                            : NVVM::MMATypes::tf32;

+  if (type.getElementType().isSignedInteger(8))
+    return NVVM::MMATypes::s8;
+  // Accumulator type is signless and implies signed.
+  if (type.getElementType().isInteger(32))
+    return NVVM::MMATypes::s32;
   llvm_unreachable("Unsupported type");
 }
@@ -106,8 +112,11 @@ struct WmmaLoadOpToNVVMLowering
     }
     NVVM::MMAFrag frag = convertOperand(retType.getOperand());
     // Check that there is an exisiting instruction for the combination we need.
-    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
+    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
+      llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
+                   << "\n";
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+    }

     Type resType = convertMMAToLLVMType(retType);
     Location loc = op->getLoc();
@@ -366,8 +375,10 @@ struct WmmaElementwiseOpToNVVMLowering
 LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
+  auto nRow = type.getShape()[0];
+  auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
-      NVVM::inferMMAType(eltType, frag, type.getContext());
+      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
@@ -140,6 +140,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;

+  // Only allow integer types if the signedness can be inferred.
+  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+    if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
+      return false;

   AffineMap map = readOp.getPermutationMap();
   OpBuilder b(readOp.getContext());
   AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
@@ -185,8 +191,16 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {

 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+  return broadcastOp.getVectorType().getRank() == 2;
 }

+/// Return true if this signed extend op can be folded into a contract op.
+static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
+    return false;
+  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
+    return isa<vector::ContractionOp>(user);
+  });
+}

 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +282,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto extend = dyn_cast<arith::ExtSIOp>(op))
+    return signedExtendSupportsMMAMatrixType(extend);
   return elementwiseSupportsMMAMatrixType(op);
 }
@@ -411,8 +427,18 @@ struct CombineTransferReadOpTranspose final

   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    auto transferReadOp =
-        op.getVector().getDefiningOp<vector::TransferReadOp>();
+    // Look through integer extend ops.
+    Value source = op.getVector();
+    auto extOp = source.getDefiningOp<arith::ExtSIOp>();
+    auto resultType = op.getVectorType();
+    if (extOp) {
+      source = extOp.getOperand();
+      resultType =
+          VectorType::get(resultType.getShape(),
+                          source.getType().cast<VectorType>().getElementType());
+    }

+    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
     if (!transferReadOp)
       return failure();
@@ -431,11 +457,23 @@ struct CombineTransferReadOpTranspose final
         AffineMap::getPermutationMap(permU, op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        op, op.getType(), transferReadOp.getSource(),
-        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
-        transferReadOp.getPadding(), transferReadOp.getMask(),
-        transferReadOp.getInBoundsAttr());

+    auto loc = op.getLoc();
+    Value result =
+        rewriter
+            .create<vector::TransferReadOp>(
+                loc, resultType, transferReadOp.getSource(),
+                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+                transferReadOp.getPadding(), transferReadOp.getMask(),
+                transferReadOp.getInBoundsAttr())
+            .getResult();

+    // Fuse through the integer extend op.
+    if (extOp)
+      result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+                   .getResult();

+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -479,14 +517,26 @@ static void convertTransferReadOp(vector::TransferReadOp op,
     stride = 0;
   }
   assert(stride);
+  Value mappingResult = op.getResult();
+  auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
+  if (op->hasOneUse()) {
+    auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
+    // Infer the signedness of the mma type from the signed extend.
+    if (extOp) {
+      elType = IntegerType::get(op.getContext(),
+                                elType.cast<IntegerType>().getWidth(),
+                                IntegerType::Signed);
+      mappingResult = extOp.getResult();
+      fragType = inferFragType(extOp);
+    }
+  }
   gpu::MMAMatrixType type =
-      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
-                              op.getVectorType().getElementType(), fragType);
+      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
       b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
-  valueMapping[op.getResult()] = load;
+  valueMapping[mappingResult] = load;
 }

 static void convertTransferWriteOp(vector::TransferWriteOp op,
@@ -78,7 +78,9 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32();
+  return elementType.isF16() || elementType.isF32() ||
+         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
+         elementType.isInteger(32);
 }

 LogicalResult
@@ -93,7 +95,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "MMAMatrixType must have exactly two dimensions";

   if (!MMAMatrixType::isValidElementType(elementType))
-    return emitError() << "MMAMatrixType elements must be F16 or F32";
+    return emitError()
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

   return success();
 }
@@ -537,7 +537,8 @@ LogicalResult ShflOp::verify() {
 }

 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
-                                                   NVVM::MMAFrag frag,
+                                                   NVVM::MMAFrag frag, int nRow,
+                                                   int nCol,
                                                    MLIRContext *context) {
   unsigned numberElements = 0;
   Type elementType;
@@ -555,11 +556,48 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
   } else if (type == NVVM::MMATypes::tf32) {
     elementType = builder.getI32Type();
     numberElements = 4;
+  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
+    elementType = builder.getI32Type();
+    int parallelSize = 0;
+    if (frag == NVVM::MMAFrag::a)
+      parallelSize = nRow;
+    if (frag == NVVM::MMAFrag::b)
+      parallelSize = nCol;

+    // m == 16 && n == 16 && k == 16
+    if (parallelSize == 16)
+      numberElements = 2;
+    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
+    else if (parallelSize == 8)
+      numberElements = 1;
+    else if (parallelSize == 32)
+      numberElements = 4;
+  } else if (type == NVVM::MMATypes::s32) {
+    elementType = builder.getI32Type();
+    numberElements = 8;
   }
   assert(numberElements != 0 && elementType != nullptr);
   return std::make_pair(elementType, numberElements);
 }

+static std::pair<mlir::Type, unsigned>
+inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
+                    int k, MLIRContext *context) {
+  int nRow, nCol;
+  if (frag == NVVM::MMAFrag::a) {
+    nRow = m;
+    nCol = k;
+  } else if (frag == NVVM::MMAFrag::b) {
+    nRow = k;
+    nCol = n;
+  } else {
+    nRow = m;
+    nCol = n;
+  }
+  assert(nRow && nCol);
+  return inferMMAType(type, frag, nRow, nCol, context);
+}

 LogicalResult NVVM::WMMALoadOp::verify() {
   unsigned addressSpace =
       getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
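For reference, a hand-worked summary of the per-thread register counts produced by the parallelSize logic above (the struct types are the ones checked in the conversion tests later in this change):

```
// s8/u8 "a" fragment, 16x16x16 tile: parallelSize == 16 -> 2 registers, i.e. !llvm.struct<(i32, i32)>
// s8/u8 "a" fragment, 32x8x16 tile:  parallelSize == 32 -> 4 registers, i.e. !llvm.struct<(i32, i32, i32, i32)>
// s8/u8 "b" fragment, 32x8x16 tile:  parallelSize == 8  -> 1 register,  i.e. !llvm.struct<(i32)>
// s32 "c"/"d" fragment (any tile):   always 8 registers, i.e. !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
```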
@@ -570,8 +608,8 @@ LogicalResult NVVM::WMMALoadOp::verify() {
   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype(), getFrag()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo =
-      inferMMAType(getEltype(), getFrag(), getContext());
+  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
+      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
   Type dstType = LLVM::LLVMStructType::getLiteral(
       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
   if (getType() != dstType)
@@ -590,8 +628,8 @@ LogicalResult NVVM::WMMAStoreOp::verify() {
   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                         getEltype()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo =
-      inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
+  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
+      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
   if (getArgs().size() != typeInfo.second)
     return emitOpError() << "expected " << typeInfo.second << " data operands";
   if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
@@ -606,12 +644,12 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
                                       getLayoutB(), getEltypeA(),
                                       getEltypeB()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfoA =
-      inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
-  std::pair<Type, unsigned> typeInfoB =
-      inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
-  std::pair<Type, unsigned> typeInfoC =
-      inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
+  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
+      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
+  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
+      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
+  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
+      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
   SmallVector<Type, 32> arguments;
   arguments.append(typeInfoA.second, typeInfoA.first);
   arguments.append(typeInfoB.second, typeInfoB.first);
@@ -40,6 +40,45 @@ gpu.module @test_module {

 // -----

+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
+  // CHECK-SAME: !llvm.struct<(i32, i32)>
+  // CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
+  func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp">
+    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i64) -> !llvm.ptr<i8, 3>
+    // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
+
+    // CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32
+    // CHECK32: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i32) -> !llvm.ptr<i8, 3>
+    // CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
+    // CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
+    return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
+  }
+}
+
+// -----

 gpu.module @test_module {

   // CHECK-LABEL: func @gpu_wmma_store_op
@@ -124,6 +163,35 @@ gpu.module @test_module {

 // -----

+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_mma_int8_op
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
+  func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
+    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
+    // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)>
+    // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
+    // CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
+    // CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK: llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    return %D : !gpu.mma_matrix<32x8xi32, "COp">
+  }
+}
+
+// -----

 gpu.module @test_module {

   // CHECK-LABEL: func @gpu_wmma_mma_loop_op
@@ -225,3 +225,44 @@ func.func @matmul_transposed_broadcasted_2d(%arg0: memref<32x32xf16>, %arg1: mem
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }

+// Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred.
+// CHECK-LABEL: func @matmul_no_extend_int8
+// CHECK-DAG: %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+// CHECK-DAG: %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+// CHECK-DAG: %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+// CHECK: %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}

+// CHECK-LABEL: func @matmul_int8
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
+func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32>
+  %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}
@@ -485,8 +485,8 @@ func.func @mmamatrix_operand_type(){
 func.func @mmamatrix_invalid_element_type(){
   %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
   %i = arith.constant 16 : index
-  // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
-  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
+  // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
+  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
   return
 }
@@ -505,7 +505,7 @@ func.func @mmaLoadOp_identity_layout(){
 // -----

 func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
-  // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+  // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
   %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
   return
 }