2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
7 #include <boost/assert.hpp>
8 #include <boost/optional.hpp>
// Looks for the next child of `node` that the depth-first search in TopologicallySort should handle.
// The children of `node` are the nodes returned by getIncomingEdges(node), i.e. the nodes that must be
// ordered before `node`. Each child is checked against nodeStates:
//   - a child with no entry in nodeStates has not been reached by the search yet;
//   - a child whose state is NodeState::Visiting is still on the current search path.
// NOTE(review): the return statements are not visible in this excerpt. Judging by how TopologicallySort
// consumes the result, this presumably returns the first child in either of the two states above, and an
// empty optional when no such child exists -- confirm against the full source.
// @param node              The node whose children are inspected.
// @param getIncomingEdges  Callback yielding the children (incoming-edge nodes) of a node. It is taken by
//                          value, so the std::function is copied on every call.
// @param nodeStates        Search state of every node reached so far; the visible lines only query it.
29 template <typename TNodeId>
30 boost::optional<TNodeId> GetNextChild(TNodeId node,
31 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
32 std::map<TNodeId, NodeState>& nodeStates)
// Walk the children in the order getIncomingEdges reports them.
34 for (TNodeId childNode : getIncomingEdges(node))
// No entry yet: the search has never reached this child.
36 if (nodeStates.find(childNode) == nodeStates.end())
// Presumably reached only when the child already has an entry (the enclosing branch is not visible here),
// which is what makes dereferencing the find() result safe. A Visiting child lies on the current search
// path; the caller treats re-encountering such a node as a cycle.
// NOTE(review): this repeats the lookup done just above; caching the iterator would avoid a second search.
42 if (nodeStates.find(childNode)->second == NodeState::Visiting)
// Runs one iterative depth-first search from initialNode and appends the nodes it finishes to outSorted.
// Each node on the stack is marked NodeState::Visiting while its children are explored via GetNextChild.
// At the bottom of the loop the node is marked NodeState::Visited and appended to outSorted. Presumably
// this happens only once GetNextChild has no further child to offer, which gives a post-order in which
// each node follows its incoming-edge nodes. NOTE(review): the control flow that skips that step while a
// child is pending (continue / pop) is not visible in this excerpt.
// @param initialNode       Node to start from (its parameter declaration is not visible in this excerpt).
//                          Nothing is searched if nodeStates already has an entry for it.
// @param getIncomingEdges  Callback yielding the nodes that must come before a given node; forwarded to
//                          GetNextChild.
// @param outSorted         Output list; finished nodes are appended to its end.
// @param nodeStates        Search state shared across calls (GraphTopologicalSort passes the same map for
//                          every target), so nodes handled by an earlier call are not searched again.
// @return                  NOTE(review): return statements are not visible here; presumably false when a
//                          cycle is detected and true otherwise -- confirm.
52 template<typename TNodeId>
53 bool TopologicallySort(
55 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
56 std::vector<TNodeId>& outSorted,
57 std::map<TNodeId, NodeState>& nodeStates)
// Explicit stack driving the iterative (non-recursive) depth-first search.
59 std::stack<TNodeId> nodeStack;
61 // Only start a search if no earlier search has already reached this node.
62 if (nodeStates.find(initialNode) == nodeStates.end())
64 nodeStack.push(initialNode);
67 while (!nodeStack.empty())
69 TNodeId current = nodeStack.top();
// (Re-)mark the node being worked on as lying on the current search path.
71 nodeStates[current] = NodeState::Visiting;
// Ask for a child that is either unsearched or still being visited.
73 boost::optional<TNodeId> nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
75 if (nextChildOfCurrent)
77 TNodeId nextChild = nextChildOfCurrent.get();
79 // If the child has not been searched yet, push it so the search descends into it; `current` stays on the stack.
80 if (nodeStates.find(nextChild) == nodeStates.end())
82 nodeStack.push(nextChild);
86 // Re-encountering a node that is still being visited means the graph contains a cycle.
// NOTE(review): operator[] would default-insert a missing key; presumably the child is known to be present
// here because the "not yet searched" case was handled above.
87 if (nodeStates[nextChild] == NodeState::Visiting)
// The node is finished: mark it Visited and emit it after the nodes it depends on.
95 nodeStates[current] = NodeState::Visited;
96 outSorted.push_back(current);
104 // Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node appear before the node itself.
105 // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
106 // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
107 // it must return the list of nodes which are required to come before it.
108 // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
109 // This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
110 template<typename TNodeId, typename TTargetNodes>
111 bool GraphTopologicalSort(
112 const TTargetNodes& targetNodes,
113 std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
114 std::vector<TNodeId>& outSorted)
117 std::map<TNodeId, NodeState> nodeStates;
119 for (TNodeId targetNode : targetNodes)
121 if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))