diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 4f1c446faec3..9a6d3161be3d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -453,11 +453,22 @@ SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
     // Do constant propagation on the affine map.
     AffineExpr evalExp =
         simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
-    if (auto c = evalExp.dyn_cast<AffineConstantExpr>())
+    if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
       ret.push_back(c.getValue() + 1);
-    else
+    } else {
+      if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
+          mod && mod.getKind() == AffineExprKind::Mod) {
+        // We can still infer a static bound for expressions in form
+        // "d % constant" since d % constant \in [0, constant).
+        if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
+          ret.push_back(bound.getValue());
+          continue;
+        }
+      }
       ret.push_back(ShapedType::kDynamic);
+    }
   }
+  assert(ret.size() == rank);
   return ret;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index e855a6e19a71..5cdf8cd7ccc9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -725,6 +725,18 @@ public:
   }
 };
 
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply fold the operation.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse codegen rule for the alloc operator.
 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
 class SparseTensorAllocConverter
@@ -1564,7 +1576,7 @@ void mlir::populateSparseTensorCodegenPatterns(
            SparseCastConverter, SparseExtractSliceConverter,
            SparseTensorLoadConverter, SparseExpandConverter,
            SparseCompressConverter, SparseInsertConverter,
-           SparseReorderCOOConverter,
+           SparseReorderCOOConverter, SparseReMapConverter,
            SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                         StorageSpecifierKind::DimOffset>,
            SparseSliceGetterOpConverter<ToSliceStrideOp,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply fold the operation.
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for the new operator.
 class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 public:
@@ -770,7 +782,7 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
                                                   RewritePatternSet &patterns) {
   patterns
       .add
+#DSDD = #sparse_tensor.encoding<{
+  map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
+}>
+
 !Filename = !llvm.ptr
 
 //
@@ -77,6 +82,13 @@ module {
     %vecv = vector.transfer_read %val[%c0], %f0 : memref<?xf64>, vector<12xf64>
     vector.print %vecv : vector<12xf64>
 
+    // CHECK-NEXT: ( 1, 2, 0, 3, 4, 0, 0, 5, 6, 7, 8, 0 )
+    %t1 = sparse_tensor.reinterpret_map %A : tensor<?x?xf64, #BSR>
+                                          to tensor<?x?x2x3xf64, #DSDD>
+    %vdsdd = sparse_tensor.values %t1 : tensor<?x?x2x3xf64, #DSDD> to memref<?xf64>
+    %vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
+    vector.print %vecdsdd : vector<12xf64>
+
     // Release the resources.
     bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>