[mlir][sparse] fix a bug in sparse2sparse reshape.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133521
This commit is contained in:
Peiming Liu
2022-09-08 20:48:26 +00:00
parent f76dcede3f
commit 180bf5f940

View File

@@ -541,14 +541,18 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
/// coo->add(reshape(elem.indices), elem.value)
/// }
/// s = newSparseTensor(coo)
template <typename ReshapeOp>
static LogicalResult
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
ArrayRef<ReassociationIndices> reassociation, Value src,
RankedTensorType dstTp, RankedTensorType srcTp) {
Location loc = op->getLoc();
auto encDst = getSparseTensorEncoding(dstTp);
genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) {
Location loc = op.getLoc();
auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
auto encSrc = getSparseTensorEncoding(srcTp);
assert(encDst && encSrc);
auto encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc)
return failure();
unsigned srcRank = srcTp.getRank();
unsigned dstRank = dstTp.getRank();
Type elemTp = srcTp.getElementType();
@@ -560,14 +564,16 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
sizesFromSrc(rewriter, sizes, loc, op.getSrc());
newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
src);
adaptor.getSrc());
Value iter = genNewCall(rewriter, loc, params);
// Start a new COO for the destination tensor.
sizes.clear();
params.clear();
sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
// Fills sizes array using the sizes from destination type.
assert(dstTp.hasStaticShape());
sizesFromType(rewriter, sizes, loc, dstTp);
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
Value coo = genNewCall(rewriter, loc, params);
Value dstPerm = params[2];
@@ -586,7 +592,8 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
// not need to store the value in elemPtr, as the value is still there.
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
rewriter.setInsertionPointToStart(after);
translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
dstIdx, srcIdx);
genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
rewriter.create<scf::YieldOp>(loc);
// Final call to construct sparse tensor storage and free temporary resources.
@@ -756,15 +763,7 @@ public:
LogicalResult
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = op.getResult().getType();
Type srcType = op.getSrc().getType();
auto encDst = getSparseTensorEncoding(dstType);
auto encSrc = getSparseTensorEncoding(srcType);
if (encDst && encSrc)
return genSparse2SparseReshape(
op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
return failure(); // handled elsewhere
return genSparse2SparseReshape(op, adaptor, rewriter);
}
};