[mlir][sparse] extend unpack operation to support unpacking a batched COO type

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149103
Peiming Liu
2023-04-21 20:06:36 +00:00
parent f9fbda7102
commit d4db528938
13 changed files with 393 additions and 125 deletions


@@ -602,6 +602,25 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
return ifOp.getResult(0);
}
static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange bounds) {
assert(ivs.size() == bounds.size());
Value crd = constantIndex(builder, loc, 0);
for (unsigned i = 0, e = ivs.size(); i < e; i++) {
crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
if (i != ivs.size() - 1)
crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
}
return crd;
}
static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
ReassociationIndices reassociation;
for (int i = 0, e = srcTp.getRank(); i < e; i++)
reassociation.push_back(i);
return reassociation;
}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
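The new linearize helper folds a multi-dimensional loop index into a row-major flat index; the arith ops it emits compute ((i0 * b1 + i1) * b2 + i2) * ... for index variables (i0, ..., in-1) with bounds (b0, ..., bn-1). A scalar C++ model of the same computation (a standalone sketch for illustration; linearizeModel is not part of the patch):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the IR that linearize() emits: one AddIOp per index
// variable, one MulIOp per bound after the first.
int64_t linearizeModel(const std::vector<int64_t> &ivs,
                       const std::vector<int64_t> &bounds) {
  assert(ivs.size() == bounds.size());
  int64_t crd = 0;
  for (std::size_t i = 0, e = ivs.size(); i < e; i++) {
    crd = crd + ivs[i];          // arith::AddIOp
    if (i != e - 1)
      crd = crd * bounds[i + 1]; // arith::MulIOp
  }
  return crd;
}

// For example, with ivs = {1, 2} and bounds = {4, 3} (a 4x3 batch space),
// linearizeModel returns 1 * 3 + 2 == 5.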
@@ -1252,12 +1271,7 @@ static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
[&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
ValueRange ivs) {
// Linearize index variables
Value crd = linearize(builder, loc, ivs, ubs);
Value len = constantIndex(builder, loc, nse);
Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
SmallVector<Value> indices(ivs.begin(), ivs.end());
@@ -1420,6 +1434,166 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
};
static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
SparseTensorDescriptor desc,
ConversionPatternRewriter &rewriter) {
Location loc = op.getLoc();
const auto srcTp = getSparseTensorType(op.getTensor());
const Level lvlRank = srcTp.getLvlRank();
Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
: desc.getAOSMemRef();
Value valuesBuf = desc.getValMemRef();
// If the frontend requests buffers of static shape, we reallocate the
// values/coordinates so the returned buffers match the requested sizes.
const auto valuesTp = getRankedTensorType(op.getValues());
if (valuesTp.hasStaticShape()) {
// FIXME: Reallocation is not always safe! E.g., if we are unpacking a
// tensor that is packed from constants.
valuesBuf =
reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
}
const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
if (coordinatesTp.hasStaticShape()) {
// FIXME: Reallocation is not always safe! E.g., if we are unpacking a
// tensor that is packed from constants.
auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
}
Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
loc,
MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
flatBuf, ArrayRef{ReassociationIndices{0, 1}});
// Converts MemRefs back to Tensors.
Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
Value coordinates =
rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
op.getNse().getType());
rewriter.replaceOp(op, {values, coordinates, nse});
return success();
}
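For a COO tensor with level rank > 1, the coordinates live in a single array-of-structures (AOS) buffer of nse * rank entries, which is why the un-batched path above can recover the <nse x rank> coordinate tensor with one memref::ExpandShapeOp over the flat buffer. A scalar sketch of that reinterpretation (coordAt is an illustrative helper, not from the patch):

#include <cstdint>
#include <vector>

// The AOS buffer of a rank-2 COO tensor stores coordinates interleaved as
// [x0, y0, x1, y1, ...]. Expanding it to shape <nse x rank> is a pure
// reinterpretation: element (i, d) lives at flat index i * rank + d.
int64_t coordAt(const std::vector<int64_t> &flatCrds, int64_t rank,
                int64_t i, int64_t d) {
  return flatCrds[i * rank + d];
}

// E.g. for the 2-D coordinates (0,1) and (2,3) stored as {0, 1, 2, 3},
// coordAt(flatCrds, /*rank=*/2, /*i=*/1, /*d=*/0) == 2.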
static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
SparseTensorDescriptor desc,
ConversionPatternRewriter &rewriter) {
assert(nBatched != 0);
Location loc = op.getLoc();
Value c0 = constantIndex(rewriter, loc, 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value c2 = constantIndex(rewriter, loc, 2);
auto genZeroedAlloc = [loc,
&rewriter](TensorType tt) -> TypedValue<MemRefType> {
auto mem = rewriter
.create<memref::AllocOp>(
loc, MemRefType::get(tt.getShape(), tt.getElementType()))
.getMemref();
// TODO: Instead of filling the entire buffer, we could fill only the
// trailing zeros, since the leading entries of each batch slice are
// overwritten by the copies below.
rewriter.create<linalg::FillOp>(
loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
return mem;
};
SparseTensorType stt = getSparseTensorType(op.getTensor());
TensorType valTensorTp = op.getValues().getType();
TensorType crdTensorTp = op.getCoordinates().getType();
TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
SmallVector<Value> ubs;
for (unsigned i = 0; i < nBatched; i++) {
assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
}
DimLevelType dlt = stt.getLvlType(nBatched);
assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
Value posStep = isCompressedDLT(dlt) ? c1  // adjacent positions, stride 1
                                     : c2; // (lo, hi) pairs, stride 2
auto loopNest = scf::buildLoopNest(
rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
[&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
&crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange args) -> scf::ValueVector {
// crdMemref has shape: <... x nse x rank>
unsigned unBatchedRank = crdMemref.getType().getShape().back();
Value values = desc.getValMemRef();
Value flatCrds = unBatchedRank == 1
? desc.getCrdMemRefOrView(builder, loc, 0)
: desc.getAOSMemRef();
Value positions = desc.getPosMemRef(nBatched);
Value positLo = builder.create<arith::MulIOp>(
loc, linearize(builder, loc, ivs, ubs), posStep);
Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
Value pLo = genIndexLoad(builder, loc, positions, positLo);
Value pHi = genIndexLoad(builder, loc, positions, positHi);
Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
Value crdLo = builder.create<arith::MulIOp>(
loc, pLo, constantIndex(builder, loc, unBatchedRank));
Value nCrd = builder.create<arith::MulIOp>(
loc, nse, constantIndex(builder, loc, unBatchedRank));
SmallVector<Value> offsets, sizes, strides;
for (unsigned i = 0; i < nBatched; i++) {
offsets.push_back(ivs[i]);
sizes.push_back(c1);
strides.push_back(c1);
}
// [0, nse, 1].
offsets.push_back(c0);
sizes.push_back(nse);
strides.push_back(c1);
auto valView = builder.create<memref::SubViewOp>(
loc, valMemref, offsets, sizes, strides);
auto valReass = getReassociationForFlattening(valView.getType());
Value valDst =
builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
Value valSrc =
builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
builder.create<memref::CopyOp>(loc, valSrc, valDst);
// [0, rank, 1].
offsets.push_back(c0);
sizes.push_back(constantIndex(builder, loc, unBatchedRank));
strides.push_back(c1);
auto crdView = builder.create<memref::SubViewOp>(
loc, crdMemref, offsets, sizes, strides);
auto crdReass = getReassociationForFlattening(crdView.getType());
Value crdDst =
builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
Value crdSrc =
builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
Value pred = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ugt, nse, args[0]);
// Choose the larger NSE
return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
});
// Converts MemRefs back to Tensors.
Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
Value coordinates =
rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
Value nse =
genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
rewriter.replaceOp(op, {values, coordinates, nse});
return success();
}
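The loop nest generated above walks the batch dimensions, slices each batch's values and coordinates out of the sparse storage, and keeps a running maximum of the per-batch number of stored entries (nse). A scalar C++ model of the same logic, assuming a single batch level and a compressed level type (posStep == 1); unpackBatchedModel is illustrative only, not part of the patch:

#include <algorithm>
#include <cstdint>
#include <vector>

// Copies each batch's slice of values/coordinates into zero-initialized
// dense buffers of shape <batch x maxNse> and <batch x maxNse x rank>,
// returning the largest per-batch nse, as the scf loop nest does.
int64_t unpackBatchedModel(const std::vector<int64_t> &positions,
                           const std::vector<double> &values,
                           const std::vector<int64_t> &flatCrds,
                           int64_t batch, int64_t maxNse, int64_t rank,
                           std::vector<double> &valDst,
                           std::vector<int64_t> &crdDst) {
  int64_t largest = 0;
  for (int64_t b = 0; b < batch; b++) {
    int64_t pLo = positions[b];          // genIndexLoad at positLo
    int64_t pHi = positions[b + 1];      // genIndexLoad at positHi
    int64_t nse = pHi - pLo;
    for (int64_t j = 0; j < nse; j++) {  // memref.copy of the value slice
      valDst[b * maxNse + j] = values[pLo + j];
      for (int64_t d = 0; d < rank; d++) // memref.copy of the AOS coords
        crdDst[(b * maxNse + j) * rank + d] = flatCrds[(pLo + j) * rank + d];
    }
    largest = std::max(largest, nse);    // the ugt compare + arith.select
  }
  return largest;
}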
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
using OpConversionPattern::OpConversionPattern;
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
@@ -1431,52 +1605,26 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
const auto srcTp = getSparseTensorType(op.getTensor());
const unsigned nBatched = op.getNumBatchedLvls();
assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
       desc.getFields().size() == 4); // specifier + pos + crds + values
auto logicRes = nBatched == 0
                    ? genUnBatchedUnpackOp(op, desc, rewriter)
                    : genBatchedUnpackOp(op, nBatched, desc, rewriter);
Value posBuf = desc.getPosMemRef(nBatched);
if (createDeallocs) {
  // Unpack ends the lifetime of the sparse tensor. While the value array
  // and coordinate array are unpacked and returned, the position array
  // becomes useless and needs to be freed (if the user requests).
  // FIXME: Depending on whether the tensor being unpacked was created by
  // PackOp or not, we may or may not need to free other memref fields of
  // the sparse tensor too (PackOp borrows the value/coordinate buffers).
  rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
}
return logicRes;
}
private: