[mlir][sparse] fix bug in workspace dimension computation

Access pattern expansion is always done along the innermost stored
dimension, but this was incorrectly reordered due to using a
general utility typically used by original dimensions only.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133472
This commit is contained in:
Aart Bik
2022-09-07 22:46:31 -07:00
parent ac3b8df8f2
commit ec8f2905a3
4 changed files with 213 additions and 12 deletions

View File

@@ -1166,10 +1166,15 @@ public:
Type idxType = rewriter.getIndexType();
// All initialization should be done on entry of the loop nest.
rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
// Determine the size for access expansion.
// Determine the size for access expansion (always the innermost stored
// dimension size, but we need to translate it back to the original
// dimension since the dim size utility applies dimension ordering).
auto enc = getSparseTensorEncoding(srcType);
Value src = adaptor.getOperands()[0];
Value sz = genDimSizeCall(rewriter, loc, enc, src, srcType.getRank() - 1);
unsigned innerDim = srcType.getRank() - 1;
if (AffineMap p = enc.getDimOrdering())
innerDim = p.getDimPosition(innerDim);
Value sz = genDimSizeCall(rewriter, loc, enc, src, innerDim);
// Allocate temporary buffers for values, filled-switch, and indices.
// We do not use stack buffers for this, since the expanded size may
// be rather large (as it envelops a single expanded dense dimension).