[mlir][SparseTensor] Fix incorrect API usage in RewritePatterns
Incorrect API usage was detected by D144552.

Differential Revision: https://reviews.llvm.org/D145166
@@ -444,7 +444,7 @@ public:
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
+      rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
       return success();
     }
     if (encDst) {
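The hunk above follows the general rule for MLIR rewrite patterns: an in-place mutation of an op that is not being replaced must go through the rewriter (here via updateRootInPlace) so the pattern driver is notified of the change; calling op->setOperand() directly bypasses that notification, which is what the checks introduced by D144552 flag. A minimal sketch of the idiom, where MyOp and buildReplacementValue() are hypothetical placeholders and not part of this patch:

#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Sketch only: MyOp and buildReplacementValue() are illustrative
// placeholders, not code from this change.
struct RewriteMyOpOperand : public OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    Value replacement = buildReplacementValue(rewriter, op);
    // Mutating the matched op directly (op->setOperand(0, replacement))
    // hides the change from the pattern driver; wrapping it in
    // updateRootInPlace reports the in-place modification.
    rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, replacement); });
    return success();
  }
};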
@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                        forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
     rewriter.setInsertionPointToStart(forOpNew.getBody());
   } else {
-    forOp.setStep(step);
+    rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
     rewriter.setInsertionPoint(yield);
   }
   vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
       // Now do some relinking (last one is not completely type safe
       // but all bad ones are removed right away). This also folds away
       // nop broadcast operations.
-      forOp.getResult(0).replaceAllUsesWith(vres);
-      forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
-      forOp.getRegionIterArg(0).replaceAllUsesWith(
-          forOpNew.getRegionIterArg(0));
+      rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
+      rewriter.replaceAllUsesWith(forOp.getInductionVar(),
+                                  forOpNew.getInductionVar());
+      rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
+                                  forOpNew.getRegionIterArg(0));
       rewriter.eraseOp(forOp);
     }
     return true;
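This hunk covers the cases that are not plain operand updates: use replacement and op erasure also have to be issued through the rewriter (rewriter.replaceAllUsesWith, rewriter.eraseOp) rather than through the Value and Operation methods, for the same notification reason. A small illustrative helper under that assumption (replaceResultsAndErase is a hypothetical name, not code from this patch):

#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Illustrative helper (not part of the patch): redirect every result of
// oldOp to the corresponding result of newOp and delete oldOp, routing
// both steps through the rewriter so the pattern driver observes them.
static void replaceResultsAndErase(RewriterBase &rewriter, Operation *oldOp,
                                   Operation *newOp) {
  for (auto [oldRes, newRes] :
       llvm::zip(oldOp->getResults(), newOp->getResults()))
    // Not oldRes.replaceAllUsesWith(newRes): that rewrites uses silently.
    rewriter.replaceAllUsesWith(oldRes, newRes);
  rewriter.eraseOp(oldOp); // likewise, never oldOp->erase() inside a pattern
}

When the results line up one-to-one like this, rewriter.replaceOp(oldOp, newOp->getResults()) has the same effect in one call; the vectorizer above keeps separate calls because it also rewires the induction variable and the region iter argument, which are block arguments rather than op results.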
@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
       return genIndexValue(env, indexOp.getDim());
     if (def->getBlock() == block) {
-      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
-        def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
+        rewriter.updateRootInPlace(def, [&]() {
+          def->setOperand(
+              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+        });
+      }
     }
   }
   return e;
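A note on granularity for the hunk above: wrapping each setOperand() in its own updateRootInPlace notifies the driver once per operand. A single callback around the whole loop would report the modification once; a sketch of that alternative formulation, assuming the recursive relinkBranch calls remain valid inside the callback (this is not what the patch does):

// Alternative sketch, not part of the patch: one notification for the
// whole operand update of def.
rewriter.updateRootInPlace(def, [&]() {
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    def->setOperand(
        i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
});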
@@ -1615,7 +1618,8 @@ private:
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.op()->setOperand(tensor, convert);
+      rewriter.updateRootInPlace(
+          env.op(), [&]() { env.op()->setOperand(tensor, convert); });
       rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();