[mlir][sparse] avoid non-perm on sparse tensor convert for new (#72459)

This avoids seeing non-perm on the convert from COO to non-COO for
higher dimensional new operators (viz. reading in BSR).

This is step 1 out of 3 to make sparse_tensor.new work for BSR
This commit is contained in:
Aart Bik
2023-11-15 20:47:37 -08:00
committed by GitHub
parent 84044061e8
commit e8fc282ff2

View File

@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
const auto encDst = dstTp.getEncoding();
if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
auto stt = getSparseTensorType(op.getResult());
auto enc = stt.getEncoding();
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();
// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
// with enveloping reinterpreted_map ops for non-permutations.
RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
op, dstTp.getRankedTensorType(), cooTensor);
Value convert = cooTensor;
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
}
convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
if (!stt.isPermutation()) // remap to original enc
convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
rewriter.replaceOp(op, convert);
// Release the ordered COO tensor.
// Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
}
};
/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();
// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,