From 34c9c59ce4bbe3f6df2b3bc82d0485d4339e057e Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Mon, 3 Apr 2023 12:55:59 -0700 Subject: [PATCH] [mlir][sparse] Using SparseTensorType in SparsePackOpConverter Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D147465 --- .../Transforms/SparseTensorCodegen.cpp | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 31bf59552f4e..b9f75e9ad005 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern { matchAndRewrite(PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - const auto rtp = getRankedTensorType(op.getResult()); - assert(isUniqueCOOType(rtp)); + const auto stt = getSparseTensorType(op.getResult()); + assert(isUniqueCOOType(stt)); SmallVector fields; Location loc = op.getLoc(); foreachFieldAndTypeInSparseTensor( - rtp, - [&rewriter, &fields, &op, rtp, + stt, + [&rewriter, &fields, &op, stt, loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); - auto enc = getSparseTensorEncoding(rtp); Value field; switch (fKind) { case SparseTensorFieldKind::StorageSpec: - field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp); + field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt); break; case SparseTensorFieldKind::PosMemRef: { // TACO-style COO starts with a PosBuffer // By creating a constant value for it, we avoid the complexity of // memory management. - const auto posTp = enc.getPosType(); + const auto posTp = stt.getPosType(); auto tensorType = RankedTensorType::get({2}, posTp); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); @@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern { return true; }); - MutSparseTensorDescriptor desc(rtp, fields); + MutSparseTensorDescriptor desc(stt, fields); auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0); - // FIXME: should use `SparseTensorType::getLvlRank` in lieu of - // `RankedTensorType::getRank`, because the latter introduces dim/lvl - // ambiguity. - for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) { - const auto sh = rtp.getShape()[lvl]; + for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { + // FIXME: dim/lvl confusion! + const auto sh = stt.getDimShape()[lvl]; assert(!ShapedType::isDynamic(sh)); desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh)); if (lvl == 0)