From ae9e1d1df46a50a6748514ee1d7d85e7fa81890d Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Thu, 2 Mar 2023 17:22:06 +0100
Subject: [PATCH] [mlir][SparseTensor] Fix incorrect API usage in RewritePatterns

Incorrect API usage was detected by D144552.

Differential Revision: https://reviews.llvm.org/D145166
---
 .../SparseTensor/Transforms/SparseTensorRewriting.cpp       |  2 +-
 .../Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 11 ++++++-----
 mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 12 ++++++++----
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index b128669..0663bd9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -444,7 +444,7 @@ public:
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
+      rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
       return success();
     }
     if (encDst) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index bc05137..1772eef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
           forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
       rewriter.setInsertionPointToStart(forOpNew.getBody());
     } else {
-      forOp.setStep(step);
+      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
       rewriter.setInsertionPoint(yield);
     }
     vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
       // Now do some relinking (last one is not completely type safe
       // but all bad ones are removed right away). This also folds away
       // nop broadcast operations.
-      forOp.getResult(0).replaceAllUsesWith(vres);
-      forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
-      forOp.getRegionIterArg(0).replaceAllUsesWith(
-          forOpNew.getRegionIterArg(0));
+      rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
+      rewriter.replaceAllUsesWith(forOp.getInductionVar(),
+                                  forOpNew.getInductionVar());
+      rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
+                                  forOpNew.getRegionIterArg(0));
       rewriter.eraseOp(forOp);
     }
     return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index d8fcac0..575b505 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
       return genIndexValue(env, indexOp.getDim());
     if (def->getBlock() == block) {
-      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
-        def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
+        rewriter.updateRootInPlace(def, [&]() {
+          def->setOperand(
+              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+        });
+      }
     }
   }
   return e;
@@ -1615,7 +1618,8 @@ private:
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.op()->setOperand(tensor, convert);
+      rewriter.updateRootInPlace(
+          env.op(), [&]() { env.op()->setOperand(tensor, convert); });
       rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();
-- 
2.7.4
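
Not part of the patch itself: all three files receive the same kind of fix. A
rewrite pattern must not mutate an op in place without informing the rewriter,
so the direct calls to setOperand()/setStep()/replaceAllUsesWith() are wrapped
in rewriter.updateRootInPlace() or routed through rewriter.replaceAllUsesWith(),
which is the kind of unannounced modification D144552 detects. Below is a
minimal sketch of the idiom; MyOp, MyCastOp, and getInput() are hypothetical
names used only for illustration and do not come from these files.

// Sketch only: MyOp and MyCastOp are placeholder ops, not part of the patch.
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Bypasses a cast feeding operand 0 of MyOp by rewriting the op in place.
struct BypassCastOnMyOp : public OpRewritePattern<MyOp> {
  using OpRewritePattern<MyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MyOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op->getOperand(0).getDefiningOp<MyCastOp>();
    if (!cast)
      return failure();
    // Wrong: op->setOperand(0, cast.getInput());
    //        the driver never learns that the matched op changed.
    // Right: announce the in-place modification to the rewriter.
    rewriter.updateRootInPlace(
        op, [&]() { op->setOperand(0, cast.getInput()); });
    return success();
  }
};

Pattern drivers rely on these notifications to re-enqueue users of the modified
op and to keep their bookkeeping (including the expensive-checks mode) consistent.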