diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 710fd460267c..b5e2185c206d 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -46,14 +46,14 @@ public: Function, SplatElements, + DenseIntElements, + DenseFPElements, FIRST_ELEMENTS_ATTR = SplatElements, - LAST_ELEMENTS_ATTR = SplatElements, + LAST_ELEMENTS_ATTR = DenseFPElements, }; /// Return the classification for this attribute. - Kind getKind() const { - return kind; - } + Kind getKind() const { return kind; } /// Return true if this field is, or contains, a function attribute. bool isOrContainsFunction() const { return isOrContainsFunctionCache; } @@ -74,8 +74,8 @@ private: /// This field is true if this is, or contains, a function attribute. bool isOrContainsFunctionCache : 1; - Attribute(const Attribute&) = delete; - void operator=(const Attribute&) = delete; + Attribute(const Attribute &) = delete; + void operator=(const Attribute &) = delete; }; inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) { @@ -87,14 +87,13 @@ class BoolAttr : public Attribute { public: static BoolAttr *get(bool value, MLIRContext *context); - bool getValue() const { - return value; - } + bool getValue() const { return value; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Attribute *attr) { return attr->getKind() == Kind::Bool; } + private: BoolAttr(bool value) : Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {} @@ -106,14 +105,13 @@ class IntegerAttr : public Attribute { public: static IntegerAttr *get(int64_t value, MLIRContext *context); - int64_t getValue() const { - return value; - } + int64_t getValue() const { return value; } /// Methods for support type inquiry through isa, cast, and dyn_cast. 
static bool classof(const Attribute *attr) { return attr->getKind() == Kind::Integer; } + private: IntegerAttr(int64_t value) : Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) { @@ -130,14 +128,13 @@ public: // correctness, otherwise constant folding will be done with host math. This // is completely incorrect for BF16 and other datatypes, and subtly wrong // for float32. - double getValue() const { - return value; - } + double getValue() const { return value; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Attribute *attr) { return attr->getKind() == Kind::Float; } + private: FloatAttr(double value) : Attribute(Kind::Float, /*isOrContainsFunction=*/false), value(value) {} @@ -149,14 +146,13 @@ class StringAttr : public Attribute { public: static StringAttr *get(StringRef bytes, MLIRContext *context); - StringRef getValue() const { - return value; - } + StringRef getValue() const { return value; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Attribute *attr) { return attr->getKind() == Kind::String; } + private: StringAttr(StringRef value) : Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {} @@ -168,21 +164,20 @@ private: /// type homogenous given that attributes don't, in general, carry types. class ArrayAttr : public Attribute { public: - static ArrayAttr *get(ArrayRef value, MLIRContext *context); + static ArrayAttr *get(ArrayRef value, MLIRContext *context); - ArrayRef getValue() const { - return value; - } + ArrayRef getValue() const { return value; } /// Methods for support type inquiry through isa, cast, and dyn_cast. 
static bool classof(const Attribute *attr) { return attr->getKind() == Kind::Array; } + private: ArrayAttr(ArrayRef value, bool isOrContainsFunction) : Attribute(Kind::Array, isOrContainsFunction), value(value) {} ~ArrayAttr() = delete; - ArrayRef value; + ArrayRef value; }; class AffineMapAttr : public Attribute { @@ -289,6 +284,98 @@ private: : ElementsAttr(Kind::SplatElements, type), elt(elt) {} Attribute *elt; }; + +/// An attribute represents a reference to a dense vector or tensor object. +/// +/// This class is designed to store elements with any bit widths equal or less +/// than 64. +class DenseElementsAttr : public ElementsAttr { +public: + /// It assumes the elements in the input array have been truncated to the bits + /// width specified by the element type (note all float type are 64 bits). + /// When the value is retrieved, the bits are read from the storage and extend + /// to 64 bits if necessary. + static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef data); + + // TODO: Read the data from the attribute list and compress them + // to a character array. Then call the above method to construct the + // attribute. + static DenseElementsAttr *get(VectorOrTensorType *type, + ArrayRef values); + + void getValues(SmallVectorImpl &values) const; + + ArrayRef getRawData() const { return data; } + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(const Attribute *attr) { + return attr->getKind() == Kind::DenseIntElements || + attr->getKind() == Kind::DenseFPElements; + } + +protected: + DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef data) + : ElementsAttr(kind, type), data(data) {} + +private: + ArrayRef data; +}; + +/// An attribute represents a reference to a dense integer vector or tensor +/// object. 
+class DenseIntElementsAttr : public DenseElementsAttr { +public: + DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef data, + size_t bitsWidth) + : DenseElementsAttr(Kind::DenseIntElements, type, data), + bitsWidth(bitsWidth) {} + + // TODO: returns APInts instead of IntegerAttr. + void getValues(SmallVectorImpl &values) const; + + APInt getValue(ArrayRef indices) const; + + /// Writes the lowest `bitWidth` bits of `value` to the bit position `bitPos` + /// in array `rawData`. + static void writeBits(char *rawData, size_t bitPos, size_t bitWidth, + uint64_t value); + + /// Reads the next `bitWidth` bits from the bit position `bitPos` in array + /// `rawData` and return them as the lowest bits of an uint64 integer. + static uint64_t readBits(const char *rawData, size_t bitPos, + size_t bitsWidth); + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(const Attribute *attr) { + return attr->getKind() == Kind::DenseIntElements; + } + +private: + ~DenseIntElementsAttr() = delete; + + size_t bitsWidth; +}; + +/// An attribute represents a reference to a dense float vector or tensor +/// object. Each element is stored as a double. +class DenseFPElementsAttr : public DenseElementsAttr { +public: + DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef data) + : DenseElementsAttr(Kind::DenseFPElements, type, data) {} + + // TODO: returns APFPs instead of FloatAttr. + void getValues(SmallVectorImpl &values) const; + + APFloat getValue(ArrayRef indices) const; + + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(const Attribute *attr) { + return attr->getKind() == Kind::DenseFPElements; + } + +private: + ~DenseFPElementsAttr() = delete; +}; } // end namespace mlir. 
#endif diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 347075978884..62c60e386151 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -100,6 +100,8 @@ public: TypeAttr *getTypeAttr(Type *type); FunctionAttr *getFunctionAttr(const Function *value); ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt); + ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type, + ArrayRef data); // Affine expressions and affine maps. AffineExpr getAffineDimExpr(unsigned position); diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index b9a4e6e24070..3d0afdf607d5 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -92,6 +92,10 @@ public: /// Return true if this is an integer type with the specified width. bool isInteger(unsigned width) const; + /// Return the bitwidth of this type. For vector or tensor types, returns the + /// element type's bitwidth. + unsigned getBitWidth() const; + // Convenience factories. static IntegerType *getInteger(unsigned width, MLIRContext *ctx); static FloatType *getBF16(MLIRContext *ctx); @@ -293,6 +297,10 @@ class VectorOrTensorType : public Type { public: Type *getElementType() const { return elementType; } + /// If this is ranked tensor or vector type, return the number of elements. If + /// it is an unranked tensor or vector, abort. + unsigned getNumElements() const; + /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. 
int getRank() const; @@ -466,7 +474,6 @@ static bool isValidTensorElementType(Type *type) { return isa(type) || isa(type) || isa(type) || isa(type); } - } // end namespace mlir #endif // MLIR_IR_TYPES_H diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 4ea236fe3489..69e0b67249e3 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -41,8 +41,7 @@ namespace mlir { template inline void interleave(ForwardIterator begin, ForwardIterator end, - UnaryFunctor each_fn, - NullaryFunctor between_fn) { + UnaryFunctor each_fn, NullaryFunctor between_fn) { if (begin == end) return; each_fn(*begin); @@ -59,6 +58,11 @@ inline void interleave(const Container &c, UnaryFunctor each_fn, interleave(c.begin(), c.end(), each_fn, between_fn); } +template class Container, typename raw_ostream> +inline void interleaveComma(const Container &c, raw_ostream &os) { + interleave(c.begin(), c.end(), [&](T a) { os << a; }, [&]() { os << ", "; }); +} + } // end namespace mlir // Allow tuples to be usable as DenseMap keys. 
@@ -80,8 +84,7 @@ static inline unsigned llvm_combineHashValue(unsigned a, unsigned b) { } namespace llvm { -template -struct DenseMapInfo > { +template struct DenseMapInfo> { typedef std::tuple Tuple; static inline Tuple getEmptyKey() { @@ -92,34 +95,34 @@ struct DenseMapInfo > { return Tuple(DenseMapInfo::getTombstoneKey()...); } - template - static unsigned getHashValueImpl(const Tuple& values, std::false_type) { + template + static unsigned getHashValueImpl(const Tuple &values, std::false_type) { typedef typename std::tuple_element::type EltType; - std::integral_constant atEnd; + std::integral_constant atEnd; return llvm_combineHashValue( - DenseMapInfo::getHashValue(std::get(values)), - getHashValueImpl(values, atEnd)); + DenseMapInfo::getHashValue(std::get(values)), + getHashValueImpl(values, atEnd)); } - template - static unsigned getHashValueImpl(const Tuple& values, std::true_type) { + template + static unsigned getHashValueImpl(const Tuple &values, std::true_type) { return 0; } - static unsigned getHashValue(const std::tuple& values) { + static unsigned getHashValue(const std::tuple &values) { std::integral_constant atEnd; return getHashValueImpl<0>(values, atEnd); } - template + template static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type) { typedef typename std::tuple_element::type EltType; - std::integral_constant atEnd; - return DenseMapInfo::isEqual(std::get(lhs), std::get(rhs)) - && isEqualImpl(lhs, rhs, atEnd); + std::integral_constant atEnd; + return DenseMapInfo::isEqual(std::get(lhs), std::get(rhs)) && + isEqualImpl(lhs, rhs, atEnd); } - template + template static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type) { return true; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5fb69e08e2b9..a3efbf6eea8e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -296,6 +296,7 @@ protected: void printAffineMapReference(AffineMap affineMap); void 
printIntegerSetId(int integerSetId) const; void printIntegerSetReference(IntegerSet integerSet); + void printDenseElementsAttr(const DenseElementsAttr *attr); /// This enum is used to represent the binding stength of the enclosing /// context that an AffineExprStorage is being printed in, so we can @@ -457,6 +458,16 @@ void ModulePrinter::printAttribute(const Attribute *attr) { } break; } + case Attribute::Kind::DenseIntElements: + case Attribute::Kind::DenseFPElements: { + auto *eltsAttr = cast(attr); + os << "dense<"; + printType(eltsAttr->getType()); + os << ", "; + printDenseElementsAttr(eltsAttr); + os << '>'; + break; + } case Attribute::Kind::SplatElements: { auto *elementsAttr = cast(attr); os << "splat<"; @@ -469,6 +480,59 @@ void ModulePrinter::printAttribute(const Attribute *attr) { } } +void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) { + auto *type = attr->getType(); + auto shape = type->getShape(); + auto rank = type->getRank(); + + SmallVector elements; + attr->getValues(elements); + + // Special case for degenerate tensors. + if (elements.empty()) { + for (int i = 0; i < rank; ++i) + os << '['; + for (int i = 0; i < rank; ++i) + os << ']'; + return; + } + + // We use a mixed-radix counter to iterate through the shape. When we bump a + // non-least-significant digit, we emit a close bracket. When we next emit an + // element we re-open all closed brackets. + + // The mixed-radix counter, with radices in 'shape'. + SmallVector counter(rank, 0); + // The number of brackets that have been opened and not closed. + unsigned openBrackets = 0; + + auto bumpCounter = [&]() { + // Bump the least significant digit. + ++counter[rank - 1]; + // Iterate backwards bubbling back the increment. + for (unsigned i = rank - 1; i > 0; --i) + if (counter[i] >= shape[i]) { + // Index 'i' is rolled over. Bump (i-1) and close a bracket. 
+ counter[i] = 0; + ++counter[i - 1]; + --openBrackets; + os << ']'; + } + }; + + for (unsigned idx = 0, e = elements.size(); idx != e; ++idx) { + if (idx != 0) + os << ", "; + while (openBrackets++ < rank) + os << '['; + openBrackets = rank; + printAttribute(elements[idx]); + bumpCounter(); + } + while (openBrackets-- > 0) + os << ']'; +} + void ModulePrinter::printType(const Type *type) { switch (type->getKind()) { case Type::Kind::Index: diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8d1c02ef9ea3..cd541171c48a 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -149,6 +149,11 @@ ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type, return SplatElementsAttr::get(type, elt); } +ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type, + ArrayRef data) { + return DenseElementsAttr::get(type, data); +} + //===----------------------------------------------------------------------===// // Affine Expressions, Affine Maps, and Integet Sets. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index f2caa15eb813..059c686dbdc7 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/Types.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Allocator.h" @@ -180,6 +181,22 @@ struct AttributeListKeyInfo : DenseMapInfo { } }; +struct DenseElementsAttrInfo : DenseMapInfo { + using KeyTy = std::pair>; + using DenseMapInfo::getHashValue; + using DenseMapInfo::isEqual; + + static unsigned getHashValue(KeyTy key) { + return hash_combine( + key.first, hash_combine_range(key.second.begin(), key.second.end())); + } + + static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == std::make_pair(rhs->getType(), rhs->getRawData()); + } +}; } // end anonymous namespace. namespace mlir { @@ -277,6 +294,9 @@ public: DenseMap functionAttrs; DenseMap, SplatElementsAttr *> splatElementsAttrs; + using DenseElementsAttrSet = + DenseSet; + DenseElementsAttrSet denseElementsAttrs; public: MLIRContextImpl() : filenames(locationAllocator), identifiers(allocator) {} @@ -798,6 +818,139 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef attrs, return *existing.first = result; } +DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type, + ArrayRef data) { + auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); + assert((bitsRequired <= data.size() * 8L) && + "Input data bit size should be larger than that type requires"); + + auto &impl = type->getContext()->getImpl(); + + // Look to see if this constant is already defined. 
+ DenseElementsAttrInfo::KeyTy key({type, data}); + auto existing = impl.denseElementsAttrs.insert_as(nullptr, key); + + // If we already have it, return that value. + if (!existing.second) + return *existing.first; + + // Otherwise, allocate a new one, unique it and return it. + auto *eltType = type->getElementType(); + switch (eltType->getKind()) { + case Type::Kind::BF16: + case Type::Kind::F16: + case Type::Kind::F32: + case Type::Kind::F64: { + auto *result = impl.allocator.Allocate(); + auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); + std::uninitialized_copy(data.begin(), data.end(), copy); + new (result) DenseFPElementsAttr(type, {copy, data.size()}); + return *existing.first = result; + } + case Type::Kind::Integer: { + auto width = cast(eltType)->getWidth(); + auto *result = impl.allocator.Allocate(); + auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); + std::uninitialized_copy(data.begin(), data.end(), copy); + new (result) DenseIntElementsAttr(type, {copy, data.size()}, width); + return *existing.first = result; + } + default: + llvm_unreachable("unexpected element type"); + } +} + +/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos` +/// starting from `rawData`. +void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth, + uint64_t value) { + // Read the destination bytes which will be written to. + uint64_t dst = 0; + auto dstData = reinterpret_cast(&dst); + auto endPos = bitPos + bitWidth; + auto start = data + bitPos / 8; + auto end = data + endPos / 8 + (endPos % 8 != 0); + std::copy(start, end, dstData); + + // Clean up the invalid bits in the destination bytes. + dst &= ~(-1UL << (bitPos % 8)); + + // Get the valid bits of the source value, shift them to right position, + // then add them to the destination bytes. + value <<= bitPos % 8; + dst |= value; + + // Write the destination bytes back. 
+ ArrayRef range({dstData, (size_t)(end - start)}); + std::copy(range.begin(), range.end(), start); +} + +/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData` +/// and put them in the lowest bits. +uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos, + size_t bitsWidth) { + uint64_t dst = 0; + auto dstData = reinterpret_cast(&dst); + auto endPos = bitPos + bitsWidth; + auto start = rawData + bitPos / 8; + auto end = rawData + endPos / 8 + (endPos % 8 != 0); + std::copy(start, end, dstData); + + dst >>= bitPos % 8; + dst &= ~(-1UL << bitsWidth); + return dst; +} + +void DenseElementsAttr::getValues(SmallVectorImpl &values) const { + switch (getKind()) { + case Attribute::Kind::DenseIntElements: + cast(this)->getValues(values); + return; + case Attribute::Kind::DenseFPElements: + cast(this)->getValues(values); + return; + default: + llvm_unreachable("unexpected element type"); + } +} + +void DenseIntElementsAttr::getValues( + SmallVectorImpl &values) const { + auto elementNum = getType()->getNumElements(); + auto context = getType()->getContext(); + values.reserve(elementNum); + if (bitsWidth == 64) { + ArrayRef vs( + {reinterpret_cast(getRawData().data()), + getRawData().size() / 8}); + for (auto value : vs) { + auto *attr = IntegerAttr::get(value, context); + values.push_back(attr); + } + } else { + const auto *rawData = getRawData().data(); + for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) { + uint64_t bits = readBits(rawData, pos, bitsWidth); + APInt value(bitsWidth, bits, /*isSigned=*/true); + auto *attr = IntegerAttr::get(value.getSExtValue(), context); + values.push_back(attr); + } + } +} + +void DenseFPElementsAttr::getValues( + SmallVectorImpl &values) const { + auto elementNum = getType()->getNumElements(); + auto context = getType()->getContext(); + ArrayRef vs({reinterpret_cast(getRawData().data()), + getRawData().size() / 8}); + values.reserve(elementNum); + for (auto v : vs) { + auto 
*attr = FloatAttr::get(v, context); + values.push_back(attr); + } +} + ElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type, Attribute *elt) { auto &impl = type->getContext()->getImpl(); diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 6b9525f3dbec..0ad3f4728fe6 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -17,13 +17,35 @@ #include "mlir/IR/Types.h" #include "mlir/IR/AffineMap.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Support/STLExtras.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; +unsigned Type::getBitWidth() const { + switch (getKind()) { + // TODO: Currently the IR uses host double type to store all the float + // datatypes. This is completely incorrect for BF16 and other datatypes. + // We have to fix this once APFloat is used in the IR. + case Type::Kind::BF16: + case Type::Kind::F16: + case Type::Kind::F32: + case Type::Kind::F64: + return 64; + case Type::Kind::Integer: + return cast(this)->getWidth(); + case Type::Kind::Vector: + case Type::Kind::RankedTensor: + case Type::Kind::UnrankedTensor: + return cast(this)->getElementType()->getBitWidth(); + // TODO: Handle more types. 
+ default: + llvm_unreachable("unexpected type"); + } +} + IntegerType::IntegerType(unsigned width, MLIRContext *context) - : Type(Kind::Integer, context), width(width) { - assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); + : Type(Kind::Integer, context), width(width) { + assert(width <= kMaxWidth && "admissible integer bitwidth exceeded"); } FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {} @@ -32,25 +54,39 @@ OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {} FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs, unsigned numResults, MLIRContext *context) - : Type(Kind::Function, context, numInputs), - numResults(numResults), inputsAndResults(inputsAndResults) { -} + : Type(Kind::Function, context, numInputs), numResults(numResults), + inputsAndResults(inputsAndResults) {} VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType, unsigned subClassData) : Type(kind, context, subClassData), elementType(elementType) {} +unsigned VectorOrTensorType::getNumElements() const { + switch (getKind()) { + case Kind::Vector: + case Kind::RankedTensor: { + auto shape = getShape(); + unsigned num = 1; + for (auto dim : shape) + num *= dim; + return num; + } + default: + llvm_unreachable("not a VectorOrTensorType or not ranked"); + } +} + /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. 
int VectorOrTensorType::getRank() const { switch (getKind()) { - default: - llvm_unreachable("not a VectorOrTensorType"); case Kind::Vector: case Kind::RankedTensor: return getShape().size(); case Kind::UnrankedTensor: return -1; + default: + llvm_unreachable("not a VectorOrTensorType"); } } @@ -60,7 +96,7 @@ int VectorOrTensorType::getDimSize(unsigned i) const { case Kind::RankedTensor: return getShape()[i]; default: - llvm_unreachable("not a VectorOrTensorType"); + llvm_unreachable("not a VectorOrTensorType or not ranked"); } } @@ -94,14 +130,13 @@ TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context) RankedTensorType::RankedTensorType(ArrayRef shape, Type *elementType, MLIRContext *context) - : TensorType(Kind::RankedTensor, elementType, context), - shapeElements(shape.data()) { + : TensorType(Kind::RankedTensor, elementType, context), + shapeElements(shape.data()) { setSubclassData(shape.size()); } UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context) - : TensorType(Kind::UnrankedTensor, elementType, context) { -} + : TensorType(Kind::UnrankedTensor, elementType, context) {} MemRefType::MemRefType(ArrayRef shape, Type *elementType, ArrayRef affineMapList, unsigned memorySpace, diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 938d9606c68a..d202cf306329 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -35,6 +35,8 @@ #include "mlir/IR/OperationSet.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/IR/Types.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/PrettyStackTrace.h" @@ -205,6 +207,8 @@ public: AffineMap parseAffineMapReference(); IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetReference(); + ElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type); + VectorOrTensorType *parseVectorOrTensorType(); private: // The Parser is 
subclassed and reinstantiated. Do not add additional @@ -624,6 +628,144 @@ ParseResult Parser::parseTypeList(SmallVectorImpl &elements) { // Attribute parsing. //===----------------------------------------------------------------------===// +namespace { +class TensorLiteralParser { +public: + TensorLiteralParser(Parser &p, Type *eltTy) + : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {} + + ParseResult parse() { return parseList(shape); } + + ArrayRef getValues() const { + return {reinterpret_cast(storage.data()), storage.size() * 8}; + } + + ArrayRef getShape() const { return shape; } + +private: + /// Parse either a single element or a list of elements. Return the dimensions + /// of the parsed sub-tensor in dims. + ParseResult parseElementOrList(llvm::SmallVectorImpl &dims); + + /// Parse a list of either lists or elements, returning the dimensions of the + /// parsed sub-tensors in dims. For example: + /// parseList([1, 2, 3]) -> Success, [3] + /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] + /// parseList([[1, 2], 3]) -> Failure + /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure + ParseResult parseList(llvm::SmallVectorImpl &dims); + + void addToStorage(uint64_t value) { + if (bitsWidth == 64) + storage.push_back(value); + + if (currBitPos + bitsWidth > storage.size() * 64) + storage.push_back(0L); + + auto *rawData = reinterpret_cast(storage.data()); + DenseIntElementsAttr::writeBits(rawData, currBitPos, bitsWidth, value); + currBitPos += bitsWidth; + } + + Parser &p; + Type *eltTy; + size_t currBitPos; + size_t bitsWidth; + SmallVector shape; + std::vector storage; +}; +} // namespace + +/// Parse either a single element or a list of elements. Return the dimensions +/// of the parsed sub-tensor in dims. 
+ParseResult +TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { + switch (p.getToken().getKind()) { + case Token::l_square: + return parseList(dims); + case Token::floatliteral: + case Token::integer: + case Token::minus: { + auto *result = p.parseAttribute(); + if (!result) + return p.emitError("expected tensor element"); + // check result matches the element type. + switch (eltTy->getKind()) { + case Type::Kind::BF16: + case Type::Kind::F16: + case Type::Kind::F32: + case Type::Kind::F64: { + if (!isa(result)) + return p.emitError("expected tensor literal element has float type"); + double value = cast(result)->getValue(); + addToStorage(*(uint64_t *)(&value)); + break; + } + case Type::Kind::Integer: { + if (!isa(result)) + return p.emitError("expected tensor literal element has integer type"); + auto value = cast(result)->getValue(); + // If we couldn't successfully round trip the value, it means some bits + // are truncated and we should give up here. + llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true); + if (apint.getSExtValue() != value) + return p.emitError("tensor literal element has more bits than that " + "specified in the type"); + addToStorage((uint64_t)value); + break; + } + default: + return p.emitError("expected integer or float tensor element"); + } + break; + } + default: + return p.emitError("expected '[' or scalar constant inside tensor literal"); + } + return ParseSuccess; +} + +/// Parse a list of either lists or elements, returning the dimensions of the +/// parsed sub-tensors in dims. 
For example: +/// parseList([1, 2, 3]) -> Success, [3] +/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2] +/// parseList([[1, 2], 3]) -> Failure +/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure +ParseResult TensorLiteralParser::parseList(llvm::SmallVectorImpl &dims) { + p.consumeToken(Token::l_square); + + auto checkDims = [&](const llvm::SmallVectorImpl &prevDims, + const llvm::SmallVectorImpl &newDims) { + if (prevDims == newDims) + return ParseSuccess; + return p.emitError("tensor literal is invalid; ranks are not consistent " + "between elements"); + }; + + bool first = true; + llvm::SmallVector newDims; + unsigned size = 0; + auto parseCommaSeparatedList = [&]() { + llvm::SmallVector thisDims; + if (parseElementOrList(thisDims)) + return ParseFailure; + ++size; + if (!first) + return checkDims(newDims, thisDims); + newDims = thisDims; + first = false; + return ParseSuccess; + }; + if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList)) + return ParseFailure; + + // Return the sublists' dimensions with 'size' prepended. + dims.clear(); + dims.push_back(size); + dims.insert(dims.end(), newDims.begin(), newDims.end()); + return ParseSuccess; +} + /// Given a parsed reference to a function name like @foo and a type that it /// corresponds to, resolve it to a concrete function object (possibly /// synthesizing a forward reference) or emit an error and return null on @@ -659,7 +801,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, /// | type /// | `[` (attribute-value (`,` attribute-value)*)? 
`]` /// | function-id `:` function-type -/// | `splat<` (tensor-type | vector-type)`,` +/// | (`splat<` | `dense<`) (tensor-type | vector-type)`,` /// attribute-value `>` /// Attribute *Parser::parseAttribute() { @@ -757,24 +899,13 @@ Attribute *Parser::parseAttribute() { case Token::kw_splat: { consumeToken(Token::kw_splat); - if (parseToken(Token::less, "Expected '<' after 'elements'")) + if (parseToken(Token::less, "expected '<' after 'splat'")) return nullptr; - auto *type = dyn_cast(parseType()); - if (!type) { - return ( - emitError("expected elements literal has a tensor or vector type"), - nullptr); - } - - if (parseToken(Token::comma, "Expected ','")) + auto *type = parseVectorOrTensorType(); + if (!type) return nullptr; - if (!type->hasStaticShape() || type->getRank() == -1) { - return (emitError("tensor literals must be ranked and have static shape"), - nullptr); - } - switch (getToken().getKind()) { case Token::floatliteral: case Token::integer: @@ -785,12 +916,32 @@ Attribute *Parser::parseAttribute() { return builder.getSplatElementsAttr(type, scalar); } default: - return ( - emitError("expected '[' or scalar constant inside tensor literal"), - nullptr); + return (emitError("expected scalar constant inside tensor literal"), + nullptr); } } + case Token::kw_dense: { + consumeToken(Token::kw_dense); + if (parseToken(Token::less, "expected '<' after 'dense'")) + return nullptr; + auto *type = parseVectorOrTensorType(); + if (!type) + return nullptr; + + switch (getToken().getKind()) { + case Token::l_square: { + auto attr = parseDenseElementsAttr(type); + if (!attr) + return nullptr; + if (parseToken(Token::greater, "expected '>'")) + return nullptr; + return attr; + } + default: + return (emitError("expected '[' to start dense tensor literal"), nullptr); + } + } default: { if (Type *type = parseType()) return builder.getTypeAttr(type); @@ -799,6 +950,42 @@ Attribute *Parser::parseAttribute() { } } +ElementsAttr 
*Parser::parseDenseElementsAttr(VectorOrTensorType *type) { + auto *eltTy = type->getElementType(); + TensorLiteralParser literalParser(*this, eltTy); + if (literalParser.parse()) + return nullptr; + if (literalParser.getShape() != type->getShape()) { + std::string str; + llvm::raw_string_ostream s(str); + s << "inferred shape of elements literal (["; + interleaveComma(literalParser.getShape(), s); + s << "]) does not match type (["; + interleaveComma(type->getShape(), s); + s << "])"; + return (emitError(s.str()), nullptr); + } + return builder.getDenseElementsAttr(type, literalParser.getValues()); +} + +/// Parse a ranked static-shape vector/tensor type and ','; null on error. +VectorOrTensorType *Parser::parseVectorOrTensorType() { + // parseType() returns null on failure; plain dyn_cast asserts on null. + auto *type = dyn_cast_or_null<VectorOrTensorType>(parseType()); + if (!type) + return (emitError("expected elements literal has a tensor or vector type"), + nullptr); + + if (parseToken(Token::comma, "expected ','")) + return nullptr; + + if (!type->hasStaticShape() || type->getRank() == -1) + return (emitError("tensor literals must be ranked and have static shape"), + nullptr); + + return type; +} + /// Attribute dictionary. /// /// attribute-dict ::= `{` `}` @@ -848,8 +1035,8 @@ enum AffineLowPrecOp { Sub }; -/// Higher precedence ops - all at the same precedence level. HNoOp is false in -/// the boolean sense. +/// Higher precedence ops - all at the same precedence level. HNoOp is false +/// in the boolean sense. enum AffineHighPrecOp { /// Null value. HNoOp, @@ -957,8 +1144,8 @@ AffineExpr AffineParser::getBinaryAffineOpExpr(AffineLowPrecOp op, } } -/// Consume this token if it is a lower precedence affine op (there are only two -/// precedence levels). +/// Consume this token if it is a lower precedence affine op (there are only +/// two precedence levels).
AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { switch (getToken().getKind()) { case Token::plus: @@ -1103,8 +1290,8 @@ AffineExpr AffineParser::parseIntegerExpr() { // Eg: for an expression without parentheses (like i + j + k + l), each // of the four identifiers is an operand. For i + j*k + l, j*k is not an // operand expression, it's an op expression and will be parsed via -// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and -l -// are valid operands that will be parsed by this function. +// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and +// -l are valid operands that will be parsed by this function. AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { switch (getToken().getKind()) { case Token::bare_identifier: @@ -1148,13 +1335,13 @@ AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { /// /// llhs: the affine expression appearing on the left of the one being parsed. /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, -/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned if -/// llhs is non-null; otherwise lhs is returned. This is to deal with left +/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned +/// if llhs is non-null; otherwise lhs is returned. This is to deal with left /// associativity. /// /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function -/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where (e2*e3) -/// will be parsed using parseAffineHighPrecOpExpr(). +/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where +/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 
AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp) { AffineExpr lhs; @@ -1208,16 +1395,16 @@ AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, /// | bare-id /// | integer-literal /// -/// Additional conditions are checked depending on the production. For eg., one -/// of the operands for `*` has to be either constant/symbolic; the second +/// Additional conditions are checked depending on the production. For eg., +/// one of the operands for `*` has to be either constant/symbolic; the second /// operand for floordiv, ceildiv, and mod has to be a positive integer. AffineExpr AffineParser::parseAffineExpr() { return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); } -/// Parse a dim or symbol from the lists appearing before the actual expressions -/// of the affine map. Update our state to store the dimensional/symbolic -/// identifier. +/// Parse a dim or symbol from the lists appearing before the actual +/// expressions of the affine map. Update our state to store the +/// dimensional/symbolic identifier. ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) { if (getToken().isNot(Token::bare_identifier)) return emitError("expected bare identifier"); @@ -1288,9 +1475,9 @@ AffineMap AffineParser::parseAffineMapInline() { return res; }; - // Parse a multi-dimensional affine expression (a comma-separated list of 1-d - // affine expressions); the list cannot be empty. - // Grammar: multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) + // Parse a multi-dimensional affine expression (a comma-separated list of + // 1-d affine expressions); the list cannot be empty. 
Grammar: + // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false)) return AffineMap::Invalid(); @@ -1357,8 +1544,8 @@ AffineMap Parser::parseAffineMapReference() { //===----------------------------------------------------------------------===// namespace { -/// This class contains parser state that is common across CFG and ML functions, -/// notably for dealing with operations and SSA values. +/// This class contains parser state that is common across CFG and ML +/// functions, notably for dealing with operations and SSA values. class FunctionParser : public Parser { public: enum class Kind { CFGFunc, MLFunc }; @@ -1371,15 +1558,15 @@ public: /// This represents a use of an SSA value in the program. The first two /// entries in the tuple are the name and result number of a reference. The - /// third is the location of the reference, which is used in case this ends up - /// being a use of an undefined value. + /// third is the location of the reference, which is used in case this ends + /// up being a use of an undefined value. struct SSAUseInfo { StringRef name; // Value name, e.g. %42 or %abc unsigned number; // Number, specified with #12 SMLoc loc; // Location of first definition or use. }; - /// Given a reference to an SSA value and its type, return a reference. This + /// Given a reference to an SSA value and its type, return a reference. This /// returns null on failure. SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type); @@ -1442,8 +1629,9 @@ SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc, // Forward references are always created as instructions, even in ML // functions, because we just need something with a def/use chain. // - // We create these placeholders as having an empty name, which we know cannot - // be created through normal user input, allowing us to distinguish them.
+ // We create these placeholders as having an empty name, which we know + // cannot be created through normal user input, allowing us to distinguish + // them. auto name = OperationName("placeholder", getContext()); auto *inst = OperationInst::create(getEncodedSourceLocation(loc), name, /*operands=*/{}, type, @@ -1512,9 +1700,9 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) { "previously defined here"); } - // If it was a forward reference, update everything that used it to use the - // actual definition instead, delete the forward ref, and remove it from our - // set of forward references we track. + // If it was a forward reference, update everything that used it to use + // the actual definition instead, delete the forward ref, and remove it + // from our set of forward references we track. existing->replaceAllUsesWith(value); existing->getDefiningInst()->destroy(); forwardReferencePlaceholders.erase(existing); @@ -1528,7 +1716,8 @@ ParseResult FunctionParser::addDefinition(SSAUseInfo useInfo, SSAValue *value) { /// After the function is finished parsing, this function checks to see if /// there are any remaining issues. ParseResult FunctionParser::finalizeFunction(Function *func, SMLoc loc) { - // Check for any forward references that are left. If we find any, error out. + // Check for any forward references that are left. If we find any, error + // out. if (!forwardReferencePlaceholders.empty()) { SmallVector, 4> errors; // Iteration over the map isn't deterministic, so sort by source location. @@ -1825,9 +2014,9 @@ public: return !(result = parser.parseType()); } - /// Parse an arbitrary attribute and return it in result. This also adds the - /// attribute to the specified attribute list with the specified name. this - /// captures the location of the attribute in 'loc' if it is non-null. + /// Parse an arbitrary attribute and return it in result. 
This also adds + /// the attribute to the specified attribute list with the specified name. + /// this captures the location of the attribute in 'loc' if it is non-null. bool parseAttribute(Attribute *&result, const char *attrName, SmallVectorImpl &attrs) override { result = parser.parseAttribute(); @@ -1997,7 +2186,8 @@ Operation *FunctionParser::parseCustomOperation( consumeToken(); - // If the custom op parser crashes, produce some indication to help debugging. + // If the custom op parser crashes, produce some indication to help + // debugging. std::string opNameStr = opName.str(); llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'", opNameStr.c_str()); @@ -2176,7 +2366,8 @@ ParseResult CFGFunctionParser::parseBasicBlock() { if (parseToken(Token::colon, "expected ':' after basic block name")) return ParseFailure; - // Set the insertion point to the block we want to insert new operations into. + // Set the insertion point to the block we want to insert new operations + // into. builder.setInsertionPoint(block); auto createOpFunc = [&](const OperationState &result) -> Operation * { @@ -2218,7 +2409,8 @@ ParseResult CFGFunctionParser::parseBranchBlockAndUseList( /// terminator-stmt ::= `br` bb-id branch-use-list? /// branch-use-list ::= `(` ssa-use-list `)` ':' type-list-no-parens /// terminator-stmt ::= -/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list? +/// `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id +/// branch-use-list? /// terminator-stmt ::= `return` ssa-use-and-type-list? /// TerminatorInst *CFGFunctionParser::parseTerminator() { @@ -2471,9 +2663,9 @@ MLFunctionParser::parseDimAndSymbolList(SmallVectorImpl &operands, // Loop bound. /// -/// lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound -/// upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound -/// shorthand-bound ::= ssa-id | `-`? integer-literal +/// lower-bound ::= `max`? 
affine-map dim-and-symbol-use-list | shorthand-bound +/// upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound +/// shorthand-bound ::= ssa-id | `-`? integer-literal /// ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower) { @@ -2532,8 +2724,8 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl &operands, /// affine-constraint ::= affine-expr `>=` `0` /// | affine-expr `==` `0` /// -/// isEq is set to true if the parsed constraint is an equality, false if it is -/// an inequality (greater than or equal). +/// isEq is set to true if the parsed constraint is an equality, false if it +/// is an inequality (greater than or equal). /// AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { AffineExpr expr = parseAffineExpr(); @@ -2568,9 +2760,11 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) { /// Parse an integer set definition. /// integer-set-inline -/// ::= dim-and-symbol-id-lists `:` affine-constraint-conjunction +/// ::= dim-and-symbol-id-lists `:` +/// affine-constraint-conjunction /// affine-constraint-conjunction ::= /*empty*/ -/// | affine-constraint (`,` affine-constraint)* +/// | affine-constraint (`,` +/// affine-constraint)* /// IntegerSet AffineParser::parseIntegerSetInline() { unsigned numDims = 0, numSymbols = 0; @@ -2859,11 +3053,12 @@ ModuleParser::parseMLArgumentList(SmallVectorImpl &argTypes, return parseCommaSeparatedListUntil(Token::r_paren, parseElt); } -/// Parse a function signature, starting with a name and including the parameter -/// list. +/// Parse a function signature, starting with a name and including the +/// parameter list. /// /// argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list -/// function-signature ::= function-id `(` argument-list `)` (`->` type-list)? +/// function-signature ::= function-id `(` argument-list `)` (`->` +/// type-list)?
/// ParseResult ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type, @@ -2963,7 +3158,8 @@ ParseResult ModuleParser::parseCFGFunc() { return ParseFailure; } - // Okay, the CFG function signature was parsed correctly, create the function. + // Okay, the CFG function signature was parsed correctly, create the + // function. auto *function = new CFGFunction(getEncodedSourceLocation(loc), name, type, attrs); getModule()->getFunctions().push_back(function); @@ -2979,7 +3175,8 @@ ParseResult ModuleParser::parseCFGFunc() { /// ML function declarations. /// /// ml-func ::= `mlfunc` ml-func-signature -/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt `}` +/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt +/// `}` /// ParseResult ModuleParser::parseMLFunc() { consumeToken(Token::kw_mlfunc); @@ -2997,7 +3194,8 @@ ParseResult ModuleParser::parseMLFunc() { return ParseFailure; } - // Okay, the ML function signature was parsed correctly, create the function. + // Okay, the ML function signature was parsed correctly, create the + // function. auto *function = MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs); getModule()->getFunctions().push_back(function); @@ -3019,9 +3217,9 @@ ParseResult ModuleParser::parseMLFunc() { return parser.parseFunctionBody(); } -/// Given an attribute that could refer to a function attribute in the remapping -/// table, walk it and rewrite it to use the mapped function. If it doesn't -/// refer to anything in the table, then it is returned unmodified. +/// Given an attribute that could refer to a function attribute in the +/// remapping table, walk it and rewrite it to use the mapped function. If it +/// doesn't refer to anything in the table, then it is returned unmodified. 
static Attribute * remapFunctionAttrs(Attribute *input, DenseMap &remappingTable, @@ -3097,8 +3295,8 @@ ParseResult ModuleParser::finalizeModule() { if (remappingTable.empty()) return ParseSuccess; - // Otherwise, walk the entire module replacing uses of one attribute set with - // the correct ones. + // Otherwise, walk the entire module replacing uses of one attribute set + // with the correct ones. for (auto &fn : *getModule()) { if (auto *cfgFn = dyn_cast(&fn)) { for (auto &bb : *cfgFn) { @@ -3147,8 +3345,8 @@ ParseResult ModuleParser::parseModule() { return finalizeModule(); // If we got an error token, then the lexer already emitted an error, just - // stop. Someday we could introduce error recovery if there was demand for - // it. + // stop. Someday we could introduce error recovery if there was demand + // for it. case Token::error: return ParseFailure; @@ -3183,7 +3381,8 @@ ParseResult ModuleParser::parseModule() { //===----------------------------------------------------------------------===// /// This parses the file specified by the indicated SourceMgr and returns an -/// MLIR module if it was valid. If not, it emits diagnostics and returns null. +/// MLIR module if it was valid. If not, it emits diagnostics and returns +/// null. Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context) { @@ -3195,16 +3394,16 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, return nullptr; } - // Make sure the parse module has no other structural problems detected by the - // verifier. + // Make sure the parse module has no other structural problems detected by + // the verifier. if (module->verify()) return nullptr; return module.release(); } -/// This parses the program string to a MLIR module if it was valid. If not, it -/// emits diagnostics and returns null. +/// This parses the program string to a MLIR module if it was valid. If not, +/// it emits diagnostics and returns null. 
Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) { auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr); if (!memBuffer) diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index ad0183edb20c..3ffe6351ed8c 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -93,6 +93,7 @@ TOK_KEYWORD(br) TOK_KEYWORD(ceildiv) TOK_KEYWORD(cfgfunc) TOK_KEYWORD(cond_br) +TOK_KEYWORD(dense) TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(extfunc) diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 22b69e444092..b7f30f1885ca 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -628,3 +628,64 @@ mlfunc @calls(%arg0 : i32) { // expected-error@+2 {{expected SSA operand}} cfgfunc@n(){b( // ----- + +cfgfunc @elementsattr_non_tensor_type() -> () { +bb0: + "foo"(){bar: dense} : () -> () // expected-error {{expected elements literal has a tensor or vector type}} +} + +// ----- + +cfgfunc @elementsattr_non_ranked() -> () { +bb0: + "foo"(){bar: dense, [4]>} : () -> () // expected-error {{tensor literals must be ranked and have static shape}} +} + +// ----- + +cfgfunc @elementsattr_shape_mismatch() -> () { +bb0: + "foo"(){bar: dense, [4]>} : () -> () // expected-error {{inferred shape of elements literal ([1]) does not match type ([5])}} +} + +// ----- + +cfgfunc @elementsattr_invalid() -> () { +bb0: + "foo"(){bar: dense, [4, [5]]>} : () -> () // expected-error {{tensor literal is invalid; ranks are not consistent between elements}} +} + +// ----- + +cfgfunc @elementsattr_badtoken() -> () { +bb0: + "foo"(){bar: dense, [tf_opaque]>} : () -> () // expected-error {{expected '[' or scalar constant inside tensor literal}} +} + +// ----- + +cfgfunc @elementsattr_floattype1() -> () { +bb0: + "foo"(){bar: dense, [4.0]>} : () -> () // expected-error {{expected tensor literal element has integer type}} +} + +// ----- + +cfgfunc @elementsattr_floattype2() -> () { +bb0: + "foo"(){bar: 
dense, [4]>} : () -> () // expected-error {{expected tensor literal element has float type}} +} + +// ----- + +cfgfunc @elementsattr_toolarge1() -> () { +bb0: + "foo"(){bar: dense, [777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}} +} + +// ----- + +cfgfunc @elementsattr_toolarge2() -> () { +bb0: + "foo"(){bar: dense, [-777]>} : () -> () // expected-error {{tensor literal element has more bits than that specified in the type}} +} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index db19d22cab20..8e7d7fdbc0a8 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -485,8 +485,8 @@ mlfunc @mlfuncsimplemap(%arg0 : index, %arg1 : index) -> () { return } -// CHECK-LABEL: cfgfunc @tensorattr -cfgfunc @tensorattr() -> () { +// CHECK-LABEL: cfgfunc @splattensorattr +cfgfunc @splattensorattr() -> () { bb0: // CHECK: "splatIntTensor"() {bar: splat, 5>} : () -> () "splatIntTensor"(){bar: splat, 5>} : () -> () @@ -498,3 +498,96 @@ bb0: "splatFloatVector"(){bar: splat, -5.0>} : () -> () return } + +// CHECK-LABEL: cfgfunc @densetensorattr +cfgfunc @densetensorattr() -> () { +bb0: + +// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck. 
+// CHECK: "fooi3"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi3"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi6"() {bar: dense, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> () + "fooi6"(){bar: dense, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> () +// CHECK: "fooi8"() {bar: dense, {{\[\[\[}}5]]]>} : () -> () + "fooi8"(){bar: dense, [[[5]]]>} : () -> () +// CHECK: "fooi13"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi13"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi16"() {bar: dense, {{\[\[\[}}-5]]]>} : () -> () + "fooi16"(){bar: dense, [[[-5]]]>} : () -> () +// CHECK: "fooi23"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi23"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi32"() {bar: dense, {{\[\[\[}}5]]]>} : () -> () + "fooi32"(){bar: dense, [[[5]]]>} : () -> () +// CHECK: "fooi33"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi33"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi43"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi43"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi53"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 2, -1, 2]]]>} : () -> () + "fooi53"(){bar: dense, [[[1, -2, 1, 2]], [[0, 2, -1, 2]]]>} : () -> () +// CHECK: "fooi64"() {bar: dense, {{\[\[\[}}1, -2, 1, 2]], {{\[\[}}0, 3, -1, 2]]]>} : () -> () + "fooi64"(){bar: dense, [[[1, -2, 1, 2]], [[0, 3, -1, 2]]]>} : () -> () +// CHECK: "fooi64"() {bar: dense, {{\[\[\[}}-5]]]>} : () -> () + "fooi64"(){bar: dense, [[[-5]]]>} : () -> () + +// CHECK: "foo2"() {bar: dense, []>} : () -> () + "foo2"(){bar: dense, []>} : () -> () +// CHECK: "foo2"() {bar: dense, {{\[\[}}]]>} : () -> () + "foo2"(){bar: dense, [[]]>} : () -> () +// CHECK: "foo3"() {bar: 
dense, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> () + "foo3"(){bar: dense, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> () + +// CHECK: "float1"() {bar: dense, {{\[\[\[}}5.000000e+00]]]>} : () -> () + "float1"(){bar: dense, [[[5.0]]]>} : () -> () +// CHECK: "float2"() {bar: dense, []>} : () -> () + "float2"(){bar: dense, []>} : () -> () +// CHECK: "float2"() {bar: dense, {{\[\[}}]]>} : () -> () + "float2"(){bar: dense, [[]]>} : () -> () + +// CHECK: "bfloat16"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "bfloat16"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float16"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float16"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float32"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float32"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float64"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float64"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () + return +} + +// CHECK-LABEL: cfgfunc @densevectorattr +cfgfunc @densevectorattr() -> () { +bb0: +// NOTE: The {{\[\[}} syntax is because "[[" confuses FileCheck. 
+// CHECK: "fooi8"() {bar: dense, {{\[\[\[}}5]]]>} : () -> () + "fooi8"(){bar: dense, [[[5]]]>} : () -> () +// CHECK: "fooi16"() {bar: dense, {{\[\[\[}}-5]]]>} : () -> () + "fooi16"(){bar: dense, [[[-5]]]>} : () -> () +// CHECK: "foo32"() {bar: dense, {{\[\[\[}}5]]]>} : () -> () + "foo32"(){bar: dense, [[[5]]]>} : () -> () +// CHECK: "fooi64"() {bar: dense, {{\[\[\[}}-5]]]>} : () -> () + "fooi64"(){bar: dense, [[[-5]]]>} : () -> () + +// CHECK: "foo2"() {bar: dense, []>} : () -> () + "foo2"(){bar: dense, []>} : () -> () +// CHECK: "foo2"() {bar: dense, {{\[\[}}]]>} : () -> () + "foo2"(){bar: dense, [[]]>} : () -> () +// CHECK: "foo3"() {bar: dense, {{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]>} : () -> () + "foo3"(){bar: dense, [[[5, -6, 1, 2]], [[7, 8, 3, 4]]]>} : () -> () + +// CHECK: "float1"() {bar: dense, {{\[\[\[}}5.000000e+00]]]>} : () -> () + "float1"(){bar: dense, [[[5.0]]]>} : () -> () +// CHECK: "float2"() {bar: dense, []>} : () -> () + "float2"(){bar: dense, []>} : () -> () +// CHECK: "float2"() {bar: dense, {{\[\[}}]]>} : () -> () + "float2"(){bar: dense, [[]]>} : () -> () + +// CHECK: "bfloat16"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "bfloat16"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float16"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float16"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float32"() {bar: dense, {{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float32"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () +// CHECK: "float64"() {bar: dense, {{\[\[\[}}-5.000000e+00, 
6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]>} : () -> () + "float64"(){bar: dense, [[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]>} : () -> () + return +} \ No newline at end of file