[mlir][NFC] Add helper for common pattern of replaceAllUsesExcept
authorSean Silva <silvasean@google.com>
Wed, 12 May 2021 21:59:12 +0000 (14:59 -0700)
committerSean Silva <silvasean@google.com>
Thu, 13 May 2021 19:42:10 +0000 (12:42 -0700)
This covers the extremely common case of replacing all uses of a Value
with a new op that is itself a user of the original Value.

This should also be a little bit more efficient than the
`SmallPtrSet<Operation *, 1>{op}` idiom that was being used before.

Differential Revision: https://reviews.llvm.org/D102373

mlir/include/mlir/IR/Value.h
mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
mlir/lib/IR/Value.cpp

index 2a0bd67..bd80b2b 100644 (file)
@@ -166,6 +166,11 @@ public:
   replaceAllUsesExcept(Value newValue,
                        const SmallPtrSetImpl<Operation *> &exceptions) const;
 
+  /// Replace all uses of 'this' value with 'newValue', updating anything in the
+  /// IR that uses 'this' to use the other value instead except if the user is
+  /// 'exceptedUser'.
+  void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const;
+
   /// Replace all uses of 'this' value with 'newValue' if the given callback
   /// returns true.
   void replaceUsesWithIf(Value newValue,
index 8653bcf..1a785a0 100644 (file)
@@ -72,7 +72,7 @@ void mlir::normalizeAffineParallel(AffineParallelOp op) {
     applyOperands.push_back(iv);
     applyOperands.append(symbolOperands.begin(), symbolOperands.end());
     auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
-    iv.replaceAllUsesExcept(apply, SmallPtrSet<Operation *, 1>{apply});
+    iv.replaceAllUsesExcept(apply, apply);
   }
 
   SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
@@ -181,8 +181,7 @@ static void normalizeAffineFor(AffineForOp op) {
   AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
                                    origLbMap.getNumSymbols(), newIVExpr);
   Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
-  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0),
-                                            SmallPtrSet<Operation *, 1>{newIV});
+  op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
 }
 
 namespace {
index 79903c0..4c06d3d 100644 (file)
@@ -191,8 +191,7 @@ static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
       AffineApplyOp applyOp = builder.create<AffineApplyOp>(
           indexOp.getLoc(), index + offset,
           ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
-      indexOp.getResult().replaceAllUsesExcept(
-          applyOp, SmallPtrSet<Operation *, 1>{applyOp});
+      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
     }
   }
 
index 0479ab6..bdc1d70 100644 (file)
@@ -155,8 +155,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
     AffineApplyOp applyOp = b.create<AffineApplyOp>(
         indexOp.getLoc(), index + iv,
         ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
-    indexOp.getResult().replaceAllUsesExcept(
-        applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
+    indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
   }
 }
 
index cdb4afa..8282c07 100644 (file)
@@ -121,8 +121,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
     Value inner_index = std::get<0>(ivs);
     AddIOp newIndex =
         b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
-    inner_index.replaceAllUsesExcept(
-        newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
+    inner_index.replaceAllUsesExcept(newIndex, newIndex);
   }
 
   op.erase();
index e28ab9b..a4baa93 100644 (file)
@@ -63,12 +63,23 @@ void Value::replaceAllUsesWith(Value newValue) const {
 /// listed in 'exceptions' .
 void Value::replaceAllUsesExcept(
     Value newValue, const SmallPtrSetImpl<Operation *> &exceptions) const {
-  for (auto &use : llvm::make_early_inc_range(getUses())) {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
     if (exceptions.count(use.getOwner()) == 0)
       use.set(newValue);
   }
 }
 
+/// Replace all uses of 'this' value with 'newValue', updating anything in the
+/// IR that uses 'this' to use the other value instead except if the user is
+/// 'exceptedUser'.
+void Value::replaceAllUsesExcept(Value newValue,
+                                 Operation *exceptedUser) const {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
+    if (use.getOwner() != exceptedUser)
+      use.set(newValue);
+  }
+}
+
 /// Replace all uses of 'this' value with 'newValue' if the given callback
 /// returns true.
 void Value::replaceUsesWithIf(Value newValue,