[mlir] Refactor ShapedType into an interface

ShapedType was created in a time before interfaces, and is one of the earliest
type base classes in the ecosystem. This commit refactors ShapedType into
an interface, which is what it would have been if interfaces had existed at that
time. The API of ShapedType and it's derived classes are essentially untouched
by this refactor, with the exception being the API surrounding kDynamicIndex
(which requires a sole home).

For now, the API of ShapedType and its name have been kept as consistent to
the current state of the world as possible (to help with potential migration churn,
among other reasons). Moving forward though, we should look into potentially
restructuring its API and possible its name as well (it should really have "Interface"
at the end like other interfaces at the very least).

One other potentially interesting note is that I've attached the ShapedType::Trait
to TensorType/BaseMemRefType to act as mixins for the ShapedType API. This
is kind of weird, but allows for sharing the same API (i.e. preventing API loss from
the transition from base class -> Interface). This inheritance doesn't affect any
of the derived classes, it is just for API mixin.

Differential Revision: https://reviews.llvm.org/D116962
This commit is contained in:
River Riddle
2022-01-10 10:55:57 -08:00
parent a60e83fe7c
commit 676bfb2a22
14 changed files with 403 additions and 314 deletions

View File

@@ -41,4 +41,151 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
def ShapedTypeInterface : TypeInterface<"ShapedType"> {
let cppNamespace = "::mlir";
let description = [{
This interface provides a common API for interacting with multi-dimensional
container types. These types contain a shape and an element type.
A shape is a list of sizes corresponding to the dimensions of the container.
If the number of dimensions in the shape is unknown, the shape is "unranked".
If the number of dimensions is known, the shape "ranked". The sizes of the
dimensions of the shape must be positive, or kDynamicSize (in which case the
size of the dimension is dynamic, or not statically known).
}];
let methods = [
InterfaceMethod<[{
Returns a clone of this type with the given shape and element
type. If a shape is not provided, the current shape of the type is used.
}],
"::mlir::ShapedType", "cloneWith", (ins
"::llvm::Optional<::llvm::ArrayRef<int64_t>>":$shape,
"::mlir::Type":$elementType
)>,
InterfaceMethod<[{
Returns the element type of this shaped type.
}],
"::mlir::Type", "getElementType">,
InterfaceMethod<[{
Returns if this type is ranked, i.e. it has a known number of dimensions.
}],
"bool", "hasRank">,
InterfaceMethod<[{
Returns the shape of this type if it is ranked, otherwise asserts.
}],
"::llvm::ArrayRef<int64_t>", "getShape">,
];
let extraClassDeclaration = [{
// TODO: merge these two special values in a single one used everywhere.
// Unfortunately, uses of `-1` have crept deep into the codebase now and are
// hard to track.
static constexpr int64_t kDynamicSize = -1;
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) {
return dSize == kDynamicSize;
}
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
return dStrideOrOffset == kDynamicStrideOrOffset;
}
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
/// Returns the total amount of bits occupied by a value of this type. This
/// does not take into account any memory layout or widening constraints,
/// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
/// practice it will likely be stored as in a 4xi64 vector register. Fails
/// with an assertion if the size cannot be computed statically, e.g. if the
/// type has a dynamic shape or if its elemental type does not have a known
/// bit width.
int64_t getSizeInBits() const;
}];
let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
return $_type.cloneWith(shape, elementType);
}
/// Return a clone of this type with the given new shape.
auto clone(::llvm::ArrayRef<int64_t> shape) {
return $_type.cloneWith(shape, $_type.getElementType());
}
/// Return a clone of this type with the given new element type.
auto clone(::mlir::Type elementType) {
return $_type.cloneWith(/*shape=*/llvm::None, elementType);
}
/// If an element type is an integer or a float, return its width. Otherwise,
/// abort.
unsigned getElementTypeBitWidth() const {
return $_type.getElementType().getIntOrFloatBitWidth();
}
/// If this is a ranked type, return the rank. Otherwise, abort.
int64_t getRank() const {
assert($_type.hasRank() && "cannot query rank of unranked shaped type");
return $_type.getShape().size();
}
/// If it has static shape, return the number of elements. Otherwise, abort.
int64_t getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
return ::mlir::ShapedType::getNumElements($_type.getShape());
}
/// Returns true if this dimension has a dynamic size (for ranked types);
/// aborts for unranked types.
bool isDynamicDim(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
}
/// Returns if this type has a static shape, i.e. if the type is ranked and
/// all dimensions have known size (>= 0).
bool hasStaticShape() const {
return $_type.hasRank() &&
llvm::none_of($_type.getShape(), ::mlir::ShapedType::isDynamic);
}
/// Returns if this type has a static shape and the shape is equal to
/// `shape` return true.
bool hasStaticShape(::llvm::ArrayRef<int64_t> shape) const {
return hasStaticShape() && $_type.getShape() == shape;
}
/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const {
return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
}
/// If this is ranked type, return the size of the specified dimension.
/// Otherwise, abort.
int64_t getDimSize(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return $_type.getShape()[idx];
}
/// Returns the position of the dynamic dimension relative to just the dynamic
/// dimensions, given its `index` within the shape.
unsigned getDynamicDimIndex(unsigned index) const {
assert(index < getRank() && "invalid index");
assert(::mlir::ShapedType::isDynamic(getDimSize(index)) && "invalid index");
return llvm::count_if($_type.getShape().take_front(index),
::mlir::ShapedType::isDynamic);
}
}];
}
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

View File

@@ -16,6 +16,12 @@ namespace llvm {
struct fltSemantics;
} // namespace llvm
//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
namespace mlir {
class AffineExpr;
class AffineMap;
@@ -55,119 +61,68 @@ public:
const llvm::fltSemantics &getFloatSemantics();
};
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
/// and MemRef types because they share behavior and semantics around shape,
/// rank, and fixed element type. Any type with these semantics should inherit
/// from ShapedType.
class ShapedType : public Type {
public:
using Type::Type;
// TODO: merge these two special values in a single one used everywhere.
// Unfortunately, uses of `-1` have crept deep into the codebase now and are
// hard to track.
static constexpr int64_t kDynamicSize = -1;
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
/// Return clone of this type with new shape and element type.
ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
ShapedType clone(ArrayRef<int64_t> shape);
ShapedType clone(Type elementType);
/// Return the element type.
Type getElementType() const;
/// If an element type is an integer or a float, return its width. Otherwise,
/// abort.
unsigned getElementTypeBitWidth() const;
/// If it has static shape, return the number of elements. Otherwise, abort.
int64_t getNumElements() const;
/// If this is a ranked type, return the rank. Otherwise, abort.
int64_t getRank() const;
/// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
/// have a rank, while unranked tensors do not.
bool hasRank() const;
/// If this is a ranked type, return the shape. Otherwise, abort.
ArrayRef<int64_t> getShape() const;
/// If this is unranked type or any dimension has unknown size (<0), it
/// doesn't have static shape. If all dimensions have known size (>= 0), it
/// has static shape.
bool hasStaticShape() const;
/// If this has a static shape and the shape is equal to `shape` return true.
bool hasStaticShape(ArrayRef<int64_t> shape) const;
/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const;
/// If this is ranked type, return the size of the specified dimension.
/// Otherwise, abort.
int64_t getDimSize(unsigned idx) const;
/// Returns true if this dimension has a dynamic size (for ranked types);
/// aborts for unranked types.
bool isDynamicDim(unsigned idx) const;
/// Returns the position of the dynamic dimension relative to just the dynamic
/// dimensions, given its `index` within the shape.
unsigned getDynamicDimIndex(unsigned index) const;
/// Get the total amount of bits occupied by a value of this type. This does
/// not take into account any memory layout or widening constraints, e.g. a
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
/// if the size cannot be computed statically, i.e. if the type has a dynamic
/// shape or if its elemental type does not have a known bit width.
int64_t getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) {
return dSize == kDynamicSize;
}
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
return dStrideOrOffset == kDynamicStrideOrOffset;
}
};
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
class TensorType : public ShapedType {
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived tensor types.
class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
using ShapedType::ShapedType;
using Type::Type;
/// Returns the element type of this tensor type.
Type getElementType() const;
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
bool hasRank() const;
/// Returns the shape of this tensor type.
ArrayRef<int64_t> getShape() const;
/// Clone this type with the given shape and element type. If the
/// provided shape is `None`, the current shape of the type is used.
TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) const;
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return cast<ShapedType>(); }
};
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
/// Base MemRef for Ranked and Unranked variants
class BaseMemRefType : public ShapedType {
/// This class provides a shared interface for ranked and unranked memref types.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
public:
using ShapedType::ShapedType;
using Type::Type;
/// Returns the element type of this memref type.
Type getElementType() const;
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
bool hasRank() const;
/// Returns the shape of this memref type.
ArrayRef<int64_t> getShape() const;
/// Clone this type with the given shape and element type. If the
/// provided shape is `None`, the current shape of the type is used.
BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) const;
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type);
@@ -181,6 +136,9 @@ public:
/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return cast<ShapedType>(); }
};
} // namespace mlir
@@ -192,12 +150,6 @@ public:
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
namespace mlir {
//===----------------------------------------------------------------------===//
@@ -439,11 +391,6 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
return Float128Type::get(ctx);
}
inline bool ShapedType::classof(Type type) {
return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
UnrankedMemRefType, MemRefType>();
}
inline bool TensorType::classof(Type type) {
return type.isa<RankedTensorType, UnrankedTensorType>();
}

View File

@@ -266,7 +266,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
let description = [{
@@ -541,6 +541,16 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
using ShapedType::Trait<MemRefType>::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
using ShapedType::Trait<MemRefType>::isDynamicDim;
using ShapedType::Trait<MemRefType>::hasStaticShape;
using ShapedType::Trait<MemRefType>::getNumDynamicDims;
using ShapedType::Trait<MemRefType>::getDimSize;
using ShapedType::Trait<MemRefType>::getDynamicDimIndex;
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
@@ -620,7 +630,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
@@ -702,6 +712,16 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
}]>
];
let extraClassDeclaration = [{
using ShapedType::Trait<RankedTensorType>::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
using ShapedType::Trait<RankedTensorType>::isDynamicDim;
using ShapedType::Trait<RankedTensorType>::hasStaticShape;
using ShapedType::Trait<RankedTensorType>::getNumDynamicDims;
using ShapedType::Trait<RankedTensorType>::getDimSize;
using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
@@ -784,7 +804,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
@@ -831,6 +851,16 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
}]>
];
let extraClassDeclaration = [{
using ShapedType::Trait<UnrankedMemRefType>::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
using ShapedType::Trait<UnrankedMemRefType>::isDynamicDim;
using ShapedType::Trait<UnrankedMemRefType>::hasStaticShape;
using ShapedType::Trait<UnrankedMemRefType>::getNumDynamicDims;
using ShapedType::Trait<UnrankedMemRefType>::getDimSize;
using ShapedType::Trait<UnrankedMemRefType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return llvm::None; }
/// [deprecated] Returns the memory space in old raw integer representation.
@@ -846,7 +876,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "TensorType"> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
@@ -874,6 +904,16 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
}]>
];
let extraClassDeclaration = [{
using ShapedType::Trait<UnrankedTensorType>::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
using ShapedType::Trait<UnrankedTensorType>::isDynamicDim;
using ShapedType::Trait<UnrankedTensorType>::hasStaticShape;
using ShapedType::Trait<UnrankedTensorType>::getNumDynamicDims;
using ShapedType::Trait<UnrankedTensorType>::getDimSize;
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
ArrayRef<int64_t> getShape() const { return llvm::None; }
}];
let skipDefaultBuilders = 1;
@@ -885,8 +925,8 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
//===----------------------------------------------------------------------===//
def Builtin_Vector : Builtin_Type<"Vector", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "ShapedType"> {
DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
let description = [{
Syntax:
@@ -966,6 +1006,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [
/// element type of bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
VectorType scaleElementBitwidth(unsigned scale);
/// Returns if this type is ranked (always true).
bool hasRank() const { return true; }
/// Clone this vector type with the given shape and element type. If the
/// provided shape is `None`, the current shape of the type is used.
VectorType cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType);
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;

View File

@@ -51,10 +51,10 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
auto result = getStridesAndOffset(type, strides, offset);
(void)result;
assert(succeeded(result) && "unexpected failure in stride computation");
assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
assert(!ShapedType::isDynamicStrideOrOffset(offset) &&
"expected static offset");
assert(!llvm::any_of(strides, [](int64_t stride) {
return MemRefType::isDynamicStrideOrOffset(stride);
return ShapedType::isDynamicStrideOrOffset(stride);
}) && "expected static strides");
auto convertedType = typeConverter.convertType(type);

View File

@@ -79,14 +79,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Value index;
if (offset != 0) // Skip if offset is zero.
index = MemRefType::isDynamicStrideOrOffset(offset)
index = ShapedType::isDynamicStrideOrOffset(offset)
? memRefDescriptor.offset(rewriter, loc)
: createIndexConstant(rewriter, loc, offset);
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
? memRefDescriptor.stride(rewriter, loc, i)
: createIndexConstant(rewriter, loc, strides[i]);
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);

View File

@@ -106,7 +106,7 @@ struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
Operation *op) const {
uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
for (unsigned i = 0, e = type.getRank(); i < e; i++) {
if (type.isDynamic(type.getDimSize(i)))
if (ShapedType::isDynamic(type.getDimSize(i)))
continue;
sizeDivisor = sizeDivisor * type.getDimSize(i);
}
@@ -1467,7 +1467,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> strides, Value nextSize,
Value runningStride, unsigned idx) const {
assert(idx < strides.size());
if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
return createIndexConstant(rewriter, loc, strides[idx]);
if (nextSize)
return runningStride

View File

@@ -342,22 +342,22 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}
// If cast is towards more static offset along any dimension, don't fold.
if (sourceOffset != resultOffset)
if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
!MemRefType::isDynamicStrideOrOffset(resultOffset))
if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
!ShapedType::isDynamicStrideOrOffset(resultOffset))
return false;
// If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (MemRefType::isDynamicStrideOrOffset(ss) &&
!MemRefType::isDynamicStrideOrOffset(st))
if (ShapedType::isDynamicStrideOrOffset(ss) &&
!ShapedType::isDynamicStrideOrOffset(st))
return false;
}

View File

@@ -518,7 +518,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen,
// Find upper bound in current dimension.
unsigned p = perm(enc, d);
Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
if (shape[p] == MemRefType::kDynamicSize)
if (ShapedType::isDynamic(shape[p]))
args.push_back(up);
assert(codegen.highs[tensor][idx] == nullptr);
codegen.sizes[idx] = codegen.highs[tensor][idx] = up;

View File

@@ -268,13 +268,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
fromElements.getResult().getType().cast<RankedTensorType>();
// The case where the type encodes the size of the dimension is handled
// above.
assert(resultType.getShape()[index.getInt()] ==
RankedTensorType::kDynamicSize);
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
// Find the operand of the fromElements that corresponds to this index.
auto dynExtents = fromElements.dynamicExtents().begin();
for (auto dim : resultType.getShape().take_front(index.getInt()))
if (dim == RankedTensorType::kDynamicSize)
if (ShapedType::isDynamic(dim))
dynExtents++;
return Value{*dynExtents};
@@ -523,13 +522,13 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
auto operandsIt = tensorFromElements.dynamicExtents().begin();
for (int64_t dim : resultType.getShape()) {
if (dim != RankedTensorType::kDynamicSize) {
if (!ShapedType::isDynamic(dim)) {
newShape.push_back(dim);
continue;
}
APInt index;
if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
newShape.push_back(RankedTensorType::kDynamicSize);
newShape.push_back(ShapedType::kDynamicSize);
newOperands.push_back(*operandsIt++);
continue;
}
@@ -661,7 +660,7 @@ static LogicalResult verify(ReshapeOp op) {
return op.emitOpError("source and destination tensor should have the "
"same number of elements");
}
if (shapeSize == TensorType::kDynamicSize)
if (ShapedType::isDynamic(shapeSize))
return op.emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked tensor type");
if (shapeSize != resultRankedType.getRank())

View File

@@ -172,13 +172,13 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
resStrides(bT.getRank(), 0);
for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
resShape[idx] =
(aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
(aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
resStrides[idx] = (aStrides[idx] == bStrides[idx])
? aStrides[idx]
: MemRefType::kDynamicStrideOrOffset;
: ShapedType::kDynamicStrideOrOffset;
}
resOffset =
(aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
(aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
return MemRefType::get(
resShape, aT.getElementType(),
makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));

View File

@@ -0,0 +1,51 @@
//===- BuiltinTypeInterfaces.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
int64_t num = 1;
for (int64_t dim : shape) {
num *= dim;
assert(num >= 0 && "integer overflow in element count computation");
}
return num;
}
int64_t ShapedType::getSizeInBits() const {
assert(hasStaticShape() &&
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
if (auto complexType = elementType.dyn_cast<ComplexType>()) {
elementType = complexType.getElementType();
return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
}
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

View File

@@ -32,12 +32,6 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"
//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
@@ -271,171 +265,6 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setShape(shape);
b.setElementType(elementType);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, elementType);
b.setMemorySpace(other.getMemorySpace());
return b;
}
if (isa<TensorType>())
return RankedTensorType::get(shape, elementType);
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(shape, elementType, vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone case");
}
ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setShape(shape);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
MemRefType::Builder b(shape, other.getElementType());
b.setShape(shape);
b.setMemorySpace(other.getMemorySpace());
return b;
}
if (isa<TensorType>())
return RankedTensorType::get(shape, getElementType());
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone case");
}
ShapedType ShapedType::clone(Type elementType) {
if (auto other = dyn_cast<MemRefType>()) {
MemRefType::Builder b(other);
b.setElementType(elementType);
return b;
}
if (auto other = dyn_cast<UnrankedMemRefType>()) {
return UnrankedMemRefType::get(elementType, other.getMemorySpace());
}
if (isa<TensorType>()) {
if (hasRank())
return RankedTensorType::get(getShape(), elementType);
return UnrankedTensorType::get(elementType);
}
if (auto vecTy = dyn_cast<VectorType>())
return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());
llvm_unreachable("Unhandled ShapedType clone hit");
}
Type ShapedType::getElementType() const {
return TypeSwitch<Type, Type>(*this)
.Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
}
unsigned ShapedType::getElementTypeBitWidth() const {
return getElementType().getIntOrFloatBitWidth();
}
int64_t ShapedType::getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
auto shape = getShape();
int64_t num = 1;
for (auto dim : shape) {
num *= dim;
assert(num >= 0 && "integer overflow in element count computation");
}
return num;
}
int64_t ShapedType::getRank() const {
assert(hasRank() && "cannot query rank of unranked shaped type");
return getShape().size();
}
bool ShapedType::hasRank() const {
return !isa<UnrankedMemRefType, UnrankedTensorType>();
}
int64_t ShapedType::getDimSize(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return getShape()[idx];
}
bool ShapedType::isDynamicDim(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return isDynamic(getShape()[idx]);
}
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
assert(index < getRank() && "invalid index");
assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}
/// Get the number of bits require to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
int64_t ShapedType::getSizeInBits() const {
assert(hasStaticShape() &&
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
if (auto complexType = elementType.dyn_cast<ComplexType>()) {
elementType = complexType.getElementType();
return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
}
// Tensors can have vectors and other tensors as elements, other shaped types
// cannot.
assert(isa<TensorType>() && "unsupported element type");
assert((elementType.isa<VectorType, TensorType>()) &&
"unsupported tensor element type");
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}
ArrayRef<int64_t> ShapedType::getShape() const {
if (auto vectorType = dyn_cast<VectorType>())
return vectorType.getShape();
if (auto tensorType = dyn_cast<RankedTensorType>())
return tensorType.getShape();
return cast<MemRefType>().getShape();
}
int64_t ShapedType::getNumDynamicDims() const {
return llvm::count_if(getShape(), isDynamic);
}
bool ShapedType::hasStaticShape() const {
return hasRank() && llvm::none_of(getShape(), isDynamic);
}
bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
return hasStaticShape() && getShape() == shape;
}
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
@@ -474,10 +303,44 @@ void VectorType::walkImmediateSubElements(
walkTypesFn(getElementType());
}
VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) {
return VectorType::get(shape.getValueOr(getShape()), elementType,
getNumScalableDims());
}
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
Type TensorType::getElementType() const {
return llvm::TypeSwitch<TensorType, Type>(*this)
.Case<RankedTensorType, UnrankedTensorType>(
[](auto type) { return type.getElementType(); });
}
bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
ArrayRef<int64_t> TensorType::getShape() const {
return cast<RankedTensorType>().getShape();
}
TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
if (shape)
return RankedTensorType::get(*shape, elementType);
return UnrankedTensorType::get(elementType);
}
auto rankedTy = cast<RankedTensorType>();
if (!shape)
return RankedTensorType::get(rankedTy.getShape(), elementType,
rankedTy.getEncoding());
return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()),
elementType, rankedTy.getEncoding());
}
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -542,6 +405,35 @@ void UnrankedTensorType::walkImmediateSubElements(
// BaseMemRefType
//===----------------------------------------------------------------------===//
Type BaseMemRefType::getElementType() const {
return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
.Case<MemRefType, UnrankedMemRefType>(
[](auto type) { return type.getElementType(); });
}
bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
ArrayRef<int64_t> BaseMemRefType::getShape() const {
return cast<MemRefType>().getShape();
}
BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
if (!shape)
return UnrankedMemRefType::get(elementType, getMemorySpace());
MemRefType::Builder builder(*shape, elementType);
builder.setMemorySpace(getMemorySpace());
return builder;
}
MemRefType::Builder builder(cast<MemRefType>());
if (shape)
builder.setShape(*shape);
builder.setElementType(elementType);
return builder;
}
Attribute BaseMemRefType::getMemorySpace() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
return rankedMemRefTy.getMemorySpace();

View File

@@ -9,6 +9,7 @@ add_mlir_library(MLIRIR
BuiltinAttributes.cpp
BuiltinDialect.cpp
BuiltinTypes.cpp
BuiltinTypeInterfaces.cpp
Diagnostics.cpp
Dialect.cpp
Dominance.cpp

View File

@@ -30,14 +30,14 @@ TEST(ShapedTypeTest, CloneMemref) {
AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
ShapedType memrefType =
MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
(ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setLayout(AffineMapAttr::get(map));
// Update shape.
llvm::SmallVector<int64_t> memrefNewShape({30, 40});
ASSERT_NE(memrefOriginalShape, memrefNewShape);
ASSERT_EQ(memrefType.clone(memrefNewShape),
(MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
(ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
.setMemorySpace(memSpace)
.setLayout(AffineMapAttr::get(map)));
// Update type.
@@ -81,25 +81,29 @@ TEST(ShapedTypeTest, CloneTensor) {
// Update shape.
llvm::SmallVector<int64_t> tensorNewShape({30, 40});
ASSERT_NE(tensorOriginalShape, tensorNewShape);
ASSERT_EQ(tensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
ASSERT_EQ(
tensorType.clone(tensorNewShape),
(ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
// Update type.
Type tensorNewType = f32;
ASSERT_NE(tensorOriginalType, tensorNewType);
ASSERT_EQ(tensorType.clone(tensorNewType),
RankedTensorType::get(tensorOriginalShape, tensorNewType));
ASSERT_EQ(
tensorType.clone(tensorNewType),
(ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
// Update both.
ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
RankedTensorType::get(tensorNewShape, tensorNewType));
(ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
// Test unranked tensor cloning.
ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
ASSERT_EQ(
unrankedTensorType.clone(tensorNewShape),
(ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
UnrankedTensorType::get(tensorNewType));
ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
RankedTensorType::get(tensorNewShape, tensorOriginalType));
(ShapedType)UnrankedTensorType::get(tensorNewType));
ASSERT_EQ(
unrankedTensorType.clone(tensorNewShape),
(ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
}
TEST(ShapedTypeTest, CloneVector) {