[mlir][sparse] avoid non-perm on sparse tensor convert for new (#72459)
This avoids a non-permutation map on the convert from COO to non-COO for higher-dimensional new operators (viz. reading in BSR). This is step 1 out of 3 to make sparse_tensor.new work for BSR.
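Schematically, the rewritten pattern produces the op sequence sketched below. This is only an illustration that mirrors the comment inside the pattern: the value names, the encoding labels, and the omitted types are placeholders, not actual output of the pass.

    %coo    = sparse_tensor.new %filename          // read the file into an ordered COO tensor
    %demap  = sparse_tensor.reinterpret_map %coo   // strip the non-permutation dimToLvl (e.g. for BSR)
    %conv   = sparse_tensor.convert %demap         // COO to the demapped destination format
    %result = sparse_tensor.reinterpret_map %conv  // reinstate the original destination encoding
    bufferization.dealloc_tensor %coo              // release the temporary ordered COO tensor

When the destination's dimToLvl map is a permutation, the two reinterpret_map ops are not emitted and the convert consumes the ordered COO tensor directly.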
@@ -1189,20 +1189,30 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   LogicalResult matchAndRewrite(NewOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    const auto dstTp = getSparseTensorType(op.getResult());
-    const auto encDst = dstTp.getEncoding();
-    if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
+    auto stt = getSparseTensorType(op.getResult());
+    auto enc = stt.getEncoding();
+    if (!stt.hasEncoding() || getCOOStart(enc) == 0)
       return failure();
 
     // Implement the NewOp as follows:
     //   %orderedCoo = sparse_tensor.new %filename
     //   %t = sparse_tensor.convert %orderedCoo
+    // with enveloping reinterpreted_map ops for non-permutations.
+    RankedTensorType dstTp = stt.getRankedTensorType();
     RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
     Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
-    Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
-        op, dstTp.getRankedTensorType(), cooTensor);
+    Value convert = cooTensor;
+    if (!stt.isPermutation()) { // demap coo, demap dstTp
+      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
+      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
+      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
+    }
+    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
+    if (!stt.isPermutation()) // remap to original enc
+      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
+    rewriter.replaceOp(op, convert);
 
-    // Release the ordered COO tensor.
+    // Release the temporary ordered COO tensor.
     rewriter.setInsertionPointAfterValue(convert);
     rewriter.create<DeallocTensorOp>(loc, cooTensor);
 
@@ -1210,6 +1220,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   }
 };
 
 /// Sparse rewriting rule for the out operator.
 struct OutRewriter : public OpRewritePattern<OutOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(OutOp op,
@@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
                                     primaryTypeFunctionSuffix(eltTp)};
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     ModuleOp module = op->getParentOfType<ModuleOp>();
 
     // For each element in the source tensor, output the element.
     rewriter.create<ForeachOp>(
         loc, src, std::nullopt,
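As a concrete instance of an encoding that takes the !stt.isPermutation() path above, a 2x2 block-sparse (BSR) destination could be declared as follows; the block size and the #BSR name are arbitrary choices for illustration.

    #BSR = #sparse_tensor.encoding<{
      map = ( i, j ) ->
        ( i floordiv 2 : dense,
          j floordiv 2 : compressed,
          i mod 2      : dense,
          j mod 2      : dense )
    }>

Because this map uses floordiv and mod, it is not a permutation of the dimensions, so the pattern demaps both the ordered COO source and the destination type before the convert and reinterprets the result back to the original encoding afterwards.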