[mlir][sparse] Adding SparseTensorType::{operator==, hasSameDimToLvlMap}

Depends On D143800

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D144052
This commit is contained in:
wren romano
2023-02-15 12:03:52 -08:00
parent ee437afa91
commit bb4fc6b6d6
2 changed files with 27 additions and 9 deletions

View File

@@ -105,6 +105,16 @@ public:
/// implicit conversion.
RankedTensorType getRankedTensorType() const { return rtp; }
bool operator==(const SparseTensorType &other) const {
// All other fields are derived from `rtp` and therefore don't need
// to be checked.
return rtp == other.rtp;
}
bool operator!=(const SparseTensorType &other) const {
return !(*this == other);
}
MLIRContext *getContext() const { return rtp.getContext(); }
Type getElementType() const { return rtp.getElementType(); }
@@ -130,6 +140,8 @@ public:
bool isIdentity() const { return !dim2lvl; }
/// Returns the dimToLvl mapping (or the null-map for the identity).
/// If you intend to compare the results of this method for equality,
/// see `hasSameDimToLvlMap` instead.
AffineMap getDimToLvlMap() const { return dim2lvl; }
/// Returns the dimToLvl mapping, where the identity map is expanded out
@@ -142,6 +154,17 @@ public:
: AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
}
/// Returns true iff the two types have the same mapping. This method
/// takes care to handle identity maps properly, so it should be preferred
/// over using `getDimToLvlMap` followed by `AffineMap::operator==`.
bool hasSameDimToLvlMap(const SparseTensorType &other) const {
// If the maps are the identity, then we need to check the rank
// to be sure they're the same size identity. (And since identity
// means dimRank==lvlRank, we use lvlRank as a minor optimization.)
return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
: (dim2lvl == other.dim2lvl);
}
/// Returns the dimension-rank.
Dimension getDimRank() const { return rtp.getRank(); }

View File

@@ -48,12 +48,6 @@ static bool isSparseTensor(OpOperand *op) {
llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed);
}
static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) {
assert(rtp1.getRank() == rtp2.getRank());
return SparseTensorType(rtp1).getDimToLvlMap() ==
SparseTensorType(rtp2).getDimToLvlMap();
}
// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
Value val = op->get();
@@ -796,8 +790,9 @@ private:
// 2. the src tensor is not ordered in the same way as the target
// tensor (e.g., src tensor is not ordered or src tensor haves a different
// dimOrdering).
if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() &&
hasSameDimOrdering(srcRTT, dstTp))) {
if (const SparseTensorType srcTp(srcRTT);
!isUniqueCOOType(srcRTT) &&
!(srcTp.isAllOrdered() && srcTp.hasSameDimToLvlMap(dstTp))) {
// Construct a COO tensor from the src tensor.
// TODO: there may be cases for which more efficiently without
// going through an intermediate COO, such as cases that only change
@@ -841,7 +836,7 @@ private:
// Sort the COO tensor so that its elements are ordered via increasing
// indices for the storage ordering of the dst tensor. Use SortCoo if the
// COO tensor has the same dim ordering as the dst tensor.
if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) {
if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
MemRefType indTp =
get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
/*withLayout=*/false);