[mlir][sparse] introduce a higher-order tensor mapping

This extension to the sparse tensor type system in MLIR
opens up a whole new set of sparse storage schemes, such as
block sparse storage (e.g. BCSR) and ELL (aka jagged diagonals).

This revision merely introduces the type extension and
initial documentation. The actual interpretation of the type
(reading in tensors, lowering to code, etc.) will follow.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D135206
This commit is contained in:
Aart Bik
2022-10-04 14:34:37 -07:00
parent 1522f190eb
commit c48e90877f
18 changed files with 191 additions and 39 deletions

View File

@@ -525,7 +525,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
"reshape should not change element type");
// Start an iterator over the source tensor (in original index order).
auto noPerm = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(),
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> srcSizes;
SmallVector<Value, 8> params;
@@ -595,7 +595,7 @@ static void genSparseCOOIterationLoop(
// Start an iterator over the tensor (in original index order).
auto noPerm = SparseTensorEncodingAttr::get(
rewriter.getContext(), enc.getDimLevelType(), AffineMap(),
rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
enc.getPointerBitWidth(), enc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
@@ -857,7 +857,8 @@ public:
// the correct sparsity information to either of them.
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
encSrc.getIndexBitWidth());
newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, loc, params);
params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
@@ -889,7 +890,8 @@ public:
op->getContext(),
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
rank, SparseTensorEncodingAttr::DimLevelType::Dense),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
AffineMap(), AffineMap(), encSrc.getPointerBitWidth(),
encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
@@ -1373,7 +1375,7 @@ public:
SmallVector<Value, 8> params;
sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
auto enc = SparseTensorEncodingAttr::get(
op->getContext(), encSrc.getDimLevelType(), AffineMap(),
op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src);
Value coo = genNewCall(rewriter, loc, params);