mirror of
https://github.com/intel/llvm.git
synced 2026-02-04 11:38:04 +08:00
[mlir][sparse] fix a bug in sparse2sparse reshape.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133521
This commit is contained in:
@@ -541,14 +541,18 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter,
|
||||
/// coo->add(reshape(elem.indices), elem.value)
|
||||
/// }
|
||||
/// s = newSparseTensor(coo)
|
||||
template <typename ReshapeOp>
|
||||
static LogicalResult
|
||||
genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
|
||||
ArrayRef<ReassociationIndices> reassociation, Value src,
|
||||
RankedTensorType dstTp, RankedTensorType srcTp) {
|
||||
Location loc = op->getLoc();
|
||||
auto encDst = getSparseTensorEncoding(dstTp);
|
||||
genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
Location loc = op.getLoc();
|
||||
auto srcTp = op.getSrc().getType().template cast<RankedTensorType>();
|
||||
auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
|
||||
auto encSrc = getSparseTensorEncoding(srcTp);
|
||||
assert(encDst && encSrc);
|
||||
auto encDst = getSparseTensorEncoding(dstTp);
|
||||
if (!encDst || !encSrc)
|
||||
return failure();
|
||||
|
||||
unsigned srcRank = srcTp.getRank();
|
||||
unsigned dstRank = dstTp.getRank();
|
||||
Type elemTp = srcTp.getElementType();
|
||||
@@ -560,14 +564,16 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
|
||||
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
sizesFromPtr(rewriter, sizes, loc, noPerm, srcTp, src);
|
||||
sizesFromSrc(rewriter, sizes, loc, op.getSrc());
|
||||
newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, sizes,
|
||||
src);
|
||||
adaptor.getSrc());
|
||||
Value iter = genNewCall(rewriter, loc, params);
|
||||
// Start a new COO for the destination tensor.
|
||||
sizes.clear();
|
||||
params.clear();
|
||||
sizesFromPtr(rewriter, sizes, loc, encDst, dstTp, src);
|
||||
// Fills sizes array using the sizes from destination type.
|
||||
assert(dstTp.hasStaticShape());
|
||||
sizesFromType(rewriter, sizes, loc, dstTp);
|
||||
newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
|
||||
Value coo = genNewCall(rewriter, loc, params);
|
||||
Value dstPerm = params[2];
|
||||
@@ -586,7 +592,8 @@ genSparse2SparseReshape(Operation *op, ConversionPatternRewriter &rewriter,
|
||||
// not need to store the value in elemPtr, as the value is still there.
|
||||
Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes);
|
||||
rewriter.setInsertionPointToStart(after);
|
||||
translateIndices(loc, rewriter, reassociation, dstTp, srcTp, dstIdx, srcIdx);
|
||||
translateIndices(loc, rewriter, op.getReassociationIndices(), dstTp, srcTp,
|
||||
dstIdx, srcIdx);
|
||||
genAddEltCall(rewriter, loc, elemTp, coo, elemPtr, dstIdx, dstPerm);
|
||||
rewriter.create<scf::YieldOp>(loc);
|
||||
// Final call to construct sparse tensor storage and free temporary resources.
|
||||
@@ -756,15 +763,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dstType = op.getResult().getType();
|
||||
Type srcType = op.getSrc().getType();
|
||||
auto encDst = getSparseTensorEncoding(dstType);
|
||||
auto encSrc = getSparseTensorEncoding(srcType);
|
||||
if (encDst && encSrc)
|
||||
return genSparse2SparseReshape(
|
||||
op, rewriter, op.getReassociationIndices(), adaptor.getOperands()[0],
|
||||
dstType.cast<RankedTensorType>(), srcType.cast<RankedTensorType>());
|
||||
return failure(); // handled elsewhere
|
||||
return genSparse2SparseReshape(op, adaptor, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user