[mlir][sparse] Renaming "pointer/index" to "position/coordinate"

The old "pointer/index" names often cause confusion since these names clash with names of unrelated things in MLIR; so this change rectifies this by changing everything to use "position/coordinate" terminology instead.

In addition to the basic terminology, there have also been various conventions for making certain distinctions, such as: (1) the overall storage for coordinates in the sparse-tensor, vs the particular collection of coordinates of a given element; and (2) particular coordinates given as a `Value` or `TypedValue<MemRefType>`, vs particular coordinates given as a `ValueRange` or similar.  I have striven to maintain these distinctions as follows:

  * "p/c" are used for individual position/coordinate values, when there is no risk of confusion.  (Just like we use "d/l" to abbreviate "dim/lvl".)

  * "pos/crd" are used for individual position/coordinate values, when a longer name is helpful to avoid ambiguity or to form compound names (e.g., "parentPos").  (Just like we use "dim/lvl" when we need a longer form of "d/l".)

    I have also used these forms for a handful of compound names where the old name had already been using a three-letter form, even though a longer form would be more appropriate.  I've avoided renaming these to a longer form purely for expediency's sake, since changing them would require a cascade of other renamings.  They should be updated to follow the new naming scheme, but that can be done in future patches.

  * "coords" is used for the complete collection of crd values associated with a single element.  In the runtime library this includes both `std::vector` and raw pointer representations.  In the compiler, this is used specifically for buffer variables with C++ type `Value`, `TypedValue<MemRefType>`, etc.

    The bare form "coords" is discouraged, since it fails to make the dim/lvl distinction; so the compound names "dimCoords/lvlCoords" should be used instead.  (Though there may exist a rare few cases where is is appropriate to be intentionally ambiguous about what coordinate-space the coords live in; in which case the bare "coords" is appropriate.)

    There is seldom a need for the pos variant of this notion.  In most circumstances we use the term "cursor" instead, since the same buffer is reused for a 'moving' collection of positions.

  * "dcvs/lcvs" is used in the compiler as the `ValueRange` analogue of "dimCoords/lvlCoords".  (The "vs" stands for "`Value`s".)  I haven't found the need for it, but "pvs" would be the obvious name for a pos-`ValueRange`.

    The old "ind"-vs-"ivs" naming scheme does not seem to have been sustained in more recent code, which instead prefers other mnemonics (e.g., adding "Buf" to the end of the names for `TypeValue<MemRefType>`).  I have cleaned up a lot of these to follow the "coords"-vs-"cvs" naming scheme, though haven't done an exhaustive cleanup.

  * "positions/coordinates" are used for larger collections of pos/crd values; in particular, these are used when referring to the complete sparse-tensor storage components.

    I also prefer to use these unabbreviated names in the documentation, unless there is some specific reason why using the abbreviated forms helps resolve ambiguity.
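To make the storage-level terminology concrete, here is a minimal, self-contained C++ sketch of plain CSR storage for a 2-D matrix.  It is purely illustrative (not code from this patch, and all names are chosen only to demonstrate the naming scheme): the "positions" array stores, per row, the range of positions into the "coordinates" and values arrays; an individual "pos" indexes those arrays; and an individual "crd" is a column coordinate.

```c++
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical example, only to illustrate the naming scheme above:
  // CSR storage of the 2x3 matrix  [ 0 5 0 ]
  //                                [ 7 0 8 ]
  std::vector<uint64_t> positions = {0, 1, 3};   // per-row position ranges ("pointers" in the old naming)
  std::vector<uint64_t> coordinates = {1, 0, 2}; // column coordinate of each stored value ("indices" in the old naming)
  std::vector<double> values = {5.0, 7.0, 8.0};

  for (uint64_t r = 0; r + 1 < positions.size(); ++r) {
    // "pos" is an individual position; "crd" is an individual coordinate.
    for (uint64_t pos = positions[r]; pos < positions[r + 1]; ++pos) {
      const uint64_t crd = coordinates[pos];
      std::printf("A[%llu][%llu] = %g\n", (unsigned long long)r,
                  (unsigned long long)crd, values[pos]);
    }
  }
  return 0;
}
```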

In addition to the terminology change, this patch also does some cleanup along the way:
  * correcting the dim/lvl terminology in certain places.
  * adding `const` when it requires no other code changes.
  * miscellaneous cleanup entailed by making the proper distinctions; most of these changes are in CodegenUtils.{h,cpp}.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144773
Author: wren romano
Date:   2023-03-06 12:19:41 -08:00
Commit: 84cd51bb97 (parent 3e00f24f63)
92 changed files with 2902 additions and 2807 deletions

@@ -67,13 +67,13 @@ static void flattenOperands(ValueRange operands,
}
}
/// Generates a load with proper index typing.
/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
return builder.create<memref::LoadOp>(loc, mem, idx);
}
/// Generates a store with proper index typing and (for indices) proper value.
/// Generates a store with proper `index` typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
@@ -111,8 +111,7 @@ static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
// accounting for the reordering applied to the sparse storage.
// FIXME: `toStoredDim` is deprecated.
const Level lvl = toStoredDim(stt, dim);
// FIXME: this method seems to get *level* sizes, but the name is confusing
return desc.getDimSize(builder, loc, lvl);
return desc.getLvlSize(builder, loc, lvl);
}
// Gets the dimension size at the given stored level 'lvl', either as a
@@ -150,13 +149,12 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
for (Level l = startLvl; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isCompressedDLT(dlt)) {
// Append linear x pointers, initialized to zero. Since each compressed
// Append linear x positions, initialized to zero. Since each compressed
// dimension initially already has a single zero entry, this maintains
// the desired "linear + 1" length property at all times.
Type ptrType = stt.getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l,
ptrZero, linear);
Value posZero = constantZero(builder, loc, stt.getPosType());
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
posZero, linear);
return;
}
if (isSingletonDLT(dlt)) {
@@ -215,32 +213,32 @@ static void createAllocFields(OpBuilder &builder, Location loc,
// size based on available information. Otherwise we just
// initialize a few elements to start the reallocation chain.
// TODO: refine this
Value ptrHeuristic, idxHeuristic, valHeuristic;
Value posHeuristic, crdHeuristic, valHeuristic;
if (stt.isAllDense()) {
valHeuristic = dimSizes[0];
for (const Value sz : ArrayRef<Value>{dimSizes}.drop_front())
valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic, sz);
} else if (sizeHint) {
if (getCOOStart(stt.getEncoding()) == 0) {
ptrHeuristic = constantIndex(builder, loc, 2);
idxHeuristic = builder.create<arith::MulIOp>(
posHeuristic = constantIndex(builder, loc, 2);
crdHeuristic = builder.create<arith::MulIOp>(
loc, constantIndex(builder, loc, dimRank), sizeHint); // AOS
} else if (dimRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
ptrHeuristic = builder.create<arith::AddIOp>(
posHeuristic = builder.create<arith::AddIOp>(
loc, sizeHint, constantIndex(builder, loc, 1));
idxHeuristic = sizeHint;
crdHeuristic = sizeHint;
} else {
ptrHeuristic = idxHeuristic = constantIndex(builder, loc, 16);
posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
}
valHeuristic = sizeHint;
} else {
ptrHeuristic = idxHeuristic = valHeuristic =
posHeuristic = crdHeuristic = valHeuristic =
constantIndex(builder, loc, 16);
}
foreachFieldAndTypeInSparseTensor(
stt,
[&builder, &fields, stt, loc, ptrHeuristic, idxHeuristic, valHeuristic,
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
assert(fields.size() == fIdx);
@@ -249,13 +247,13 @@ static void createAllocFields(OpBuilder &builder, Location loc,
case SparseTensorFieldKind::StorageSpec:
field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
break;
case SparseTensorFieldKind::PtrMemRef:
case SparseTensorFieldKind::IdxMemRef:
case SparseTensorFieldKind::PosMemRef:
case SparseTensorFieldKind::CrdMemRef:
case SparseTensorFieldKind::ValMemRef:
field = createAllocation(
builder, loc, fType.cast<MemRefType>(),
(fKind == SparseTensorFieldKind::PtrMemRef) ? ptrHeuristic
: (fKind == SparseTensorFieldKind::IdxMemRef) ? idxHeuristic
(fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
: (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
: valHeuristic,
enableInit);
break;
@@ -269,87 +267,89 @@ static void createAllocFields(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc(stt, fields);
// Initialize the storage scheme to an empty tensor. Initialized memSizes
// to all zeros, sets the dimSizes to known values and gives all pointer
// to all zeros, sets the dimSizes to known values and gives all position
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
Value ptrZero = constantZero(builder, loc, stt.getPointerType());
Value posZero = constantZero(builder, loc, stt.getPosType());
for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
// Fills dim sizes array.
// FIXME: this method seems to set *level* sizes, but the name is confusing
// FIXME: `toOrigDim` is deprecated.
desc.setDimSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]);
// Pushes a leading zero to pointers memref.
desc.setLvlSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]);
// Pushes a leading zero to positions memref.
if (stt.isCompressedLvl(l))
createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l,
ptrZero);
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
posZero);
}
allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}
/// Helper method that generates block specific to compressed case:
///
/// plo = pointers[l][pos[l-1]]
/// phi = pointers[l][pos[l-1]+1]
/// msz = indices[l].size()
/// if (plo < phi) {
/// present = indices[l][phi-1] == i[l]
/// // given: parentPos = posCursor[lvl-1]
/// pstart = desc.positions[lvl][parentPos]
/// pstop = desc.positions[lvl][parentPos+1]
/// plast = pstop - 1
/// msz = desc.coordinates[lvl].size()
/// if (pstart < pstop) {
/// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
/// } else { // first insertion
/// present = false
/// pointers[l][pos[l-1]] = msz
/// isPresent = false
/// desc.positions[lvl][parentPos] = msz
/// }
/// if (present) { // index already present
/// next = phi-1
/// if (isPresent) { // coordinate is already present
/// pnext = plast
/// } else {
/// indices[l].push_back(i[l])
/// pointers[l][pos[l-1]+1] = msz+1
/// next = msz
/// <prepare level l + 1>
/// desc.coordinates[lvl].push_back(lvlCoords[lvl])
/// desc.positions[lvl][parentPos+1] = msz+1
/// pnext = msz
/// <prepare level lvl+1>
/// }
/// pos[l] = next
/// posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
MutSparseTensorDescriptor desc, ValueRange indices,
Value value, Value pos, Level lvl) {
MutSparseTensorDescriptor desc, ValueRange lvlCoords,
Value /*unused*/, Value parentPos, Level lvl) {
const SparseTensorType stt(desc.getRankedTensorType());
const Level lvlRank = stt.getLvlRank();
assert(lvl < lvlRank && "Level is out of bounds");
assert(indices.size() == static_cast<size_t>(lvlRank) &&
assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
"Level-rank mismatch");
SmallVector<Type> types;
Type indexType = builder.getIndexType();
Type boolType = builder.getIntegerType(1);
unsigned idxIndex;
unsigned idxStride;
std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(lvl);
Value one = constantIndex(builder, loc, 1);
Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
Value plo = genLoad(builder, loc, desc.getPtrMemRef(lvl), pos);
Value phi = genLoad(builder, loc, desc.getPtrMemRef(lvl), pp1);
Value msz = desc.getIdxMemSize(builder, loc, lvl);
Value idxStrideC;
if (idxStride > 1) {
idxStrideC = constantIndex(builder, loc, idxStride);
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
}
Value phim1 = builder.create<arith::SubIOp>(
loc, genCast(builder, loc, phi, indexType), one);
unsigned crdFidx;
unsigned crdStride;
std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
const Value one = constantIndex(builder, loc, 1);
const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
const Value positionsAtLvl = desc.getPosMemRef(lvl);
const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
const Value crdStrideC =
crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
const Value msz =
crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
: crdMsz;
const Value plast = builder.create<arith::SubIOp>(
loc, genCast(builder, loc, pstop, indexType), one);
// Conditional expression.
Value lt =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
pstart, pstop);
types.push_back(boolType);
scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
types.pop_back();
builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
Value crd = genLoad(
builder, loc, desc.getMemRefField(idxIndex),
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
: phim1);
Value crd =
genLoad(builder, loc, desc.getMemRefField(crdFidx),
crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
: plast);
Value eq = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
indices[lvl]);
lvlCoords[lvl]);
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (lvl > 0)
genStore(builder, loc, msz, desc.getPtrMemRef(lvl), pos);
genStore(builder, loc, msz, positionsAtLvl, parentPos);
builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
builder.setInsertionPointAfter(ifOp1);
// If present construct. Note that for a non-unique dimension level, we
@@ -363,22 +363,22 @@ static Value genCompressed(OpBuilder &builder, Location loc,
const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
: constantI1(builder, loc, false);
scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
// If present (fields unaffected, update next to phim1).
// If present (fields unaffected, update pnext to plast).
builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
// FIXME: This does not looks like a clean way, but probably the most
// efficient way.
desc.getFields().push_back(phim1);
desc.getFields().push_back(plast);
builder.create<scf::YieldOp>(loc, desc.getFields());
desc.getFields().pop_back();
// If !present (changes fields, update next).
// If !present (changes fields, update pnext).
builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
genStore(builder, loc, mszp1, desc.getPtrMemRef(lvl), pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, lvl,
indices[lvl]);
// Prepare the next dimension "as needed".
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
lvlCoords[lvl]);
// Prepare the next level "as needed".
if ((lvl + 1) < lvlRank)
allocSchemeForRank(builder, loc, desc, lvl + 1);
@@ -415,40 +415,41 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
const SparseTensorType stt(rtp);
const Level lvlRank = stt.getLvlRank();
// Construct fields and indices arrays from parameters.
// Extract fields and coordinates from args.
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
MutSparseTensorDescriptor desc(rtp, fields);
const SmallVector<Value> indices =
const SmallVector<Value> coordinates =
llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
Value value = args.back();
Value pos = constantZero(builder, loc, builder.getIndexType());
Value parentPos = constantZero(builder, loc, builder.getIndexType());
// Generate code for every level.
for (Level l = 0; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isCompressedDLT(dlt)) {
// Create:
// if (!present) {
// indices[l].push_back(i[l])
// <update pointers and prepare level l + 1>
// coordinates[l].push_back(coords[l])
// <update positions and prepare level l + 1>
// }
// pos[l] = indices.size() - 1
// <insert @ pos[l] at next level l + 1>
pos = genCompressed(builder, loc, desc, indices, value, pos, l);
// positions[l] = coordinates.size() - 1
// <insert @ positions[l] at next level l + 1>
parentPos =
genCompressed(builder, loc, desc, coordinates, value, parentPos, l);
} else if (isSingletonDLT(dlt)) {
// Create:
// indices[l].push_back(i[l])
// pos[l] = pos[l-1]
// <insert @ pos[l] at next level l + 1>
createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, l,
indices[l]);
// coordinates[l].push_back(coords[l])
// positions[l] = positions[l-1]
// <insert @ positions[l] at next level l + 1>
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
coordinates[l]);
} else {
assert(isDenseDLT(dlt));
// Construct the new position as:
// pos[l] = size * pos[l-1] + i[l]
// <insert @ pos[l] at next level l + 1>
// positions[l] = size * positions[l-1] + coords[l]
// <insert @ positions[l] at next level l + 1>
Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
Value mult = builder.create<arith::MulIOp>(loc, size, pos);
pos = builder.create<arith::AddIOp>(loc, mult, indices[l]);
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
parentPos = builder.create<arith::AddIOp>(loc, mult, coordinates[l]);
}
}
// Reached the actual value append/insert.
@@ -456,7 +457,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
std::nullopt, value);
else
genStore(builder, loc, value, desc.getValMemRef(), pos);
genStore(builder, loc, value, desc.getValMemRef(), parentPos);
builder.create<func::ReturnOp>(loc, fields);
}
@@ -464,19 +465,18 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
/// function doesn't exist yet, call `createFunc` to generate the function.
static void genInsertionCallHelper(OpBuilder &builder,
MutSparseTensorDescriptor desc,
SmallVectorImpl<Value> &indices, Value value,
SmallVectorImpl<Value> &lcvs, Value value,
func::FuncOp insertPoint,
StringRef namePrefix,
FuncGeneratorType createFunc) {
// The mangled name of the function has this format:
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>
// _<indexBitWidth>_<pointerBitWidth>
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
const SparseTensorType stt(desc.getRankedTensorType());
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
nameOstream << namePrefix;
assert(static_cast<size_t>(stt.getLvlRank()) == indices.size());
const Level lvlRank = stt.getLvlRank();
assert(lcvs.size() == static_cast<size_t>(lvlRank));
for (Level l = 0; l < lvlRank; l++)
nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
// Static dim sizes are used in the generated code while dynamic sizes are
@@ -488,7 +488,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
if (!stt.isIdentity())
nameOstream << stt.getDimToLvlMap() << "_";
nameOstream << stt.getElementType() << "_";
nameOstream << stt.getIndexBitWidth() << "_" << stt.getPointerBitWidth();
nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
// Look up the function.
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -496,9 +496,9 @@ static void genInsertionCallHelper(OpBuilder &builder,
auto result = SymbolRefAttr::get(context, nameOstream.str());
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
// Construct parameters for fields and indices.
// Construct operands: fields, coords, and value.
SmallVector<Value> operands = llvm::to_vector(desc.getFields());
operands.append(indices);
operands.append(lcvs);
operands.push_back(value);
Location loc = insertPoint.getLoc();
@@ -531,31 +531,31 @@ static void genEndInsert(OpBuilder &builder, Location loc,
for (Level l = 0; l < lvlRank; l++) {
const auto dlt = stt.getLvlType(l);
if (isCompressedDLT(dlt)) {
// Compressed dimensions need a pointer cleanup for all entries
// Compressed dimensions need a position cleanup for all entries
// that were not visited during the insertion pass.
//
// TODO: avoid cleanup and keep compressed scheme consistent at all
// times?
//
if (l > 0) {
Type ptrType = stt.getPointerType();
Value ptrMemRef = desc.getPtrMemRef(l);
Value hi = desc.getPtrMemSize(builder, loc, l);
Type posType = stt.getPosType();
Value posMemRef = desc.getPosMemRef(l);
Value hi = desc.getPosMemSize(builder, loc, l);
Value zero = constantIndex(builder, loc, 0);
Value one = constantIndex(builder, loc, 1);
// Vector of only one, but needed by createFor's prototype.
SmallVector<Value, 1> inits{genLoad(builder, loc, ptrMemRef, zero)};
SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
scf::ForOp loop = createFor(builder, loc, hi, inits, one);
Value i = loop.getInductionVar();
Value oldv = loop.getRegionIterArg(0);
Value newv = genLoad(builder, loc, ptrMemRef, i);
Value ptrZero = constantZero(builder, loc, ptrType);
Value newv = genLoad(builder, loc, posMemRef, i);
Value posZero = constantZero(builder, loc, posType);
Value cond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newv, ptrZero);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(ptrType),
loc, arith::CmpIPredicate::eq, newv, posZero);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
cond, /*else*/ true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
genStore(builder, loc, oldv, ptrMemRef, i);
genStore(builder, loc, oldv, posMemRef, i);
builder.create<scf::YieldOp>(loc, oldv);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, newv);
@@ -684,12 +684,12 @@ public:
LogicalResult
matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::optional<int64_t> index = op.getConstantIndex();
if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
std::optional<int64_t> dim = op.getConstantIndex();
if (!dim || !getSparseTensorEncoding(adaptor.getSource().getType()))
return failure();
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *dim);
rewriter.replaceOp(op, sz);
return success();
@@ -843,7 +843,7 @@ public:
rewriter.create<linalg::FillOp>(
loc, ValueRange{constantZero(rewriter, loc, boolType)},
ValueRange{filled});
// Replace expansion op with these buffers and initial index.
// Replace expansion op with these buffers and initial coordinate.
assert(op.getNumResults() == 4);
rewriter.replaceOp(op, {values, filled, added, zero});
return success();
@@ -866,9 +866,9 @@ public:
Value count = adaptor.getCount();
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();
// Prepare indices.
SmallVector<Value> indices(adaptor.getIndices());
// If the innermost level is ordered, we need to sort the indices
// Prepare level-coords.
SmallVector<Value> lcvs(adaptor.getLvlCoords());
// If the innermost level is ordered, we need to sort the coordinates
// in the "added" array prior to applying the compression.
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
@@ -880,26 +880,25 @@ public:
//
// Generate
// out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
// index = added[i];
// value = values[index];
// insert({prev_indices, index}, value);
// new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
// values[index] = 0;
// filled[index] = false;
// crd = added[i];
// value = values[crd];
// insert({lvlCoords, crd}, value);
// new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
// values[crd] = 0;
// filled[crd] = false;
// yield new_memrefs
// }
scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
Value i = loop.getInductionVar();
Value index = genLoad(rewriter, loc, added, i);
Value value = genLoad(rewriter, loc, values, index);
indices.push_back(index);
Value crd = genLoad(rewriter, loc, added, i);
Value value = genLoad(rewriter, loc, values, crd);
lcvs.push_back(crd);
// TODO: faster for subsequent insertions?
auto insertPoint = op->template getParentOfType<func::FuncOp>();
genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
kInsertFuncNamePrefix, genInsertBody);
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
index);
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
rewriter.create<scf::YieldOp>(loc, desc.getFields());
rewriter.setInsertionPointAfter(loop);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
@@ -924,12 +923,11 @@ public:
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
// Prepare and indices.
SmallVector<Value> indices(adaptor.getIndices());
SmallVector<Value> lcvs(adaptor.getLvlCoords());
// Generate insertion.
Value value = adaptor.getValue();
auto insertPoint = op->template getParentOfType<func::FuncOp>();
genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
kInsertFuncNamePrefix, genInsertBody);
// Replace operation with resulting memrefs.
@@ -938,39 +936,38 @@ public:
}
};
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
using OpAdaptor = typename ToPointersOp::Adaptor;
using OpConversionPattern<ToPointersOp>::OpConversionPattern;
using OpAdaptor = typename ToPositionsOp::Adaptor;
using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// Replace the requested position access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
return success();
}
};
/// Sparse codegen rule for index accesses.
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
: public OpConversionPattern<ToCoordinatesOp> {
public:
using OpAdaptor = typename ToIndicesOp::Adaptor;
using OpConversionPattern<ToIndicesOp>::OpConversionPattern;
using OpAdaptor = typename ToCoordinatesOp::Adaptor;
using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
Value field = desc.getIdxMemRefOrView(rewriter, loc, dim);
Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
@@ -984,16 +981,16 @@ public:
}
};
/// Sparse codegen rule for accessing the linear indices buffer.
class SparseToIndicesBufferConverter
: public OpConversionPattern<ToIndicesBufferOp> {
/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
using OpAdaptor = typename ToIndicesBufferOp::Adaptor;
using OpConversionPattern<ToIndicesBufferOp>::OpConversionPattern;
using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToIndicesBufferOp op, OpAdaptor adaptor,
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
@@ -1011,7 +1008,7 @@ public:
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// Replace the requested values access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
@@ -1118,8 +1115,8 @@ public:
assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType());
assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering());
assert(srcEnc.getHigherOrdering() == dstEnc.getHigherOrdering());
assert(srcEnc.getPointerBitWidth() == dstEnc.getPointerBitWidth());
assert(srcEnc.getIndexBitWidth() == dstEnc.getIndexBitWidth());
assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
// TODO: support dynamic slices.
for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
@@ -1174,11 +1171,12 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
case SparseTensorFieldKind::StorageSpec:
field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
break;
case SparseTensorFieldKind::PtrMemRef: {
// TACO-style COO starts with a PtrBuffer
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
// By creating a constant value for it, we avoid the complexity of
// memory management.
auto tensorType = RankedTensorType::get({2}, enc.getPointerType());
const auto posTp = enc.getPosType();
auto tensorType = RankedTensorType::get({2}, posTp);
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
auto cstPtr = rewriter.create<arith::ConstantOp>(
@@ -1186,35 +1184,34 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
DenseElementsAttr::get(
tensorType,
ArrayRef<Attribute>{
IntegerAttr::get(enc.getPointerType(), 0),
IntegerAttr::get(posTp, 0),
IntegerAttr::get(
enc.getPointerType(),
op.getData().getType().getShape()[0])}));
posTp, op.getValues().getType().getShape()[0])}));
field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
cstPtr);
break;
}
case SparseTensorFieldKind::IdxMemRef: {
auto tensorType = op.getIndices().getType();
case SparseTensorFieldKind::CrdMemRef: {
auto tensorType = op.getCoordinates().getType();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
auto idxMemRef = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getIndices());
auto crdMemRef = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getCoordinates());
ReassociationIndices reassociation;
for (int i = 0, e = tensorType.getRank(); i < e; i++)
reassociation.push_back(i);
// Flattened the indices buffer to rank 1.
field = rewriter.create<memref::CollapseShapeOp>(
loc, idxMemRef, ArrayRef<ReassociationIndices>(reassociation));
loc, crdMemRef, ArrayRef<ReassociationIndices>(reassociation));
break;
}
case SparseTensorFieldKind::ValMemRef: {
auto tensorType = op.getData().getType();
auto tensorType = op.getValues().getType();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
field = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getData());
op->getLoc(), memrefType, op.getValues());
break;
}
}
@@ -1228,15 +1225,18 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
});
MutSparseTensorDescriptor desc(rtp, fields);
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getData(), 0);
for (unsigned i = 0, e = rtp.getRank(); i < e; i++) {
int dim = rtp.getShape()[i];
assert(!ShapedType::isDynamic(dim));
desc.setDimSize(rewriter, loc, i, constantIndex(rewriter, loc, dim));
if (i == 0)
desc.setPtrMemSize(rewriter, loc, i, constantIndex(rewriter, loc, 2));
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
// FIXME: should use `SparseTensorType::getLvlRank` in lieu of
// `RankedTensorType::getRank`, because the latter introduces dim/lvl
// ambiguity.
for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
const auto sh = rtp.getShape()[lvl];
assert(!ShapedType::isDynamic(sh));
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
if (lvl == 0)
desc.setPosMemSize(rewriter, loc, lvl, constantIndex(rewriter, loc, 2));
desc.setIdxMemSize(rewriter, loc, i, noe);
desc.setCrdMemSize(rewriter, loc, lvl, noe);
}
desc.setValMemSize(rewriter, loc, noe);
@@ -1252,39 +1252,43 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Location loc = op.getLoc();
int64_t rank = op.getTensor().getType().getRank();
const auto srcTp = getSparseTensorType(op.getTensor());
const Level lvlRank = srcTp.getLvlRank();
assert(isUniqueCOOType(op.getTensor().getType()) &&
desc.getFields().size() == 4);
assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4);
Value flatBuf = rank == 1 ? desc.getIdxMemRefOrView(rewriter, loc, 0)
: desc.getAOSMemRef();
Value dataBuf = desc.getValMemRef();
Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
: desc.getAOSMemRef();
Value valuesBuf = desc.getValMemRef();
// If frontend requests a static buffer, we reallocate the data/indices
// to ensure that we meet their need.
TensorType dataTp = op.getData().getType();
if (dataTp.hasStaticShape()) {
dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
// If frontend requests a static buffer, we reallocate the
// values/coordinates to ensure that we meet their need.
const auto valuesTp = getRankedTensorType(op.getValues());
if (valuesTp.hasStaticShape()) {
valuesBuf =
reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
}
TensorType indicesTp = op.getIndices().getType();
if (indicesTp.hasStaticShape()) {
auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
if (coordinatesTp.hasStaticShape()) {
auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
}
Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
loc, MemRefType::get(indicesTp.getShape(), indicesTp.getElementType()),
Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
loc,
MemRefType::get(coordinatesTp.getShape(),
coordinatesTp.getElementType()),
flatBuf, ArrayRef{ReassociationIndices{0, 1}});
// Converts MemRefs back to Tensors.
Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
op.getNnz().getType());
Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
Value coordinates =
rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
op.getNse().getType());
rewriter.replaceOp(op, {data, indices, nnz});
rewriter.replaceOp(op, {values, coordinates, nse});
return success();
}
};
@@ -1296,52 +1300,54 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
const auto encDst = dstTp.getEncoding();
// Creating COO with NewOp is handled by direct IR codegen. All other cases
// are handled by rewriting.
if (!dstTp.hasEncoding() || getCOOStart(encDst) != 0)
if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
return failure();
// Implement the NewOp(filename) as follows:
// reader = getSparseTensorReader(filename)
// nse = getSparseTensorNNZ()
// tmp = bufferization.alloc_tensor an ordered COO with
// dst dim ordering, size_hint = nse
// indices = to_indices_buffer(tmp)
// values = to_values(tmp)
// isSorted = getSparseTensorReaderRead(indices, values, dimOrdering)
// if (!isSorted) sort_coo(nse, indices, values)
// %reader = @getSparseTensorReader(%filename)
// %nse = @getSparseTensorNSE(%reader)
// %coo = bufferization.alloc_tensor an ordered COO with
// dst dim ordering, size_hint = %nse
// %coordinates = sparse_tensor.coordinates_buffer(%coo)
// %values = sparse_tensor.values(%coo)
// %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
// if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
// update storage specifier
// dst = sparse_tensor.ConvertOp tmp
// @delSparseTensorReader(%reader)
// Create a sparse tensor reader.
Value fileName = op.getSource();
Type opaqueTp = getOpaquePointerType(rewriter);
const Value fileName = op.getSource();
const Type opaqueTp = getOpaquePointerType(rewriter);
// FIXME: use `createCheckedSparseTensorReader` instead, because
// `createSparseTensorReader` is unsafe.
Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
{opaqueTp}, {fileName}, EmitCInterface::Off)
.getResult(0);
Type indexTp = rewriter.getIndexType();
const Type indexTp = rewriter.getIndexType();
const Dimension dimRank = dstTp.getDimRank();
const Level lvlRank = dstTp.getLvlRank();
// If the result tensor has dynamic dimensions, get the dynamic sizes from
// the sparse tensor reader.
SmallVector<Value> dynSizes;
if (dstTp.hasDynamicDimShape()) {
// FIXME: call `getSparseTensorReaderDimSizes` instead, because
// `copySparseTensorReaderDimSizes` copies the memref over,
// instead of just accessing the reader's memory directly.
Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
{reader, dimSizes}, EmitCInterface::On)
.getResult(0);
ArrayRef<int64_t> dstShape = dstTp.getRankedTensorType().getShape();
for (auto &d : llvm::enumerate(dstShape)) {
if (d.value() == ShapedType::kDynamic) {
for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
if (ShapedType::isDynamic(d.value()))
dynSizes.push_back(rewriter.create<memref::LoadOp>(
loc, dimSizes, constantIndex(rewriter, loc, d.index())));
}
}
}
Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
{indexTp}, {reader}, EmitCInterface::Off)
.getResult(0);
// Construct allocation for each field.
@@ -1350,63 +1356,71 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
fields, nse);
MutSparseTensorDescriptor desc(dstTp, fields);
// Read the COO tensor data.
Type eltTp = dstTp.getElementType();
Type indBufEleTp = getIndexOverheadType(rewriter, encDst);
SmallString<32> getReadFuncName{"getSparseTensorReaderRead",
overheadTypeFunctionSuffix(indBufEleTp),
primaryTypeFunctionSuffix(eltTp)};
Value xs = desc.getAOSMemRef();
Value ys = desc.getValMemRef();
SmallVector<Value> dim2lvlValues(dimRank, Value());
if (auto dimOrder = encDst.getDimOrdering()) {
// Construct the `dim2lvl` buffer for handing off to the runtime library.
// FIXME: This code is (mostly) copied from the SparseTensorConversion.cpp
// handling of `NewOp`, and only handles permutations. Fixing this
// requires waiting for wrengr to finish redoing the CL that handles
// all dim<->lvl stuff more robustly.
SmallVector<Value> dim2lvlValues(dimRank);
if (!dstTp.isIdentity()) {
const auto dimOrder = dstTp.getDimToLvlMap();
assert(dimOrder.isPermutation() && "Got non-permutation");
for (uint64_t l = 0; l < dimRank; l++) {
uint64_t d = dimOrder.getDimPosition(l);
for (Level l = 0; l < lvlRank; l++) {
const Dimension d = dimOrder.getDimPosition(l);
dim2lvlValues[d] = constantIndex(rewriter, loc, l);
}
} else {
for (uint64_t l = 0; l < dimRank; l++)
dim2lvlValues[l] = constantIndex(rewriter, loc, l);
// The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
// when `isIdentity`; so no need to re-assert it here.
for (Dimension d = 0; d < dimRank; d++)
dim2lvlValues[d] = constantIndex(rewriter, loc, d);
}
Value dim2lvl = allocaBuffer(rewriter, loc, dim2lvlValues);
Value f = constantI1(rewriter, loc, false);
// Read the COO tensor data.
Value xs = desc.getAOSMemRef();
Value ys = desc.getValMemRef();
const Type boolTp = rewriter.getIntegerType(1);
const Type elemTp = dstTp.getElementType();
const Type crdTp = dstTp.getCrdType();
// FIXME: This function name is weird; should rename to
// "sparseTensorReaderReadToBuffers".
SmallString<32> readToBuffersFuncName{"getSparseTensorReaderRead",
overheadTypeFunctionSuffix(crdTp),
primaryTypeFunctionSuffix(elemTp)};
Value isSorted =
createFuncCall(rewriter, loc, getReadFuncName, {f.getType()},
createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
{reader, dim2lvl, xs, ys}, EmitCInterface::On)
.getResult(0);
// If the destination tensor is a sorted COO, we need to sort the COO tensor
// data if the input elements aren't sorted yet.
if (encDst.isOrderedLvl(dimRank - 1)) {
if (dstTp.isOrderedLvl(lvlRank - 1)) {
Value kFalse = constantI1(rewriter, loc, false);
Value notSorted = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, isSorted, f);
loc, arith::CmpIPredicate::eq, isSorted, kFalse);
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<SortCooOp>(
loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(dimRank),
loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
rewriter.setInsertionPointAfter(ifOp);
}
// Set PtrMemRef0[1] = nse.
Value c1 = constantIndex(rewriter, loc, 1);
Value ptrMemref0 = desc.getPtrMemRef(0);
Type ptrEleTy = getMemRefType(ptrMemref0).getElementType();
Value ptrNse =
ptrEleTy == nse.getType()
? nse
: rewriter.create<arith::IndexCastOp>(loc, ptrEleTy, nse);
rewriter.create<memref::StoreOp>(loc, ptrNse, ptrMemref0, c1);
// Set PosMemRef0[1] = nse.
const Value c1 = constantIndex(rewriter, loc, 1);
const Value posMemref0 = desc.getPosMemRef(0);
const Type posTp = dstTp.getPosType();
const Value posNse = genCast(rewriter, loc, nse, posTp);
rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
// Update storage specifier.
Value idxSize = rewriter.create<arith::MulIOp>(
loc, nse, constantIndex(rewriter, loc, dimRank));
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::IdxMemSize, 0,
idxSize);
Value coordinatesSize = rewriter.create<arith::MulIOp>(
loc, nse, constantIndex(rewriter, loc, lvlRank));
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
coordinatesSize);
desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
std::nullopt, nse);
@@ -1436,8 +1450,8 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseTensorDeallocConverter,
SparseExtractSliceCoverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
SparseInsertConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToIndicesBufferConverter,
SparseInsertConverter, SparseToPositionsConverter,
SparseToCoordinatesConverter, SparseToCoordinatesBufferConverter,
SparseToValuesConverter, SparseConvertConverter,
SparseNewOpConverter, SparseNumberOfEntriesConverter>(
typeConverter, patterns.getContext());