Add utility 'replaceAllUsesWith' methods to Operation.
authorRiver Riddle <riverriddle@google.com>
Wed, 7 Aug 2019 20:48:19 +0000 (13:48 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 7 Aug 2019 20:48:52 +0000 (13:48 -0700)
These methods will allow replacing the uses of results with an existing operation, with the same number of results, or a range of values. This removes a number of hand-rolled result replacement loops and simplifies replacement for operations with multiple results.

PiperOrigin-RevId: 262206600

mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Transforms/CSE.cpp
mlir/lib/Transforms/MemRefDataFlowOpt.cpp
mlir/lib/Transforms/Utils/Utils.cpp

index 16777ba2cc27e1d5c2aadd22a3f931fd1a27c8bf..6c75cb54cfde08599a739e2bbd0153d94ab12593 100644 (file)
@@ -529,9 +529,10 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
   /// Return the result at index 'i'.
   Value *getResult(unsigned i) { return this->getOperation()->getResult(i); }
 
-  /// Set the result at index 'i' to 'value'.
-  void setResult(unsigned i, Value *value) {
-    this->getOperation()->setResult(i, value);
+  /// Replace all uses of results of this operation with the provided 'values'.
+  /// 'values' may correspond to an existing operation, or a range of 'Value'.
+  template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
+    this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
   }
 
   /// Return the type of the `i`-th result.
@@ -572,6 +573,11 @@ public:
     getResult()->replaceAllUsesWith(newValue);
   }
 
+  /// Replace all uses of 'this' value with the result of 'op'.
+  void replaceAllUsesWith(Operation *op) {
+    this->getOperation()->replaceAllUsesWith(op);
+  }
+
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifyOneResult(op);
   }
index 6e17ef063f86ab4ce767ceb85b7e256964af5044..515cd857dd05fb03f378dd18aaf4018cef360bd3 100644 (file)
@@ -137,6 +137,25 @@ public:
   /// Replace any uses of 'from' with 'to' within this operation.
   void replaceUsesOfWith(Value *from, Value *to);
 
+  /// Replace all uses of results of this operation with the provided 'values'.
+  template <typename ValuesT,
+            typename = decltype(std::declval<ValuesT>().begin())>
+  void replaceAllUsesWith(ValuesT &&values) {
+    assert(std::distance(values.begin(), values.end()) == getNumResults() &&
+           "expected 'values' to correspond 1-1 with the number of results");
+
+    auto valueIt = values.begin();
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      getResult(i)->replaceAllUsesWith(*(valueIt++));
+  }
+
+  /// Replace all uses of results of this operation with results of 'op'.
+  void replaceAllUsesWith(Operation *op) {
+    assert(getNumResults() == op->getNumResults());
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      getResult(i)->replaceAllUsesWith(op->getResult(i));
+  }
+
   /// Destroys this operation and its subclass data.
   void destroy();
 
index 94fa7ab43f7d0280f7f3cbf700745c01592a6fde..b575abe941d653c4cc9233c354f131a0731dd444 100644 (file)
@@ -91,8 +91,7 @@ void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
 
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
-  for (unsigned i = 0, e = newValues.size(); i != e; ++i)
-    op->getResult(i)->replaceAllUsesWith(newValues[i]);
+  op->replaceAllUsesWith(newValues);
 
   notifyOperationRemoved(op);
   op->erase();
index 188db62549096b563ffe3f8f15a8248aacd334ad..eeb63e7f9eb2fefa8ff873939da729748610b29d 100644 (file)
@@ -150,8 +150,7 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
   if (auto *existing = knownValues.lookup(op)) {
     // If we find one then replace all uses of the current operation with the
     // existing one and mark it for deletion.
-    for (unsigned i = 0, e = existing->getNumResults(); i != e; ++i)
-      op->getResult(i)->replaceAllUsesWith(existing->getResult(i));
+    op->replaceAllUsesWith(existing);
     opsToErase.push_back(op);
 
     // If the existing operation has an unknown location and the current
index 93f7331f7a31489d29e6c3524c10024c93965475..4f8b1c61cbf08a230358134cf504d0b0e3b45e2b 100644 (file)
@@ -204,7 +204,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
 
   // Perform the actual store to load forwarding.
   Value *storeVal = cast<AffineStoreOp>(lastWriteStoreOp).getValueToStore();
-  loadOp.getResult()->replaceAllUsesWith(storeVal);
+  loadOp.replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
index 55b831010987d45dc2110272f0b0538a8a3057b0..250c76913c26c09290121e9c2eac5fad8ef5289b 100644 (file)
@@ -242,11 +242,8 @@ bool mlir::replaceAllMemRefUsesWith(Value *oldMemRef, Value *newMemRef,
 
     // Create the new operation.
     auto *repOp = builder.createOperation(state);
-    // Replace old memref's deferencing op's uses.
-    unsigned r = 0;
-    for (auto *res : opInst->getResults()) {
-      res->replaceAllUsesWith(repOp->getResult(r++));
-    }
+    opInst->replaceAllUsesWith(repOp);
+
     // Collect and erase at the end since one of these op's could be
     // domInstFilter or postDomInstFilter as well!
     opsToErase.push_back(opInst);