mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 03:56:16 +08:00
[mlir][Type] Remove usages of Type::getKind
This is in preparation for removing the use of "kinds" within attributes and types in MLIR. Differential Revision: https://reviews.llvm.org/D85475
This commit is contained in:
@@ -101,10 +101,7 @@ public:
|
||||
}
|
||||
|
||||
/// Support for isa/cast.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() >= FIRST_NEW_LLVM_TYPE &&
|
||||
type.getKind() <= LAST_NEW_LLVM_TYPE;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
|
||||
LLVMDialect &getDialect();
|
||||
|
||||
|
||||
@@ -71,10 +71,7 @@ public:
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
|
||||
type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Gets the minimum possible stored by a storageType. storageTypeMin must
|
||||
/// be greater than or equal to this value.
|
||||
|
||||
@@ -294,13 +294,7 @@ public:
|
||||
int64_t getSizeInBits() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() == StandardTypes::Vector ||
|
||||
type.getKind() == StandardTypes::RankedTensor ||
|
||||
type.getKind() == StandardTypes::UnrankedTensor ||
|
||||
type.getKind() == StandardTypes::UnrankedMemRef ||
|
||||
type.getKind() == StandardTypes::MemRef;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Whether the given dimension size indicates a dynamic dimension.
|
||||
static constexpr bool isDynamic(int64_t dSize) {
|
||||
@@ -358,20 +352,10 @@ public:
|
||||
using ShapedType::ShapedType;
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidElementType(Type type) {
|
||||
// Note: Non standard/builtin types are allowed to exist within tensor
|
||||
// types. Dialects are expected to verify that tensor types have a valid
|
||||
// element type within that dialect.
|
||||
return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
|
||||
IndexType>() ||
|
||||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
|
||||
}
|
||||
static bool isValidElementType(Type type);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() == StandardTypes::RankedTensor ||
|
||||
type.getKind() == StandardTypes::UnrankedTensor;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -443,10 +427,7 @@ public:
|
||||
using ShapedType::ShapedType;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type) {
|
||||
return type.getKind() == StandardTypes::MemRef ||
|
||||
type.getKind() == StandardTypes::UnrankedMemRef;
|
||||
}
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -629,6 +610,23 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Deferred Method Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline bool BaseMemRefType::classof(Type type) {
|
||||
return type.isa<MemRefType, UnrankedMemRefType>();
|
||||
}
|
||||
|
||||
inline bool ShapedType::classof(Type type) {
|
||||
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
|
||||
UnrankedMemRefType, MemRefType>();
|
||||
}
|
||||
|
||||
inline bool TensorType::classof(Type type) {
|
||||
return type.isa<RankedTensorType, UnrankedTensorType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -27,6 +27,10 @@ using namespace mlir::LLVM;
|
||||
// LLVMType.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool LLVMType::classof(Type type) {
|
||||
return llvm::isa<LLVMDialect>(type.getDialect());
|
||||
}
|
||||
|
||||
LLVMDialect &LLVMType::getDialect() {
|
||||
return static_cast<LLVMDialect &>(Type::getDialect());
|
||||
}
|
||||
|
||||
@@ -55,11 +55,5 @@ static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
|
||||
|
||||
void mlir::linalg::LinalgDialect::printType(Type type,
|
||||
DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
default:
|
||||
llvm_unreachable("Unhandled Linalg type");
|
||||
case LinalgTypes::Range:
|
||||
print(type.cast<RangeType>(), os);
|
||||
break;
|
||||
}
|
||||
print(type.cast<RangeType>(), os);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Dialect/Quant/QuantTypes.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
@@ -23,6 +24,10 @@ unsigned QuantizedType::getFlags() const {
|
||||
return static_cast<ImplType *>(impl)->flags;
|
||||
}
|
||||
|
||||
bool QuantizedType::classof(Type type) {
|
||||
return llvm::isa<QuantizationDialect>(type.getDialect());
|
||||
}
|
||||
|
||||
LogicalResult QuantizedType::verifyConstructionInvariants(
|
||||
Location loc, unsigned flags, Type storageType, Type expressedType,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
|
||||
@@ -365,18 +365,12 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
default:
|
||||
if (auto anyType = type.dyn_cast<AnyQuantizedType>())
|
||||
printAnyQuantizedType(anyType, os);
|
||||
else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
|
||||
printUniformQuantizedType(uniformType, os);
|
||||
else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
|
||||
printUniformQuantizedPerAxisType(perAxisType, os);
|
||||
else
|
||||
llvm_unreachable("Unhandled quantized type");
|
||||
case QuantizationTypes::Any:
|
||||
printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
|
||||
break;
|
||||
case QuantizationTypes::UniformQuantized:
|
||||
printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
|
||||
break;
|
||||
case QuantizationTypes::UniformQuantizedPerAxis:
|
||||
printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
|
||||
os);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,48 +19,33 @@ static bool isQuantizablePrimitiveType(Type inputType) {
|
||||
|
||||
const ExpressedToQuantizedConverter
|
||||
ExpressedToQuantizedConverter::forInputType(Type inputType) {
|
||||
switch (inputType.getKind()) {
|
||||
default:
|
||||
if (isQuantizablePrimitiveType(inputType)) {
|
||||
// Supported primitive type (which just is the expressed type).
|
||||
return ExpressedToQuantizedConverter{inputType, inputType};
|
||||
}
|
||||
// Unsupported.
|
||||
return ExpressedToQuantizedConverter{inputType, nullptr};
|
||||
case StandardTypes::RankedTensor:
|
||||
case StandardTypes::UnrankedTensor:
|
||||
case StandardTypes::Vector: {
|
||||
if (inputType.isa<TensorType, VectorType>()) {
|
||||
Type elementType = inputType.cast<ShapedType>().getElementType();
|
||||
if (!isQuantizablePrimitiveType(elementType)) {
|
||||
// Unsupported.
|
||||
if (!isQuantizablePrimitiveType(elementType))
|
||||
return ExpressedToQuantizedConverter{inputType, nullptr};
|
||||
}
|
||||
return ExpressedToQuantizedConverter{
|
||||
inputType, inputType.cast<ShapedType>().getElementType()};
|
||||
}
|
||||
return ExpressedToQuantizedConverter{inputType, elementType};
|
||||
}
|
||||
// Supported primitive type (which just is the expressed type).
|
||||
if (isQuantizablePrimitiveType(inputType))
|
||||
return ExpressedToQuantizedConverter{inputType, inputType};
|
||||
// Unsupported.
|
||||
return ExpressedToQuantizedConverter{inputType, nullptr};
|
||||
}
|
||||
|
||||
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
|
||||
assert(expressedType && "convert() on unsupported conversion");
|
||||
|
||||
switch (inputType.getKind()) {
|
||||
default:
|
||||
if (elementalType.getExpressedType() == expressedType) {
|
||||
// If the expressed types match, just use the new elemental type.
|
||||
return elementalType;
|
||||
}
|
||||
// Unsupported.
|
||||
return nullptr;
|
||||
case StandardTypes::RankedTensor:
|
||||
return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
|
||||
elementalType);
|
||||
case StandardTypes::UnrankedTensor:
|
||||
if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), elementalType);
|
||||
if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
|
||||
return UnrankedTensorType::get(elementalType);
|
||||
case StandardTypes::Vector:
|
||||
return VectorType::get(inputType.cast<VectorType>().getShape(),
|
||||
elementalType);
|
||||
}
|
||||
if (auto vectorType = inputType.dyn_cast<VectorType>())
|
||||
return VectorType::get(vectorType.getShape(), elementalType);
|
||||
|
||||
// If the expressed types match, just use the new elemental type.
|
||||
if (elementalType.getExpressedType() == expressedType)
|
||||
return elementalType;
|
||||
// Unsupported.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ElementsAttr
|
||||
|
||||
@@ -78,20 +78,17 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
|
||||
size = alignment;
|
||||
return type;
|
||||
}
|
||||
|
||||
switch (type.getKind()) {
|
||||
case spirv::TypeKind::Struct:
|
||||
return decorateType(type.cast<spirv::StructType>(), size, alignment);
|
||||
case spirv::TypeKind::Array:
|
||||
return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
|
||||
case StandardTypes::Vector:
|
||||
return decorateType(type.cast<VectorType>(), size, alignment);
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
if (auto structType = type.dyn_cast<spirv::StructType>())
|
||||
return decorateType(structType, size, alignment);
|
||||
if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
|
||||
return decorateType(arrayType, size, alignment);
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return decorateType(vectorType, size, alignment);
|
||||
if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
|
||||
size = std::numeric_limits<Size>().max();
|
||||
return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
|
||||
default:
|
||||
llvm_unreachable("unhandled SPIR-V type");
|
||||
return decorateType(arrayType, alignment);
|
||||
}
|
||||
llvm_unreachable("unhandled SPIR-V type");
|
||||
}
|
||||
|
||||
Type VulkanLayoutUtils::decorateType(VectorType vectorType,
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
@@ -727,31 +728,11 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
|
||||
}
|
||||
|
||||
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
print(type.cast<ArrayType>(), os);
|
||||
return;
|
||||
case TypeKind::CooperativeMatrix:
|
||||
print(type.cast<CooperativeMatrixNVType>(), os);
|
||||
return;
|
||||
case TypeKind::Pointer:
|
||||
print(type.cast<PointerType>(), os);
|
||||
return;
|
||||
case TypeKind::RuntimeArray:
|
||||
print(type.cast<RuntimeArrayType>(), os);
|
||||
return;
|
||||
case TypeKind::Image:
|
||||
print(type.cast<ImageType>(), os);
|
||||
return;
|
||||
case TypeKind::Struct:
|
||||
print(type.cast<StructType>(), os);
|
||||
return;
|
||||
case TypeKind::Matrix:
|
||||
print(type.cast<MatrixType>(), os);
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unhandled SPIR-V type");
|
||||
}
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
|
||||
ImageType, StructType, MatrixType>(
|
||||
[&](auto type) { print(type, os); })
|
||||
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1534,8 +1534,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
|
||||
if (!type.isa<spirv::SPIRVType>())
|
||||
return false;
|
||||
|
||||
if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
|
||||
type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
|
||||
if (isa<SPIRVDialect>(type.getDialect())) {
|
||||
// TODO: support constant struct
|
||||
return type.isa<spirv::ArrayType>();
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::spirv;
|
||||
@@ -163,18 +164,11 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool CompositeType::classof(Type type) {
|
||||
switch (type.getKind()) {
|
||||
case TypeKind::Array:
|
||||
case TypeKind::CooperativeMatrix:
|
||||
case TypeKind::Matrix:
|
||||
case TypeKind::RuntimeArray:
|
||||
case TypeKind::Struct:
|
||||
return true;
|
||||
case StandardTypes::Vector:
|
||||
return isValid(type.cast<VectorType>());
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
return isValid(vectorType);
|
||||
return type
|
||||
.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
|
||||
spirv::RuntimeArrayType, spirv::StructType>();
|
||||
}
|
||||
|
||||
bool CompositeType::isValid(VectorType type) {
|
||||
@@ -183,22 +177,14 @@ bool CompositeType::isValid(VectorType type) {
|
||||
}
|
||||
|
||||
Type CompositeType::getElementType(unsigned index) const {
|
||||
switch (getKind()) {
|
||||
case spirv::TypeKind::Array:
|
||||
return cast<ArrayType>().getElementType();
|
||||
case spirv::TypeKind::CooperativeMatrix:
|
||||
return cast<CooperativeMatrixNVType>().getElementType();
|
||||
case spirv::TypeKind::Matrix:
|
||||
return cast<MatrixType>().getColumnType();
|
||||
case spirv::TypeKind::RuntimeArray:
|
||||
return cast<RuntimeArrayType>().getElementType();
|
||||
case spirv::TypeKind::Struct:
|
||||
return cast<StructType>().getElementType(index);
|
||||
case StandardTypes::Vector:
|
||||
return cast<VectorType>().getElementType();
|
||||
default:
|
||||
llvm_unreachable("invalid composite type");
|
||||
}
|
||||
return TypeSwitch<Type, Type>(*this)
|
||||
.Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
|
||||
[](auto type) { return type.getElementType(); })
|
||||
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
|
||||
.Case<StructType>(
|
||||
[index](StructType type) { return type.getElementType(index); })
|
||||
.Default(
|
||||
[](Type) -> Type { llvm_unreachable("invalid composite type"); });
|
||||
}
|
||||
|
||||
unsigned CompositeType::getNumElements() const {
|
||||
|
||||
@@ -123,16 +123,16 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
|
||||
|
||||
// Returns the type kind if the given type is a vector or ranked tensor type.
|
||||
// Returns llvm::None otherwise.
|
||||
auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
|
||||
auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
|
||||
if (type.isa<VectorType, RankedTensorType>())
|
||||
return static_cast<StandardTypes::Kind>(type.getKind());
|
||||
return type.getTypeID();
|
||||
return llvm::None;
|
||||
};
|
||||
|
||||
// Make sure the composite type, if has, is consistent.
|
||||
auto compositeKind1 = getCompositeTypeKind(type1);
|
||||
auto compositeKind2 = getCompositeTypeKind(type2);
|
||||
Optional<StandardTypes::Kind> resultCompositeKind;
|
||||
Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
|
||||
Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
|
||||
Optional<TypeID> resultCompositeKind;
|
||||
|
||||
if (compositeKind1 && compositeKind2) {
|
||||
// Disallow mixing vector and tensor.
|
||||
@@ -151,9 +151,9 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
|
||||
return {};
|
||||
|
||||
// Compose the final broadcasted type
|
||||
if (resultCompositeKind == StandardTypes::Vector)
|
||||
if (resultCompositeKind == VectorType::getTypeID())
|
||||
return VectorType::get(resultShape, elementType);
|
||||
if (resultCompositeKind == StandardTypes::RankedTensor)
|
||||
if (resultCompositeKind == RankedTensorType::getTypeID())
|
||||
return RankedTensorType::get(resultShape, elementType);
|
||||
return elementType;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
@@ -244,16 +245,11 @@ int64_t ShapedType::getSizeInBits() const {
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> ShapedType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
return cast<VectorType>().getShape();
|
||||
case StandardTypes::RankedTensor:
|
||||
return cast<RankedTensorType>().getShape();
|
||||
case StandardTypes::MemRef:
|
||||
return cast<MemRefType>().getShape();
|
||||
default:
|
||||
llvm_unreachable("not a ShapedType or not ranked");
|
||||
}
|
||||
if (auto vectorType = dyn_cast<VectorType>())
|
||||
return vectorType.getShape();
|
||||
if (auto tensorType = dyn_cast<RankedTensorType>())
|
||||
return tensorType.getShape();
|
||||
return cast<MemRefType>().getShape();
|
||||
}
|
||||
|
||||
int64_t ShapedType::getNumDynamicDims() const {
|
||||
@@ -305,13 +301,23 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
|
||||
|
||||
// Check if "elementType" can be an element type of a tensor. Emit errors if
|
||||
// location is not nullptr. Returns failure if check failed.
|
||||
static inline LogicalResult checkTensorElementType(Location location,
|
||||
Type elementType) {
|
||||
static LogicalResult checkTensorElementType(Location location,
|
||||
Type elementType) {
|
||||
if (!TensorType::isValidElementType(elementType))
|
||||
return emitError(location, "invalid tensor element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
bool TensorType::isValidElementType(Type type) {
|
||||
// Note: Non standard/builtin types are allowed to exist within tensor
|
||||
// types. Dialects are expected to verify that tensor types have a valid
|
||||
// element type within that dialect.
|
||||
return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
|
||||
IndexType>() ||
|
||||
!type.getDialect().getNamespace().empty();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RankedTensorType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user