mirror of
https://github.com/intel/llvm.git
synced 2026-01-22 23:49:22 +08:00
[mlir][spirv] Extract Atomic/Cast/Group op implementation. NFC.
Continue to work outlined in D155747 and split the main SPIR-V ops implementation file into a few smaller and quicker to compile files. This organization matches the op definition organizaion in `.td` files. In this patch, extract atomic, cast/conversion, and group op implementation into separate files. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D155777
This commit is contained in:
@@ -51,7 +51,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
|
||||
|
||||
If Result Type has a different number of components than Operand, the
|
||||
total number of bits in Result Type must equal the total number of bits
|
||||
in Operand. Let L be the type, either Result Type or Operand’s type,
|
||||
in Operand. Let L be the type, either Result Type or Operand's type,
|
||||
that has the larger number of components. Let S be the other type, with
|
||||
the smaller number of components. The number of components in L must be
|
||||
an integer multiple of the number of components in S. The first
|
||||
@@ -335,17 +335,17 @@ def SPIRV_UConvertOp : SPIRV_CastOp<"UConvert",
|
||||
|
||||
def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
|
||||
let summary = [{
|
||||
Bit pattern-preserving conversion of a pointer to
|
||||
Bit pattern-preserving conversion of a pointer to
|
||||
an unsigned scalar integer of possibly different bit width.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type must be a scalar of integer type, whose Signedness operand is 0.
|
||||
|
||||
Pointer must be a physical pointer type. If the bit width of Pointer is
|
||||
smaller than that of Result Type, the conversion zero extends Pointer.
|
||||
If the bit width of Pointer is larger than that of Result Type,
|
||||
the conversion truncates Pointer.
|
||||
Pointer must be a physical pointer type. If the bit width of Pointer is
|
||||
smaller than that of Result Type, the conversion zero extends Pointer.
|
||||
If the bit width of Pointer is larger than that of Result Type,
|
||||
the conversion truncates Pointer.
|
||||
|
||||
For same bit width Pointer and Result Type, this is the same as OpBitcast.
|
||||
|
||||
@@ -359,7 +359,7 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
|
||||
%1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
|
||||
```
|
||||
}];
|
||||
|
||||
@@ -390,18 +390,18 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
|
||||
|
||||
def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
|
||||
let summary = [{
|
||||
Bit pattern-preserving conversion of an unsigned scalar integer
|
||||
Bit pattern-preserving conversion of an unsigned scalar integer
|
||||
to a pointer.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Result Type must be a physical pointer type.
|
||||
|
||||
Integer Value must be a scalar of integer type, whose Signedness
|
||||
operand is 0. If the bit width of Integer Value is smaller
|
||||
Integer Value must be a scalar of integer type, whose Signedness
|
||||
operand is 0. If the bit width of Integer Value is smaller
|
||||
than that of Result Type, the conversion zero extends Integer Value.
|
||||
If the bit width of Integer Value is larger than that of Result Type,
|
||||
the conversion truncates Integer Value.
|
||||
If the bit width of Integer Value is larger than that of Result Type,
|
||||
the conversion truncates Integer Value.
|
||||
|
||||
For same-width Integer Value and Result Type, this is the same as OpBitcast.
|
||||
|
||||
@@ -415,7 +415,7 @@ def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
|
||||
%1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
|
||||
```
|
||||
}];
|
||||
|
||||
|
||||
441
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
Normal file
441
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
Normal file
@@ -0,0 +1,441 @@
|
||||
//===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops ----------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines the atomic operations in the SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
|
||||
#include "SPIRVOpUtils.h"
|
||||
#include "SPIRVParsingUtils.h"
|
||||
|
||||
using namespace mlir::spirv::AttrNames;
|
||||
|
||||
namespace mlir::spirv {
|
||||
|
||||
// Parses an atomic update op. If the update op does not take a value (like
|
||||
// AtomicIIncrement) `hasValue` must be false.
|
||||
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
|
||||
OperationState &state, bool hasValue) {
|
||||
spirv::Scope scope;
|
||||
spirv::MemorySemantics memoryScope;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
|
||||
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
|
||||
Type type;
|
||||
SMLoc loc;
|
||||
if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
|
||||
kMemoryScopeAttrName) ||
|
||||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
|
||||
kSemanticsAttrName) ||
|
||||
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
|
||||
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
|
||||
if (!ptrType)
|
||||
return parser.emitError(loc, "expected pointer type");
|
||||
|
||||
SmallVector<Type, 2> operandTypes;
|
||||
operandTypes.push_back(ptrType);
|
||||
if (hasValue)
|
||||
operandTypes.push_back(ptrType.getPointeeType());
|
||||
if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
|
||||
state.operands))
|
||||
return failure();
|
||||
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
|
||||
}
|
||||
|
||||
// Prints an atomic update op.
|
||||
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
|
||||
printer << " \"";
|
||||
auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
|
||||
printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
|
||||
auto memorySemanticsAttr =
|
||||
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
|
||||
printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
|
||||
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static StringRef stringifyTypeName();
|
||||
|
||||
template <>
|
||||
StringRef stringifyTypeName<IntegerType>() {
|
||||
return "integer";
|
||||
}
|
||||
|
||||
template <>
|
||||
StringRef stringifyTypeName<FloatType>() {
|
||||
return "float";
|
||||
}
|
||||
|
||||
// Verifies an atomic update op.
|
||||
template <typename ExpectedElementType>
|
||||
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
||||
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
|
||||
auto elementType = ptrType.getPointeeType();
|
||||
if (!llvm::isa<ExpectedElementType>(elementType))
|
||||
return op->emitOpError() << "pointer operand must point to an "
|
||||
<< stringifyTypeName<ExpectedElementType>()
|
||||
<< " value, found " << elementType;
|
||||
|
||||
if (op->getNumOperands() > 1) {
|
||||
auto valueType = op->getOperand(1).getType();
|
||||
if (valueType != elementType)
|
||||
return op->emitOpError("expected value to have the same type as the "
|
||||
"pointer operand's pointee type ")
|
||||
<< elementType << ", but found " << valueType;
|
||||
}
|
||||
auto memorySemantics =
|
||||
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
|
||||
.getValue();
|
||||
if (failed(verifyMemorySemantics(op, memorySemantics))) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
|
||||
printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
|
||||
<< stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
|
||||
<< stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
|
||||
<< atomOp.getOperands() << " : " << atomOp.getPointer().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
spirv::Scope memoryScope;
|
||||
spirv::MemorySemantics equalSemantics, unequalSemantics;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
|
||||
Type type;
|
||||
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
|
||||
kMemoryScopeAttrName) ||
|
||||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
|
||||
equalSemantics, parser, state, kEqualSemanticsAttrName) ||
|
||||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
|
||||
unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
|
||||
parser.parseOperandList(operandInfo, 3))
|
||||
return failure();
|
||||
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
|
||||
if (!ptrType)
|
||||
return parser.emitError(loc, "expected pointer type");
|
||||
|
||||
if (parser.resolveOperands(
|
||||
operandInfo,
|
||||
{ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
|
||||
parser.getNameLoc(), state.operands))
|
||||
return failure();
|
||||
|
||||
return parser.addTypeToList(ptrType.getPointeeType(), state.types);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
|
||||
// According to the spec:
|
||||
// "The type of Value must be the same as Result Type. The type of the value
|
||||
// pointed to by Pointer must be the same as Result Type. This type must also
|
||||
// match the type of Comparator."
|
||||
if (atomOp.getType() != atomOp.getValue().getType())
|
||||
return atomOp.emitOpError("value operand must have the same type as the op "
|
||||
"result, but found ")
|
||||
<< atomOp.getValue().getType() << " vs " << atomOp.getType();
|
||||
|
||||
if (atomOp.getType() != atomOp.getComparator().getType())
|
||||
return atomOp.emitOpError(
|
||||
"comparator operand must have the same type as the op "
|
||||
"result, but found ")
|
||||
<< atomOp.getComparator().getType() << " vs " << atomOp.getType();
|
||||
|
||||
Type pointeeType =
|
||||
llvm::cast<spirv::PointerType>(atomOp.getPointer().getType())
|
||||
.getPointeeType();
|
||||
if (atomOp.getType() != pointeeType)
|
||||
return atomOp.emitOpError(
|
||||
"pointer operand's pointee type must have the same "
|
||||
"as the op result type, but found ")
|
||||
<< pointeeType << " vs " << atomOp.getType();
|
||||
|
||||
// TODO: Unequal cannot be set to Release or Acquire and Release.
|
||||
// In addition, Unequal cannot be set to a stronger memory-order then Equal.
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicAndOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicAndOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicAndOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicCompareExchangeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicCompareExchangeOp::verify() {
|
||||
return verifyAtomicCompareExchangeImpl(*this);
|
||||
}
|
||||
|
||||
ParseResult AtomicCompareExchangeOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseAtomicCompareExchangeImpl(parser, result);
|
||||
}
|
||||
|
||||
void AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
|
||||
printAtomicCompareExchangeImpl(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicCompareExchangeWeakOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicCompareExchangeWeakOp::verify() {
|
||||
return verifyAtomicCompareExchangeImpl(*this);
|
||||
}
|
||||
|
||||
ParseResult AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseAtomicCompareExchangeImpl(parser, result);
|
||||
}
|
||||
|
||||
void AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
|
||||
printAtomicCompareExchangeImpl(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicExchange
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AtomicExchangeOp::print(OpAsmPrinter &printer) {
|
||||
printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
|
||||
<< stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
|
||||
<< " : " << getPointer().getType();
|
||||
}
|
||||
|
||||
ParseResult AtomicExchangeOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
spirv::Scope memoryScope;
|
||||
spirv::MemorySemantics semantics;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
|
||||
Type type;
|
||||
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
|
||||
kMemoryScopeAttrName) ||
|
||||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
|
||||
kSemanticsAttrName) ||
|
||||
parser.parseOperandList(operandInfo, 2))
|
||||
return failure();
|
||||
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseColonType(type))
|
||||
return failure();
|
||||
|
||||
auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
|
||||
if (!ptrType)
|
||||
return parser.emitError(loc, "expected pointer type");
|
||||
|
||||
if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
|
||||
parser.getNameLoc(), result.operands))
|
||||
return failure();
|
||||
|
||||
return parser.addTypeToList(ptrType.getPointeeType(), result.types);
|
||||
}
|
||||
|
||||
LogicalResult AtomicExchangeOp::verify() {
|
||||
if (getType() != getValue().getType())
|
||||
return emitOpError("value operand must have the same type as the op "
|
||||
"result, but found ")
|
||||
<< getValue().getType() << " vs " << getType();
|
||||
|
||||
Type pointeeType =
|
||||
llvm::cast<spirv::PointerType>(getPointer().getType()).getPointeeType();
|
||||
if (getType() != pointeeType)
|
||||
return emitOpError("pointer operand's pointee type must have the same "
|
||||
"as the op result type, but found ")
|
||||
<< pointeeType << " vs " << getType();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicIAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicIAddOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicIAddOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.EXT.AtomicFAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult EXTAtomicFAddOp::verify() {
|
||||
return verifyAtomicUpdateOp<FloatType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult EXTAtomicFAddOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
|
||||
printAtomicUpdateOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicIDecrementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicIDecrementOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicIDecrementOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, false);
|
||||
}
|
||||
|
||||
void AtomicIDecrementOp::print(OpAsmPrinter &p) {
|
||||
printAtomicUpdateOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicIIncrementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicIIncrementOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicIIncrementOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, false);
|
||||
}
|
||||
|
||||
void AtomicIIncrementOp::print(OpAsmPrinter &p) {
|
||||
printAtomicUpdateOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicISubOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicISubOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicISubOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicOrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicOrOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicOrOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicSMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicSMaxOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicSMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicSMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicSMinOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicSMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicUMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicUMaxOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicUMaxOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicUMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicUMinOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicUMinOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.AtomicXorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtomicXorOp::verify() {
|
||||
return verifyAtomicUpdateOp<IntegerType>(getOperation());
|
||||
}
|
||||
|
||||
ParseResult AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseAtomicUpdateOp(parser, result, true);
|
||||
}
|
||||
|
||||
void AtomicXorOp::print(OpAsmPrinter &p) { printAtomicUpdateOp(*this, p); }
|
||||
|
||||
} // namespace mlir::spirv
|
||||
@@ -3,7 +3,10 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
|
||||
|
||||
add_mlir_dialect_library(MLIRSPIRVDialect
|
||||
AtomicOps.cpp
|
||||
CastOps.cpp
|
||||
CooperativeMatrixOps.cpp
|
||||
GroupOps.cpp
|
||||
IntegerDotProductOps.cpp
|
||||
JointMatrixOps.cpp
|
||||
SPIRVAttributes.cpp
|
||||
|
||||
339
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
Normal file
339
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
Normal file
@@ -0,0 +1,339 @@
|
||||
//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines the cast and conversion operations in the SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
|
||||
#include "SPIRVOpUtils.h"
|
||||
#include "SPIRVParsingUtils.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir::spirv::AttrNames;
|
||||
|
||||
namespace mlir::spirv {
|
||||
|
||||
static LogicalResult verifyCastOp(Operation *op,
|
||||
bool requireSameBitWidth = true,
|
||||
bool skipBitWidthCheck = false) {
|
||||
// Some CastOps have no limit on bit widths for result and operand type.
|
||||
if (skipBitWidthCheck)
|
||||
return success();
|
||||
|
||||
Type operandType = op->getOperand(0).getType();
|
||||
Type resultType = op->getResult(0).getType();
|
||||
|
||||
// ODS checks that result type and operand type have the same shape. Check
|
||||
// that composite types match and extract the element types, if any.
|
||||
using TypePair = std::pair<Type, Type>;
|
||||
auto [operandElemTy, resultElemTy] =
|
||||
TypeSwitch<Type, TypePair>(operandType)
|
||||
.Case<VectorType, spirv::CooperativeMatrixType,
|
||||
spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
|
||||
[resultType](auto concreteOperandTy) -> TypePair {
|
||||
if (auto concreteResultTy =
|
||||
dyn_cast<decltype(concreteOperandTy)>(resultType)) {
|
||||
return {concreteOperandTy.getElementType(),
|
||||
concreteResultTy.getElementType()};
|
||||
}
|
||||
return {};
|
||||
})
|
||||
.Default([resultType](Type operandType) -> TypePair {
|
||||
return {operandType, resultType};
|
||||
});
|
||||
|
||||
if (!operandElemTy || !resultElemTy)
|
||||
return op->emitOpError("incompatible operand and result types");
|
||||
|
||||
unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
|
||||
unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
|
||||
bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
|
||||
|
||||
if (requireSameBitWidth) {
|
||||
if (!isSameBitWidth) {
|
||||
return op->emitOpError(
|
||||
"expected the same bit widths for operand type and result "
|
||||
"type, but provided ")
|
||||
<< operandElemTy << " and " << resultElemTy;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
if (isSameBitWidth) {
|
||||
return op->emitOpError(
|
||||
"expected the different bit widths for operand type and result "
|
||||
"type, but provided ")
|
||||
<< operandElemTy << " and " << resultElemTy;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.BitcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult BitcastOp::verify() {
|
||||
// TODO: The SPIR-V spec validation rules are different for different
|
||||
// versions.
|
||||
auto operandType = getOperand().getType();
|
||||
auto resultType = getResult().getType();
|
||||
if (operandType == resultType) {
|
||||
return emitError("result type must be different from operand type");
|
||||
}
|
||||
if (llvm::isa<spirv::PointerType>(operandType) &&
|
||||
!llvm::isa<spirv::PointerType>(resultType)) {
|
||||
return emitError(
|
||||
"unhandled bit cast conversion from pointer type to non-pointer type");
|
||||
}
|
||||
if (!llvm::isa<spirv::PointerType>(operandType) &&
|
||||
llvm::isa<spirv::PointerType>(resultType)) {
|
||||
return emitError(
|
||||
"unhandled bit cast conversion from non-pointer type to pointer type");
|
||||
}
|
||||
auto operandBitWidth = getBitWidth(operandType);
|
||||
auto resultBitWidth = getBitWidth(resultType);
|
||||
if (operandBitWidth != resultBitWidth) {
|
||||
return emitOpError("mismatch in result type bitwidth ")
|
||||
<< resultBitWidth << " and operand type bitwidth "
|
||||
<< operandBitWidth;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertPtrToUOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertPtrToUOp::verify() {
|
||||
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||
auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
|
||||
if (!resultType || !resultType.isSignlessInteger())
|
||||
return emitError("result must be a scalar type of unsigned integer");
|
||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||
if (!spirvModule)
|
||||
return success();
|
||||
auto addressingModel = spirvModule.getAddressingModel();
|
||||
if ((addressingModel == spirv::AddressingModel::Logical) ||
|
||||
(addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
|
||||
operandType.getStorageClass() !=
|
||||
spirv::StorageClass::PhysicalStorageBuffer))
|
||||
return emitError("operand must be a physical pointer");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertUToPtrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertUToPtrOp::verify() {
|
||||
auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
|
||||
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||
if (!operandType || !operandType.isSignlessInteger())
|
||||
return emitError("result must be a scalar type of unsigned integer");
|
||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||
if (!spirvModule)
|
||||
return success();
|
||||
auto addressingModel = spirvModule.getAddressingModel();
|
||||
if ((addressingModel == spirv::AddressingModel::Logical) ||
|
||||
(addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
|
||||
resultType.getStorageClass() !=
|
||||
spirv::StorageClass::PhysicalStorageBuffer))
|
||||
return emitError("result must be a physical pointer");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.PtrCastToGenericOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult PtrCastToGenericOp::verify() {
|
||||
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||
|
||||
spirv::StorageClass operandStorage = operandType.getStorageClass();
|
||||
if (operandStorage != spirv::StorageClass::Workgroup &&
|
||||
operandStorage != spirv::StorageClass::CrossWorkgroup &&
|
||||
operandStorage != spirv::StorageClass::Function)
|
||||
return emitError("pointer must point to the Workgroup, CrossWorkgroup"
|
||||
", or Function Storage Class");
|
||||
|
||||
spirv::StorageClass resultStorage = resultType.getStorageClass();
|
||||
if (resultStorage != spirv::StorageClass::Generic)
|
||||
return emitError("result type must be of storage class Generic");
|
||||
|
||||
Type operandPointeeType = operandType.getPointeeType();
|
||||
Type resultPointeeType = resultType.getPointeeType();
|
||||
if (operandPointeeType != resultPointeeType)
|
||||
return emitOpError("pointer operand's pointee type must have the same "
|
||||
"as the op result type, but found ")
|
||||
<< operandPointeeType << " vs " << resultPointeeType;
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GenericCastToPtrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GenericCastToPtrOp::verify() {
|
||||
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||
|
||||
spirv::StorageClass operandStorage = operandType.getStorageClass();
|
||||
if (operandStorage != spirv::StorageClass::Generic)
|
||||
return emitError("pointer type must be of storage class Generic");
|
||||
|
||||
spirv::StorageClass resultStorage = resultType.getStorageClass();
|
||||
if (resultStorage != spirv::StorageClass::Workgroup &&
|
||||
resultStorage != spirv::StorageClass::CrossWorkgroup &&
|
||||
resultStorage != spirv::StorageClass::Function)
|
||||
return emitError("result must point to the Workgroup, CrossWorkgroup, "
|
||||
"or Function Storage Class");
|
||||
|
||||
Type operandPointeeType = operandType.getPointeeType();
|
||||
Type resultPointeeType = resultType.getPointeeType();
|
||||
if (operandPointeeType != resultPointeeType)
|
||||
return emitOpError("pointer operand's pointee type must have the same "
|
||||
"as the op result type, but found ")
|
||||
<< operandPointeeType << " vs " << resultPointeeType;
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GenericCastToPtrExplicitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GenericCastToPtrExplicitOp::verify() {
|
||||
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||
|
||||
spirv::StorageClass operandStorage = operandType.getStorageClass();
|
||||
if (operandStorage != spirv::StorageClass::Generic)
|
||||
return emitError("pointer type must be of storage class Generic");
|
||||
|
||||
spirv::StorageClass resultStorage = resultType.getStorageClass();
|
||||
if (resultStorage != spirv::StorageClass::Workgroup &&
|
||||
resultStorage != spirv::StorageClass::CrossWorkgroup &&
|
||||
resultStorage != spirv::StorageClass::Function)
|
||||
return emitError("result must point to the Workgroup, CrossWorkgroup, "
|
||||
"or Function Storage Class");
|
||||
|
||||
Type operandPointeeType = operandType.getPointeeType();
|
||||
Type resultPointeeType = resultType.getPointeeType();
|
||||
if (operandPointeeType != resultPointeeType)
|
||||
return emitOpError("pointer operand's pointee type must have the same "
|
||||
"as the op result type, but found ")
|
||||
<< operandPointeeType << " vs " << resultPointeeType;
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertFToSOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertFToSOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
|
||||
/*skipBitWidthCheck=*/true);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertFToUOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertFToUOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
|
||||
/*skipBitWidthCheck=*/true);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertSToFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertSToFOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
|
||||
/*skipBitWidthCheck=*/true);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.ConvertUToFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ConvertUToFOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false,
|
||||
/*skipBitWidthCheck=*/true);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.INTELConvertBF16ToFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult INTELConvertBF16ToFOp::verify() {
|
||||
auto operandType = getOperand().getType();
|
||||
auto resultType = getResult().getType();
|
||||
// ODS checks that vector result type and vector operand type have the same
|
||||
// shape.
|
||||
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
|
||||
unsigned operandNumElements = vectorType.getNumElements();
|
||||
unsigned resultNumElements =
|
||||
llvm::cast<VectorType>(resultType).getNumElements();
|
||||
if (operandNumElements != resultNumElements) {
|
||||
return emitOpError(
|
||||
"operand and result must have same number of elements");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.INTELConvertFToBF16Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult INTELConvertFToBF16Op::verify() {
|
||||
auto operandType = getOperand().getType();
|
||||
auto resultType = getResult().getType();
|
||||
// ODS checks that vector result type and vector operand type have the same
|
||||
// shape.
|
||||
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
|
||||
unsigned operandNumElements = vectorType.getNumElements();
|
||||
unsigned resultNumElements =
|
||||
llvm::cast<VectorType>(resultType).getNumElements();
|
||||
if (operandNumElements != resultNumElements) {
|
||||
return emitOpError(
|
||||
"operand and result must have same number of elements");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.FConvertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::FConvertOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.SConvertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::SConvertOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.UConvertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::UConvertOp::verify() {
|
||||
return verifyCastOp(*this, /*requireSameBitWidth=*/false);
|
||||
}
|
||||
|
||||
} // namespace mlir::spirv
|
||||
407
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
Normal file
407
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
Normal file
@@ -0,0 +1,407 @@
|
||||
//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Defines the group operations in the SPIR-V dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
||||
|
||||
#include "SPIRVOpUtils.h"
|
||||
#include "SPIRVParsingUtils.h"
|
||||
|
||||
using namespace mlir::spirv::AttrNames;
|
||||
|
||||
namespace mlir::spirv {
|
||||
|
||||
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
|
||||
OperationState &state) {
|
||||
spirv::Scope executionScope;
|
||||
GroupOperation groupOperation;
|
||||
OpAsmParser::UnresolvedOperand valueInfo;
|
||||
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
|
||||
kExecutionScopeAttrName) ||
|
||||
spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
|
||||
kGroupOperationAttrName) ||
|
||||
parser.parseOperand(valueInfo))
|
||||
return failure();
|
||||
|
||||
std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
|
||||
if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
|
||||
clusterSizeInfo = OpAsmParser::UnresolvedOperand();
|
||||
if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
|
||||
parser.parseRParen())
|
||||
return failure();
|
||||
}
|
||||
|
||||
Type resultType;
|
||||
if (parser.parseColonType(resultType))
|
||||
return failure();
|
||||
|
||||
if (parser.resolveOperand(valueInfo, resultType, state.operands))
|
||||
return failure();
|
||||
|
||||
if (clusterSizeInfo) {
|
||||
Type i32Type = parser.getBuilder().getIntegerType(32);
|
||||
if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return parser.addTypeToList(resultType, state.types);
|
||||
}
|
||||
|
||||
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
|
||||
OpAsmPrinter &printer) {
|
||||
printer
|
||||
<< " \""
|
||||
<< stringifyScope(
|
||||
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
|
||||
.getValue())
|
||||
<< "\" \""
|
||||
<< stringifyGroupOperation(
|
||||
groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
|
||||
.getValue())
|
||||
<< "\" " << groupOp->getOperand(0);
|
||||
|
||||
if (groupOp->getNumOperands() > 1)
|
||||
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
|
||||
printer << " : " << groupOp->getResult(0).getType();
|
||||
}
|
||||
|
||||
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
|
||||
spirv::Scope scope =
|
||||
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
|
||||
.getValue();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return groupOp->emitOpError(
|
||||
"execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
GroupOperation operation =
|
||||
groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
|
||||
.getValue();
|
||||
if (operation == GroupOperation::ClusteredReduce &&
|
||||
groupOp->getNumOperands() == 1)
|
||||
return groupOp->emitOpError("cluster size operand must be provided for "
|
||||
"'ClusteredReduce' group operation");
|
||||
if (groupOp->getNumOperands() > 1) {
|
||||
Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
|
||||
int32_t clusterSize = 0;
|
||||
|
||||
// TODO: support specialization constant here.
|
||||
if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
|
||||
return groupOp->emitOpError(
|
||||
"cluster size operand must come from a constant op");
|
||||
|
||||
if (!llvm::isPowerOf2_32(clusterSize))
|
||||
return groupOp->emitOpError(
|
||||
"cluster size operand must be a power of two");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupBroadcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupBroadcastOp::verify() {
|
||||
spirv::Scope scope = getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
|
||||
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
|
||||
return emitOpError("localid is a vector and can be with only "
|
||||
" 2 or 3 components, actual number is ")
|
||||
<< localIdTy.getNumElements();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformBallotOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformBallotOp::verify() {
|
||||
spirv::Scope scope = getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformBroadcast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformBroadcastOp::verify() {
|
||||
spirv::Scope scope = getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
// SPIR-V spec: "Before version 1.5, Id must come from a
|
||||
// constant instruction.
|
||||
auto targetEnv = spirv::getDefaultTargetEnv(getContext());
|
||||
if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
|
||||
targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
|
||||
|
||||
if (targetEnv.getVersion() < spirv::Version::V_1_5) {
|
||||
auto *idOp = getId().getDefiningOp();
|
||||
if (!idOp || !isa<spirv::ConstantOp, // for normal constant
|
||||
spirv::ReferenceOfOp>(idOp)) // for spec constant
|
||||
return emitOpError("id must be the result of a constant op");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformShuffle*
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OpTy>
|
||||
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
|
||||
spirv::Scope scope = op.getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
if (op.getOperands().back().getType().isSignedInteger())
|
||||
return op.emitOpError("second operand must be a singless/unsigned integer");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GroupNonUniformShuffleOp::verify() {
|
||||
return verifyGroupNonUniformShuffleOp(*this);
|
||||
}
|
||||
LogicalResult GroupNonUniformShuffleDownOp::verify() {
|
||||
return verifyGroupNonUniformShuffleOp(*this);
|
||||
}
|
||||
LogicalResult GroupNonUniformShuffleUpOp::verify() {
|
||||
return verifyGroupNonUniformShuffleOp(*this);
|
||||
}
|
||||
LogicalResult GroupNonUniformShuffleXorOp::verify() {
|
||||
return verifyGroupNonUniformShuffleOp(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformElectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformElectOp::verify() {
|
||||
spirv::Scope scope = getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformFAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformFAddOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformFMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformFMaxOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformFMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformFMinOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformFMulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformFMulOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformIAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformIAddOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformIMulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformIMulOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformSMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformSMaxOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformSMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformSMinOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformUMaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformUMaxOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformUMinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult GroupNonUniformUMinOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp(*this);
|
||||
}
|
||||
|
||||
ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
return parseGroupNonUniformArithmeticOp(parser, result);
|
||||
}
|
||||
|
||||
void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
|
||||
printGroupNonUniformArithmeticOp(*this, p);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Group op verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename Op>
|
||||
static LogicalResult verifyGroupOp(Op op) {
|
||||
spirv::Scope scope = op.getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
|
||||
|
||||
} // namespace mlir::spirv
|
||||
@@ -29,4 +29,9 @@ inline unsigned getBitWidth(Type type) {
|
||||
llvm_unreachable("unhandled bit width computation for type");
|
||||
}
|
||||
|
||||
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value);
|
||||
|
||||
LogicalResult verifyMemorySemantics(Operation *op,
|
||||
spirv::MemorySemantics memorySemantics);
|
||||
|
||||
} // namespace mlir::spirv
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user