[mlir][spirv] Add cooperative matrix store op

Implement cooperative matrix store for the `SPV_KHR_cooperative_matrix`
extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155631
This commit is contained in:
Jakub Kuderski
2023-07-19 11:01:07 -04:00
parent 1fa9e150b4
commit 68cd1dbc2e
4 changed files with 180 additions and 6 deletions

View File

@@ -4447,6 +4447,7 @@ def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454
def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
@@ -4546,11 +4547,12 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpTypeCooperativeMatrixNV,
SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR,

View File

@@ -134,6 +134,75 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
);
}
// -----
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
let summary = "Stores a cooperative matrix through a pointer";
let description = [{
Store a cooperative matrix through a pointer.
Pointer is a pointer. Its type must be an OpTypePointer whose Type operand
is a scalar or vector type. If the Shader capability was declared, Pointer
must point into an array and any ArrayStride decoration on Pointer is
ignored.
Object is the object to store. Its type must be an
OpTypeCooperativeMatrixKHR.
MemoryLayout specifies how matrix elements are laid out in memory. It must
come from a 32-bit integer constant instruction whose value corresponds to a
Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
description of the layouts and detailed layout-specific rules.
Stride further qualifies how matrix elements are laid out in memory. It must
be a scalar integer type and its exact semantics depend on MemoryLayout.
Memory Operand must be a Memory Operand literal. If not present, it is the
same as specifying None.
NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
as 'Memory Access'.
For a given dynamic instance of this instruction, all operands of this
instruction must be the same for all invocations in a given scope instance
(where the scope is the scope the cooperative matrix type was created with).
All invocations in a given scope instance must be active or all must be
inactive.
``` {.ebnf}
coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
ssa-use `, ` ssa-use `, `
ssa-use `, ` cooperative-matrix-layout `, `
(`[` memory-operand `]`)? `:`
pointer-type `,` coop-matrix-type
```
#### Example:
```
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
```
}];
let availability = [
MinVersion<SPIRV_V_1_6>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_KHR_cooperative_matrix]>,
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_AnyCooperativeMatrix:$object,
SPIRV_Integer:$stride,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
);
let results = (outs);
}
//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
@@ -364,7 +433,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
ssa-use `, ` ssa-use `, `
ssa-use `, ` ssa-use `, `
(`[` memory-access `]`)? `:`
pointer-type `,` spirv-element-type
pointer-type `,` coop-matrix-type
```
For example:

View File

@@ -4111,6 +4111,58 @@ LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
getResult().getType());
}
//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixStore
//===----------------------------------------------------------------------===//
ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) {
std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
for (auto &op : operandInfo) {
if (parser.parseOperand(op) || parser.parseComma())
return failure();
}
spirv::CooperativeMatrixLayoutKHR layout;
if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
return failure();
}
if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
return failure();
Type ptrType;
Type objectType;
if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
parser.parseType(objectType)) {
return failure();
}
Type strideType = parser.getBuilder().getIntegerType(32);
if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
parser.getNameLoc(), result.operands)) {
return failure();
}
return success();
}
void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
<< ", " << getMatrixLayout();
// Print optional memory operand attribute.
if (auto memOperand = getMemoryOperand())
printer << " [\"" << *memOperand << "\"]";
printer << " : " << getPointer().getType() << ", " << getObject().getType();
}
LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
getObject().getType());
}
//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//

View File

@@ -57,6 +57,27 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
spirv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memoperand
spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
%stride : i32) "None" {
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
spirv.Return
}
// -----
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -95,6 +116,36 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
// -----
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// expected-error @+1 {{expected ','}}
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// expected-error @+1 {{expected valid keyword}}
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
spirv.Return
}
// -----
spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
%stride : i32) "None" {
// expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
!spirv.ptr<i32, StorageBuffer>, i32
spirv.Return
}
// -----
//===----------------------------------------------------------------------===//
// NV.CooperativeMatrix
//===----------------------------------------------------------------------===//