[mlir][sparse] extend unpack operation to support unpacking a batched COO type
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D149103
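Unpack is the inverse of pack: it hands back the values buffer, the coordinates buffer, and the number of stored entries (NSE) of a COO tensor, as the `rewriter.replaceOp(op, {values, coordinates, nse})` calls in the diff below show. With this change, the leading levels of the source tensor may act as batch dimensions, and the unpacked buffers gain matching leading dimensions, zero-padded per batch up to the largest per-batch NSE. A minimal sketch of the resulting plain-data layout, with hypothetical names and fixed sizes chosen only for illustration:

```cpp
#include <array>
#include <cstdint>

// Hypothetical layout after unpacking a COO tensor with one batch level of
// size 2, rank-2 coordinates per entry, and a largest batch holding 3 entries.
struct UnpackedBatchedCOO {
  // values[batch][nse]: entry values, zero-padded up to the maximum NSE.
  std::array<std::array<double, 3>, 2> values;
  // coordinates[batch][nse][rank]: per-entry coordinates, zero-padded alike.
  std::array<std::array<std::array<int32_t, 2>, 3>, 2> coordinates;
  int64_t nse; // number of stored entries of the largest batch
};
```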
@@ -602,6 +602,25 @@ static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
   return ifOp.getResult(0);
 }
 
+static Value linearize(OpBuilder &builder, Location loc, ValueRange ivs,
+                       ValueRange bounds) {
+  assert(ivs.size() == bounds.size());
+  Value crd = constantIndex(builder, loc, 0);
+  for (unsigned i = 0, e = ivs.size(); i < e; i++) {
+    crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
+    if (i != ivs.size() - 1)
+      crd = builder.create<arith::MulIOp>(loc, crd, bounds[i + 1]);
+  }
+  return crd;
+}
+
+ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+  ReassociationIndices reassociation;
+  for (int i = 0, e = srcTp.getRank(); i < e; i++)
+    reassociation.push_back(i);
+  return reassociation;
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
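The two helpers added above are small but load-bearing: linearize folds a multi-dimensional loop index into a flat row-major offset by Horner's rule (add a coordinate, then scale by the next bound), and getReassociationForFlattening groups every dimension of a shaped type into one reassociation so a later memref.collapse_shape can flatten a whole view. A standalone sketch of the linearization arithmetic (plain C++, hypothetical name):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the linearize helper above: after each partial sum, scale by the
// *next* bound, producing the row-major offset (iv0 * b1 + iv1) * b2 + iv2 ...
int64_t linearizeIndex(const std::vector<int64_t> &ivs,
                       const std::vector<int64_t> &bounds) {
  assert(ivs.size() == bounds.size());
  int64_t crd = 0;
  for (size_t i = 0, e = ivs.size(); i < e; i++) {
    crd += ivs[i];
    if (i != e - 1)
      crd *= bounds[i + 1];
  }
  return crd;
}

// Example: ivs = {1, 2} under bounds = {4, 5} yields 1 * 5 + 2 = 7.
```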
@@ -1252,12 +1271,7 @@ static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
       [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
                                                ValueRange ivs) {
-        // Linearize index variables
-        Value crd = constantIndex(builder, loc, 0);
-        for (unsigned i = 0, e = ivs.size(); i < e; i++) {
-          crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
-          if (i != ivs.size() - 1)
-            crd = builder.create<arith::MulIOp>(loc, crd, ubs[i + 1]);
-        }
+        Value crd = linearize(builder, loc, ivs, ubs);
         Value len = constantIndex(builder, loc, nse);
         Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
         SmallVector<Value> indices(ivs.begin(), ivs.end());
@@ -1420,6 +1434,166 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   }
 };
 
+static LogicalResult genUnBatchedUnpackOp(UnpackOp op,
+                                          SparseTensorDescriptor desc,
+                                          ConversionPatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  const auto srcTp = getSparseTensorType(op.getTensor());
+  const Level lvlRank = srcTp.getLvlRank();
+  Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
+                               : desc.getAOSMemRef();
+  Value valuesBuf = desc.getValMemRef();
+
+  // If frontend requests a static buffer, we reallocate the
+  // values/coordinates to ensure that we meet their need.
+  const auto valuesTp = getRankedTensorType(op.getValues());
+  if (valuesTp.hasStaticShape()) {
+    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+    // tensor that is packed from constants.
+    valuesBuf =
+        reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
+  }
+
+  const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
+  if (coordinatesTp.hasStaticShape()) {
+    // FIXME: Reallocation is not always safe! E.g., if we are unpacking a
+    // tensor that is packed from constants.
+    auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
+    flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
+  }
+
+  Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
+      loc,
+      MemRefType::get(coordinatesTp.getShape(), coordinatesTp.getElementType()),
+      flatBuf, ArrayRef{ReassociationIndices{0, 1}});
+
+  // Converts MemRefs back to Tensors.
+  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
+  Value coordinates =
+      rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
+  Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                      op.getNse().getType());
+
+  rewriter.replaceOp(op, {values, coordinates, nse});
+  return success();
+}
+
+static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
+                                        SparseTensorDescriptor desc,
+                                        ConversionPatternRewriter &rewriter) {
+  assert(nBatched != 0);
+  Location loc = op.getLoc();
+  Value c0 = constantIndex(rewriter, loc, 0);
+  Value c1 = constantIndex(rewriter, loc, 1);
+  Value c2 = constantIndex(rewriter, loc, 2);
+
+  auto genZeroedAlloc = [loc,
+                         &rewriter](TensorType tt) -> TypedValue<MemRefType> {
+    auto mem = rewriter
+                   .create<memref::AllocOp>(
+                       loc, MemRefType::get(tt.getShape(), tt.getElementType()))
+                   .getMemref();
+    // TODO: Instead of filling the entire buffer, we can only fill the
+    // trailing zeros.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, tt.getElementType())}, mem);
+    return mem;
+  };
+  SparseTensorType stt = getSparseTensorType(op.getTensor());
+  TensorType valTensorTp = op.getValues().getType();
+  TensorType crdTensorTp = op.getCoordinates().getType();
+  TypedValue<MemRefType> valMemref = genZeroedAlloc(valTensorTp);
+  TypedValue<MemRefType> crdMemref = genZeroedAlloc(crdTensorTp);
+  assert(valTensorTp.hasStaticShape() && crdTensorTp.hasStaticShape());
+
+  SmallVector<Value> lbs(nBatched, c0), steps(nBatched, c1);
+  SmallVector<Value> ubs;
+  for (unsigned i = 0; i < nBatched; i++) {
+    assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+    ubs.push_back(constantIndex(rewriter, loc, stt.getDimShape()[i]));
+  }
+
+  DimLevelType dlt = stt.getLvlType(nBatched);
+  assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
+  Value posStep = isCompressedDLT(dlt) ? c1  // forward position index by 1
+                                       : c2; // forward position index by 2
+  auto loopNest = scf::buildLoopNest(
+      rewriter, loc, lbs, ubs, steps, {c0 /*maximum nse*/},
+      [&ubs, c0, c1, posStep, desc, nBatched, &valMemref,
+       &crdMemref](OpBuilder &builder, Location loc, ValueRange ivs,
+                   ValueRange args) -> scf::ValueVector {
+        // crdMemref has shape: <... x nse x rank>
+        unsigned unBatchedRank = crdMemref.getType().getShape().back();
+        Value values = desc.getValMemRef();
+        Value flatCrds = unBatchedRank == 1
+                             ? desc.getCrdMemRefOrView(builder, loc, 0)
+                             : desc.getAOSMemRef();
+
+        Value positions = desc.getPosMemRef(nBatched);
+        Value positLo = builder.create<arith::MulIOp>(
+            loc, linearize(builder, loc, ivs, ubs), posStep);
+        Value positHi = builder.create<arith::AddIOp>(loc, positLo, c1);
+
+        Value pLo = genIndexLoad(builder, loc, positions, positLo);
+        Value pHi = genIndexLoad(builder, loc, positions, positHi);
+        Value nse = builder.create<arith::SubIOp>(loc, pHi, pLo);
+
+        Value crdLo = builder.create<arith::MulIOp>(
+            loc, pLo, constantIndex(builder, loc, unBatchedRank));
+        Value nCrd = builder.create<arith::MulIOp>(
+            loc, nse, constantIndex(builder, loc, unBatchedRank));
+
+        SmallVector<Value> offsets, sizes, strides;
+        for (unsigned i = 0; i < nBatched; i++) {
+          offsets.push_back(ivs[i]);
+          sizes.push_back(c1);
+          strides.push_back(c1);
+        }
+        // [0, nse, 1].
+        offsets.push_back(c0);
+        sizes.push_back(nse);
+        strides.push_back(c1);
+
+        auto valView = builder.create<memref::SubViewOp>(
+            loc, valMemref, offsets, sizes, strides);
+        auto valReass = getReassociationForFlattening(valView.getType());
+        Value valDst =
+            builder.create<memref::CollapseShapeOp>(loc, valView, valReass);
+        Value valSrc =
+            builder.create<memref::SubViewOp>(loc, values, pLo, nse, c1);
+        builder.create<memref::CopyOp>(loc, valSrc, valDst);
+
+        // [0, rank, 1].
+        offsets.push_back(c0);
+        sizes.push_back(constantIndex(builder, loc, unBatchedRank));
+        strides.push_back(c1);
+
+        auto crdView = builder.create<memref::SubViewOp>(
+            loc, crdMemref, offsets, sizes, strides);
+        auto crdReass = getReassociationForFlattening(crdView.getType());
+        Value crdDst =
+            builder.create<memref::CollapseShapeOp>(loc, crdView, crdReass);
+        Value crdSrc =
+            builder.create<memref::SubViewOp>(loc, flatCrds, crdLo, nCrd, c1);
+        builder.create<memref::CopyOp>(loc, crdSrc, crdDst);
+
+        Value pred = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::ugt, nse, args[0]);
+        // Choose the larger NSE
+        return {builder.create<arith::SelectOp>(loc, pred, nse, args[0])};
+      });
+
+  // Converts MemRefs back to Tensors.
+  Value values = rewriter.create<bufferization::ToTensorOp>(loc, valMemref);
+  Value coordinates =
+      rewriter.create<bufferization::ToTensorOp>(loc, crdMemref);
+  Value nse =
+      genCast(rewriter, loc, loopNest.results.front(), op.getNse().getType());
+
+  rewriter.replaceOp(op, {values, coordinates, nse});
+  return success();
+}
+
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
   SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
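genBatchedUnpackOp above loops over every batch, reads that batch's [pLo, pHi) range from the position buffer (stepping by one position entry for "compressed" and by two for "compressed-hi" levels, which is the c1/c2 choice), copies nse values plus nse * rank array-of-structs coordinates into the zero-initialized outputs, and carries the maximum NSE through the loop nest. A scalar model of the same control flow, assuming a single batch level and pre-zeroed outputs (all names hypothetical):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Scalar model of the batched unpack loop. positions holds per-batch entry
// offsets; posStep is 1 for "compressed" and 2 for "compressed-hi" levels.
// valOut ([batch][maxNsePerBatch]) and crdOut ([batch][maxNsePerBatch][rank])
// are assumed zero-filled, as genZeroedAlloc guarantees in the codegen.
int64_t unpackBatches(const std::vector<int64_t> &positions,
                      const std::vector<double> &values,
                      const std::vector<int32_t> &flatCrds, int64_t rank,
                      int64_t posStep, int64_t nBatches,
                      int64_t maxNsePerBatch, std::vector<double> &valOut,
                      std::vector<int32_t> &crdOut) {
  int64_t maxNse = 0;
  for (int64_t b = 0; b < nBatches; b++) {
    int64_t pLo = positions[b * posStep];
    int64_t pHi = positions[b * posStep + 1];
    int64_t nse = pHi - pLo; // entries stored in this batch
    // Copy this batch's values and AoS coordinates into the padded outputs.
    std::copy_n(values.begin() + pLo, nse, valOut.begin() + b * maxNsePerBatch);
    std::copy_n(flatCrds.begin() + pLo * rank, nse * rank,
                crdOut.begin() + b * maxNsePerBatch * rank);
    maxNse = std::max(maxNse, nse); // loop-carried max, like the scf reduction
  }
  return maxNse;
}
```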
@@ -1431,52 +1605,26 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     Location loc = op.getLoc();
     const auto srcTp = getSparseTensorType(op.getTensor());
-    const Level lvlRank = srcTp.getLvlRank();
+    const unsigned nBatched = op.getNumBatchedLvls();
+    assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
+           desc.getFields().size() == 4); // specifier + pos + crds + values
+    auto logicRes = nBatched == 0
+                        ? genUnBatchedUnpackOp(op, desc, rewriter)
+                        : genBatchedUnpackOp(op, nBatched, desc, rewriter);
+    Value posBuf = desc.getPosMemRef(nBatched);
 
-    assert(isUniqueCOOType(srcTp) && desc.getFields().size() == 4);
-
-    Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
-                                 : desc.getAOSMemRef();
-    Value valuesBuf = desc.getValMemRef();
-    Value posBuf = desc.getPosMemRef(0);
     if (createDeallocs) {
       // Unpack ends the lifetime of the sparse tensor. While the value array
       // and coordinate array are unpacked and returned, the position array
       // becomes useless and need to be freed (if user requests).
-      rewriter.create<memref::DeallocOp>(loc, posBuf);
+      // FIXME: Depending on whether the tensor being unpacked is created by
+      // PackOp or not, we may or may not need to free other memref fields of
+      // the sparse tensor too (PackOp borrows value/coordinate buffer).
+      rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
     }
 
-    // If frontend requests a static buffer, we reallocate the
-    // values/coordinates to ensure that we meet their need.
-    const auto valuesTp = getRankedTensorType(op.getValues());
-    if (valuesTp.hasStaticShape()) {
-      valuesBuf =
-          reallocOrSubView(rewriter, loc, valuesTp.getShape()[0], valuesBuf);
-    }
-
-    const auto coordinatesTp = getRankedTensorType(op.getCoordinates());
-    if (coordinatesTp.hasStaticShape()) {
-      auto len = coordinatesTp.getShape()[0] * coordinatesTp.getShape()[1];
-      flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
-    }
-
-    Value coordinatesBuf = rewriter.create<memref::ExpandShapeOp>(
-        loc,
-        MemRefType::get(coordinatesTp.getShape(),
-                        coordinatesTp.getElementType()),
-        flatBuf, ArrayRef{ReassociationIndices{0, 1}});
-
-    // Converts MemRefs back to Tensors.
-    Value values = rewriter.create<bufferization::ToTensorOp>(loc, valuesBuf);
-    Value coordinates =
-        rewriter.create<bufferization::ToTensorOp>(loc, coordinatesBuf);
-    Value nse = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
-                        op.getNse().getType());
-
-    rewriter.replaceOp(op, {values, coordinates, nse});
-    return success();
+    return logicRes;
   }
 
 private:
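One detail worth calling out in the unbatched path (genUnBatchedUnpackOp above): for rank > 1, the coordinates of all entries live in one flat array-of-structs buffer, and the memref.expand_shape with reassociation {0, 1} only reinterprets that nse * rank run of memory as an nse x rank view, with no copy. The index arithmetic that view encodes, as a short sketch (helper name is ours, not the library's):

```cpp
#include <cstdint>
#include <vector>

// Entry e's level-l coordinate in the flat AoS buffer of a rank-`rank` COO
// tensor sits at flat[e * rank + l]; expand_shape exposes it as [nse x rank].
int32_t coordAt(const std::vector<int32_t> &flat, int64_t rank, int64_t e,
                int64_t l) {
  return flat[e * rank + l];
}
```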