[mlir] Introduce `replaceUsesOfWith` to `RewriterBase`
authorGuray Ozen <guray.ozen@gmail.com>
Wed, 16 Nov 2022 16:23:43 +0000 (17:23 +0100)
committerGuray Ozen <guray.ozen@gmail.com>
Wed, 16 Nov 2022 16:53:11 +0000 (17:53 +0100)
Finding uses of a value and replacing them with a new one is a common method. I have not seen an safe and easy shortcut that does that. This revision attempts to address that by intoroducing `replaceUsesOfWith` to `RewriterBase`.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D138110

mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/IR/PatternMatch.cpp

index e257b67..7b05a2d 100644 (file)
@@ -502,6 +502,11 @@ public:
     finalizeRootUpdate(root);
   }
 
+  /// Find uses of `from` and replace it with `to`. It also marks every modified
+  /// uses and notifies the rewriter that an in-place operation modification is
+  /// about to happen.
+  void replaceAllUsesWith(Value from, Value to);
+
   /// Used to notify the rewriter that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
   /// reason why the failure occurred. This method allows for derived rewriters
index ccac412..ec493bf 100644 (file)
@@ -246,15 +246,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
                                       sourceBlock.getOperations());
 
   // Step 5. RAUW thread indices to thread ops.
-  for (Value blockIdx : foreachThreadOp.getThreadIndices()) {
-    Value val = bvm.lookup(blockIdx);
-    SmallVector<OpOperand *> uses;
-    for (OpOperand &use : blockIdx.getUses())
-      uses.push_back(&use);
-    for (OpOperand *operand : uses) {
-      Operation *op = operand->getOwner();
-      rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
-    }
+  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+    Value blockIdx = bvm.lookup(loopIndex);
+    rewriter.replaceAllUsesWith(loopIndex, blockIdx);
   }
 
   // Step 6. Erase old op.
@@ -492,15 +486,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
                                       sourceBlock.getOperations());
 
   // Step 6. RAUW thread indices to thread ops.
-  for (Value threadIdx : foreachThreadOp.getThreadIndices()) {
-    Value val = bvm.lookup(threadIdx);
-    SmallVector<OpOperand *> uses;
-    for (OpOperand &use : threadIdx.getUses())
-      uses.push_back(&use);
-    for (OpOperand *operand : uses) {
-      Operation *op = operand->getOwner();
-      rewriter.updateRootInPlace(op, [&]() { operand->set(val); });
-    }
+  for (Value loopIndex : foreachThreadOp.getThreadIndices()) {
+    Value threadIdx = bvm.lookup(loopIndex);
+    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
   }
 
   // Step 7. syncthreads.
index d2de65e..d3072b5 100644 (file)
@@ -309,6 +309,14 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
   source->erase();
 }
 
+/// Find uses of `from` and replace it with `to`
+void RewriterBase::replaceAllUsesWith(Value from, Value to) {
+  for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
+    Operation *op = operand.getOwner();
+    updateRootInPlace(op, [&]() { operand.set(to); });
+  }
+}
+
 // Merge the operations of block 'source' before the operation 'op'. Source
 // block should not have existing predecessors or successors.
 void RewriterBase::mergeBlockBefore(Block *source, Operation *op,