mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
Add support for IndexType inside DenseIntElementsAttr.
This also fixes issues discovered in the parsing/printing path.
This commit is contained in:
@@ -76,6 +76,9 @@ public:
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
|
||||
|
||||
/// Storage bit width used for IndexType by internal compiler data structures.
|
||||
static constexpr unsigned kInternalStorageBitWidth = 64;
|
||||
};
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
|
||||
@@ -169,6 +169,8 @@ public:
|
||||
/// Return true of this is a signless integer or a float type.
|
||||
bool isSignlessIntOrFloat();
|
||||
|
||||
/// Return true if this is an integer (of any signedness) or an index type.
|
||||
bool isIntOrIndex();
|
||||
/// Return true if this is an integer (of any signedness) or a float type.
|
||||
bool isIntOrFloat();
|
||||
/// Return true if this is an integer (of any signedness), index, or float
|
||||
|
||||
@@ -1462,7 +1462,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
|
||||
bool isSigned = !type.getElementType().isUnsignedInteger();
|
||||
|
||||
// The function used to print elements of this attribute.
|
||||
auto printEltFn = type.getElementType().isa<IntegerType>()
|
||||
auto printEltFn = type.getElementType().isIntOrIndex()
|
||||
? printDenseIntElement
|
||||
: printDenseFloatElement;
|
||||
|
||||
|
||||
@@ -372,6 +372,17 @@ struct TypeAttributeStorage : public AttributeStorage {
|
||||
// Elements Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the bit width which DenseElementsAttr should use for this type.
|
||||
inline size_t getDenseElementBitWidth(Type eltType) {
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
if (eltType.isBF16())
|
||||
return 64;
|
||||
if (eltType.isIndex())
|
||||
return IndexType::kInternalStorageBitWidth;
|
||||
return eltType.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
/// An attribute representing a reference to a dense vector or tensor object.
|
||||
struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
struct KeyTy {
|
||||
@@ -405,7 +416,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
// same. Boolean values are packed at the bit level, and even though a splat
|
||||
// is detected the rest of the bits in the first byte may differ from the
|
||||
// splat value.
|
||||
if (key.type.getElementTypeBitWidth() == 1) {
|
||||
if (key.type.getElementType().isInteger(1)) {
|
||||
if (key.isSplat != isSplat)
|
||||
return false;
|
||||
if (isSplat)
|
||||
@@ -437,15 +448,10 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
assert(numElements != 1 && "splat of 1 element should already be detected");
|
||||
|
||||
// Handle boolean values directly as they are packed to 1-bit.
|
||||
size_t elementWidth = ty.getElementTypeBitWidth();
|
||||
if (elementWidth == 1)
|
||||
if (ty.getElementType().isInteger(1) == 1)
|
||||
return getKeyForBoolData(ty, data, numElements);
|
||||
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
if (ty.getElementType().isBF16())
|
||||
elementWidth = 64;
|
||||
|
||||
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
|
||||
// Non 1-bit dense elements are padded to 8-bits.
|
||||
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
|
||||
assert(((data.size() / storageSize) == numElements) &&
|
||||
@@ -517,7 +523,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
std::memcpy(rawData, data.data(), data.size());
|
||||
|
||||
// If this is a boolean splat, make sure only the first bit is used.
|
||||
if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
|
||||
if (key.isSplat && key.type.getElementType().isInteger(1))
|
||||
rawData[0] &= 1;
|
||||
copy = ArrayRef<char>(rawData, data.size());
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
|
||||
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
|
||||
// This uses 64 bit APInts by default for index type.
|
||||
if (type.isIndex())
|
||||
return get(type, APInt(64, value));
|
||||
return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
|
||||
|
||||
auto intType = type.cast<IntegerType>();
|
||||
return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
|
||||
@@ -483,12 +483,6 @@ uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
|
||||
// DenseElementAttr Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static size_t getDenseElementBitwidth(Type eltType) {
|
||||
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
|
||||
// with double semantics.
|
||||
return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
/// Get the bitwidth of a dense element type within the buffer.
|
||||
/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
|
||||
static size_t getDenseElementStorageWidth(size_t origWidth) {
|
||||
@@ -592,7 +586,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
|
||||
DenseElementsAttr attr, size_t dataIndex)
|
||||
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
|
||||
attr.getRawData().data(), attr.isSplat(), dataIndex),
|
||||
bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
|
||||
bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
|
||||
|
||||
/// Accesses the raw APInt value at this iterator position.
|
||||
APInt DenseElementsAttr::IntElementIterator::operator*() const {
|
||||
@@ -613,12 +607,12 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
assert(type.getElementType().isIntOrFloat() &&
|
||||
"expected int or float element type");
|
||||
assert(type.getElementType().isIntOrIndexOrFloat() &&
|
||||
"expected int or index or float element type");
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
|
||||
auto eltType = type.getElementType();
|
||||
size_t bitWidth = getDenseElementBitwidth(eltType);
|
||||
size_t bitWidth = getDenseElementBitWidth(eltType);
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
|
||||
// Compress the attribute values into a character buffer.
|
||||
@@ -637,6 +631,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
|
||||
break;
|
||||
case StandardTypes::Integer:
|
||||
case StandardTypes::Index:
|
||||
intVal = values[i].isa<BoolAttr>()
|
||||
? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
|
||||
: values[i].cast<IntegerAttr>().getValue();
|
||||
@@ -667,7 +662,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
/// element type of 'type'.
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(type.getElementType().isa<IntegerType>());
|
||||
assert(type.getElementType().isIntOrIndex());
|
||||
return getRaw(type, values);
|
||||
}
|
||||
|
||||
@@ -701,7 +696,7 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||
ArrayRef<APInt> values) {
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
|
||||
size_t bitWidth = getDenseElementBitwidth(type.getElementType());
|
||||
size_t bitWidth = getDenseElementBitWidth(type.getElementType());
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
|
||||
values.size());
|
||||
@@ -727,14 +722,17 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
|
||||
static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
|
||||
bool isSigned) {
|
||||
// Make sure that the data element size is the same as the type element width.
|
||||
if (getDenseElementBitwidth(type.getElementType()) !=
|
||||
if (getDenseElementBitWidth(type.getElementType()) !=
|
||||
static_cast<size_t>(dataEltSize * CHAR_BIT))
|
||||
return false;
|
||||
|
||||
// Check that the element type is either float or integer.
|
||||
// Check that the element type is either float or integer or index.
|
||||
if (!isInt)
|
||||
return type.getElementType().isa<FloatType>();
|
||||
|
||||
if (type.getElementType().isIndex())
|
||||
return true;
|
||||
|
||||
auto intType = type.getElementType().dyn_cast<IntegerType>();
|
||||
if (!intType)
|
||||
return false;
|
||||
@@ -798,18 +796,15 @@ auto DenseElementsAttr::getBoolValues() const
|
||||
/// this attribute must be of integer type.
|
||||
auto DenseElementsAttr::getIntValues() const
|
||||
-> llvm::iterator_range<IntElementIterator> {
|
||||
assert(getType().getElementType().isa<IntegerType>() &&
|
||||
"expected integer type");
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return {raw_int_begin(), raw_int_end()};
|
||||
}
|
||||
auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
|
||||
assert(getType().getElementType().isa<IntegerType>() &&
|
||||
"expected integer type");
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_begin();
|
||||
}
|
||||
auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
|
||||
assert(getType().getElementType().isa<IntegerType>() &&
|
||||
"expected integer type");
|
||||
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
|
||||
return raw_int_end();
|
||||
}
|
||||
|
||||
@@ -870,7 +865,7 @@ template <typename Fn, typename Attr>
|
||||
static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
|
||||
Type newElementType,
|
||||
llvm::SmallVectorImpl<char> &data) {
|
||||
size_t bitWidth = getDenseElementBitwidth(newElementType);
|
||||
size_t bitWidth = getDenseElementBitWidth(newElementType);
|
||||
size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
|
||||
|
||||
ShapedType newArrayType;
|
||||
@@ -937,7 +932,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
bool DenseIntElementsAttr::classof(Attribute attr) {
|
||||
return attr.isa<DenseElementsAttr>() &&
|
||||
attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
|
||||
attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -83,6 +83,8 @@ bool Type::isSignlessIntOrFloat() {
|
||||
return isSignlessInteger() || isa<FloatType>();
|
||||
}
|
||||
|
||||
bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
|
||||
|
||||
bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
|
||||
|
||||
bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
|
||||
|
||||
@@ -1797,7 +1797,8 @@ static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
|
||||
return llvm::None;
|
||||
|
||||
// Extend or truncate the bitwidth to the right size.
|
||||
unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
|
||||
unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
|
||||
: type.getIntOrFloatBitWidth();
|
||||
if (width > result.getBitWidth()) {
|
||||
result = result.zext(width);
|
||||
} else if (width < result.getBitWidth()) {
|
||||
@@ -1968,8 +1969,7 @@ private:
|
||||
}
|
||||
|
||||
/// Build a Dense Integer attribute for the given type.
|
||||
DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
|
||||
IntegerType eltTy);
|
||||
DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
|
||||
|
||||
/// Build a Dense Float attribute for the given type.
|
||||
DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
|
||||
@@ -2044,14 +2044,17 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
|
||||
|
||||
// If the type is an integer, build a set of APInt values from the storage
|
||||
// with the correct bitwidth.
|
||||
if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
|
||||
Type eltType = type.getElementType();
|
||||
if (auto intTy = eltType.dyn_cast<IntegerType>())
|
||||
return getIntAttr(loc, type, intTy);
|
||||
if (auto indexTy = eltType.dyn_cast<IndexType>())
|
||||
return getIntAttr(loc, type, indexTy);
|
||||
|
||||
// Otherwise, this must be a floating point type.
|
||||
auto floatTy = type.getElementType().dyn_cast<FloatType>();
|
||||
auto floatTy = eltType.dyn_cast<FloatType>();
|
||||
if (!floatTy) {
|
||||
p.emitError(loc) << "expected floating-point or integer element type, got "
|
||||
<< type.getElementType();
|
||||
<< eltType;
|
||||
return nullptr;
|
||||
}
|
||||
return getFloatAttr(loc, type, floatTy);
|
||||
@@ -2059,8 +2062,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
|
||||
|
||||
/// Build a Dense Integer attribute for the given type.
|
||||
DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
|
||||
ShapedType type,
|
||||
IntegerType eltTy) {
|
||||
ShapedType type, Type eltTy) {
|
||||
std::vector<APInt> intElements;
|
||||
intElements.reserve(storage.size());
|
||||
auto isUintType = type.getElementType().isUnsignedInteger();
|
||||
@@ -2085,11 +2087,12 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
|
||||
assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
|
||||
"unexpected token type");
|
||||
if (token.isAny(Token::kw_true, Token::kw_false)) {
|
||||
if (!eltTy.isInteger(1))
|
||||
if (!eltTy.isInteger(1)) {
|
||||
p.emitError(tokenLoc)
|
||||
<< "expected i1 type for 'true' or 'false' values";
|
||||
APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
|
||||
/*isSigned=*/false);
|
||||
return nullptr;
|
||||
}
|
||||
APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
|
||||
intElements.push_back(apInt);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -697,6 +697,11 @@ func @densetensorattr() -> () {
|
||||
"intscalar"(){bar = dense<1> : tensor<i32>} : () -> ()
|
||||
// CHECK: "floatscalar"() {bar = dense<5.000000e+00> : tensor<f32>} : () -> ()
|
||||
"floatscalar"(){bar = dense<5.0> : tensor<f32>} : () -> ()
|
||||
|
||||
// CHECK: "index"() {bar = dense<1> : tensor<index>} : () -> ()
|
||||
"index"(){bar = dense<1> : tensor<index>} : () -> ()
|
||||
// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
|
||||
"index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user