Added a utility to compute a topo ordering of a graph
authorBenoit Steiner <bsteiner@google.com>
Thu, 12 Apr 2018 22:41:41 +0000 (15:41 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 22:43:56 +0000 (15:43 -0700)
PiperOrigin-RevId: 192683166

tensorflow/core/grappler/utils/topological_sort.cc
tensorflow/core/grappler/utils/topological_sort.h
tensorflow/core/grappler/utils/topological_sort_test.cc

index 8d8ff4d..a8e464d 100644 (file)
@@ -26,24 +26,24 @@ namespace grappler {
 
 // Kahn's algorithm is implemented.
 // For details, see https://en.wikipedia.org/wiki/Topological_sorting
-Status TopologicalSort(GraphDef* graph) {
+Status ComputeTopologicalOrder(const GraphDef& graph,
+                               std::vector<int>* ready_nodes) {
   SimpleGraphView graph_view;
-  TF_RETURN_IF_ERROR(graph_view.Initialize(*graph));
+  TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
 
-  std::vector<int> ready_nodes;
-  ready_nodes.reserve(graph_view.num_nodes());
+  ready_nodes->reserve(graph_view.num_nodes());
 
   int front = 0;
   int back = 0;
   std::vector<int> num_ready_inputs(graph_view.num_nodes(), 0);
   for (int i = 0; i < graph_view.num_nodes(); i++) {
     if (graph_view.inputs(i).empty()) {
-      ready_nodes.push_back(i);
+      ready_nodes->push_back(i);
       back++;
     }
-    if (IsMerge(graph->node(i))) {
+    if (IsMerge(graph.node(i))) {
       for (int input : graph_view.inputs(i)) {
-        if (IsNextIteration(graph->node(input))) {
+        if (IsNextIteration(graph.node(input))) {
           num_ready_inputs[i]++;
         }
       }
@@ -51,11 +51,11 @@ Status TopologicalSort(GraphDef* graph) {
   }
 
   while (front != back) {
-    int ready_node = ready_nodes[front];
+    int ready_node = (*ready_nodes)[front];
     for (int fanout : graph_view.outputs(ready_node)) {
       ++num_ready_inputs[fanout];
       if (num_ready_inputs[fanout] == graph_view.inputs(fanout).size()) {
-        ready_nodes.push_back(fanout);
+        ready_nodes->push_back(fanout);
         ++back;
       }
     }
@@ -66,7 +66,24 @@ Status TopologicalSort(GraphDef* graph) {
     return errors::InvalidArgument(
         "The graph couldn't be sorted in topological order.");
   }
+  return Status::OK();
+}
 
+Status ComputeTopologicalOrder(
+    const GraphDef& graph,
+    std::unordered_map<const NodeDef*, int>* topo_order) {
+  std::vector<int> ready_nodes;
+  TF_RETURN_IF_ERROR(ComputeTopologicalOrder(graph, &ready_nodes));
+  topo_order->reserve(graph.node_size());
+  for (int i = 0; i < ready_nodes.size(); ++i) {
+    (*topo_order)[&graph.node(ready_nodes[i])] = i;
+  }
+  return Status::OK();
+}
+
+Status TopologicalSort(GraphDef* graph) {
+  std::vector<int> ready_nodes;
+  TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes));
   PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
   return Status::OK();
 }
index 7700fe4..668c88d 100644 (file)
@@ -22,6 +22,10 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+// Compute a topological ordering for the graph nodes.
+Status ComputeTopologicalOrder(
+    const GraphDef& graph, std::unordered_map<const NodeDef*, int>* topo_order);
+
 // Sort a graph in topological order.
 Status TopologicalSort(GraphDef* graph);
 
index c96f15b..f5c9500 100644 (file)
@@ -52,8 +52,19 @@ TEST_F(TopologicalSortTest, NoLoop) {
   *graph.add_node() = CreateNode("5", {});
   *graph.add_node() = CreateNode("4", {});
 
+  std::unordered_map<const NodeDef*, int> topo_order;
+  TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+
+  const std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
+  for (const auto& topo : topo_order) {
+    const string& node_name = topo.first->name();
+    const int topo_order = topo.second;
+    std::cout << "Node " << node_name << " at order " << topo_order
+              << std::endl;
+    EXPECT_EQ(node_name, order[topo_order]);
+  }
+
   TF_EXPECT_OK(TopologicalSort(&graph));
-  std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
   for (int i = 0; i < order.size(); i++) {
     EXPECT_EQ(graph.node(i).name(), order[i]);
   }
@@ -68,8 +79,17 @@ TEST_F(TopologicalSortTest, WithLoop) {
   *graph.add_node() = CreateNode("5", "NextIteration", {"4"});
   *graph.add_node() = CreateNode("1", {});
 
+  std::unordered_map<const NodeDef*, int> topo_order;
+  TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+
+  const std::vector<string> order = {"1", "2", "3", "4", "5"};
+  for (const auto& topo : topo_order) {
+    const string& node_name = topo.first->name();
+    const int topo_order = topo.second;
+    EXPECT_EQ(node_name, order[topo_order]);
+  }
+
   TF_EXPECT_OK(TopologicalSort(&graph));
-  std::vector<string> order = {"1", "2", "3", "4", "5"};
   for (int i = 0; i < order.size(); i++) {
     EXPECT_EQ(graph.node(i).name(), order[i]);
   }