mirror of
https://github.com/intel/llvm.git
synced 2026-02-05 04:46:27 +08:00
[mlir] Make DenseArrayAttr generic
This patch turns `DenseArrayBaseAttr` into a fully-functional attribute by adding a generic parser and printer, supporting bool or integer and floating point element types with bitwidths divisible by 8. It has been renamed to `DenseArrayAttr`. The patch maintains the specialized subclasses, e.g. `DenseI32ArrayAttr`, which remain the preferred API for accessing elements in C++. This allows `DenseArrayAttr` to hold signed and unsigned integer elements: ``` array<si8: -128, 127> array<ui8: 255> ``` "Exotic" floating point elements: ``` array<bf16: 1.2, 3.4> ``` And integers of other bitwidths: ``` array<i24: 8388607> ``` Reviewed By: rriddle, lattner Differential Revision: https://reviews.llvm.org/D132758
This commit is contained in:
@@ -761,9 +761,9 @@ namespace detail {
|
||||
/// Base class for DenseArrayAttr that is instantiated and specialized for each
|
||||
/// supported element type below.
|
||||
template <typename T>
|
||||
class DenseArrayAttr : public DenseArrayBaseAttr {
|
||||
class DenseArrayAttrImpl : public DenseArrayAttr {
|
||||
public:
|
||||
using DenseArrayBaseAttr::DenseArrayBaseAttr;
|
||||
using DenseArrayAttr::DenseArrayAttr;
|
||||
|
||||
/// Implicit conversion to ArrayRef<T>.
|
||||
operator ArrayRef<T>() const;
|
||||
@@ -773,7 +773,7 @@ public:
|
||||
T operator[](std::size_t index) const { return asArrayRef()[index]; }
|
||||
|
||||
/// Builder from ArrayRef<T>.
|
||||
static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
|
||||
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef<T> content);
|
||||
|
||||
/// Print the short form `[42, 100, -1]` without any type prefix.
|
||||
void print(AsmPrinter &printer) const;
|
||||
@@ -791,23 +791,23 @@ public:
|
||||
static bool classof(Attribute attr);
|
||||
};
|
||||
|
||||
extern template class DenseArrayAttr<bool>;
|
||||
extern template class DenseArrayAttr<int8_t>;
|
||||
extern template class DenseArrayAttr<int16_t>;
|
||||
extern template class DenseArrayAttr<int32_t>;
|
||||
extern template class DenseArrayAttr<int64_t>;
|
||||
extern template class DenseArrayAttr<float>;
|
||||
extern template class DenseArrayAttr<double>;
|
||||
extern template class DenseArrayAttrImpl<bool>;
|
||||
extern template class DenseArrayAttrImpl<int8_t>;
|
||||
extern template class DenseArrayAttrImpl<int16_t>;
|
||||
extern template class DenseArrayAttrImpl<int32_t>;
|
||||
extern template class DenseArrayAttrImpl<int64_t>;
|
||||
extern template class DenseArrayAttrImpl<float>;
|
||||
extern template class DenseArrayAttrImpl<double>;
|
||||
} // namespace detail
|
||||
|
||||
// Public name for all the supported DenseArrayAttr
|
||||
using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
|
||||
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
|
||||
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
|
||||
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
|
||||
using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
|
||||
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
|
||||
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
|
||||
using DenseBoolArrayAttr = detail::DenseArrayAttrImpl<bool>;
|
||||
using DenseI8ArrayAttr = detail::DenseArrayAttrImpl<int8_t>;
|
||||
using DenseI16ArrayAttr = detail::DenseArrayAttrImpl<int16_t>;
|
||||
using DenseI32ArrayAttr = detail::DenseArrayAttrImpl<int32_t>;
|
||||
using DenseI64ArrayAttr = detail::DenseArrayAttrImpl<int64_t>;
|
||||
using DenseF32ArrayAttr = detail::DenseArrayAttrImpl<float>;
|
||||
using DenseF64ArrayAttr = detail::DenseArrayAttrImpl<double>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseResourceElementsAttr
|
||||
|
||||
@@ -140,7 +140,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseArrayBaseAttr
|
||||
// DenseArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
|
||||
@@ -155,23 +155,28 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
|
||||
}];
|
||||
}
|
||||
|
||||
def Builtin_DenseArrayBase : Builtin_Attr<
|
||||
"DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
|
||||
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
|
||||
def Builtin_DenseArray : Builtin_Attr<
|
||||
"DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
|
||||
let summary = "A dense array of integer or floating point elements.";
|
||||
let description = [{
|
||||
A dense array attribute is an attribute that represents a dense array of
|
||||
primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
|
||||
flat unidimensional array which does not have a storage optimization for
|
||||
splat. This allows to expose the raw array through a C++ API as
|
||||
`ArrayRef<T>`. This is the base class attribute, the actual access is
|
||||
intended to be managed through the subclasses `DenseI8ArrayAttr`,
|
||||
`DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
|
||||
`DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
|
||||
`ArrayRef<T>` for compatible types. The element type must be bool or an
|
||||
integer or float whose bitwidth is a multiple of 8. Bool elements are stored
|
||||
as bytes.
|
||||
|
||||
This is the base class attribute. Access to C++ types is intended to be
|
||||
managed through the subclasses `DenseI8ArrayAttr`, `DenseI16ArrayAttr`,
|
||||
`DenseI32ArrayAttr`, `DenseI64ArrayAttr`, `DenseF32ArrayAttr`,
|
||||
and `DenseF64ArrayAttr`.
|
||||
|
||||
Syntax:
|
||||
|
||||
```
|
||||
dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
|
||||
dense-array-attribute ::= `array` `<` (integer-type | float-type)
|
||||
(`:` tensor-literal)? `>`
|
||||
```
|
||||
Examples:
|
||||
|
||||
@@ -181,16 +186,26 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
||||
array<f64: 42., 12.>
|
||||
```
|
||||
|
||||
when a specific subclass is used as argument of an operation, the declarative
|
||||
assembly will omit the type and print directly:
|
||||
```
|
||||
When a specific subclass is used as argument of an operation, the
|
||||
declarative assembly will omit the type and print directly:
|
||||
|
||||
```mlir
|
||||
[1, 2, 3]
|
||||
```
|
||||
}];
|
||||
|
||||
let parameters = (ins
|
||||
AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
|
||||
Builtin_DenseArrayRawDataParameter:$rawData
|
||||
);
|
||||
|
||||
let builders = [
|
||||
AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
|
||||
"ArrayRef<char>":$rawData), [{
|
||||
return $_get(type.getContext(), type, rawData);
|
||||
}]>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Allow implicit conversion to ElementsAttr.
|
||||
operator ElementsAttr() const {
|
||||
@@ -207,13 +222,9 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
||||
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
|
||||
const float *value_begin_impl(OverloadToken<float>) const;
|
||||
const double *value_begin_impl(OverloadToken<double>) const;
|
||||
|
||||
/// Printer for the short form: will dispatch to the appropriate subclass.
|
||||
void print(AsmPrinter &printer) const;
|
||||
void print(raw_ostream &os) const;
|
||||
/// Print the short form `42, 100, -1` without any braces or prefix.
|
||||
void printWithoutBraces(raw_ostream &os) const;
|
||||
}];
|
||||
|
||||
let genVerifyDecl = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -827,96 +827,142 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ElementsAttr Parser
|
||||
// DenseArrayAttr Parser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class provides an implementation of AsmParser, allowing to call back
|
||||
/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
|
||||
/// in libMLIRIR.
|
||||
class CustomAsmParser : public AsmParserImpl<AsmParser> {
|
||||
/// A generic dense array element parser. It parsers integer and floating point
|
||||
/// elements.
|
||||
class DenseArrayElementParser {
|
||||
public:
|
||||
CustomAsmParser(Parser &parser)
|
||||
: AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
|
||||
explicit DenseArrayElementParser(Type type) : type(type) {}
|
||||
|
||||
/// Parse an integer element.
|
||||
ParseResult parseIntegerElement(Parser &p);
|
||||
|
||||
/// Parse a floating point element.
|
||||
ParseResult parseFloatElement(Parser &p);
|
||||
|
||||
/// Convert the current contents to a dense array.
|
||||
DenseArrayAttr getAttr() {
|
||||
return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Append the raw data of an APInt to the result.
|
||||
void append(const APInt &data);
|
||||
|
||||
/// The array element type.
|
||||
Type type;
|
||||
/// The resultant byte array representing the contents of the array.
|
||||
std::vector<char> rawData;
|
||||
/// The number of elements in the array.
|
||||
int64_t size = 0;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void DenseArrayElementParser::append(const APInt &data) {
|
||||
unsigned byteSize = data.getBitWidth() / 8;
|
||||
size_t offset = rawData.size();
|
||||
rawData.insert(rawData.end(), byteSize, 0);
|
||||
llvm::StoreIntToMemory(
|
||||
data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
|
||||
++size;
|
||||
}
|
||||
|
||||
ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
|
||||
bool isNegative = p.consumeIf(Token::minus);
|
||||
|
||||
// Parse an integer literal as an APInt.
|
||||
Optional<APInt> value;
|
||||
StringRef spelling = p.getToken().getSpelling();
|
||||
if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
|
||||
if (!type.isInteger(1))
|
||||
return p.emitError("expected i1 type for 'true' or 'false' values");
|
||||
value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
|
||||
!type.isUnsignedInteger());
|
||||
p.consumeToken();
|
||||
} else if (p.consumeIf(Token::integer)) {
|
||||
value = buildAttributeAPInt(type, isNegative, spelling);
|
||||
if (!value)
|
||||
return p.emitError("integer constant out of range");
|
||||
} else {
|
||||
return p.emitError("expected integer literal");
|
||||
}
|
||||
append(*value);
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
|
||||
bool isNegative = p.consumeIf(Token::minus);
|
||||
|
||||
Token token = p.getToken();
|
||||
Optional<APFloat> result;
|
||||
auto floatType = type.cast<FloatType>();
|
||||
if (p.consumeIf(Token::integer)) {
|
||||
// Parse an integer literal as a float.
|
||||
if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
|
||||
floatType.getFloatSemantics(),
|
||||
floatType.getWidth()))
|
||||
return failure();
|
||||
} else if (p.consumeIf(Token::floatliteral)) {
|
||||
// Parse a floating point literal.
|
||||
Optional<double> val = token.getFloatingPointValue();
|
||||
if (!val)
|
||||
return failure();
|
||||
result = APFloat(isNegative ? -*val : *val);
|
||||
if (!type.isF64()) {
|
||||
bool unused;
|
||||
result->convert(floatType.getFloatSemantics(),
|
||||
APFloat::rmNearestTiesToEven, &unused);
|
||||
}
|
||||
} else {
|
||||
return p.emitError("expected integer or floating point literal");
|
||||
}
|
||||
|
||||
append(result->bitcastToAPInt());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Parse a dense array attribute.
|
||||
Attribute Parser::parseDenseArrayAttr(Type type) {
|
||||
consumeToken(Token::kw_array);
|
||||
SMLoc typeLoc = getToken().getLoc();
|
||||
if (parseToken(Token::less, "expected '<' after 'array'") ||
|
||||
(!type && !(type = parseType())))
|
||||
return {};
|
||||
CustomAsmParser parser(*this);
|
||||
Attribute result;
|
||||
// Check for empty list.
|
||||
bool isEmptyList = getToken().is(Token::greater);
|
||||
if (!isEmptyList &&
|
||||
parseToken(Token::colon, "expected ':' after dense array type"))
|
||||
if (parseToken(Token::less, "expected '<' after 'array'"))
|
||||
return {};
|
||||
|
||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||
switch (type.getIntOrFloatBitWidth()) {
|
||||
case 1:
|
||||
if (isEmptyList)
|
||||
result = DenseBoolArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 8:
|
||||
if (isEmptyList)
|
||||
result = DenseI8ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 16:
|
||||
if (isEmptyList)
|
||||
result = DenseI16ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 32:
|
||||
if (isEmptyList)
|
||||
result = DenseI32ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 64:
|
||||
if (isEmptyList)
|
||||
result = DenseI64ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
default:
|
||||
emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
|
||||
return {};
|
||||
}
|
||||
} else if (auto floatType = type.dyn_cast<FloatType>()) {
|
||||
switch (type.getIntOrFloatBitWidth()) {
|
||||
case 32:
|
||||
if (isEmptyList)
|
||||
result = DenseF32ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 64:
|
||||
if (isEmptyList)
|
||||
result = DenseF64ArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
default:
|
||||
emitError(typeLoc, "expected f32 or f64 but got: ") << type;
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
// Only bool or integer and floating point elements divisible by bytes are
|
||||
// supported.
|
||||
SMLoc typeLoc = getToken().getLoc();
|
||||
if (!type && !(type = parseType()))
|
||||
return {};
|
||||
if (!type.isIntOrIndexOrFloat()) {
|
||||
emitError(typeLoc, "expected integer or float type, got: ") << type;
|
||||
return {};
|
||||
}
|
||||
if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
|
||||
emitError(typeLoc, "element type bitwidth must be a multiple of 8");
|
||||
return {};
|
||||
}
|
||||
|
||||
// Check for empty list.
|
||||
if (consumeIf(Token::greater))
|
||||
return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
|
||||
if (parseToken(Token::colon, "expected ':' after dense array type"))
|
||||
return {};
|
||||
|
||||
DenseArrayElementParser eltParser(type);
|
||||
if (type.isIntOrIndex()) {
|
||||
if (parseCommaSeparatedList(
|
||||
[&] { return eltParser.parseIntegerElement(*this); }))
|
||||
return {};
|
||||
} else {
|
||||
if (parseCommaSeparatedList(
|
||||
[&] { return eltParser.parseFloatElement(*this); }))
|
||||
return {};
|
||||
}
|
||||
if (parseToken(Token::greater, "expected '>' to close an array attribute"))
|
||||
return {};
|
||||
return result;
|
||||
return eltParser.getAttr();
|
||||
}
|
||||
|
||||
/// Parse a dense elements attribute.
|
||||
|
||||
@@ -383,7 +383,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
|
||||
// Accessors.
|
||||
|
||||
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
|
||||
return unwrap(attr).cast<DenseArrayBaseAttr>().size();
|
||||
return unwrap(attr).cast<DenseArrayAttr>().size();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1476,6 +1476,9 @@ protected:
|
||||
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
|
||||
bool allowHex);
|
||||
|
||||
/// Print a dense array attribute.
|
||||
void printDenseArrayAttr(DenseArrayAttr attr);
|
||||
|
||||
void printDialectAttribute(Attribute attr);
|
||||
void printDialectType(Type type);
|
||||
|
||||
@@ -1860,12 +1863,13 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||
}
|
||||
} else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
|
||||
stridedLayoutAttr.print(os);
|
||||
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
|
||||
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
|
||||
typeElision = AttrTypeElision::Must;
|
||||
os << "array<" << denseArrayAttr.getType().getElementType();
|
||||
if (!denseArrayAttr.empty())
|
||||
if (!denseArrayAttr.empty()) {
|
||||
os << ": ";
|
||||
denseArrayAttr.printWithoutBraces(os);
|
||||
printDenseArrayAttr(denseArrayAttr);
|
||||
}
|
||||
os << ">";
|
||||
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
|
||||
os << "dense_resource<";
|
||||
@@ -1890,11 +1894,11 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
||||
|
||||
/// Print the integer element of a DenseElementsAttr.
|
||||
static void printDenseIntElement(const APInt &value, raw_ostream &os,
|
||||
bool isSigned) {
|
||||
if (value.getBitWidth() == 1)
|
||||
Type type) {
|
||||
if (type.isInteger(1))
|
||||
os << (value.getBoolValue() ? "true" : "false");
|
||||
else
|
||||
value.print(os, isSigned);
|
||||
value.print(os, !type.isUnsignedInteger());
|
||||
}
|
||||
|
||||
static void
|
||||
@@ -1988,14 +1992,13 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
|
||||
// printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
|
||||
// and hence was replaced.
|
||||
if (complexElementType.isa<IntegerType>()) {
|
||||
bool isSigned = !complexElementType.isUnsignedInteger();
|
||||
auto valueIt = attr.value_begin<std::complex<APInt>>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
auto complexValue = *(valueIt + index);
|
||||
os << "(";
|
||||
printDenseIntElement(complexValue.real(), os, isSigned);
|
||||
printDenseIntElement(complexValue.real(), os, complexElementType);
|
||||
os << ",";
|
||||
printDenseIntElement(complexValue.imag(), os, isSigned);
|
||||
printDenseIntElement(complexValue.imag(), os, complexElementType);
|
||||
os << ")";
|
||||
});
|
||||
} else {
|
||||
@@ -2010,10 +2013,9 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
|
||||
});
|
||||
}
|
||||
} else if (elementType.isIntOrIndex()) {
|
||||
bool isSigned = !elementType.isUnsignedInteger();
|
||||
auto valueIt = attr.value_begin<APInt>();
|
||||
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
|
||||
printDenseIntElement(*(valueIt + index), os, isSigned);
|
||||
printDenseIntElement(*(valueIt + index), os, elementType);
|
||||
});
|
||||
} else {
|
||||
assert(elementType.isa<FloatType>() && "unexpected element type");
|
||||
@@ -2031,6 +2033,29 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
|
||||
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
|
||||
}
|
||||
|
||||
void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
|
||||
Type type = attr.getElementType();
|
||||
unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
|
||||
unsigned byteSize = bitwidth / 8;
|
||||
ArrayRef<char> data = attr.getRawData();
|
||||
|
||||
auto printElementAt = [&](unsigned i) {
|
||||
APInt value(bitwidth, 0);
|
||||
llvm::LoadIntFromMemory(
|
||||
value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
|
||||
byteSize);
|
||||
// Print the data as-is or as a float.
|
||||
if (type.isIntOrIndex()) {
|
||||
printDenseIntElement(value, getStream(), type);
|
||||
} else {
|
||||
APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value);
|
||||
printFloatValue(fltVal, getStream());
|
||||
}
|
||||
};
|
||||
llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
|
||||
printElementAt);
|
||||
}
|
||||
|
||||
void AsmPrinter::Impl::printType(Type type) {
|
||||
if (!type) {
|
||||
os << "<<NULL TYPE>>";
|
||||
|
||||
@@ -741,50 +741,50 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
|
||||
// DenseArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
|
||||
LogicalResult
|
||||
DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
RankedTensorType type, ArrayRef<char> rawData) {
|
||||
if (type.getRank() != 1)
|
||||
return emitError() << "expected rank 1 tensor type";
|
||||
if (!type.getElementType().isIntOrIndexOrFloat())
|
||||
return emitError() << "expected integer or floating point element type";
|
||||
int64_t dataSize = rawData.size();
|
||||
int64_t size = type.getShape().front();
|
||||
if (type.getElementType().isInteger(1)) {
|
||||
if (size != dataSize)
|
||||
return emitError() << "expected " << size
|
||||
<< " bytes for i1 array but got " << dataSize;
|
||||
} else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
|
||||
return emitError() << "expected data size (" << size << " elements, "
|
||||
<< type.getElementTypeBitWidth()
|
||||
<< " bits each) does not match: " << dataSize
|
||||
<< " bytes";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
|
||||
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int8_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
||||
const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
||||
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int16_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
|
||||
const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
|
||||
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int32_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
|
||||
const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
|
||||
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int64_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
|
||||
const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
|
||||
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
|
||||
const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
|
||||
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const double *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
|
||||
const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
|
||||
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
|
||||
print(printer.getStream());
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
|
||||
llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
|
||||
.Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
|
||||
DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
|
||||
DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::print(raw_ostream &os) const {
|
||||
os << "[";
|
||||
printWithoutBraces(os);
|
||||
os << "]";
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Instantiations of this class provide utilities for interacting with native
|
||||
/// data types in the context of DenseArrayAttr.
|
||||
@@ -869,19 +869,19 @@ struct DenseArrayAttrUtil<double> {
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
|
||||
void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
|
||||
print(printer.getStream());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
|
||||
void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
|
||||
llvm::interleaveComma(asArrayRef(), os, [&](T value) {
|
||||
DenseArrayAttrUtil<T>::printElement(os, value);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::print(raw_ostream &os) const {
|
||||
void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
|
||||
os << "[";
|
||||
printWithoutBraces(os);
|
||||
os << "]";
|
||||
@@ -889,8 +889,8 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
|
||||
|
||||
/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
|
||||
template <typename T>
|
||||
Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
|
||||
Type odsType) {
|
||||
Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
|
||||
Type odsType) {
|
||||
SmallVector<T> data;
|
||||
if (failed(parser.parseCommaSeparatedList([&]() {
|
||||
T value;
|
||||
@@ -905,7 +905,7 @@ Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
|
||||
|
||||
/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
|
||||
template <typename T>
|
||||
Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
|
||||
Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
|
||||
if (parser.parseLSquare())
|
||||
return {};
|
||||
// Handle empty list case.
|
||||
@@ -919,7 +919,7 @@ Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
|
||||
|
||||
/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
|
||||
template <typename T>
|
||||
DenseArrayAttr<T>::operator ArrayRef<T>() const {
|
||||
DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
|
||||
ArrayRef<char> raw = getRawData();
|
||||
assert((raw.size() % sizeof(T)) == 0);
|
||||
return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
|
||||
@@ -928,19 +928,19 @@ DenseArrayAttr<T>::operator ArrayRef<T>() const {
|
||||
|
||||
/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
|
||||
template <typename T>
|
||||
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
|
||||
ArrayRef<T> content) {
|
||||
DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
|
||||
ArrayRef<T> content) {
|
||||
auto shapedType = RankedTensorType::get(
|
||||
content.size(), DenseArrayAttrUtil<T>::getElementType(context));
|
||||
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
|
||||
content.size() * sizeof(T));
|
||||
return Base::get(context, shapedType, rawArray)
|
||||
.template cast<DenseArrayAttr<T>>();
|
||||
.template cast<DenseArrayAttrImpl<T>>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DenseArrayAttr<T>::classof(Attribute attr) {
|
||||
if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
|
||||
bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
|
||||
if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
|
||||
return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
|
||||
return false;
|
||||
}
|
||||
@@ -948,13 +948,13 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
// Explicit instantiation for all the supported DenseArrayAttr.
|
||||
template class DenseArrayAttr<bool>;
|
||||
template class DenseArrayAttr<int8_t>;
|
||||
template class DenseArrayAttr<int16_t>;
|
||||
template class DenseArrayAttr<int32_t>;
|
||||
template class DenseArrayAttr<int64_t>;
|
||||
template class DenseArrayAttr<float>;
|
||||
template class DenseArrayAttr<double>;
|
||||
template class DenseArrayAttrImpl<bool>;
|
||||
template class DenseArrayAttrImpl<int8_t>;
|
||||
template class DenseArrayAttrImpl<int16_t>;
|
||||
template class DenseArrayAttrImpl<int32_t>;
|
||||
template class DenseArrayAttrImpl<int64_t>;
|
||||
template class DenseArrayAttrImpl<float>;
|
||||
template class DenseArrayAttrImpl<double>;
|
||||
} // namespace detail
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -569,6 +569,26 @@ func.func @dense_array_attr() attributes {
|
||||
f64attr = [-142.]
|
||||
// CHECK-SAME: emptyattr = []
|
||||
emptyattr = []
|
||||
|
||||
// CHECK: array.sizes
|
||||
// CHECK-SAME: i0 = array<i0: 0, 0>
|
||||
// CHECK-SAME: ui0 = array<ui0: 0, 0>
|
||||
// CHECK-SAME: si0 = array<si0: 0, 0>
|
||||
// CHECK-SAME: i24 = array<i24: -42, 42, 8388607>
|
||||
// CHECK-SAME: ui24 = array<ui24: 16777215>
|
||||
// CHECK-SAME: si24 = array<si24: -8388608>
|
||||
// CHECK-SAME: bf16 = array<bf16: 1.2{{[0-9]+}}e+00, 3.4{{[0-9]+}}e+00>
|
||||
// CHECK-SAME: f16 = array<f16: 1.{{[0-9]+}}e+00, 3.{{[0-9]+}}e+00>
|
||||
"array.sizes"() {
|
||||
x0_i0 = array<i0: 0, 0>,
|
||||
x1_ui0 = array<ui0: 0, 0>,
|
||||
x2_si0 = array<si0: 0, 0>,
|
||||
x3_i24 = array<i24: -42, 42, 8388607>,
|
||||
x4_ui24 = array<ui24: 16777215>,
|
||||
x5_si24 = array<si24: -8388608>,
|
||||
x6_bf16 = array<bf16: 1.2, 3.4>,
|
||||
x7_f16 = array<f16: 1., 3.>
|
||||
}: () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -521,3 +521,28 @@ func.func @duplicate_dictionary_attr_key() {
|
||||
|
||||
// expected-error@+1 {{`dense_resource` expected a shaped type}}
|
||||
#attr = dense_resource<resource> : i32
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{expected '<' after 'array'}}
|
||||
#attr = array
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{expected integer or float type}}
|
||||
#attr = array<vector<i32>>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{element type bitwidth must be a multiple of 8}}
|
||||
#attr = array<i7>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{expected ':' after dense array type}}
|
||||
#attr = array<i8)
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@below {{expected '>' to close an array attribute}}
|
||||
#attr = array<i8: 1)
|
||||
|
||||
@@ -41,9 +41,8 @@ struct TestElementsAttrInterface
|
||||
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
|
||||
if (!elementsAttr)
|
||||
continue;
|
||||
if (auto concreteAttr =
|
||||
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
|
||||
llvm::TypeSwitch<DenseArrayBaseAttr>(concreteAttr)
|
||||
if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
|
||||
llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
|
||||
.Case([&](DenseBoolArrayAttr attr) {
|
||||
testElementsAttrIteration<bool>(op, attr, "bool");
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user