mirror of
https://github.com/intel/llvm.git
synced 2026-02-06 06:31:50 +08:00
Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (#81923) (#81934)
This commit is contained in:
@@ -153,6 +153,57 @@ enum class Action : uint32_t {
|
||||
kSortCOOInPlace = 8,
|
||||
};
|
||||
|
||||
/// This enum defines all supported storage format without the level properties.
|
||||
enum class LevelFormat : uint64_t {
|
||||
Undef = 0x00000000,
|
||||
Dense = 0x00010000,
|
||||
Compressed = 0x00020000,
|
||||
Singleton = 0x00040000,
|
||||
LooseCompressed = 0x00080000,
|
||||
NOutOfM = 0x00100000,
|
||||
};
|
||||
|
||||
template <LevelFormat... targets>
|
||||
constexpr bool isAnyOfFmt(LevelFormat fmt) {
|
||||
return (... || (targets == fmt));
|
||||
}
|
||||
|
||||
/// Returns string representation of the given level format.
|
||||
constexpr const char *toFormatString(LevelFormat lvlFmt) {
|
||||
switch (lvlFmt) {
|
||||
case LevelFormat::Undef:
|
||||
return "undef";
|
||||
case LevelFormat::Dense:
|
||||
return "dense";
|
||||
case LevelFormat::Compressed:
|
||||
return "compressed";
|
||||
case LevelFormat::Singleton:
|
||||
return "singleton";
|
||||
case LevelFormat::LooseCompressed:
|
||||
return "loose_compressed";
|
||||
case LevelFormat::NOutOfM:
|
||||
return "structured";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/// This enum defines all the nondefault properties for storage formats.
|
||||
enum class LevelPropNonDefault : uint64_t {
|
||||
Nonunique = 0x0001,
|
||||
Nonordered = 0x0002,
|
||||
};
|
||||
|
||||
/// Returns string representation of the given level properties.
|
||||
constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
|
||||
switch (lvlProp) {
|
||||
case LevelPropNonDefault::Nonunique:
|
||||
return "nonunique";
|
||||
case LevelPropNonDefault::Nonordered:
|
||||
return "nonordered";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/// This enum defines all the sparse representations supportable by
|
||||
/// the SparseTensor dialect. We use a lightweight encoding to encode
|
||||
/// the "format" per se (dense, compressed, singleton, loose_compressed,
|
||||
@@ -167,359 +218,185 @@ enum class Action : uint32_t {
|
||||
/// where we need to store an undefined or indeterminate `LevelType`.
|
||||
/// It should not be used externally, since it does not indicate an
|
||||
/// actual/representable format.
|
||||
///
|
||||
/// Bit manipulations for LevelType:
|
||||
///
|
||||
/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
|
||||
///
|
||||
enum class LevelType : uint64_t {
|
||||
Undef = 0x000000000000,
|
||||
Dense = 0x000000010000,
|
||||
Compressed = 0x000000020000,
|
||||
CompressedNu = 0x000000020001,
|
||||
CompressedNo = 0x000000020002,
|
||||
CompressedNuNo = 0x000000020003,
|
||||
Singleton = 0x000000040000,
|
||||
SingletonNu = 0x000000040001,
|
||||
SingletonNo = 0x000000040002,
|
||||
SingletonNuNo = 0x000000040003,
|
||||
LooseCompressed = 0x000000080000,
|
||||
LooseCompressedNu = 0x000000080001,
|
||||
LooseCompressedNo = 0x000000080002,
|
||||
LooseCompressedNuNo = 0x000000080003,
|
||||
NOutOfM = 0x000000100000,
|
||||
|
||||
struct LevelType {
|
||||
public:
|
||||
/// Check that the `LevelType` contains a valid (possibly undefined) value.
|
||||
static constexpr bool isValidLvlBits(uint64_t lvlBits) {
|
||||
auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
|
||||
const uint64_t propertyBits = lvlBits & 0xffff;
|
||||
// If undefined/dense/NOutOfM, then must be unique and ordered.
|
||||
// Otherwise, the format must be one of the known ones.
|
||||
return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
|
||||
LevelFormat::NOutOfM>(fmt))
|
||||
? (propertyBits == 0)
|
||||
: (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
|
||||
LevelFormat::LooseCompressed>(fmt));
|
||||
}
|
||||
|
||||
/// Convert a LevelFormat to its corresponding LevelType with the given
|
||||
/// properties. Returns std::nullopt when the properties are not applicable
|
||||
/// for the input level format.
|
||||
static std::optional<LevelType>
|
||||
buildLvlType(LevelFormat lf,
|
||||
const std::vector<LevelPropNonDefault> &properties,
|
||||
uint64_t n = 0, uint64_t m = 0) {
|
||||
assert((n & 0xff) == n && (m & 0xff) == m);
|
||||
uint64_t newN = n << 32;
|
||||
uint64_t newM = m << 40;
|
||||
uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
|
||||
for (auto p : properties)
|
||||
ltBits |= static_cast<uint64_t>(p);
|
||||
|
||||
return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
|
||||
: std::nullopt;
|
||||
}
|
||||
static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
|
||||
bool unique, uint64_t n = 0,
|
||||
uint64_t m = 0) {
|
||||
std::vector<LevelPropNonDefault> properties;
|
||||
if (!ordered)
|
||||
properties.push_back(LevelPropNonDefault::Nonordered);
|
||||
if (!unique)
|
||||
properties.push_back(LevelPropNonDefault::Nonunique);
|
||||
return buildLvlType(lf, properties, n, m);
|
||||
}
|
||||
|
||||
/// Explicit conversion from uint64_t.
|
||||
constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
|
||||
assert(isValidLvlBits(bits));
|
||||
};
|
||||
|
||||
/// Constructs a LevelType with the given format using all default properties.
|
||||
/*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
|
||||
assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
|
||||
};
|
||||
|
||||
/// Converts to uint64_t
|
||||
explicit operator uint64_t() const { return lvlBits; }
|
||||
|
||||
bool operator==(const LevelType lhs) const {
|
||||
return static_cast<uint64_t>(lhs) == lvlBits;
|
||||
}
|
||||
bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
|
||||
|
||||
LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
|
||||
|
||||
/// Get N of NOutOfM level type.
|
||||
constexpr uint64_t getN() const {
|
||||
assert(isa<LevelFormat::NOutOfM>());
|
||||
return (lvlBits >> 32) & 0xff;
|
||||
}
|
||||
|
||||
/// Get M of NOutOfM level type.
|
||||
constexpr uint64_t getM() const {
|
||||
assert(isa<LevelFormat::NOutOfM>());
|
||||
return (lvlBits >> 40) & 0xff;
|
||||
}
|
||||
|
||||
/// Get the `LevelFormat` of the `LevelType`.
|
||||
constexpr LevelFormat getLvlFmt() const {
|
||||
return static_cast<LevelFormat>(lvlBits & 0xffff0000);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is in the `LevelFormat`.
|
||||
template <LevelFormat fmt>
|
||||
constexpr bool isa() const {
|
||||
return getLvlFmt() == fmt;
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` has the properties
|
||||
template <LevelPropNonDefault p>
|
||||
constexpr bool isa() const {
|
||||
return lvlBits & static_cast<uint64_t>(p);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs positions array.
|
||||
constexpr bool isWithPosLT() const {
|
||||
return isa<LevelFormat::Compressed>() ||
|
||||
isa<LevelFormat::LooseCompressed>();
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs coordinates array.
|
||||
constexpr bool isWithCrdLT() const {
|
||||
// All sparse levels has coordinate array.
|
||||
return !isa<LevelFormat::Dense>();
|
||||
}
|
||||
|
||||
std::string toMLIRString() const {
|
||||
std::string lvlStr = toFormatString(getLvlFmt());
|
||||
std::string propStr = "";
|
||||
if (isa<LevelPropNonDefault::Nonunique>())
|
||||
propStr += toPropString(LevelPropNonDefault::Nonunique);
|
||||
|
||||
if (isa<LevelPropNonDefault::Nonordered>()) {
|
||||
if (!propStr.empty())
|
||||
propStr += ", ";
|
||||
propStr += toPropString(LevelPropNonDefault::Nonordered);
|
||||
}
|
||||
if (!propStr.empty())
|
||||
lvlStr += ("(" + propStr + ")");
|
||||
return lvlStr;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Bit manipulations for LevelType:
|
||||
///
|
||||
/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
|
||||
///
|
||||
uint64_t lvlBits;
|
||||
};
|
||||
|
||||
/// This enum defines all supported storage format without the level properties.
|
||||
enum class LevelFormat : uint64_t {
|
||||
Dense = 0x00010000,
|
||||
Compressed = 0x00020000,
|
||||
Singleton = 0x00040000,
|
||||
LooseCompressed = 0x00080000,
|
||||
NOutOfM = 0x00100000,
|
||||
};
|
||||
|
||||
/// This enum defines all the nondefault properties for storage formats.
|
||||
enum class LevelPropertyNondefault : uint64_t {
|
||||
Nonunique = 0x0001,
|
||||
Nonordered = 0x0002,
|
||||
};
|
||||
|
||||
/// Get N of NOutOfM level type.
|
||||
constexpr uint64_t getN(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) >> 32) & 0xff;
|
||||
}
|
||||
|
||||
/// Get M of NOutOfM level type.
|
||||
constexpr uint64_t getM(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) >> 40) & 0xff;
|
||||
}
|
||||
|
||||
/// Convert N of NOutOfM level type to the stored bits.
|
||||
// For backward-compatibility. TODO: remove below after fully migration.
|
||||
constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
|
||||
|
||||
/// Convert M of NOutOfM level type to the stored bits.
|
||||
constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
|
||||
|
||||
/// Check if the `LevelType` is NOutOfM (regardless of
|
||||
/// properties and block sizes).
|
||||
constexpr bool isNOutOfMLT(LevelType lt) {
|
||||
return ((static_cast<uint64_t>(lt) & 0x100000) ==
|
||||
static_cast<uint64_t>(LevelType::NOutOfM));
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is NOutOfM with the correct block sizes.
|
||||
constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
|
||||
return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
|
||||
}
|
||||
|
||||
/// Returns string representation of the given dimension level type.
|
||||
constexpr const char *toMLIRString(LevelType lvlType) {
|
||||
auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
|
||||
switch (lt) {
|
||||
case LevelType::Undef:
|
||||
return "undef";
|
||||
case LevelType::Dense:
|
||||
return "dense";
|
||||
case LevelType::Compressed:
|
||||
return "compressed";
|
||||
case LevelType::CompressedNu:
|
||||
return "compressed(nonunique)";
|
||||
case LevelType::CompressedNo:
|
||||
return "compressed(nonordered)";
|
||||
case LevelType::CompressedNuNo:
|
||||
return "compressed(nonunique, nonordered)";
|
||||
case LevelType::Singleton:
|
||||
return "singleton";
|
||||
case LevelType::SingletonNu:
|
||||
return "singleton(nonunique)";
|
||||
case LevelType::SingletonNo:
|
||||
return "singleton(nonordered)";
|
||||
case LevelType::SingletonNuNo:
|
||||
return "singleton(nonunique, nonordered)";
|
||||
case LevelType::LooseCompressed:
|
||||
return "loose_compressed";
|
||||
case LevelType::LooseCompressedNu:
|
||||
return "loose_compressed(nonunique)";
|
||||
case LevelType::LooseCompressedNo:
|
||||
return "loose_compressed(nonordered)";
|
||||
case LevelType::LooseCompressedNuNo:
|
||||
return "loose_compressed(nonunique, nonordered)";
|
||||
case LevelType::NOutOfM:
|
||||
return "structured";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
/// Check that the `LevelType` contains a valid (possibly undefined) value.
|
||||
constexpr bool isValidLT(LevelType lt) {
|
||||
const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
|
||||
const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
|
||||
// If undefined/dense/NOutOfM, then must be unique and ordered.
|
||||
// Otherwise, the format must be one of the known ones.
|
||||
return (formatBits <= 0x10000 || formatBits == 0x100000)
|
||||
? (propertyBits == 0)
|
||||
: (formatBits == 0x20000 || formatBits == 0x40000 ||
|
||||
formatBits == 0x80000);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is the special undefined value.
|
||||
constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
|
||||
|
||||
/// Check if the `LevelType` is dense (regardless of properties).
|
||||
constexpr bool isDenseLT(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) & ~0xffff) ==
|
||||
static_cast<uint64_t>(LevelType::Dense);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is compressed (regardless of properties).
|
||||
constexpr bool isCompressedLT(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) & ~0xffff) ==
|
||||
static_cast<uint64_t>(LevelType::Compressed);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is singleton (regardless of properties).
|
||||
constexpr bool isSingletonLT(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) & ~0xffff) ==
|
||||
static_cast<uint64_t>(LevelType::Singleton);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is loose compressed (regardless of properties).
|
||||
constexpr bool isLooseCompressedLT(LevelType lt) {
|
||||
return (static_cast<uint64_t>(lt) & ~0xffff) ==
|
||||
static_cast<uint64_t>(LevelType::LooseCompressed);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs positions array.
|
||||
constexpr bool isWithPosLT(LevelType lt) {
|
||||
return isCompressedLT(lt) || isLooseCompressedLT(lt);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs coordinates array.
|
||||
constexpr bool isWithCrdLT(LevelType lt) {
|
||||
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
|
||||
isNOutOfMLT(lt);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is ordered (regardless of storage format).
|
||||
constexpr bool isOrderedLT(LevelType lt) {
|
||||
return !(static_cast<uint64_t>(lt) & 2);
|
||||
return !(static_cast<uint64_t>(lt) & 2);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is unique (regardless of storage format).
|
||||
constexpr bool isUniqueLT(LevelType lt) {
|
||||
return !(static_cast<uint64_t>(lt) & 1);
|
||||
return !(static_cast<uint64_t>(lt) & 1);
|
||||
}
|
||||
|
||||
/// Convert a LevelType to its corresponding LevelFormat.
|
||||
/// Returns std::nullopt when input lt is Undef.
|
||||
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
|
||||
if (lt == LevelType::Undef)
|
||||
return std::nullopt;
|
||||
return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
|
||||
}
|
||||
|
||||
/// Convert a LevelFormat to its corresponding LevelType with the given
|
||||
/// properties. Returns std::nullopt when the properties are not applicable
|
||||
/// for the input level format.
|
||||
inline std::optional<LevelType>
|
||||
buildLevelType(LevelFormat lf,
|
||||
const std::vector<LevelPropertyNondefault> &properties,
|
||||
const std::vector<LevelPropNonDefault> &properties,
|
||||
uint64_t n = 0, uint64_t m = 0) {
|
||||
uint64_t newN = n << 32;
|
||||
uint64_t newM = m << 40;
|
||||
uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
|
||||
for (auto p : properties) {
|
||||
ltInt |= static_cast<uint64_t>(p);
|
||||
}
|
||||
auto lt = static_cast<LevelType>(ltInt);
|
||||
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
|
||||
return LevelType::buildLvlType(lf, properties, n, m);
|
||||
}
|
||||
|
||||
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
|
||||
bool unique, uint64_t n = 0,
|
||||
uint64_t m = 0) {
|
||||
std::vector<LevelPropertyNondefault> properties;
|
||||
if (!ordered)
|
||||
properties.push_back(LevelPropertyNondefault::Nonordered);
|
||||
if (!unique)
|
||||
properties.push_back(LevelPropertyNondefault::Nonunique);
|
||||
return buildLevelType(lf, properties, n, m);
|
||||
return LevelType::buildLvlType(lf, ordered, unique, n, m);
|
||||
}
|
||||
|
||||
//
|
||||
// Ensure the above methods work as intended.
|
||||
//
|
||||
|
||||
static_assert(
|
||||
(getLevelFormat(LevelType::Undef) == std::nullopt &&
|
||||
*getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
|
||||
*getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
|
||||
*getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
|
||||
*getLevelFormat(LevelType::LooseCompressed) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(LevelType::LooseCompressedNu) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(LevelType::LooseCompressedNo) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
|
||||
LevelFormat::LooseCompressed &&
|
||||
*getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
|
||||
"getLevelFormat conversion is broken");
|
||||
|
||||
static_assert(
|
||||
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
|
||||
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
|
||||
isValidLT(LevelType::CompressedNo) &&
|
||||
isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
|
||||
isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
|
||||
isValidLT(LevelType::SingletonNuNo) &&
|
||||
isValidLT(LevelType::LooseCompressed) &&
|
||||
isValidLT(LevelType::LooseCompressedNu) &&
|
||||
isValidLT(LevelType::LooseCompressedNo) &&
|
||||
isValidLT(LevelType::LooseCompressedNuNo) &&
|
||||
isValidLT(LevelType::NOutOfM)),
|
||||
"isValidLT definition is broken");
|
||||
|
||||
static_assert((isDenseLT(LevelType::Dense) &&
|
||||
!isDenseLT(LevelType::Compressed) &&
|
||||
!isDenseLT(LevelType::CompressedNu) &&
|
||||
!isDenseLT(LevelType::CompressedNo) &&
|
||||
!isDenseLT(LevelType::CompressedNuNo) &&
|
||||
!isDenseLT(LevelType::Singleton) &&
|
||||
!isDenseLT(LevelType::SingletonNu) &&
|
||||
!isDenseLT(LevelType::SingletonNo) &&
|
||||
!isDenseLT(LevelType::SingletonNuNo) &&
|
||||
!isDenseLT(LevelType::LooseCompressed) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNu) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNo) &&
|
||||
!isDenseLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isDenseLT(LevelType::NOutOfM)),
|
||||
"isDenseLT definition is broken");
|
||||
|
||||
static_assert((!isCompressedLT(LevelType::Dense) &&
|
||||
isCompressedLT(LevelType::Compressed) &&
|
||||
isCompressedLT(LevelType::CompressedNu) &&
|
||||
isCompressedLT(LevelType::CompressedNo) &&
|
||||
isCompressedLT(LevelType::CompressedNuNo) &&
|
||||
!isCompressedLT(LevelType::Singleton) &&
|
||||
!isCompressedLT(LevelType::SingletonNu) &&
|
||||
!isCompressedLT(LevelType::SingletonNo) &&
|
||||
!isCompressedLT(LevelType::SingletonNuNo) &&
|
||||
!isCompressedLT(LevelType::LooseCompressed) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNu) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNo) &&
|
||||
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isCompressedLT(LevelType::NOutOfM)),
|
||||
"isCompressedLT definition is broken");
|
||||
|
||||
static_assert((!isSingletonLT(LevelType::Dense) &&
|
||||
!isSingletonLT(LevelType::Compressed) &&
|
||||
!isSingletonLT(LevelType::CompressedNu) &&
|
||||
!isSingletonLT(LevelType::CompressedNo) &&
|
||||
!isSingletonLT(LevelType::CompressedNuNo) &&
|
||||
isSingletonLT(LevelType::Singleton) &&
|
||||
isSingletonLT(LevelType::SingletonNu) &&
|
||||
isSingletonLT(LevelType::SingletonNo) &&
|
||||
isSingletonLT(LevelType::SingletonNuNo) &&
|
||||
!isSingletonLT(LevelType::LooseCompressed) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNu) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNo) &&
|
||||
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isSingletonLT(LevelType::NOutOfM)),
|
||||
"isSingletonLT definition is broken");
|
||||
|
||||
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
|
||||
!isLooseCompressedLT(LevelType::Compressed) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNu) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNo) &&
|
||||
!isLooseCompressedLT(LevelType::CompressedNuNo) &&
|
||||
!isLooseCompressedLT(LevelType::Singleton) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNu) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNo) &&
|
||||
!isLooseCompressedLT(LevelType::SingletonNuNo) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressed) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
|
||||
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
|
||||
!isLooseCompressedLT(LevelType::NOutOfM)),
|
||||
"isLooseCompressedLT definition is broken");
|
||||
|
||||
static_assert((!isNOutOfMLT(LevelType::Dense) &&
|
||||
!isNOutOfMLT(LevelType::Compressed) &&
|
||||
!isNOutOfMLT(LevelType::CompressedNu) &&
|
||||
!isNOutOfMLT(LevelType::CompressedNo) &&
|
||||
!isNOutOfMLT(LevelType::CompressedNuNo) &&
|
||||
!isNOutOfMLT(LevelType::Singleton) &&
|
||||
!isNOutOfMLT(LevelType::SingletonNu) &&
|
||||
!isNOutOfMLT(LevelType::SingletonNo) &&
|
||||
!isNOutOfMLT(LevelType::SingletonNuNo) &&
|
||||
!isNOutOfMLT(LevelType::LooseCompressed) &&
|
||||
!isNOutOfMLT(LevelType::LooseCompressedNu) &&
|
||||
!isNOutOfMLT(LevelType::LooseCompressedNo) &&
|
||||
!isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
|
||||
isNOutOfMLT(LevelType::NOutOfM)),
|
||||
"isNOutOfMLT definition is broken");
|
||||
|
||||
static_assert((isOrderedLT(LevelType::Dense) &&
|
||||
isOrderedLT(LevelType::Compressed) &&
|
||||
isOrderedLT(LevelType::CompressedNu) &&
|
||||
!isOrderedLT(LevelType::CompressedNo) &&
|
||||
!isOrderedLT(LevelType::CompressedNuNo) &&
|
||||
isOrderedLT(LevelType::Singleton) &&
|
||||
isOrderedLT(LevelType::SingletonNu) &&
|
||||
!isOrderedLT(LevelType::SingletonNo) &&
|
||||
!isOrderedLT(LevelType::SingletonNuNo) &&
|
||||
isOrderedLT(LevelType::LooseCompressed) &&
|
||||
isOrderedLT(LevelType::LooseCompressedNu) &&
|
||||
!isOrderedLT(LevelType::LooseCompressedNo) &&
|
||||
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
|
||||
isOrderedLT(LevelType::NOutOfM)),
|
||||
"isOrderedLT definition is broken");
|
||||
|
||||
static_assert((isUniqueLT(LevelType::Dense) &&
|
||||
isUniqueLT(LevelType::Compressed) &&
|
||||
!isUniqueLT(LevelType::CompressedNu) &&
|
||||
isUniqueLT(LevelType::CompressedNo) &&
|
||||
!isUniqueLT(LevelType::CompressedNuNo) &&
|
||||
isUniqueLT(LevelType::Singleton) &&
|
||||
!isUniqueLT(LevelType::SingletonNu) &&
|
||||
isUniqueLT(LevelType::SingletonNo) &&
|
||||
!isUniqueLT(LevelType::SingletonNuNo) &&
|
||||
isUniqueLT(LevelType::LooseCompressed) &&
|
||||
!isUniqueLT(LevelType::LooseCompressedNu) &&
|
||||
isUniqueLT(LevelType::LooseCompressedNo) &&
|
||||
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
|
||||
isUniqueLT(LevelType::NOutOfM)),
|
||||
"isUniqueLT definition is broken");
|
||||
inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
|
||||
inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
|
||||
inline bool isCompressedLT(LevelType lt) {
|
||||
return lt.isa<LevelFormat::Compressed>();
|
||||
}
|
||||
inline bool isLooseCompressedLT(LevelType lt) {
|
||||
return lt.isa<LevelFormat::LooseCompressed>();
|
||||
}
|
||||
inline bool isSingletonLT(LevelType lt) {
|
||||
return lt.isa<LevelFormat::Singleton>();
|
||||
}
|
||||
inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
|
||||
inline bool isOrderedLT(LevelType lt) {
|
||||
return !lt.isa<LevelPropNonDefault::Nonordered>();
|
||||
}
|
||||
inline bool isUniqueLT(LevelType lt) {
|
||||
return !lt.isa<LevelPropNonDefault::Nonunique>();
|
||||
}
|
||||
inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
|
||||
inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
|
||||
inline bool isValidLT(LevelType lt) {
|
||||
return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
|
||||
}
|
||||
inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
|
||||
LevelFormat fmt = lt.getLvlFmt();
|
||||
if (fmt == LevelFormat::Undef)
|
||||
return std::nullopt;
|
||||
return fmt;
|
||||
}
|
||||
inline uint64_t getN(LevelType lt) { return lt.getN(); }
|
||||
inline uint64_t getM(LevelType lt) { return lt.getM(); }
|
||||
inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
|
||||
return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
|
||||
}
|
||||
inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
|
||||
|
||||
/// Bit manipulations for affine encoding.
|
||||
///
|
||||
|
||||
@@ -34,9 +34,9 @@ static_assert(
|
||||
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
|
||||
|
||||
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
|
||||
static_cast<int>(LevelPropertyNondefault::Nonordered) &&
|
||||
static_cast<int>(LevelPropNonDefault::Nonordered) &&
|
||||
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
|
||||
static_cast<int>(LevelPropertyNondefault::Nonunique),
|
||||
static_cast<int>(LevelPropNonDefault::Nonunique),
|
||||
"MlirSparseTensorLevelProperty (C-API) and "
|
||||
"LevelPropertyNondefault (C++) mismatch");
|
||||
|
||||
@@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
|
||||
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
|
||||
LevelType lt =
|
||||
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
|
||||
return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
|
||||
return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
|
||||
}
|
||||
|
||||
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
|
||||
@@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
|
||||
const enum MlirSparseTensorLevelPropertyNondefault *properties,
|
||||
unsigned size, unsigned n, unsigned m) {
|
||||
|
||||
std::vector<LevelPropertyNondefault> props;
|
||||
std::vector<LevelPropNonDefault> props;
|
||||
for (unsigned i = 0; i < size; i++)
|
||||
props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
|
||||
props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
|
||||
|
||||
return static_cast<MlirSparseTensorLevelType>(
|
||||
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
|
||||
|
||||
@@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
|
||||
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
|
||||
"expected valid level property (e.g. nonordered, nonunique or high)")
|
||||
if (strVal.compare("nonunique") == 0) {
|
||||
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
|
||||
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
|
||||
} else if (strVal.compare("nonordered") == 0) {
|
||||
*properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
|
||||
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
|
||||
} else {
|
||||
parser.emitError(loc, "unknown level property: ") << strVal;
|
||||
return failure();
|
||||
|
||||
@@ -35,6 +35,14 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::sparse_tensor;
|
||||
|
||||
// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
|
||||
// well.
|
||||
namespace mlir::sparse_tensor {
|
||||
llvm::hash_code hash_value(LevelType lt) {
|
||||
return llvm::hash_value(static_cast<uint64_t>(lt));
|
||||
}
|
||||
} // namespace mlir::sparse_tensor
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Local Convenience Methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -83,11 +91,11 @@ void StorageLayout::foreachField(
|
||||
}
|
||||
// The values array.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
|
||||
LevelType::Undef)))
|
||||
LevelFormat::Undef)))
|
||||
return;
|
||||
// Put metadata at the end.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
|
||||
LevelType::Undef)))
|
||||
LevelFormat::Undef)))
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -341,7 +349,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
|
||||
|
||||
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
|
||||
if (!getImpl())
|
||||
return LevelType::Dense;
|
||||
return LevelFormat::Dense;
|
||||
assert(l < getLvlRank() && "Level is out of bounds");
|
||||
return getLvlTypes()[l];
|
||||
}
|
||||
@@ -975,7 +983,7 @@ static SparseTensorEncodingAttr
|
||||
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
|
||||
SmallVector<LevelType> lts;
|
||||
for (auto lt : enc.getLvlTypes())
|
||||
lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
|
||||
lts.push_back(lt.stripProperties());
|
||||
|
||||
return SparseTensorEncodingAttr::get(
|
||||
enc.getContext(), lts,
|
||||
|
||||
@@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
|
||||
static bool isSparseTensor(Value v) {
|
||||
auto enc = getSparseTensorEncoding(v.getType());
|
||||
return enc && !llvm::all_of(enc.getLvlTypes(),
|
||||
[](auto lt) { return lt == LevelType::Dense; });
|
||||
[](auto lt) { return lt == LevelFormat::Dense; });
|
||||
}
|
||||
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ protected:
|
||||
class DenseLevel : public SparseTensorLevel {
|
||||
public:
|
||||
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
|
||||
: SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
|
||||
: SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
|
||||
encoded(encoded) {}
|
||||
|
||||
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
|
||||
@@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
|
||||
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
|
||||
: b.create<tensor::DimOp>(l, t, lvl).getResult();
|
||||
|
||||
switch (*getLevelFormat(lt)) {
|
||||
switch (lt.getLvlFmt()) {
|
||||
case LevelFormat::Dense:
|
||||
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
|
||||
case LevelFormat::Compressed: {
|
||||
@@ -1296,6 +1296,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
|
||||
Value crd = genToCoordinates(b, l, t, lvl);
|
||||
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
|
||||
}
|
||||
case LevelFormat::Undef:
|
||||
llvm_unreachable("undefined level format");
|
||||
}
|
||||
llvm_unreachable("unrecognizable level format");
|
||||
}
|
||||
|
||||
@@ -226,7 +226,8 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
|
||||
syntheticTensor(numInputOutputTensors),
|
||||
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
|
||||
hasSparseOut(false),
|
||||
lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
|
||||
lvlTypes(numTensors,
|
||||
std::vector<LevelType>(numLoops, LevelFormat::Undef)),
|
||||
loopToLvl(numTensors,
|
||||
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
|
||||
lvlToLoop(numTensors,
|
||||
|
||||
@@ -313,11 +313,11 @@ protected:
|
||||
MergerTest3T1L() : MergerTestBase(3, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -327,13 +327,13 @@ protected:
|
||||
MergerTest4T1L() : MergerTestBase(4, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 2: sparse input vector
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 3: dense output vector
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -347,11 +347,11 @@ protected:
|
||||
MergerTest3T1LD() : MergerTestBase(3, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -365,13 +365,13 @@ protected:
|
||||
MergerTest4T1LU() : MergerTestBase(4, 1) {
|
||||
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
|
||||
// Tensor 0: undef input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
|
||||
// Tensor 2: undef input vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 3: dense output vector.
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -387,11 +387,11 @@ protected:
|
||||
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
|
||||
merger.setHasSparseOut(true);
|
||||
// Tensor 0: undef input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 1: undef input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 2: sparse output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user