// Kahn's algorithm is implemented.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
-Status TopologicalSort(GraphDef* graph) {
+Status ComputeTopologicalOrder(const GraphDef& graph,
+ std::vector<int>* ready_nodes) {
SimpleGraphView graph_view;
- TF_RETURN_IF_ERROR(graph_view.Initialize(*graph));
+ TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
- std::vector<int> ready_nodes;
- ready_nodes.reserve(graph_view.num_nodes());
+ ready_nodes->reserve(graph_view.num_nodes());
int front = 0;
int back = 0;
std::vector<int> num_ready_inputs(graph_view.num_nodes(), 0);
for (int i = 0; i < graph_view.num_nodes(); i++) {
if (graph_view.inputs(i).empty()) {
- ready_nodes.push_back(i);
+ ready_nodes->push_back(i);
back++;
}
- if (IsMerge(graph->node(i))) {
+ if (IsMerge(graph.node(i))) {
for (int input : graph_view.inputs(i)) {
- if (IsNextIteration(graph->node(input))) {
+ if (IsNextIteration(graph.node(input))) {
num_ready_inputs[i]++;
}
}
}
while (front != back) {
- int ready_node = ready_nodes[front];
+ int ready_node = (*ready_nodes)[front];
for (int fanout : graph_view.outputs(ready_node)) {
++num_ready_inputs[fanout];
if (num_ready_inputs[fanout] == graph_view.inputs(fanout).size()) {
- ready_nodes.push_back(fanout);
+ ready_nodes->push_back(fanout);
++back;
}
}
return errors::InvalidArgument(
"The graph couldn't be sorted in topological order.");
}
+ return Status::OK();
+}
+Status ComputeTopologicalOrder(
+ const GraphDef& graph,
+ std::unordered_map<const NodeDef*, int>* topo_order) {
+ std::vector<int> ready_nodes;
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(graph, &ready_nodes));
+ topo_order->reserve(graph.node_size());
+ for (int i = 0; i < ready_nodes.size(); ++i) {
+ (*topo_order)[&graph.node(ready_nodes[i])] = i;
+ }
+ return Status::OK();
+}
+
+Status TopologicalSort(GraphDef* graph) {
+ std::vector<int> ready_nodes;
+ TF_RETURN_IF_ERROR(ComputeTopologicalOrder(*graph, &ready_nodes));
PermuteNodesInPlace(graph, &ready_nodes, /*invert_permutation=*/true);
return Status::OK();
}
*graph.add_node() = CreateNode("5", {});
*graph.add_node() = CreateNode("4", {});
+ std::unordered_map<const NodeDef*, int> topo_order;
+ TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+
+ const std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
+ for (const auto& topo : topo_order) {
+ const string& node_name = topo.first->name();
+ const int topo_order = topo.second;
+ std::cout << "Node " << node_name << " at order " << topo_order
+ << std::endl;
+ EXPECT_EQ(node_name, order[topo_order]);
+ }
+
TF_EXPECT_OK(TopologicalSort(&graph));
- std::vector<string> order = {"5", "4", "2", "0", "3", "1"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
}
*graph.add_node() = CreateNode("5", "NextIteration", {"4"});
*graph.add_node() = CreateNode("1", {});
+ std::unordered_map<const NodeDef*, int> topo_order;
+ TF_EXPECT_OK(ComputeTopologicalOrder(graph, &topo_order));
+
+ const std::vector<string> order = {"1", "2", "3", "4", "5"};
+ for (const auto& topo : topo_order) {
+ const string& node_name = topo.first->name();
+ const int topo_order = topo.second;
+ EXPECT_EQ(node_name, order[topo_order]);
+ }
+
TF_EXPECT_OK(TopologicalSort(&graph));
- std::vector<string> order = {"1", "2", "3", "4", "5"};
for (int i = 0; i < order.size(); i++) {
EXPECT_EQ(graph.node(i).name(), order[i]);
}