[mlir][sparse] use ValueRange instead of std::pair for iterator position. (#90243)

`ValueRange` is more easy to be extended (e.g., for padded iterator).
This commit is contained in:
Peiming Liu
2024-04-29 10:47:07 -07:00
committed by GitHub
parent d566a5cd22
commit 7e2eeb5753
3 changed files with 47 additions and 47 deletions

View File

@@ -222,7 +222,7 @@ public:
///
SmallVector<Value> getValPosits(TensorId tid) const {
SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
batchCrds.push_back(lastLvlPos);
return batchCrds;
};

View File

@@ -94,8 +94,10 @@ public:
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
}
@@ -112,9 +114,9 @@ public:
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
ValueRange parentPos) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
}
@@ -127,9 +129,11 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
assert(max == nullptr &&
ValueRange parentPos) const override {
assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
Value p = parentPos.front();
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
@@ -147,11 +151,11 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
assert(max == nullptr &&
ValueRange parentPos) const override {
assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
SmallVector<Value> memCrd(batchPrefix);
Value p = parentPos.front();
p = MULI(p, C_IDX(2));
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
@@ -168,10 +172,13 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
ValueRange parentPos) const override {
assert(parentPos.size() == 1 || parentPos.size() == 2);
Value p = parentPos.front();
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
if (segHi == nullptr)
return {p, ADDI(p, C_IDX(1))};
// Use the segHi as the loop upper bound.
return {p, segHi};
}
@@ -184,11 +191,12 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
ValueRange parentPos) const override {
assert(parentPos.size() == 1 && isUnique() &&
"n:m level can not be non-unique.");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
Value posLo = MULI(p, C_IDX(n));
Value posLo = MULI(parentPos.front(), C_IDX(n));
return {posLo, ADDI(posLo, C_IDX(n))};
}
};
@@ -316,23 +324,21 @@ public:
posHi = vs.back();
};
ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
if (isBatchIterator() && batchCrds.size() <= stl.lvl)
batchCrds.resize(stl.lvl + 1, nullptr);
Value pos = C_IDX(0);
Value hi = nullptr;
Value c0 = C_IDX(0);
ValueRange pPos = c0;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator())
std::tie(pos, hi) = parent->getCurPosition();
pPos = parent->getCurPosition();
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
// Seek to the lowest position.
seek(posLo);
}
@@ -406,21 +412,19 @@ public:
return {b.getIndexType(), b.getIndexType()};
}
ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
Value c0 = C_IDX(0);
ValueRange pPos = c0;
Value pos = C_IDX(0);
Value hi = nullptr;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator())
std::tie(pos, hi) = parent->getCurPosition();
pPos = parent->getCurPosition();
Value posLo;
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
seek({posLo, genSegmentHigh(b, l, posLo)});
}
@@ -505,7 +509,7 @@ public:
SmallVector<Value> serialize() const override { return wrap->serialize(); };
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {
@@ -756,9 +760,8 @@ public:
Value upperBound(OpBuilder &b, Location l) const override {
return subSect.subSectSz;
}
std::pair<Value, Value> getCurPosition() const override {
return wrap->getCurPosition();
};
ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
Value getNxLvlTupleId(OpBuilder &b, Location l) const {
if (randomAccessible()) {

View File

@@ -36,8 +36,9 @@ public:
Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
/// the given position `p` that the immediate parent level is current at.
/// Returns a pair of values for *posLo* and *loopHi* respectively.
/// the given position `parentPos`, see SparseTensorIterator::getCurPostion(),
/// that the immediate parent level is current at. Returns a pair of values
/// for *posLo* and *loopHi* respectively.
///
/// For a dense level, the *posLo* is the linearized position at beginning,
/// while *loopHi* is the largest *coordinate*, it also implies that the
@@ -45,12 +46,9 @@ public:
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
/// to load coordinate from the coordinate buffer.
///
/// `bound` is only used when the level is `non-unique` and deduplication is
/// required. It specifies the max upper bound of the non-unique segment.
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
ValueRange batchPrefix, Value p,
Value segHi = Value()) const = 0;
ValueRange batchPrefix,
ValueRange parentPos) const = 0;
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
@@ -199,18 +197,17 @@ public:
}
virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
virtual Value derefImpl(OpBuilder &b, Location l) = 0;
// Gets the current position and the optional *position high* (for
// non-unique iterators), the value is essentially the number of sparse
// coordinate that the iterator is current visiting. It should be able to
// uniquely identify the sparse range for the next level. See
// SparseTensorLevel::peekRangeAt();
// Gets the ValueRange that together specifies the current position of the
// iterator. For a unique level, the position can be a single index points to
// the current coordinate being visited. For a non-unique level, an extra
// index for the `segment high` is needed to to specifies the range of
// duplicated coordinates. The ValueRange should be able to uniquely identify
// the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
//
// Not every type of iterator supports the operation, e.g., non-empty
// subsection iterator does not because it represent a range of coordinates
// instead of just one.
virtual std::pair<Value, Value> getCurPosition() const {
llvm_unreachable("unsupported");
};
virtual ValueRange getCurPosition() const { return getCursor(); };
// Returns a pair of values for *upper*, *lower* bound respectively.
virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {