[mlir][sparse] Moving/renaming genBuffer to allocaBuffer

This allows allocaBuffer to be used outside of SparseTensorConversion.cpp, which will be helpful for a some future commits.

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D140047
This commit is contained in:
wren romano
2022-12-14 12:58:50 -08:00
parent 71cc0f1c04
commit 62896428a7
3 changed files with 27 additions and 22 deletions

View File

@@ -207,18 +207,6 @@ static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
}
/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
unsigned sz = values.size();
assert(sz >= 1);
Value buffer = genAlloca(builder, loc, sz, values[0].getType());
for (unsigned i = 0; i < sz; i++) {
Value idx = constantIndex(builder, loc, i);
builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
}
return buffer;
}
/// Generates a temporary buffer for the level-types of the given encoding.
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
SparseTensorEncodingAttr enc) {
@@ -227,7 +215,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
lvlTypes.reserve(dlts.size());
for (auto dlt : dlts)
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
return genBuffer(builder, loc, lvlTypes);
return allocaBuffer(builder, loc, lvlTypes);
}
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
@@ -329,7 +317,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
// Dimension-sizes array of the enveloping tensor. Useful for either
// verification of external data, or for construction of internal data.
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
// The level-sizes array must be passed as well, since for arbitrary
// dim2lvl mappings it cannot be trivially reconstructed at runtime.
// For now however, since we're still assuming permutations, we will
@@ -358,10 +346,10 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
lvlSizes[i] = dimSizes[i];
}
}
params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim);
params[kParamDim2Lvl] =
dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
// Secondary and primary types encoding.
setTemplateTypes(enc, stp);
// Finally, make note that initialization is complete.
@@ -780,7 +768,7 @@ public:
// Construct the dimShape.
const auto dimShape = stp.getShape();
SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
// Allocate `SparseTensorReader` and perform all initial setup that
// does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
Type opaqueTp = getOpaquePointerType(rewriter);
@@ -833,9 +821,9 @@ public:
? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
: dimShapeValues[d];
}
lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues);
dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues);
} else {
assert(dimRank == lvlRank && "Rank mismatch");
SmallVector<Value> iotaValues;
@@ -843,7 +831,7 @@ public:
for (unsigned i = 0; i < lvlRank; i++)
iotaValues.push_back(constantIndex(rewriter, loc, i));
lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues);
}
// Use the `reader` to parse the file.
SmallVector<Value, 8> params{