[mlir][SparseTensor] Fix incorrect API usage in RewritePatterns

Incorrect API usage was detected by D144552.

Differential Revision: https://reviews.llvm.org/D145166
This commit is contained in:
Matthias Springer
2023-03-02 17:22:06 +01:00
parent 37114036aa
commit ae9e1d1df4
3 changed files with 15 additions and 10 deletions

View File

@@ -444,7 +444,7 @@ public:
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
op->setOperand(0, convert);
rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
return success();
}
if (encDst) {

View File

@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
rewriter.setInsertionPointToStart(forOpNew.getBody());
} else {
forOp.setStep(step);
rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
rewriter.setInsertionPoint(yield);
}
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
// Now do some relinking (last one is not completely type safe
// but all bad ones are removed right away). This also folds away
// nop broadcast operations.
forOp.getResult(0).replaceAllUsesWith(vres);
forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
forOp.getRegionIterArg(0).replaceAllUsesWith(
forOpNew.getRegionIterArg(0));
rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
rewriter.replaceAllUsesWith(forOp.getInductionVar(),
forOpNew.getInductionVar());
rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
forOpNew.getRegionIterArg(0));
rewriter.eraseOp(forOp);
}
return true;

View File

@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
return genIndexValue(env, indexOp.getDim());
if (def->getBlock() == block) {
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
def->setOperand(
i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
def->setOperand(
i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
});
}
}
}
return e;
@@ -1615,7 +1618,8 @@ private:
auto dstTp = RankedTensorType::get(srcTp.getShape(),
srcTp.getElementType(), dstEnc);
auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
env.op()->setOperand(tensor, convert);
rewriter.updateRootInPlace(
env.op(), [&]() { env.op()->setOperand(tensor, convert); });
rewriter.setInsertionPointAfter(env.op());
rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
return success();