[mlir] Make DenseArrayAttr generic

This patch turns `DenseArrayBaseAttr` into a fully-functional attribute by
adding a generic parser and printer, supporting bool or integer and floating
point element types with bitwidths divisible by 8. It has been renamed
to `DenseArrayAttr`. The patch maintains the specialized subclasses,
e.g. `DenseI32ArrayAttr`, which remain the preferred API for accessing
elements in C++.

This allows `DenseArrayAttr` to hold signed and unsigned integer elements:

```
array<si8: -128, 127>
array<ui8: 255>
```

"Exotic" floating point elements:

```
array<bf16: 1.2, 3.4>
```

And integers of other bitwidths:

```
array<i24: 8388607>
```

Reviewed By: rriddle, lattner

Differential Revision: https://reviews.llvm.org/D132758
This commit is contained in:
Jeff Niu
2022-08-25 16:21:28 -07:00
parent 039b969b32
commit cec7e80ebd
9 changed files with 297 additions and 171 deletions

View File

@@ -761,9 +761,9 @@ namespace detail {
/// Base class for DenseArrayAttr that is instantiated and specialized for each
/// supported element type below.
template <typename T>
class DenseArrayAttr : public DenseArrayBaseAttr {
class DenseArrayAttrImpl : public DenseArrayAttr {
public:
using DenseArrayBaseAttr::DenseArrayBaseAttr;
using DenseArrayAttr::DenseArrayAttr;
/// Implicit conversion to ArrayRef<T>.
operator ArrayRef<T>() const;
@@ -773,7 +773,7 @@ public:
T operator[](std::size_t index) const { return asArrayRef()[index]; }
/// Builder from ArrayRef<T>.
static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef<T> content);
/// Print the short form `[42, 100, -1]` without any type prefix.
void print(AsmPrinter &printer) const;
@@ -791,23 +791,23 @@ public:
static bool classof(Attribute attr);
};
extern template class DenseArrayAttr<bool>;
extern template class DenseArrayAttr<int8_t>;
extern template class DenseArrayAttr<int16_t>;
extern template class DenseArrayAttr<int32_t>;
extern template class DenseArrayAttr<int64_t>;
extern template class DenseArrayAttr<float>;
extern template class DenseArrayAttr<double>;
extern template class DenseArrayAttrImpl<bool>;
extern template class DenseArrayAttrImpl<int8_t>;
extern template class DenseArrayAttrImpl<int16_t>;
extern template class DenseArrayAttrImpl<int32_t>;
extern template class DenseArrayAttrImpl<int64_t>;
extern template class DenseArrayAttrImpl<float>;
extern template class DenseArrayAttrImpl<double>;
} // namespace detail
// Public name for all the supported DenseArrayAttr
using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
using DenseBoolArrayAttr = detail::DenseArrayAttrImpl<bool>;
using DenseI8ArrayAttr = detail::DenseArrayAttrImpl<int8_t>;
using DenseI16ArrayAttr = detail::DenseArrayAttrImpl<int16_t>;
using DenseI32ArrayAttr = detail::DenseArrayAttrImpl<int32_t>;
using DenseI64ArrayAttr = detail::DenseArrayAttrImpl<int64_t>;
using DenseF32ArrayAttr = detail::DenseArrayAttrImpl<float>;
using DenseF64ArrayAttr = detail::DenseArrayAttrImpl<double>;
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr

View File

@@ -140,7 +140,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
}
//===----------------------------------------------------------------------===//
// DenseArrayBaseAttr
// DenseArrayAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
@@ -155,23 +155,28 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
}];
}
def Builtin_DenseArrayBase : Builtin_Attr<
"DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
def Builtin_DenseArray : Builtin_Attr<
"DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
let summary = "A dense array of integer or floating point elements.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
flat unidimensional array which does not have a storage optimization for
splat. This allows to expose the raw array through a C++ API as
`ArrayRef<T>`. This is the base class attribute, the actual access is
intended to be managed through the subclasses `DenseI8ArrayAttr`,
`DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
`DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
`ArrayRef<T>` for compatible types. The element type must be bool or an
integer or float whose bitwidth is a multiple of 8. Bool elements are stored
as bytes.
This is the base class attribute. Access to C++ types is intended to be
managed through the subclasses `DenseI8ArrayAttr`, `DenseI16ArrayAttr`,
`DenseI32ArrayAttr`, `DenseI64ArrayAttr`, `DenseF32ArrayAttr`,
and `DenseF64ArrayAttr`.
Syntax:
```
dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
dense-array-attribute ::= `array` `<` (integer-type | float-type)
(`:` tensor-literal)? `>`
```
Examples:
@@ -181,16 +186,26 @@ def Builtin_DenseArrayBase : Builtin_Attr<
array<f64: 42., 12.>
```
when a specific subclass is used as argument of an operation, the declarative
assembly will omit the type and print directly:
```
When a specific subclass is used as argument of an operation, the
declarative assembly will omit the type and print directly:
```mlir
[1, 2, 3]
```
}];
let parameters = (ins
AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
Builtin_DenseArrayRawDataParameter:$rawData
);
let builders = [
AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
"ArrayRef<char>":$rawData), [{
return $_get(type.getContext(), type, rawData);
}]>,
];
let extraClassDeclaration = [{
/// Allow implicit conversion to ElementsAttr.
operator ElementsAttr() const {
@@ -207,13 +222,9 @@ def Builtin_DenseArrayBase : Builtin_Attr<
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
const float *value_begin_impl(OverloadToken<float>) const;
const double *value_begin_impl(OverloadToken<double>) const;
/// Printer for the short form: will dispatch to the appropriate subclass.
void print(AsmPrinter &printer) const;
void print(raw_ostream &os) const;
/// Print the short form `42, 100, -1` without any braces or prefix.
void printWithoutBraces(raw_ostream &os) const;
}];
let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//

View File

@@ -827,96 +827,142 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
}
//===----------------------------------------------------------------------===//
// ElementsAttr Parser
// DenseArrayAttr Parser
//===----------------------------------------------------------------------===//
namespace {
/// This class provides an implementation of AsmParser, allowing to call back
/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
/// in libMLIRIR.
class CustomAsmParser : public AsmParserImpl<AsmParser> {
/// A generic dense array element parser. It parsers integer and floating point
/// elements.
class DenseArrayElementParser {
public:
CustomAsmParser(Parser &parser)
: AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
explicit DenseArrayElementParser(Type type) : type(type) {}
/// Parse an integer element.
ParseResult parseIntegerElement(Parser &p);
/// Parse a floating point element.
ParseResult parseFloatElement(Parser &p);
/// Convert the current contents to a dense array.
DenseArrayAttr getAttr() {
return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
}
private:
/// Append the raw data of an APInt to the result.
void append(const APInt &data);
/// The array element type.
Type type;
/// The resultant byte array representing the contents of the array.
std::vector<char> rawData;
/// The number of elements in the array.
int64_t size = 0;
};
} // namespace
void DenseArrayElementParser::append(const APInt &data) {
unsigned byteSize = data.getBitWidth() / 8;
size_t offset = rawData.size();
rawData.insert(rawData.end(), byteSize, 0);
llvm::StoreIntToMemory(
data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
++size;
}
ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
// Parse an integer literal as an APInt.
Optional<APInt> value;
StringRef spelling = p.getToken().getSpelling();
if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
if (!type.isInteger(1))
return p.emitError("expected i1 type for 'true' or 'false' values");
value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
!type.isUnsignedInteger());
p.consumeToken();
} else if (p.consumeIf(Token::integer)) {
value = buildAttributeAPInt(type, isNegative, spelling);
if (!value)
return p.emitError("integer constant out of range");
} else {
return p.emitError("expected integer literal");
}
append(*value);
return success();
}
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
Token token = p.getToken();
Optional<APFloat> result;
auto floatType = type.cast<FloatType>();
if (p.consumeIf(Token::integer)) {
// Parse an integer literal as a float.
if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
floatType.getFloatSemantics(),
floatType.getWidth()))
return failure();
} else if (p.consumeIf(Token::floatliteral)) {
// Parse a floating point literal.
Optional<double> val = token.getFloatingPointValue();
if (!val)
return failure();
result = APFloat(isNegative ? -*val : *val);
if (!type.isF64()) {
bool unused;
result->convert(floatType.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
}
} else {
return p.emitError("expected integer or floating point literal");
}
append(result->bitcastToAPInt());
return success();
}
/// Parse a dense array attribute.
Attribute Parser::parseDenseArrayAttr(Type type) {
consumeToken(Token::kw_array);
SMLoc typeLoc = getToken().getLoc();
if (parseToken(Token::less, "expected '<' after 'array'") ||
(!type && !(type = parseType())))
return {};
CustomAsmParser parser(*this);
Attribute result;
// Check for empty list.
bool isEmptyList = getToken().is(Token::greater);
if (!isEmptyList &&
parseToken(Token::colon, "expected ':' after dense array type"))
if (parseToken(Token::less, "expected '<' after 'array'"))
return {};
if (auto intType = type.dyn_cast<IntegerType>()) {
switch (type.getIntOrFloatBitWidth()) {
case 1:
if (isEmptyList)
result = DenseBoolArrayAttr::get(parser.getContext(), {});
else
result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 8:
if (isEmptyList)
result = DenseI8ArrayAttr::get(parser.getContext(), {});
else
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 16:
if (isEmptyList)
result = DenseI16ArrayAttr::get(parser.getContext(), {});
else
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 32:
if (isEmptyList)
result = DenseI32ArrayAttr::get(parser.getContext(), {});
else
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 64:
if (isEmptyList)
result = DenseI64ArrayAttr::get(parser.getContext(), {});
else
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
break;
default:
emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
return {};
}
} else if (auto floatType = type.dyn_cast<FloatType>()) {
switch (type.getIntOrFloatBitWidth()) {
case 32:
if (isEmptyList)
result = DenseF32ArrayAttr::get(parser.getContext(), {});
else
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 64:
if (isEmptyList)
result = DenseF64ArrayAttr::get(parser.getContext(), {});
else
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
break;
default:
emitError(typeLoc, "expected f32 or f64 but got: ") << type;
return {};
}
} else {
// Only bool or integer and floating point elements divisible by bytes are
// supported.
SMLoc typeLoc = getToken().getLoc();
if (!type && !(type = parseType()))
return {};
if (!type.isIntOrIndexOrFloat()) {
emitError(typeLoc, "expected integer or float type, got: ") << type;
return {};
}
if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
emitError(typeLoc, "element type bitwidth must be a multiple of 8");
return {};
}
// Check for empty list.
if (consumeIf(Token::greater))
return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
if (parseToken(Token::colon, "expected ':' after dense array type"))
return {};
DenseArrayElementParser eltParser(type);
if (type.isIntOrIndex()) {
if (parseCommaSeparatedList(
[&] { return eltParser.parseIntegerElement(*this); }))
return {};
} else {
if (parseCommaSeparatedList(
[&] { return eltParser.parseFloatElement(*this); }))
return {};
}
if (parseToken(Token::greater, "expected '>' to close an array attribute"))
return {};
return result;
return eltParser.getAttr();
}
/// Parse a dense elements attribute.

View File

@@ -383,7 +383,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
// Accessors.
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
return unwrap(attr).cast<DenseArrayBaseAttr>().size();
return unwrap(attr).cast<DenseArrayAttr>().size();
}
//===----------------------------------------------------------------------===//

View File

@@ -1476,6 +1476,9 @@ protected:
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex);
/// Print a dense array attribute.
void printDenseArrayAttr(DenseArrayAttr attr);
void printDialectAttribute(Attribute attr);
void printDialectType(Type type);
@@ -1860,12 +1863,13 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
}
} else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
stridedLayoutAttr.print(os);
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
typeElision = AttrTypeElision::Must;
os << "array<" << denseArrayAttr.getType().getElementType();
if (!denseArrayAttr.empty())
if (!denseArrayAttr.empty()) {
os << ": ";
denseArrayAttr.printWithoutBraces(os);
printDenseArrayAttr(denseArrayAttr);
}
os << ">";
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
os << "dense_resource<";
@@ -1890,11 +1894,11 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
/// Print the integer element of a DenseElementsAttr.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
bool isSigned) {
if (value.getBitWidth() == 1)
Type type) {
if (type.isInteger(1))
os << (value.getBoolValue() ? "true" : "false");
else
value.print(os, isSigned);
value.print(os, !type.isUnsignedInteger());
}
static void
@@ -1988,14 +1992,13 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
// printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
// and hence was replaced.
if (complexElementType.isa<IntegerType>()) {
bool isSigned = !complexElementType.isUnsignedInteger();
auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
printDenseIntElement(complexValue.real(), os, isSigned);
printDenseIntElement(complexValue.real(), os, complexElementType);
os << ",";
printDenseIntElement(complexValue.imag(), os, isSigned);
printDenseIntElement(complexValue.imag(), os, complexElementType);
os << ")";
});
} else {
@@ -2010,10 +2013,9 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
});
}
} else if (elementType.isIntOrIndex()) {
bool isSigned = !elementType.isUnsignedInteger();
auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printDenseIntElement(*(valueIt + index), os, isSigned);
printDenseIntElement(*(valueIt + index), os, elementType);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
@@ -2031,6 +2033,29 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
Type type = attr.getElementType();
unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
unsigned byteSize = bitwidth / 8;
ArrayRef<char> data = attr.getRawData();
auto printElementAt = [&](unsigned i) {
APInt value(bitwidth, 0);
llvm::LoadIntFromMemory(
value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
byteSize);
// Print the data as-is or as a float.
if (type.isIntOrIndex()) {
printDenseIntElement(value, getStream(), type);
} else {
APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value);
printFloatValue(fltVal, getStream());
}
};
llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
printElementAt);
}
void AsmPrinter::Impl::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";

View File

@@ -741,50 +741,50 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
// DenseArrayAttr
//===----------------------------------------------------------------------===//
const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
LogicalResult
DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
RankedTensorType type, ArrayRef<char> rawData) {
if (type.getRank() != 1)
return emitError() << "expected rank 1 tensor type";
if (!type.getElementType().isIntOrIndexOrFloat())
return emitError() << "expected integer or floating point element type";
int64_t dataSize = rawData.size();
int64_t size = type.getShape().front();
if (type.getElementType().isInteger(1)) {
if (size != dataSize)
return emitError() << "expected " << size
<< " bytes for i1 array but got " << dataSize;
} else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
return emitError() << "expected data size (" << size << " elements, "
<< type.getElementTypeBitWidth()
<< " bits each) does not match: " << dataSize
<< " bytes";
}
return success();
}
const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
}
const int8_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
}
const int16_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
}
const int32_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
}
const int64_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
}
const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
}
const double *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
}
void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
print(printer.getStream());
}
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
.Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
}
void DenseArrayBaseAttr::print(raw_ostream &os) const {
os << "[";
printWithoutBraces(os);
os << "]";
}
namespace {
/// Instantiations of this class provide utilities for interacting with native
/// data types in the context of DenseArrayAttr.
@@ -869,19 +869,19 @@ struct DenseArrayAttrUtil<double> {
} // namespace
template <typename T>
void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
print(printer.getStream());
}
template <typename T>
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
llvm::interleaveComma(asArrayRef(), os, [&](T value) {
DenseArrayAttrUtil<T>::printElement(os, value);
});
}
template <typename T>
void DenseArrayAttr<T>::print(raw_ostream &os) const {
void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
os << "[";
printWithoutBraces(os);
os << "]";
@@ -889,8 +889,8 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
template <typename T>
Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
Type odsType) {
Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
Type odsType) {
SmallVector<T> data;
if (failed(parser.parseCommaSeparatedList([&]() {
T value;
@@ -905,7 +905,7 @@ Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
template <typename T>
Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
if (parser.parseLSquare())
return {};
// Handle empty list case.
@@ -919,7 +919,7 @@ Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
template <typename T>
DenseArrayAttr<T>::operator ArrayRef<T>() const {
DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
ArrayRef<char> raw = getRawData();
assert((raw.size() % sizeof(T)) == 0);
return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
@@ -928,19 +928,19 @@ DenseArrayAttr<T>::operator ArrayRef<T>() const {
/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
template <typename T>
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
ArrayRef<T> content) {
DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
ArrayRef<T> content) {
auto shapedType = RankedTensorType::get(
content.size(), DenseArrayAttrUtil<T>::getElementType(context));
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
content.size() * sizeof(T));
return Base::get(context, shapedType, rawArray)
.template cast<DenseArrayAttr<T>>();
.template cast<DenseArrayAttrImpl<T>>();
}
template <typename T>
bool DenseArrayAttr<T>::classof(Attribute attr) {
if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
return false;
}
@@ -948,13 +948,13 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
template class DenseArrayAttr<bool>;
template class DenseArrayAttr<int8_t>;
template class DenseArrayAttr<int16_t>;
template class DenseArrayAttr<int32_t>;
template class DenseArrayAttr<int64_t>;
template class DenseArrayAttr<float>;
template class DenseArrayAttr<double>;
template class DenseArrayAttrImpl<bool>;
template class DenseArrayAttrImpl<int8_t>;
template class DenseArrayAttrImpl<int16_t>;
template class DenseArrayAttrImpl<int32_t>;
template class DenseArrayAttrImpl<int64_t>;
template class DenseArrayAttrImpl<float>;
template class DenseArrayAttrImpl<double>;
} // namespace detail
} // namespace mlir

View File

@@ -569,6 +569,26 @@ func.func @dense_array_attr() attributes {
f64attr = [-142.]
// CHECK-SAME: emptyattr = []
emptyattr = []
// CHECK: array.sizes
// CHECK-SAME: i0 = array<i0: 0, 0>
// CHECK-SAME: ui0 = array<ui0: 0, 0>
// CHECK-SAME: si0 = array<si0: 0, 0>
// CHECK-SAME: i24 = array<i24: -42, 42, 8388607>
// CHECK-SAME: ui24 = array<ui24: 16777215>
// CHECK-SAME: si24 = array<si24: -8388608>
// CHECK-SAME: bf16 = array<bf16: 1.2{{[0-9]+}}e+00, 3.4{{[0-9]+}}e+00>
// CHECK-SAME: f16 = array<f16: 1.{{[0-9]+}}e+00, 3.{{[0-9]+}}e+00>
"array.sizes"() {
x0_i0 = array<i0: 0, 0>,
x1_ui0 = array<ui0: 0, 0>,
x2_si0 = array<si0: 0, 0>,
x3_i24 = array<i24: -42, 42, 8388607>,
x4_ui24 = array<ui24: 16777215>,
x5_si24 = array<si24: -8388608>,
x6_bf16 = array<bf16: 1.2, 3.4>,
x7_f16 = array<f16: 1., 3.>
}: () -> ()
return
}

View File

@@ -521,3 +521,28 @@ func.func @duplicate_dictionary_attr_key() {
// expected-error@+1 {{`dense_resource` expected a shaped type}}
#attr = dense_resource<resource> : i32
// -----
// expected-error@below {{expected '<' after 'array'}}
#attr = array
// -----
// expected-error@below {{expected integer or float type}}
#attr = array<vector<i32>>
// -----
// expected-error@below {{element type bitwidth must be a multiple of 8}}
#attr = array<i7>
// -----
// expected-error@below {{expected ':' after dense array type}}
#attr = array<i8)
// -----
// expected-error@below {{expected '>' to close an array attribute}}
#attr = array<i8: 1)

View File

@@ -41,9 +41,8 @@ struct TestElementsAttrInterface
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr)
continue;
if (auto concreteAttr =
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
llvm::TypeSwitch<DenseArrayBaseAttr>(concreteAttr)
if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
.Case([&](DenseBoolArrayAttr attr) {
testElementsAttrIteration<bool>(op, attr, "bool");
})