[SCCIterator] Fix an issue in scc_member_iterator sorting
authorHongtao Yu <hoy@fb.com>
Tue, 21 Mar 2023 17:57:11 +0000 (10:57 -0700)
committerHongtao Yu <hoy@fb.com>
Mon, 27 Mar 2023 17:48:05 +0000 (10:48 -0700)
Members in an scc are supposed to be sorted in a top-down or topological order based on edge weights. Previously this is achived by building a MST out of the SCC and enforcing an BFS walk on the MST. A BFS on a tree does give a top-down topological order, however, the MST built here isn't really a tree. This is becuase of a trick done to avoid expansive detection of a cycle on a directed graph when an edge is added. When the MST is built, its edges are considered undirected. But in reality they are directed, thus a BST walk doesn't necessarily give a topological order. I'm tweaking the BFS walk slightly to yield a topological order.

Basically I'm using Kahn's algorithm on MST to compute a topological traversal order. The algorithm starts from nodes that have no incoming edge. These nodes are "roots" of the MST forest. This ensures that nodes are visited before their descendants are, thus ensures a topological traversal order of the MST.

Reviewed By: wenlei

Differential Revision: https://reviews.llvm.org/D130717

llvm/include/llvm/ADT/SCCIterator.h

index e4035a0..267139d 100644 (file)
@@ -23,6 +23,7 @@
 #define LLVM_ADT_SCCITERATOR_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/iterator.h"
 #include <cassert>
@@ -258,7 +259,8 @@ class scc_member_iterator {
   struct NodeInfo {
     NodeInfo *Group = this;
     uint32_t Rank = 0;
-    bool Visited = true;
+    bool Visited = false;
+    DenseSet<const EdgeType *> IncomingMSTEdges;
   };
 
   // Find the root group of the node and compress the path from node to the
@@ -340,20 +342,22 @@ scc_member_iterator<GraphT, GT>::scc_member_iterator(
       MSTEdges.insert(Edge);
   }
 
-  // Do BFS on MST, starting from nodes that have no incoming edge. These nodes
-  // are "roots" of the MST forest. This ensures that nodes are visited before
-  // their decsendents are, thus ensures hot edges are processed before cold
-  // edges, based on how MST is computed.
+  // Run Kahn's algorithm on MST to compute a topological traversal order.
+  // The algorithm starts from nodes that have no incoming edge. These nodes are
+  // "roots" of the MST forest. This ensures that nodes are visited before their
+  // descendants are, thus ensures hot edges are processed before cold edges,
+  // based on how MST is computed.
+  std::queue<NodeType *> Queue;
   for (const auto *Edge : MSTEdges)
-    NodeInfoMap[Edge->Target].Visited = false;
+    NodeInfoMap[Edge->Target].IncomingMSTEdges.insert(Edge);
 
-  std::queue<NodeType *> Queue;
-  // Initialze the queue with MST roots. Note that walking through SortedEdges
-  // instead of NodeInfoMap ensures an ordered deterministic push.
+  // Walk through SortedEdges to initialize the queue, instead of using NodeInfoMap
+  // to ensure an ordered deterministic push.
   for (auto *Edge : SortedEdges) {
-    if (NodeInfoMap[Edge->Source].Visited) {
+    if (!NodeInfoMap[Edge->Source].Visited &&
+        NodeInfoMap[Edge->Source].IncomingMSTEdges.empty()) {
       Queue.push(Edge->Source);
-      NodeInfoMap[Edge->Source].Visited = false;
+      NodeInfoMap[Edge->Source].Visited = true;
     }
   }
 
@@ -362,8 +366,9 @@ scc_member_iterator<GraphT, GT>::scc_member_iterator(
     Queue.pop();
     Nodes.push_back(Node);
     for (auto &Edge : Node->Edges) {
-      if (MSTEdges.count(&Edge) && !NodeInfoMap[Edge.Target].Visited) {
-        NodeInfoMap[Edge.Target].Visited = true;
+      NodeInfoMap[Edge.Target].IncomingMSTEdges.erase(&Edge);
+      if (MSTEdges.count(&Edge) &&
+          NodeInfoMap[Edge.Target].IncomingMSTEdges.empty()) {
         Queue.push(Edge.Target);
       }
     }