mirror of
https://github.com/intel/llvm.git
synced 2026-02-03 02:26:27 +08:00
[mlir][sparse] Moving/renaming genBuffer to allocaBuffer
This allows allocaBuffer to be used outside of SparseTensorConversion.cpp, which will be helpful for a some future commits. Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D140047
This commit is contained in:
@@ -207,18 +207,6 @@ static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
|
||||
return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
|
||||
}
|
||||
|
||||
/// Generates a temporary buffer of the given type and given contents.
|
||||
static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
|
||||
unsigned sz = values.size();
|
||||
assert(sz >= 1);
|
||||
Value buffer = genAlloca(builder, loc, sz, values[0].getType());
|
||||
for (unsigned i = 0; i < sz; i++) {
|
||||
Value idx = constantIndex(builder, loc, i);
|
||||
builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// Generates a temporary buffer for the level-types of the given encoding.
|
||||
static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
|
||||
SparseTensorEncodingAttr enc) {
|
||||
@@ -227,7 +215,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
|
||||
lvlTypes.reserve(dlts.size());
|
||||
for (auto dlt : dlts)
|
||||
lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
|
||||
return genBuffer(builder, loc, lvlTypes);
|
||||
return allocaBuffer(builder, loc, lvlTypes);
|
||||
}
|
||||
|
||||
/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
|
||||
@@ -329,7 +317,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
|
||||
// Dimension-sizes array of the enveloping tensor. Useful for either
|
||||
// verification of external data, or for construction of internal data.
|
||||
assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
|
||||
params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
|
||||
params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
|
||||
// The level-sizes array must be passed as well, since for arbitrary
|
||||
// dim2lvl mappings it cannot be trivially reconstructed at runtime.
|
||||
// For now however, since we're still assuming permutations, we will
|
||||
@@ -358,10 +346,10 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
|
||||
lvlSizes[i] = dimSizes[i];
|
||||
}
|
||||
}
|
||||
params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
|
||||
params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
|
||||
params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
|
||||
params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim);
|
||||
params[kParamDim2Lvl] =
|
||||
dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
|
||||
dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
|
||||
// Secondary and primary types encoding.
|
||||
setTemplateTypes(enc, stp);
|
||||
// Finally, make note that initialization is complete.
|
||||
@@ -780,7 +768,7 @@ public:
|
||||
// Construct the dimShape.
|
||||
const auto dimShape = stp.getShape();
|
||||
SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
|
||||
Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
|
||||
Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
|
||||
// Allocate `SparseTensorReader` and perform all initial setup that
|
||||
// does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
|
||||
Type opaqueTp = getOpaquePointerType(rewriter);
|
||||
@@ -833,9 +821,9 @@ public:
|
||||
? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
|
||||
: dimShapeValues[d];
|
||||
}
|
||||
lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
|
||||
lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
|
||||
dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
|
||||
lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
|
||||
lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues);
|
||||
dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues);
|
||||
} else {
|
||||
assert(dimRank == lvlRank && "Rank mismatch");
|
||||
SmallVector<Value> iotaValues;
|
||||
@@ -843,7 +831,7 @@ public:
|
||||
for (unsigned i = 0; i < lvlRank; i++)
|
||||
iotaValues.push_back(constantIndex(rewriter, loc, i));
|
||||
lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
|
||||
dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
|
||||
dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues);
|
||||
}
|
||||
// Use the `reader` to parse the file.
|
||||
SmallVector<Value, 8> params{
|
||||
|
||||
Reference in New Issue
Block a user