[mlir][Type] Remove usages of Type::getKind

This is in preparation for removing the use of "kinds" within attributes and types in MLIR.

Differential Revision: https://reviews.llvm.org/D85475
This commit is contained in:
River Riddle
2020-08-07 13:30:43 -07:00
parent fff39b62bb
commit c8c45985fb
14 changed files with 114 additions and 171 deletions

View File

@@ -101,10 +101,7 @@ public:
}
/// Support for isa/cast.
static bool classof(Type type) {
return type.getKind() >= FIRST_NEW_LLVM_TYPE &&
type.getKind() <= LAST_NEW_LLVM_TYPE;
}
static bool classof(Type type);
LLVMDialect &getDialect();

View File

@@ -71,10 +71,7 @@ public:
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type) {
return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
}
static bool classof(Type type);
/// Gets the minimum possible stored by a storageType. storageTypeMin must
/// be greater than or equal to this value.

View File

@@ -294,13 +294,7 @@ public:
int64_t getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type) {
return type.getKind() == StandardTypes::Vector ||
type.getKind() == StandardTypes::RankedTensor ||
type.getKind() == StandardTypes::UnrankedTensor ||
type.getKind() == StandardTypes::UnrankedMemRef ||
type.getKind() == StandardTypes::MemRef;
}
static bool classof(Type type);
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) {
@@ -358,20 +352,10 @@ public:
using ShapedType::ShapedType;
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type) {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
IndexType>() ||
(type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
}
static bool isValidElementType(Type type);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type) {
return type.getKind() == StandardTypes::RankedTensor ||
type.getKind() == StandardTypes::UnrankedTensor;
}
static bool classof(Type type);
};
//===----------------------------------------------------------------------===//
@@ -443,10 +427,7 @@ public:
using ShapedType::ShapedType;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type) {
return type.getKind() == StandardTypes::MemRef ||
type.getKind() == StandardTypes::UnrankedMemRef;
}
static bool classof(Type type);
};
//===----------------------------------------------------------------------===//
@@ -629,6 +610,23 @@ public:
}
};
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
inline bool BaseMemRefType::classof(Type type) {
return type.isa<MemRefType, UnrankedMemRefType>();
}
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();
}
inline bool TensorType::classof(Type type) {
return type.isa<RankedTensorType, UnrankedTensorType>();
}
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//

View File

@@ -27,6 +27,10 @@ using namespace mlir::LLVM;
// LLVMType.
//===----------------------------------------------------------------------===//
bool LLVMType::classof(Type type) {
return llvm::isa<LLVMDialect>(type.getDialect());
}
LLVMDialect &LLVMType::getDialect() {
return static_cast<LLVMDialect &>(Type::getDialect());
}

View File

@@ -55,11 +55,5 @@ static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; }
void mlir::linalg::LinalgDialect::printType(Type type,
DialectAsmPrinter &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;
}
print(type.cast<RangeType>(), os);
}

View File

@@ -8,6 +8,7 @@
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
@@ -23,6 +24,10 @@ unsigned QuantizedType::getFlags() const {
return static_cast<ImplType *>(impl)->flags;
}
bool QuantizedType::classof(Type type) {
return llvm::isa<QuantizationDialect>(type.getDialect());
}
LogicalResult QuantizedType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {

View File

@@ -365,18 +365,12 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
/// Print a type registered to this dialect.
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
default:
if (auto anyType = type.dyn_cast<AnyQuantizedType>())
printAnyQuantizedType(anyType, os);
else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
printUniformQuantizedType(uniformType, os);
else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
printUniformQuantizedPerAxisType(perAxisType, os);
else
llvm_unreachable("Unhandled quantized type");
case QuantizationTypes::Any:
printAnyQuantizedType(type.cast<AnyQuantizedType>(), os);
break;
case QuantizationTypes::UniformQuantized:
printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
break;
case QuantizationTypes::UniformQuantizedPerAxis:
printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
os);
break;
}
}

View File

@@ -19,48 +19,33 @@ static bool isQuantizablePrimitiveType(Type inputType) {
const ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
switch (inputType.getKind()) {
default:
if (isQuantizablePrimitiveType(inputType)) {
// Supported primitive type (which just is the expressed type).
return ExpressedToQuantizedConverter{inputType, inputType};
}
// Unsupported.
return ExpressedToQuantizedConverter{inputType, nullptr};
case StandardTypes::RankedTensor:
case StandardTypes::UnrankedTensor:
case StandardTypes::Vector: {
if (inputType.isa<TensorType, VectorType>()) {
Type elementType = inputType.cast<ShapedType>().getElementType();
if (!isQuantizablePrimitiveType(elementType)) {
// Unsupported.
if (!isQuantizablePrimitiveType(elementType))
return ExpressedToQuantizedConverter{inputType, nullptr};
}
return ExpressedToQuantizedConverter{
inputType, inputType.cast<ShapedType>().getElementType()};
}
return ExpressedToQuantizedConverter{inputType, elementType};
}
// Supported primitive type (which just is the expressed type).
if (isQuantizablePrimitiveType(inputType))
return ExpressedToQuantizedConverter{inputType, inputType};
// Unsupported.
return ExpressedToQuantizedConverter{inputType, nullptr};
}
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
assert(expressedType && "convert() on unsupported conversion");
switch (inputType.getKind()) {
default:
if (elementalType.getExpressedType() == expressedType) {
// If the expressed types match, just use the new elemental type.
return elementalType;
}
// Unsupported.
return nullptr;
case StandardTypes::RankedTensor:
return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
elementalType);
case StandardTypes::UnrankedTensor:
if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), elementalType);
if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
return UnrankedTensorType::get(elementalType);
case StandardTypes::Vector:
return VectorType::get(inputType.cast<VectorType>().getShape(),
elementalType);
}
if (auto vectorType = inputType.dyn_cast<VectorType>())
return VectorType::get(vectorType.getShape(), elementalType);
// If the expressed types match, just use the new elemental type.
if (elementalType.getExpressedType() == expressedType)
return elementalType;
// Unsupported.
return nullptr;
}
ElementsAttr

View File

@@ -78,20 +78,17 @@ Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
size = alignment;
return type;
}
switch (type.getKind()) {
case spirv::TypeKind::Struct:
return decorateType(type.cast<spirv::StructType>(), size, alignment);
case spirv::TypeKind::Array:
return decorateType(type.cast<spirv::ArrayType>(), size, alignment);
case StandardTypes::Vector:
return decorateType(type.cast<VectorType>(), size, alignment);
case spirv::TypeKind::RuntimeArray:
if (auto structType = type.dyn_cast<spirv::StructType>())
return decorateType(structType, size, alignment);
if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
return decorateType(arrayType, size, alignment);
if (auto vectorType = type.dyn_cast<VectorType>())
return decorateType(vectorType, size, alignment);
if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
size = std::numeric_limits<Size>().max();
return decorateType(type.cast<spirv::RuntimeArrayType>(), alignment);
default:
llvm_unreachable("unhandled SPIR-V type");
return decorateType(arrayType, alignment);
}
llvm_unreachable("unhandled SPIR-V type");
}
Type VulkanLayoutUtils::decorateType(VectorType vectorType,

View File

@@ -26,6 +26,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -727,31 +728,11 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
}
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
switch (type.getKind()) {
case TypeKind::Array:
print(type.cast<ArrayType>(), os);
return;
case TypeKind::CooperativeMatrix:
print(type.cast<CooperativeMatrixNVType>(), os);
return;
case TypeKind::Pointer:
print(type.cast<PointerType>(), os);
return;
case TypeKind::RuntimeArray:
print(type.cast<RuntimeArrayType>(), os);
return;
case TypeKind::Image:
print(type.cast<ImageType>(), os);
return;
case TypeKind::Struct:
print(type.cast<StructType>(), os);
return;
case TypeKind::Matrix:
print(type.cast<MatrixType>(), os);
return;
default:
llvm_unreachable("unhandled SPIR-V type");
}
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
ImageType, StructType, MatrixType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
//===----------------------------------------------------------------------===//

View File

@@ -1534,8 +1534,7 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
if (!type.isa<spirv::SPIRVType>())
return false;
if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
if (isa<SPIRVDialect>(type.getDialect())) {
// TODO: support constant struct
return type.isa<spirv::ArrayType>();
}

View File

@@ -18,6 +18,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::spirv;
@@ -163,18 +164,11 @@ Optional<int64_t> ArrayType::getSizeInBytes() {
//===----------------------------------------------------------------------===//
bool CompositeType::classof(Type type) {
switch (type.getKind()) {
case TypeKind::Array:
case TypeKind::CooperativeMatrix:
case TypeKind::Matrix:
case TypeKind::RuntimeArray:
case TypeKind::Struct:
return true;
case StandardTypes::Vector:
return isValid(type.cast<VectorType>());
default:
return false;
}
if (auto vectorType = type.dyn_cast<VectorType>())
return isValid(vectorType);
return type
.isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
spirv::RuntimeArrayType, spirv::StructType>();
}
bool CompositeType::isValid(VectorType type) {
@@ -183,22 +177,14 @@ bool CompositeType::isValid(VectorType type) {
}
Type CompositeType::getElementType(unsigned index) const {
switch (getKind()) {
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
case spirv::TypeKind::CooperativeMatrix:
return cast<CooperativeMatrixNVType>().getElementType();
case spirv::TypeKind::Matrix:
return cast<MatrixType>().getColumnType();
case spirv::TypeKind::RuntimeArray:
return cast<RuntimeArrayType>().getElementType();
case spirv::TypeKind::Struct:
return cast<StructType>().getElementType(index);
case StandardTypes::Vector:
return cast<VectorType>().getElementType();
default:
llvm_unreachable("invalid composite type");
}
return TypeSwitch<Type, Type>(*this)
.Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
[](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
.Default(
[](Type) -> Type { llvm_unreachable("invalid composite type"); });
}
unsigned CompositeType::getNumElements() const {

View File

@@ -123,16 +123,16 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns llvm::None otherwise.
auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
auto getCompositeTypeKind = [](Type type) -> Optional<TypeID> {
if (type.isa<VectorType, RankedTensorType>())
return static_cast<StandardTypes::Kind>(type.getKind());
return type.getTypeID();
return llvm::None;
};
// Make sure the composite type, if has, is consistent.
auto compositeKind1 = getCompositeTypeKind(type1);
auto compositeKind2 = getCompositeTypeKind(type2);
Optional<StandardTypes::Kind> resultCompositeKind;
Optional<TypeID> compositeKind1 = getCompositeTypeKind(type1);
Optional<TypeID> compositeKind2 = getCompositeTypeKind(type2);
Optional<TypeID> resultCompositeKind;
if (compositeKind1 && compositeKind2) {
// Disallow mixing vector and tensor.
@@ -151,9 +151,9 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
return {};
// Compose the final broadcasted type
if (resultCompositeKind == StandardTypes::Vector)
if (resultCompositeKind == VectorType::getTypeID())
return VectorType::get(resultShape, elementType);
if (resultCompositeKind == StandardTypes::RankedTensor)
if (resultCompositeKind == RankedTensorType::getTypeID())
return RankedTensorType::get(resultShape, elementType);
return elementType;
}

View File

@@ -11,6 +11,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"
@@ -244,16 +245,11 @@ int64_t ShapedType::getSizeInBits() const {
}
ArrayRef<int64_t> ShapedType::getShape() const {
switch (getKind()) {
case StandardTypes::Vector:
return cast<VectorType>().getShape();
case StandardTypes::RankedTensor:
return cast<RankedTensorType>().getShape();
case StandardTypes::MemRef:
return cast<MemRefType>().getShape();
default:
llvm_unreachable("not a ShapedType or not ranked");
}
if (auto vectorType = dyn_cast<VectorType>())
return vectorType.getShape();
if (auto tensorType = dyn_cast<RankedTensorType>())
return tensorType.getShape();
return cast<MemRefType>().getShape();
}
int64_t ShapedType::getNumDynamicDims() const {
@@ -305,13 +301,23 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
// Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed.
static inline LogicalResult checkTensorElementType(Location location,
Type elementType) {
static LogicalResult checkTensorElementType(Location location,
Type elementType) {
if (!TensorType::isValidElementType(elementType))
return emitError(location, "invalid tensor element type");
return success();
}
/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
IndexType>() ||
!type.getDialect().getNamespace().empty();
}
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//