[mlir][cse] do not replace operands in previously simplified operations
authorAndrew Young <youngar17@gmail.com>
Mon, 29 Mar 2021 02:25:32 +0000 (19:25 -0700)
committerAndrew Young <youngar17@gmail.com>
Wed, 31 Mar 2021 19:20:34 +0000 (12:20 -0700)
If an operation has been inserted as a key in to the known values
hashtable, then it can not be modified in a way which changes its hash.
This change avoids modifying the operands of any previously recorded
operation, which prevents their hash from changing.

In an SSACFG region, it is impossible to visit an operation before
visiting its operands, so this is not a problem. This situation can only
happen in regions without strict dominance, such as graph regions.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D99486

mlir/lib/Transforms/CSE.cpp
mlir/test/Transforms/cse.mlir

index 57f3904..efeb051 100644 (file)
@@ -72,10 +72,10 @@ struct CSE : public CSEBase<CSE> {
 
   /// Attempt to eliminate a redundant operation. Returns success if the
   /// operation was marked for removal, failure otherwise.
-  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op);
-
+  LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+                                  bool hasSSADominance);
   void simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
-                     Block *bb);
+                     Block *bb, bool hasSSADominance);
   void simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
                       Region &region);
 
@@ -88,7 +88,8 @@ private:
 } // end anonymous namespace
 
 /// Attempt to eliminate a redundant operation.
-LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
+LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
+                                     bool hasSSADominance) {
   // Don't simplify terminator operations.
   if (op->hasTrait<OpTrait::IsTerminator>())
     return failure();
@@ -113,10 +114,29 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
 
   // Look for an existing definition for the operation.
   if (auto *existing = knownValues.lookup(op)) {
+
     // If we find one then replace all uses of the current operation with the
-    // existing one and mark it for deletion.
-    op->replaceAllUsesWith(existing);
-    opsToErase.push_back(op);
+    // existing one and mark it for deletion. We can only replace an operand in
+    // an operation if it has not been visited yet.
+    if (hasSSADominance) {
+      // If the region has SSA dominance, then we are guaranteed to have not
+      // visited any use of the current operation.
+      op->replaceAllUsesWith(existing);
+      opsToErase.push_back(op);
+    } else {
+      // When the region does not have SSA dominance, we need to check if we
+      // have visited a use before replacing any use.
+      for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
+        std::get<0>(it).replaceUsesWithIf(
+            std::get<1>(it), [&](OpOperand &operand) {
+              return !knownValues.count(operand.getOwner());
+            });
+      }
+
+      // There may be some remaining uses of the operation.
+      if (op->use_empty())
+        opsToErase.push_back(op);
+    }
 
     // If the existing operation has an unknown location and the current
     // operation doesn't, then set the existing op's location to that of the
@@ -136,10 +156,10 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op) {
 }
 
 void CSE::simplifyBlock(ScopedMapTy &knownValues, DominanceInfo &domInfo,
-                        Block *bb) {
+                        Block *bb, bool hasSSADominance) {
   for (auto &inst : *bb) {
     // If the operation is simplified, we don't process any held regions.
-    if (succeeded(simplifyOperation(knownValues, &inst)))
+    if (succeeded(simplifyOperation(knownValues, &inst, hasSSADominance)))
       continue;
 
     // If this operation is isolated above, we can't process nested regions with
@@ -164,17 +184,19 @@ void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
   if (region.empty())
     return;
 
+  bool hasSSADominance = domInfo.hasDominanceInfo(&region);
+
   // If the region only contains one block, then simplify it directly.
   if (std::next(region.begin()) == region.end()) {
     ScopedMapTy::ScopeTy scope(knownValues);
-    simplifyBlock(knownValues, domInfo, &region.front());
+    simplifyBlock(knownValues, domInfo, &region.front(), hasSSADominance);
     return;
   }
 
   // If the region does not have dominanceInfo, then skip it.
   // TODO: Regions without SSA dominance should define a different
   // traversal order which is appropriate and can be used here.
-  if (!domInfo.hasDominanceInfo(&region))
+  if (!hasSSADominance)
     return;
 
   // Note, deque is being used here because there was significant performance
@@ -195,7 +217,8 @@ void CSE::simplifyRegion(ScopedMapTy &knownValues, DominanceInfo &domInfo,
     // Check to see if we need to process this node.
     if (!currentNode->processed) {
       currentNode->processed = true;
-      simplifyBlock(knownValues, domInfo, currentNode->node->getBlock());
+      simplifyBlock(knownValues, domInfo, currentNode->node->getBlock(),
+                    hasSSADominance);
     }
 
     // Otherwise, check to see if we need to process a child node.
index 455d560..17109f7 100644 (file)
@@ -244,3 +244,24 @@ func @nested_isolated() -> i32 {
 
   return %0 : i32
 }
+
+/// This test is checking that CSE gracefully handles values in graph regions
+/// where the use occurs before the def, and one of the defs could be CSE'd with
+/// the other.
+// CHECK-LABEL: @use_before_def
+func @use_before_def() {
+  // CHECK-NEXT: test.graph_region
+  test.graph_region {
+    // CHECK-NEXT: addi %c1_i32, %c1_i32_0
+    %0 = addi %1, %2 : i32
+
+    // CHECK-NEXT: constant 1
+    // CHECK-NEXT: constant 1
+    %1 = constant 1 : i32
+    %2 = constant 1 : i32
+
+    // CHECK-NEXT: "foo.yield"(%0) : (i32) -> ()
+    "foo.yield"(%0) : (i32) -> ()
+  }
+  return
+}