mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 21:55:39 +08:00
[mlir][sparse] use ValueRange instead of std::pair for iterator position. (#90243)
`ValueRange` is more easy to be extended (e.g., for padded iterator).
This commit is contained in:
@@ -222,7 +222,7 @@ public:
|
||||
///
|
||||
SmallVector<Value> getValPosits(TensorId tid) const {
|
||||
SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
|
||||
Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
|
||||
Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
|
||||
batchCrds.push_back(lastLvlPos);
|
||||
return batchCrds;
|
||||
};
|
||||
|
||||
@@ -94,8 +94,10 @@ public:
|
||||
|
||||
ValueRange getLvlBuffers() const override { return {}; }
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
|
||||
Value max) const override {
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
|
||||
ValueRange parentPos) const override {
|
||||
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
|
||||
Value p = parentPos.front();
|
||||
Value posLo = MULI(p, lvlSize);
|
||||
return {posLo, lvlSize};
|
||||
}
|
||||
@@ -112,9 +114,9 @@ public:
|
||||
|
||||
ValueRange getLvlBuffers() const override { return {}; }
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
|
||||
Value max) const override {
|
||||
assert(max == nullptr && "Dense level can not be non-unique.");
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
|
||||
ValueRange parentPos) const override {
|
||||
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
|
||||
// No need to linearize the position for non-annotated tensors.
|
||||
return {C_IDX(0), lvlSize};
|
||||
}
|
||||
@@ -127,9 +129,11 @@ public:
|
||||
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
|
||||
Value p, Value max) const override {
|
||||
assert(max == nullptr &&
|
||||
ValueRange parentPos) const override {
|
||||
|
||||
assert(parentPos.size() == 1 &&
|
||||
"compressed level must be the first non-unique level.");
|
||||
Value p = parentPos.front();
|
||||
|
||||
SmallVector<Value> memCrd(batchPrefix);
|
||||
memCrd.push_back(p);
|
||||
@@ -147,11 +151,11 @@ public:
|
||||
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
|
||||
Value p, Value max) const override {
|
||||
assert(max == nullptr &&
|
||||
ValueRange parentPos) const override {
|
||||
assert(parentPos.size() == 1 &&
|
||||
"loose-compressed level must be the first non-unique level.");
|
||||
SmallVector<Value> memCrd(batchPrefix);
|
||||
|
||||
Value p = parentPos.front();
|
||||
p = MULI(p, C_IDX(2));
|
||||
memCrd.push_back(p);
|
||||
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
|
||||
@@ -168,10 +172,13 @@ public:
|
||||
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
|
||||
Value p, Value segHi) const override {
|
||||
ValueRange parentPos) const override {
|
||||
assert(parentPos.size() == 1 || parentPos.size() == 2);
|
||||
Value p = parentPos.front();
|
||||
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
|
||||
|
||||
if (segHi == nullptr)
|
||||
return {p, ADDI(p, C_IDX(1))};
|
||||
|
||||
// Use the segHi as the loop upper bound.
|
||||
return {p, segHi};
|
||||
}
|
||||
@@ -184,11 +191,12 @@ public:
|
||||
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
|
||||
|
||||
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
|
||||
Value p, Value max) const override {
|
||||
assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
|
||||
ValueRange parentPos) const override {
|
||||
assert(parentPos.size() == 1 && isUnique() &&
|
||||
"n:m level can not be non-unique.");
|
||||
// Each n:m blk has exactly n specified elements.
|
||||
auto n = getN(lt);
|
||||
Value posLo = MULI(p, C_IDX(n));
|
||||
Value posLo = MULI(parentPos.front(), C_IDX(n));
|
||||
return {posLo, ADDI(posLo, C_IDX(n))};
|
||||
}
|
||||
};
|
||||
@@ -316,23 +324,21 @@ public:
|
||||
posHi = vs.back();
|
||||
};
|
||||
|
||||
ValuePair getCurPosition() const override { return {getItPos(), nullptr}; }
|
||||
|
||||
void genInitImpl(OpBuilder &b, Location l,
|
||||
const SparseIterator *parent) override {
|
||||
|
||||
if (isBatchIterator() && batchCrds.size() <= stl.lvl)
|
||||
batchCrds.resize(stl.lvl + 1, nullptr);
|
||||
|
||||
Value pos = C_IDX(0);
|
||||
Value hi = nullptr;
|
||||
Value c0 = C_IDX(0);
|
||||
ValueRange pPos = c0;
|
||||
// If the parent iterator is a batch iterator, we also start from 0 (but
|
||||
// on a different batch).
|
||||
if (parent && !parent->isBatchIterator())
|
||||
std::tie(pos, hi) = parent->getCurPosition();
|
||||
pPos = parent->getCurPosition();
|
||||
|
||||
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
|
||||
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
|
||||
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
|
||||
// Seek to the lowest position.
|
||||
seek(posLo);
|
||||
}
|
||||
@@ -406,21 +412,19 @@ public:
|
||||
return {b.getIndexType(), b.getIndexType()};
|
||||
}
|
||||
|
||||
ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
|
||||
|
||||
void genInitImpl(OpBuilder &b, Location l,
|
||||
const SparseIterator *parent) override {
|
||||
Value c0 = C_IDX(0);
|
||||
ValueRange pPos = c0;
|
||||
|
||||
Value pos = C_IDX(0);
|
||||
Value hi = nullptr;
|
||||
// If the parent iterator is a batch iterator, we also start from 0 (but
|
||||
// on a different batch).
|
||||
if (parent && !parent->isBatchIterator())
|
||||
std::tie(pos, hi) = parent->getCurPosition();
|
||||
pPos = parent->getCurPosition();
|
||||
|
||||
Value posLo;
|
||||
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
|
||||
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pos, hi);
|
||||
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
|
||||
|
||||
seek({posLo, genSegmentHigh(b, l, posLo)});
|
||||
}
|
||||
@@ -505,7 +509,7 @@ public:
|
||||
|
||||
SmallVector<Value> serialize() const override { return wrap->serialize(); };
|
||||
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
|
||||
ValuePair getCurPosition() const override { return wrap->getCurPosition(); }
|
||||
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
|
||||
|
||||
void genInitImpl(OpBuilder &b, Location l,
|
||||
const SparseIterator *parent) override {
|
||||
@@ -756,9 +760,8 @@ public:
|
||||
Value upperBound(OpBuilder &b, Location l) const override {
|
||||
return subSect.subSectSz;
|
||||
}
|
||||
std::pair<Value, Value> getCurPosition() const override {
|
||||
return wrap->getCurPosition();
|
||||
};
|
||||
|
||||
ValueRange getCurPosition() const override { return wrap->getCurPosition(); };
|
||||
|
||||
Value getNxLvlTupleId(OpBuilder &b, Location l) const {
|
||||
if (randomAccessible()) {
|
||||
|
||||
@@ -36,8 +36,9 @@ public:
|
||||
Value iv) const = 0;
|
||||
|
||||
/// Peeks the lower and upper bound to *fully* traverse the level with
|
||||
/// the given position `p` that the immediate parent level is current at.
|
||||
/// Returns a pair of values for *posLo* and *loopHi* respectively.
|
||||
/// the given position `parentPos`, see SparseTensorIterator::getCurPostion(),
|
||||
/// that the immediate parent level is current at. Returns a pair of values
|
||||
/// for *posLo* and *loopHi* respectively.
|
||||
///
|
||||
/// For a dense level, the *posLo* is the linearized position at beginning,
|
||||
/// while *loopHi* is the largest *coordinate*, it also implies that the
|
||||
@@ -45,12 +46,9 @@ public:
|
||||
///
|
||||
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
|
||||
/// to load coordinate from the coordinate buffer.
|
||||
///
|
||||
/// `bound` is only used when the level is `non-unique` and deduplication is
|
||||
/// required. It specifies the max upper bound of the non-unique segment.
|
||||
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
|
||||
ValueRange batchPrefix, Value p,
|
||||
Value segHi = Value()) const = 0;
|
||||
ValueRange batchPrefix,
|
||||
ValueRange parentPos) const = 0;
|
||||
|
||||
Level getLevel() const { return lvl; }
|
||||
LevelType getLT() const { return lt; }
|
||||
@@ -199,18 +197,17 @@ public:
|
||||
}
|
||||
virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
|
||||
virtual Value derefImpl(OpBuilder &b, Location l) = 0;
|
||||
// Gets the current position and the optional *position high* (for
|
||||
// non-unique iterators), the value is essentially the number of sparse
|
||||
// coordinate that the iterator is current visiting. It should be able to
|
||||
// uniquely identify the sparse range for the next level. See
|
||||
// SparseTensorLevel::peekRangeAt();
|
||||
// Gets the ValueRange that together specifies the current position of the
|
||||
// iterator. For a unique level, the position can be a single index points to
|
||||
// the current coordinate being visited. For a non-unique level, an extra
|
||||
// index for the `segment high` is needed to to specifies the range of
|
||||
// duplicated coordinates. The ValueRange should be able to uniquely identify
|
||||
// the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
|
||||
//
|
||||
// Not every type of iterator supports the operation, e.g., non-empty
|
||||
// subsection iterator does not because it represent a range of coordinates
|
||||
// instead of just one.
|
||||
virtual std::pair<Value, Value> getCurPosition() const {
|
||||
llvm_unreachable("unsupported");
|
||||
};
|
||||
virtual ValueRange getCurPosition() const { return getCursor(); };
|
||||
|
||||
// Returns a pair of values for *upper*, *lower* bound respectively.
|
||||
virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
|
||||
|
||||
Reference in New Issue
Block a user