mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
[mlir][spirv] Add cooperative matrix store op
Implement cooperative matrix store for the `SPV_KHR_cooperative_matrix` extension: https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_cooperative_matrix.html. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D155631
This commit is contained in:
@@ -4447,6 +4447,7 @@ def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454
|
||||
def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
|
||||
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
|
||||
def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
|
||||
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
|
||||
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
|
||||
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
|
||||
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
|
||||
@@ -4546,11 +4547,12 @@ def SPIRV_OpcodeAttr :
|
||||
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
|
||||
SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
|
||||
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
|
||||
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
|
||||
SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
|
||||
SPIRV_OC_OpTypeCooperativeMatrixNV,
|
||||
SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
|
||||
SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
|
||||
SPIRV_OC_OpSUDotAccSat,
|
||||
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
|
||||
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
|
||||
SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
|
||||
SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
|
||||
SPIRV_OC_OpCooperativeMatrixLengthNV,
|
||||
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
|
||||
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
|
||||
SPIRV_OC_OpGroupFMulKHR,
|
||||
|
||||
@@ -134,6 +134,75 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
|
||||
);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
|
||||
let summary = "Stores a cooperative matrix through a pointer";
|
||||
|
||||
let description = [{
|
||||
Store a cooperative matrix through a pointer.
|
||||
Pointer is a pointer. Its type must be an OpTypePointer whose Type operand
|
||||
is a scalar or vector type. If the Shader capability was declared, Pointer
|
||||
must point into an array and any ArrayStride decoration on Pointer is
|
||||
ignored.
|
||||
|
||||
Object is the object to store. Its type must be an
|
||||
OpTypeCooperativeMatrixKHR.
|
||||
|
||||
MemoryLayout specifies how matrix elements are laid out in memory. It must
|
||||
come from a 32-bit integer constant instruction whose value corresponds to a
|
||||
Cooperative Matrix Layout. See the Cooperative Matrix Layout table for a
|
||||
description of the layouts and detailed layout-specific rules.
|
||||
|
||||
Stride further qualifies how matrix elements are laid out in memory. It must
|
||||
be a scalar integer type and its exact semantics depend on MemoryLayout.
|
||||
|
||||
Memory Operand must be a Memory Operand literal. If not present, it is the
|
||||
same as specifying None.
|
||||
|
||||
NOTE: In earlier versions of the SPIR-V spec, 'Memory Operand' was known
|
||||
as 'Memory Access'.
|
||||
|
||||
For a given dynamic instance of this instruction, all operands of this
|
||||
instruction must be the same for all invocations in a given scope instance
|
||||
(where the scope is the scope the cooperative matrix type was created with).
|
||||
All invocations in a given scope instance must be active or all must be
|
||||
inactive.
|
||||
|
||||
``` {.ebnf}
|
||||
coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
|
||||
ssa-use `, ` ssa-use `, `
|
||||
ssa-use `, ` cooperative-matrix-layout `, `
|
||||
(`[` memory-operand `]`)? `:`
|
||||
pointer-type `,` coop-matrix-type
|
||||
```
|
||||
|
||||
#### Example:
|
||||
|
||||
```
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %obj, %stride :
|
||||
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPIRV_V_1_6>,
|
||||
MaxVersion<SPIRV_V_1_6>,
|
||||
Extension<[SPV_KHR_cooperative_matrix]>,
|
||||
Capability<[SPIRV_C_CooperativeMatrixKHR]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_AnyPtr:$pointer,
|
||||
SPIRV_AnyCooperativeMatrix:$object,
|
||||
SPIRV_Integer:$stride,
|
||||
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
|
||||
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPV_NV_cooperative_matrix extension ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -364,7 +433,7 @@ def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore"
|
||||
ssa-use `, ` ssa-use `, `
|
||||
ssa-use `, ` ssa-use `, `
|
||||
(`[` memory-access `]`)? `:`
|
||||
pointer-type `,` spirv-element-type
|
||||
pointer-type `,` coop-matrix-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
@@ -4111,6 +4111,58 @@ LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
|
||||
getResult().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.KHR.CooperativeMatrixStore
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
|
||||
for (auto &op : operandInfo) {
|
||||
if (parser.parseOperand(op) || parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
|
||||
spirv::CooperativeMatrixLayoutKHR layout;
|
||||
if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
|
||||
layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
|
||||
return failure();
|
||||
|
||||
Type ptrType;
|
||||
Type objectType;
|
||||
if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
|
||||
parser.parseType(objectType)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Type strideType = parser.getBuilder().getIntegerType(32);
|
||||
if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
|
||||
parser.getNameLoc(), result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
|
||||
<< ", " << getMatrixLayout();
|
||||
|
||||
// Print optional memory operand attribute.
|
||||
if (auto memOperand = getMemoryOperand())
|
||||
printer << " [\"" << *memOperand << "\"]";
|
||||
printer << " : " << getPointer().getType() << ", " << getObject().getType();
|
||||
}
|
||||
|
||||
LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() {
|
||||
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
|
||||
getObject().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.NV.CooperativeMatrixLength
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -57,6 +57,27 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cooperative_matrix_store
|
||||
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
|
||||
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
|
||||
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
|
||||
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
|
||||
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cooperative_matrix_store_memoperand
|
||||
spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
|
||||
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
|
||||
%stride : i32) "None" {
|
||||
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
|
||||
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
|
||||
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
|
||||
@@ -95,6 +116,36 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
|
||||
|
||||
// -----
|
||||
|
||||
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
|
||||
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
|
||||
// expected-error @+1 {{expected ','}}
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
|
||||
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
|
||||
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
|
||||
// expected-error @+1 {{expected valid keyword}}
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
|
||||
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
|
||||
%stride : i32) "None" {
|
||||
// expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
|
||||
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
|
||||
!spirv.ptr<i32, StorageBuffer>, i32
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NV.CooperativeMatrix
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user