[mlir][SparseTensor] Fix incorrect API usage in RewritePatterns
authorMatthias Springer <me@m-sp.org>
Thu, 2 Mar 2023 16:22:06 +0000 (17:22 +0100)
committerMatthias Springer <me@m-sp.org>
Thu, 2 Mar 2023 16:59:57 +0000 (17:59 +0100)
Incorrect API usage was detected by D144552.

Differential Revision: https://reviews.llvm.org/D145166

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

index b128669..0663bd9 100644 (file)
@@ -444,7 +444,7 @@ public:
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
-      op->setOperand(0, convert);
+      rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
       return success();
     }
     if (encDst) {
index bc05137..1772eef 100644 (file)
@@ -546,7 +546,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
           forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
       rewriter.setInsertionPointToStart(forOpNew.getBody());
     } else {
-      forOp.setStep(step);
+      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
       rewriter.setInsertionPoint(yield);
     }
     vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
@@ -575,10 +575,11 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
         // Now do some relinking (last one is not completely type safe
         // but all bad ones are removed right away). This also folds away
         // nop broadcast operations.
-        forOp.getResult(0).replaceAllUsesWith(vres);
-        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
-        forOp.getRegionIterArg(0).replaceAllUsesWith(
-            forOpNew.getRegionIterArg(0));
+        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
+        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
+                                    forOpNew.getInductionVar());
+        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
+                                    forOpNew.getRegionIterArg(0));
         rewriter.eraseOp(forOp);
       }
       return true;
index d8fcac0..575b505 100644 (file)
@@ -838,9 +838,12 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
       return genIndexValue(env, indexOp.getDim());
     if (def->getBlock() == block) {
-      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
-        def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+      for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
+        rewriter.updateRootInPlace(def, [&]() {
+          def->setOperand(
+              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+        });
+      }
     }
   }
   return e;
@@ -1615,7 +1618,8 @@ private:
       auto dstTp = RankedTensorType::get(srcTp.getShape(),
                                          srcTp.getElementType(), dstEnc);
       auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
-      env.op()->setOperand(tensor, convert);
+      rewriter.updateRootInPlace(
+          env.op(), [&]() { env.op()->setOperand(tensor, convert); });
       rewriter.setInsertionPointAfter(env.op());
       rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert);
       return success();