Release 18.08
[platform/upstream/armnn.git] / src / armnnUtils / GraphTopologicalSort.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include <boost/assert.hpp>
8 #include <boost/optional.hpp>
9
10 #include <functional>
11 #include <map>
12 #include <stack>
13 #include <vector>
14
15
16 namespace armnnUtils
17 {
18
19 namespace
20 {
21
// Traversal state of a node during the depth-first search. A node that is absent from the
// state map has not been reached yet.
enum class NodeState
{
    Visiting, // Reached, but its inputs are still being explored (the node is on the current DFS path).
    Visited   // The node and all of its inputs have been appended to the sorted output.
};
27
28
29 template <typename TNodeId>
30 boost::optional<TNodeId> GetNextChild(TNodeId node,
31                                       std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
32                                       std::map<TNodeId, NodeState>& nodeStates)
33 {
34     for (TNodeId childNode : getIncomingEdges(node))
35     {
36         if (nodeStates.find(childNode) == nodeStates.end())
37         {
38             return childNode;
39         }
40         else
41         {
42             if (nodeStates.find(childNode)->second == NodeState::Visiting)
43             {
44                 return childNode;
45             }
46         }
47     }
48
49     return {};
50 }
51
52 template<typename TNodeId>
53 bool TopologicallySort(
54     TNodeId initialNode,
55     std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
56     std::vector<TNodeId>& outSorted,
57     std::map<TNodeId, NodeState>& nodeStates)
58 {
59     std::stack<TNodeId> nodeStack;
60
61     // If the node is never visited we should search it
62     if (nodeStates.find(initialNode) == nodeStates.end())
63     {
64         nodeStack.push(initialNode);
65     }
66
67     while (!nodeStack.empty())
68     {
69         TNodeId current = nodeStack.top();
70
71         nodeStates[current] = NodeState::Visiting;
72
73         boost::optional<TNodeId> nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
74
75         if (nextChildOfCurrent)
76         {
77             TNodeId nextChild = nextChildOfCurrent.get();
78
79             // If the child has not been searched, add to the stack and iterate over this node
80             if (nodeStates.find(nextChild) == nodeStates.end())
81             {
82                 nodeStack.push(nextChild);
83                 continue;
84             }
85
86             // If we re-encounter a node being visited there is a cycle
87             if (nodeStates[nextChild] == NodeState::Visiting)
88             {
89                 return false;
90             }
91         }
92
93         nodeStack.pop();
94
95         nodeStates[current] = NodeState::Visited;
96         outSorted.push_back(current);
97     }
98
99     return true;
100 }
101
102 }
103
104 // Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
105 // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
106 // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
107 // it must return the list of nodes which are required to come before it.
108 // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
109 // This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
110 template<typename TNodeId, typename TTargetNodes>
111 bool GraphTopologicalSort(
112     const TTargetNodes& targetNodes,
113     std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
114     std::vector<TNodeId>& outSorted)
115 {
116     outSorted.clear();
117     std::map<TNodeId, NodeState> nodeStates;
118
119     for (TNodeId targetNode : targetNodes)
120     {
121         if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
122         {
123             return false;
124         }
125     }
126
127     return true;
128 }
129
130 }