[mlir][spirv] Extract Atomic/Cast/Group op implementation. NFC.

Continue to work outlined in D155747 and split the main SPIR-V ops
implementation file into a few smaller and quicker to compile files.
This organization matches the op definition organizaion in `.td` files.

In this patch, extract atomic, cast/conversion, and group op
implementation into separate files.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155777
This commit is contained in:
Jakub Kuderski
2023-07-20 11:13:51 -04:00
parent 8dacf55af4
commit ab6827f2d4
7 changed files with 1213 additions and 1144 deletions

View File

@@ -51,7 +51,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
If Result Type has a different number of components than Operand, the
total number of bits in Result Type must equal the total number of bits
in Operand. Let L be the type, either Result Type or Operands type,
in Operand. Let L be the type, either Result Type or Operand's type,
that has the larger number of components. Let S be the other type, with
the smaller number of components. The number of components in L must be
an integer multiple of the number of components in S. The first
@@ -335,17 +335,17 @@ def SPIRV_UConvertOp : SPIRV_CastOp<"UConvert",
def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
let summary = [{
Bit pattern-preserving conversion of a pointer to
Bit pattern-preserving conversion of a pointer to
an unsigned scalar integer of possibly different bit width.
}];
let description = [{
Result Type must be a scalar of integer type, whose Signedness operand is 0.
Pointer must be a physical pointer type. If the bit width of Pointer is
smaller than that of Result Type, the conversion zero extends Pointer.
If the bit width of Pointer is larger than that of Result Type,
the conversion truncates Pointer.
Pointer must be a physical pointer type. If the bit width of Pointer is
smaller than that of Result Type, the conversion zero extends Pointer.
If the bit width of Pointer is larger than that of Result Type,
the conversion truncates Pointer.
For same bit width Pointer and Result Type, this is the same as OpBitcast.
@@ -359,7 +359,7 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
#### Example:
```mlir
%1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
%1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
```
}];
@@ -390,18 +390,18 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
let summary = [{
Bit pattern-preserving conversion of an unsigned scalar integer
Bit pattern-preserving conversion of an unsigned scalar integer
to a pointer.
}];
let description = [{
Result Type must be a physical pointer type.
Integer Value must be a scalar of integer type, whose Signedness
operand is 0. If the bit width of Integer Value is smaller
Integer Value must be a scalar of integer type, whose Signedness
operand is 0. If the bit width of Integer Value is smaller
than that of Result Type, the conversion zero extends Integer Value.
If the bit width of Integer Value is larger than that of Result Type,
the conversion truncates Integer Value.
If the bit width of Integer Value is larger than that of Result Type,
the conversion truncates Integer Value.
For same-width Integer Value and Result Type, this is the same as OpBitcast.
@@ -415,7 +415,7 @@ def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
#### Example:
```mlir
%1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
%1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
```
}];

View File

@@ -0,0 +1,441 @@
//===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the atomic operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
// Parses an atomic update op. If the update op does not take a value (like
// AtomicIIncrement) `hasValue` must be false.
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OperationState &state, bool hasValue) {
spirv::Scope scope;
spirv::MemorySemantics memoryScope;
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
Type type;
SMLoc loc;
if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
return failure();
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
SmallVector<Type, 2> operandTypes;
operandTypes.push_back(ptrType);
if (hasValue)
operandTypes.push_back(ptrType.getPointeeType());
if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
state.operands))
return failure();
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}
// Prints an atomic update op.
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
printer << " \"";
auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
auto memorySemanticsAttr =
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
}
template <typename T>
static StringRef stringifyTypeName();
template <>
StringRef stringifyTypeName<IntegerType>() {
return "integer";
}
template <>
StringRef stringifyTypeName<FloatType>() {
return "float";
}
// Verifies an atomic update op.
template <typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
if (!llvm::isa<ExpectedElementType>(elementType))
return op->emitOpError() << "pointer operand must point to an "
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
if (op->getNumOperands() > 1) {
auto valueType = op->getOperand(1).getType();
if (valueType != elementType)
return op->emitOpError("expected value to have the same type as the "
"pointer operand's pointee type ")
<< elementType << ", but found " << valueType;
}
auto memorySemantics =
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
}
return success();
}
template <typename T>
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
<< stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
<< stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
<< atomOp.getOperands() << " : " << atomOp.getPointer().getType();
}
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
OperationState &state) {
spirv::Scope memoryScope;
spirv::MemorySemantics equalSemantics, unequalSemantics;
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type type;
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
equalSemantics, parser, state, kEqualSemanticsAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 3))
return failure();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type))
return failure();
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
if (parser.resolveOperands(
operandInfo,
{ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
parser.getNameLoc(), state.operands))
return failure();
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
}
template <typename T>
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
// According to the spec:
// "The type of Value must be the same as Result Type. The type of the value
// pointed to by Pointer must be the same as Result Type. This type must also
// match the type of Comparator."
if (atomOp.getType() != atomOp.getValue().getType())
return atomOp.emitOpError("value operand must have the same type as the op "
"result, but found ")
<< atomOp.getValue().getType() << " vs " << atomOp.getType();
if (atomOp.getType() != atomOp.getComparator().getType())
return atomOp.emitOpError(
"comparator operand must have the same type as the op "
"result, but found ")
<< atomOp.getComparator().getType() << " vs " << atomOp.getType();
Type pointeeType =
llvm::cast<spirv::PointerType>(atomOp.getPointer().getType())
.getPointeeType();
if (atomOp.getType() != pointeeType)
return atomOp.emitOpError(
"pointer operand's pointee type must have the same "
"as the op result type, but found ")
<< pointeeType << " vs " << atomOp.getType();
// TODO: Unequal cannot be set to Release or Acquire and Release.
// In addition, Unequal cannot be set to a stronger memory-order then Equal.
return success();
}
//===----------------------------------------------------------------------===//
// spirv.AtomicAndOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicAndOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicAndOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicCompareExchangeOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicCompareExchangeOp::verify() {
return verifyAtomicCompareExchangeImpl(*this);
}
ParseResult AtomicCompareExchangeOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAtomicCompareExchangeImpl(parser, result);
}
void AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
printAtomicCompareExchangeImpl(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.AtomicCompareExchangeWeakOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicCompareExchangeWeakOp::verify() {
return verifyAtomicCompareExchangeImpl(*this);
}
ParseResult AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAtomicCompareExchangeImpl(parser, result);
}
void AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
printAtomicCompareExchangeImpl(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.AtomicExchange
//===----------------------------------------------------------------------===//
void AtomicExchangeOp::print(OpAsmPrinter &printer) {
printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
<< stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
<< " : " << getPointer().getType();
}
ParseResult AtomicExchangeOp::parse(OpAsmParser &parser,
OperationState &result) {
spirv::Scope memoryScope;
spirv::MemorySemantics semantics;
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
Type type;
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 2))
return failure();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type))
return failure();
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return parser.emitError(loc, "expected pointer type");
if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
parser.getNameLoc(), result.operands))
return failure();
return parser.addTypeToList(ptrType.getPointeeType(), result.types);
}
LogicalResult AtomicExchangeOp::verify() {
if (getType() != getValue().getType())
return emitOpError("value operand must have the same type as the op "
"result, but found ")
<< getValue().getType() << " vs " << getType();
Type pointeeType =
llvm::cast<spirv::PointerType>(getPointer().getType()).getPointeeType();
if (getType() != pointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
<< pointeeType << " vs " << getType();
return success();
}
//===----------------------------------------------------------------------===//
// spirv.AtomicIAddOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicIAddOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicIAddOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.EXT.AtomicFAddOp
//===----------------------------------------------------------------------===//
LogicalResult EXTAtomicFAddOp::verify() {
return verifyAtomicUpdateOp<FloatType>(getOperation());
}
ParseResult EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
printAtomicUpdateOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.AtomicIDecrementOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicIDecrementOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicIDecrementOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAtomicUpdateOp(parser, result, false);
}
void AtomicIDecrementOp::print(OpAsmPrinter &p) {
printAtomicUpdateOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.AtomicIIncrementOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicIIncrementOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicIIncrementOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAtomicUpdateOp(parser, result, false);
}
void AtomicIIncrementOp::print(OpAsmPrinter &p) {
printAtomicUpdateOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.AtomicISubOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicISubOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicISubOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicOrOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicOrOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicOrOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicSMaxOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMaxOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicSMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicSMinOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMinOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicSMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicUMaxOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMaxOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicUMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicUMinOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMinOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicUMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
//===----------------------------------------------------------------------===//
// spirv.AtomicXorOp
//===----------------------------------------------------------------------===//
LogicalResult AtomicXorOp::verify() {
return verifyAtomicUpdateOp<IntegerType>(getOperation());
}
ParseResult AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseAtomicUpdateOp(parser, result, true);
}
void AtomicXorOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
} // namespace mlir::spirv

View File

@@ -3,7 +3,10 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
AtomicOps.cpp
CastOps.cpp
CooperativeMatrixOps.cpp
GroupOps.cpp
IntegerDotProductOps.cpp
JointMatrixOps.cpp
SPIRVAttributes.cpp

View File

@@ -0,0 +1,339 @@
//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the cast and conversion operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
static LogicalResult verifyCastOp(Operation *op,
bool requireSameBitWidth = true,
bool skipBitWidthCheck = false) {
// Some CastOps have no limit on bit widths for result and operand type.
if (skipBitWidthCheck)
return success();
Type operandType = op->getOperand(0).getType();
Type resultType = op->getResult(0).getType();
// ODS checks that result type and operand type have the same shape. Check
// that composite types match and extract the element types, if any.
using TypePair = std::pair<Type, Type>;
auto [operandElemTy, resultElemTy] =
TypeSwitch<Type, TypePair>(operandType)
.Case<VectorType, spirv::CooperativeMatrixType,
spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
[resultType](auto concreteOperandTy) -> TypePair {
if (auto concreteResultTy =
dyn_cast<decltype(concreteOperandTy)>(resultType)) {
return {concreteOperandTy.getElementType(),
concreteResultTy.getElementType()};
}
return {};
})
.Default([resultType](Type operandType) -> TypePair {
return {operandType, resultType};
});
if (!operandElemTy || !resultElemTy)
return op->emitOpError("incompatible operand and result types");
unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
if (requireSameBitWidth) {
if (!isSameBitWidth) {
return op->emitOpError(
"expected the same bit widths for operand type and result "
"type, but provided ")
<< operandElemTy << " and " << resultElemTy;
}
return success();
}
if (isSameBitWidth) {
return op->emitOpError(
"expected the different bit widths for operand type and result "
"type, but provided ")
<< operandElemTy << " and " << resultElemTy;
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
LogicalResult BitcastOp::verify() {
// TODO: The SPIR-V spec validation rules are different for different
// versions.
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
if (operandType == resultType) {
return emitError("result type must be different from operand type");
}
if (llvm::isa<spirv::PointerType>(operandType) &&
!llvm::isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from pointer type to non-pointer type");
}
if (!llvm::isa<spirv::PointerType>(operandType) &&
llvm::isa<spirv::PointerType>(resultType)) {
return emitError(
"unhandled bit cast conversion from non-pointer type to pointer type");
}
auto operandBitWidth = getBitWidth(operandType);
auto resultBitWidth = getBitWidth(resultType);
if (operandBitWidth != resultBitWidth) {
return emitOpError("mismatch in result type bitwidth ")
<< resultBitWidth << " and operand type bitwidth "
<< operandBitWidth;
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.ConvertPtrToUOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertPtrToUOp::verify() {
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
if (!resultType || !resultType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
if (!spirvModule)
return success();
auto addressingModel = spirvModule.getAddressingModel();
if ((addressingModel == spirv::AddressingModel::Logical) ||
(addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
operandType.getStorageClass() !=
spirv::StorageClass::PhysicalStorageBuffer))
return emitError("operand must be a physical pointer");
return success();
}
//===----------------------------------------------------------------------===//
// spirv.ConvertUToPtrOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertUToPtrOp::verify() {
auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
if (!operandType || !operandType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
if (!spirvModule)
return success();
auto addressingModel = spirvModule.getAddressingModel();
if ((addressingModel == spirv::AddressingModel::Logical) ||
(addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
resultType.getStorageClass() !=
spirv::StorageClass::PhysicalStorageBuffer))
return emitError("result must be a physical pointer");
return success();
}
//===----------------------------------------------------------------------===//
// spirv.PtrCastToGenericOp
//===----------------------------------------------------------------------===//
LogicalResult PtrCastToGenericOp::verify() {
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Workgroup &&
operandStorage != spirv::StorageClass::CrossWorkgroup &&
operandStorage != spirv::StorageClass::Function)
return emitError("pointer must point to the Workgroup, CrossWorkgroup"
", or Function Storage Class");
spirv::StorageClass resultStorage = resultType.getStorageClass();
if (resultStorage != spirv::StorageClass::Generic)
return emitError("result type must be of storage class Generic");
Type operandPointeeType = operandType.getPointeeType();
Type resultPointeeType = resultType.getPointeeType();
if (operandPointeeType != resultPointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
<< operandPointeeType << " vs " << resultPointeeType;
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrOp
//===----------------------------------------------------------------------===//
LogicalResult GenericCastToPtrOp::verify() {
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
return emitError("pointer type must be of storage class Generic");
spirv::StorageClass resultStorage = resultType.getStorageClass();
if (resultStorage != spirv::StorageClass::Workgroup &&
resultStorage != spirv::StorageClass::CrossWorkgroup &&
resultStorage != spirv::StorageClass::Function)
return emitError("result must point to the Workgroup, CrossWorkgroup, "
"or Function Storage Class");
Type operandPointeeType = operandType.getPointeeType();
Type resultPointeeType = resultType.getPointeeType();
if (operandPointeeType != resultPointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
<< operandPointeeType << " vs " << resultPointeeType;
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrExplicitOp
//===----------------------------------------------------------------------===//
LogicalResult GenericCastToPtrExplicitOp::verify() {
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
return emitError("pointer type must be of storage class Generic");
spirv::StorageClass resultStorage = resultType.getStorageClass();
if (resultStorage != spirv::StorageClass::Workgroup &&
resultStorage != spirv::StorageClass::CrossWorkgroup &&
resultStorage != spirv::StorageClass::Function)
return emitError("result must point to the Workgroup, CrossWorkgroup, "
"or Function Storage Class");
Type operandPointeeType = operandType.getPointeeType();
Type resultPointeeType = resultType.getPointeeType();
if (operandPointeeType != resultPointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
<< operandPointeeType << " vs " << resultPointeeType;
return success();
}
//===----------------------------------------------------------------------===//
// spirv.ConvertFToSOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertFToSOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
/*skipBitWidthCheck=*/true);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertFToUOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertFToUOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
/*skipBitWidthCheck=*/true);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertSToFOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertSToFOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
/*skipBitWidthCheck=*/true);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertUToFOp
//===----------------------------------------------------------------------===//
LogicalResult ConvertUToFOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
/*skipBitWidthCheck=*/true);
}
//===----------------------------------------------------------------------===//
// spirv.INTELConvertBF16ToFOp
//===----------------------------------------------------------------------===//
LogicalResult INTELConvertBF16ToFOp::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
unsigned resultNumElements =
llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.INTELConvertFToBF16Op
//===----------------------------------------------------------------------===//
LogicalResult INTELConvertFToBF16Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
unsigned resultNumElements =
llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
}
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
LogicalResult spirv::FConvertOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
}
//===----------------------------------------------------------------------===//
// spirv.SConvertOp
//===----------------------------------------------------------------------===//
LogicalResult spirv::SConvertOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
}
//===----------------------------------------------------------------------===//
// spirv.UConvertOp
//===----------------------------------------------------------------------===//
LogicalResult spirv::UConvertOp::verify() {
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
}
} // namespace mlir::spirv

View File

@@ -0,0 +1,407 @@
//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the group operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
kExecutionScopeAttrName) ||
spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
kGroupOperationAttrName) ||
parser.parseOperand(valueInfo))
return failure();
std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
clusterSizeInfo = OpAsmParser::UnresolvedOperand();
if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
parser.parseRParen())
return failure();
}
Type resultType;
if (parser.parseColonType(resultType))
return failure();
if (parser.resolveOperand(valueInfo, resultType, state.operands))
return failure();
if (clusterSizeInfo) {
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
return failure();
}
return parser.addTypeToList(resultType, state.types);
}
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
printer
<< " \""
<< stringifyScope(
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
.getValue())
<< "\" \""
<< stringifyGroupOperation(
groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
.getValue())
<< "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
printer << " : " << groupOp->getResult(0).getType();
}
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
return groupOp->emitOpError("cluster size operand must be provided for "
"'ClusteredReduce' group operation");
if (groupOp->getNumOperands() > 1) {
Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
int32_t clusterSize = 0;
// TODO: support specialization constant here.
if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
return groupOp->emitOpError(
"cluster size operand must come from a constant op");
if (!llvm::isPowerOf2_32(clusterSize))
return groupOp->emitOpError(
"cluster size operand must be a power of two");
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GroupBroadcast
//===----------------------------------------------------------------------===//
LogicalResult GroupBroadcastOp::verify() {
spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
return emitOpError("localid is a vector and can be with only "
" 2 or 3 components, actual number is ")
<< localIdTy.getNumElements();
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformBallotOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBallotOp::verify() {
spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformBroadcast
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBroadcastOp::verify() {
spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
// SPIR-V spec: "Before version 1.5, Id must come from a
// constant instruction.
auto targetEnv = spirv::getDefaultTargetEnv(getContext());
if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
if (targetEnv.getVersion() < spirv::Version::V_1_5) {
auto *idOp = getId().getDefiningOp();
if (!idOp || !isa<spirv::ConstantOp, // for normal constant
spirv::ReferenceOfOp>(idOp)) // for spec constant
return emitOpError("id must be the result of a constant op");
}
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformShuffle*
//===----------------------------------------------------------------------===//
template <typename OpTy>
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
spirv::Scope scope = op.getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
if (op.getOperands().back().getType().isSignedInteger())
return op.emitOpError("second operand must be a singless/unsigned integer");
return success();
}
LogicalResult GroupNonUniformShuffleOp::verify() {
return verifyGroupNonUniformShuffleOp(*this);
}
LogicalResult GroupNonUniformShuffleDownOp::verify() {
return verifyGroupNonUniformShuffleOp(*this);
}
LogicalResult GroupNonUniformShuffleUpOp::verify() {
return verifyGroupNonUniformShuffleOp(*this);
}
LogicalResult GroupNonUniformShuffleXorOp::verify() {
return verifyGroupNonUniformShuffleOp(*this);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformElectOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformElectOp::verify() {
spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
return success();
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFAddOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFAddOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFMaxOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMaxOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFMinOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMinOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformFMulOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMulOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformIAddOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIAddOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformIMulOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIMulOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformSMaxOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMaxOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformSMinOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMinOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformUMaxOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMaxOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// spirv.GroupNonUniformUMinOp
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMinOp::verify() {
return verifyGroupNonUniformArithmeticOp(*this);
}
ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseGroupNonUniformArithmeticOp(parser, result);
}
void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
printGroupNonUniformArithmeticOp(*this, p);
}
//===----------------------------------------------------------------------===//
// Group op verification
//===----------------------------------------------------------------------===//
template <typename Op>
static LogicalResult verifyGroupOp(Op op) {
spirv::Scope scope = op.getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
return success();
}
LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
} // namespace mlir::spirv

View File

@@ -29,4 +29,9 @@ inline unsigned getBitWidth(Type type) {
llvm_unreachable("unhandled bit width computation for type");
}
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value);
LogicalResult verifyMemorySemantics(Operation *op,
spirv::MemorySemantics memorySemantics);
} // namespace mlir::spirv

File diff suppressed because it is too large Load Diff