[mlir][spirv] Add block read and write from SPV_INTEL_subgroups

Added support to OpSubgroupBlockReadINTEL and OpSubgroupBlockWriteINTEL

Differential Revision: https://reviews.llvm.org/D86876
This commit is contained in:
Artur Bialas
2020-09-02 19:52:29 -07:00
committed by Thomas Raoux
parent f7e04b710d
commit d9b4245f56
5 changed files with 268 additions and 2 deletions

View File

@@ -3252,6 +3252,8 @@ def SPV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoa
def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@ def SPV_OpcodeAttr :
SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
SPV_OC_OpCooperativeMatrixLengthNV
SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
SPV_OC_OpSubgroupBlockWriteINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

View File

@@ -88,7 +88,6 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
let assemblyFormat = [{
$execution_scope operands attr-dict `:` type($value) `,` type($localid)
}];
}
// -----
@@ -147,4 +146,104 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// -----
def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{
Reads one or more components of Result data for each invocation in the
subgroup from the specified Ptr as a block operation.
The data is read strided, so the first value read is:
Ptr[ SubgroupLocalInvocationId ]
and the second value read is:
Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
etc.
Result Type may be a scalar or vector type, and its component type must be
equal to the type pointed to by Ptr.
The type of Ptr must be a pointer type, and must point to a scalar type.
<!-- End of AutoGen section -->
```
subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
storage-class ssa_use `:` spirv-element-type
```mlir
#### Example:
```
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[SPV_INTEL_subgroups]>,
Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
];
let arguments = (ins
SPV_AnyPtr:$ptr
);
let results = (outs
SPV_Type:$value
);
}
// -----
def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{
Writes one or more components of Data for each invocation in the subgroup
from the specified Ptr as a block operation.
The data is written strided, so the first value is written to:
Ptr[ SubgroupLocalInvocationId ]
and the second value written is:
Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
etc.
The type of Ptr must be a pointer type, and must point to a scalar type.
The component type of Data must be equal to the type pointed to by Ptr.
<!-- End of AutoGen section -->
```
subgroup-block-write-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockWriteINTEL`
storage-class ssa_use `,` ssa-use `:` spirv-element-type
```mlir
#### Example:
```
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
```
}];
let availability = [
MinVersion<SPV_V_1_0>,
MaxVersion<SPV_V_1_5>,
Extension<[SPV_INTEL_subgroups]>,
Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
];
let arguments = (ins
SPV_AnyPtr:$ptr,
SPV_Type:$value
);
let results = (outs);
}
// -----
#endif // SPIRV_GROUP_OPS

View File

@@ -468,6 +468,19 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
return success();
}
template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
Value ptr, Value val) {
auto valType = val.getType();
if (auto valVecTy = valType.dyn_cast<VectorType>())
valType = valVecTy.getElementType();
if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
return op.emitOpError("mismatch in result type and pointer type");
}
return success();
}
static ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state) {
auto builtInName = llvm::convertToSnakeFromCamelCase(
@@ -2025,6 +2038,93 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
return success();
}
//===----------------------------------------------------------------------===//
// spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===//
static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
OperationState &state) {
// Parse the storage class specification
spirv::StorageClass storageClass;
OpAsmParser::OperandType ptrInfo;
Type elementType;
if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
parser.parseColon() || parser.parseType(elementType)) {
return failure();
}
auto ptrType = spirv::PointerType::get(elementType, storageClass);
if (auto valVecTy = elementType.dyn_cast<VectorType>())
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
return failure();
}
state.addTypes(elementType);
return success();
}
static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
<< blockReadOp.ptr();
printer << " : " << blockReadOp.getType();
}
static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
blockReadOp.value())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===//
static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
OperationState &state) {
// Parse the storage class specification
spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
auto loc = parser.getCurrentLocation();
Type elementType;
if (parseEnumStrAttr(storageClass, parser) ||
parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
}
auto ptrType = spirv::PointerType::get(elementType, storageClass);
if (auto valVecTy = elementType.dyn_cast<VectorType>())
ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
state.operands)) {
return failure();
}
return success();
}
static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
<< blockWriteOp.ptr() << ", " << blockWriteOp.value();
printer << " : " << blockWriteOp.value().getType();
}
static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
if (failed(verifyBlockReadWritePtrAndValTypes(
blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// spv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//

View File

@@ -19,4 +19,28 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @subgroup_block_read_intel
spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
spv.ReturnValue %0: i32
}
// CHECK-LABEL: @subgroup_block_read_intel_vector
spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
spv.ReturnValue %0: vector<3xi32>
}
// CHECK-LABEL: @subgroup_block_write_intel
spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
spv.Return
}
// CHECK-LABEL: @subgroup_block_write_intel_vector
spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
spv.Return
}
}

View File

@@ -61,3 +61,43 @@ func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> )
%0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
return %0: f32
}
// -----
//===----------------------------------------------------------------------===//
// spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===//
func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
return %0: i32
}
// -----
func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
return %0: vector<3xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===//
func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
return
}
// -----
func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
return
}