[MLIR] Add replaceUsesWithIf on Operation
authorUday Bondhugula <uday@polymagelabs.com>
Tue, 21 Feb 2023 04:40:15 +0000 (10:10 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Tue, 21 Feb 2023 04:40:22 +0000 (10:10 +0530)
Add replaceUsesWithIf on Operation along the lines of
Value::replaceUsesWithIf. This had been missing on Operation and is
convenient to replace multi-result operations' results conditionally.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D144348

mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/ValueRange.h
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/CSE.cpp

index e7ef01e..ac6bdfc 100644 (file)
@@ -257,6 +257,15 @@ public:
     getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
   }
 
+  /// Replace uses of results of this operation with the provided `values` if
+  /// the given callback returns true.
+  template <typename ValuesT>
+  void replaceUsesWithIf(ValuesT &&values,
+                         function_ref<bool(OpOperand &)> shouldReplace) {
+    getResults().replaceUsesWithIf(std::forward<ValuesT>(values),
+                                   shouldReplace);
+  }
+
   /// Destroys this operation and its subclass data.
   void destroy();
 
index 8873260..0f7354b 100644 (file)
@@ -279,6 +279,26 @@ public:
   /// Replace all uses of results of this range with results of 'op'.
   void replaceAllUsesWith(Operation *op);
 
+  /// Replace uses of results of this range with the provided 'values' if the
+  /// given callback returns true. The size of `values` must match the size of
+  /// this range.
+  template <typename ValuesT>
+  std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
+  replaceUsesWithIf(ValuesT &&values,
+                    function_ref<bool(OpOperand &)> shouldReplace) {
+    assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
+               size() &&
+           "expected 'values' to correspond 1-1 with the number of results");
+
+    for (auto it : llvm::zip(*this, values))
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
+  }
+
+  /// Replace uses of results of this range with results of `op` if the given
+  /// callback returns true.
+  void replaceUsesWithIf(Operation *op,
+                         function_ref<bool(OpOperand &)> shouldReplace);
+
   //===--------------------------------------------------------------------===//
   // Users
   //===--------------------------------------------------------------------===//
index 20ce9b3..a38a12d 100644 (file)
@@ -589,6 +589,11 @@ void ResultRange::replaceAllUsesWith(Operation *op) {
   replaceAllUsesWith(op->getResults());
 }
 
+void ResultRange::replaceUsesWithIf(
+    Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
+  replaceUsesWithIf(op->getResults(), shouldReplace);
+}
+
 //===----------------------------------------------------------------------===//
 // ValueRange
 
index 93e5c95..e98cccc 100644 (file)
@@ -124,12 +124,9 @@ void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
   } else {
     // When the region does not have SSA dominance, we need to check if we
     // have visited a use before replacing any use.
-    for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
-      std::get<0>(it).replaceUsesWithIf(
-          std::get<1>(it), [&](OpOperand &operand) {
-            return !knownValues.count(operand.getOwner());
-          });
-    }
+    op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
+      return !knownValues.count(operand.getOwner());
+    });
 
     // There may be some remaining uses of the operation.
     if (op->use_empty())