Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / graph.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from collections import deque
18 from re import match, compile
19
20 import logging as log
21 import networkx as nx
22
23 from mo.graph.graph import Node, Graph
24 from mo.utils.error import Error
25 from mo.utils.utils import refer_to_faq_msg
26
27
28 def backward_bfs_for_operation(start_node: Node, op_names: list):
29     """
30     Find node with 'op' attribute equal to one of from 'op_name', searching in the backward direction.
31     In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
32     empty list.
33
34     :param start_node: Start node for BFS algorithm
35     :param op_names: The list with names of operations to search
36     """
37     ret = []
38     q = deque([start_node])
39     while len(q) != 0:
40         node = q.popleft()
41         in_nodes_size = len(node.in_nodes())
42         for id in range(in_nodes_size):  # in_nodes() can return either list or dict
43             pnode = node.in_node(id)
44             if pnode.kind == 'op':
45                 if pnode.has_valid('op') and pnode.op in op_names:
46                     if pnode.id not in ret:
47                         ret.append(pnode.id)
48                 else:
49                     q.append(pnode)
50             elif pnode.kind == 'data' and pnode.value is None:
51                 q.append(pnode)
52     return [Node(start_node.graph, x) for x in ret]
53
54
55 def bfs_search(graph: Graph, start_nodes: list = list()):
56     """
57     Performs breadth-first search over a graph and returns a list of nodes in the BFS order.
58     :param graph: networkx graph to traverse.
59     :param start_nodes: list of start nodes of the graph. If the list is empty then start from all nodes that do not
60     have input nodes.
61     :return: the list of nodes in the BFS order.
62     """
63     result = list()
64     if len(start_nodes) == 0:
65         start_nodes = [node_name for node_name in graph.nodes() if len(graph.in_edges(node_name)) == 0]
66
67     visited = set(start_nodes)
68     d = deque(start_nodes)
69
70     while len(d) != 0:
71         cur_node_name = d.popleft()
72         result.append(cur_node_name)
73         for src_node, dst_node in graph.out_edges(cur_node_name):
74             if dst_node not in visited:
75                 d.append(dst_node)
76                 visited.add(dst_node)
77     return result
78
79
80 def dfs(graph: Graph, node_name: str, visited: set):
81     """
82     Implementation of the depth-first search algorithm starting from the specific node.
83     :param graph: networkx graph to operate on.
84     :param node_name: node name to start search from.
85     :param visited: set of already visited nodes.
86     :return: list of nodes in the DFS-visit order.
87     """
88     order = []
89     stack = [node_name]
90     while len(stack) != 0:
91         node_name = stack[0]
92         stack.pop(0)
93         visited.add(node_name)
94         has_child = False
95         for _, out_node_name in graph.out_edges(node_name):
96             if out_node_name not in visited:
97                 stack.insert(0, node_name)
98                 stack.insert(0, out_node_name)
99                 has_child = True
100                 break
101         if not has_child:
102             order.append(node_name)
103     return order
104
105
106 def pseudo_topological_sort(graph: Graph, reverse: bool = False):
107     """
108     The function performs topological sort but doesn't check for cycle existence. So it may produce wrong nodes order
109     for some applications.
110     :param graph: graph to pseudo-topologically sort.
111     :param reverse: flag indicating whether need to reverse nodes order.
112     :return: nodes in the topological sort if cycle doesn't exist and in pseudo-topological sort if not.
113     """
114     nodes_without_inputs = list()
115     for node_name in graph.nodes():
116         if len(graph.in_edges(node_name)) == 0:
117             nodes_without_inputs.append(node_name)
118     order = list()
119     visited = set()
120     for node_name in nodes_without_inputs:
121         if node_name not in visited:
122             order.extend(dfs(graph, node_name, visited))
123
124     if reverse:
125         return order
126     else:
127         return list(reversed(order))
128
129
130 def nodes_matching_name_pattern(graph: Graph, pattern: str):
131     """
132     Returns list of node names of the graph that match regular expression.
133     :param graph: graph to operate on.
134     :param pattern: regular expression describing node name pattern.
135     :return: list of matched node names.
136     """
137     compiled_pattern = compile(pattern)
138     return [node_name for node_name in list(graph.nodes()) if match(compiled_pattern, node_name)]
139
140
141 def is_connected_component(graph: Graph, node_names: list):
142     """
143     Checks that specified list of nodes forms a connected sub-graph. It ignores edges direction.
144     The algorithm is the following. Run BFS from one of the nodes from the node_names list ignoring edges order and
145     visiting only nodes from the node_names list. Prepare list of visited nodes. If this list is equal to the
146     node_names list (we actually check that the node_names set is sub-set of 'visited' set that is equivalent) then the
147     sub-graph is connected.
148     :param graph: graph to operate on.
149     :param node_names: list of node names to be checked.
150     :return: Result of the check.
151     """
152     if len(node_names) == 0:
153         return True
154
155     d = deque([node_names[0]])
156     visited = set([node_names[0]])
157     while len(d) != 0:
158         cur_node_name = d.popleft()
159         visited.add(cur_node_name)
160         # find adjacent nodes from the list of node_names. Ignoring edges direction
161         adj_nodes = [src_node for src_node, _ in graph.in_edges(cur_node_name) if src_node in node_names] + \
162                     [dst_node for _, dst_node in graph.out_edges(cur_node_name) if dst_node in node_names]
163         for adj_node in adj_nodes:
164             if adj_node not in visited:
165                 d.append(adj_node)
166                 visited.add(adj_node)
167     return set(node_names).issubset(visited)
168
169
170 def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
171     """
172     Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
173     added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
174     :param graph: graph to operate on.
175     :param start_nodes: list of nodes names that specifies start nodes.
176     :param end_nodes: list of nodes names that specifies end nodes.
177     :return: list of nodes of the identified sub-graph or None if the sub-graph cannot be extracted.
178     """
179     sub_graph_nodes = list()
180     visited = set(start_nodes)
181     d = deque(start_nodes)
182     extra_start_nodes = []
183
184     nx.set_node_attributes(G=graph, name='prev', values=None)
185     while len(d) != 0:
186         cur_node_name = d.popleft()
187         sub_graph_nodes.append(cur_node_name)
188         if cur_node_name not in end_nodes:  # do not add output nodes of the end_nodes
189             for _, dst_node_name in graph.out_edges(cur_node_name):
190                 if dst_node_name not in visited:
191                     d.append(dst_node_name)
192                     visited.add(dst_node_name)
193                     graph.node[dst_node_name]['prev'] = cur_node_name
194
195         for src_node_name, _ in graph.in_edges(cur_node_name):
196             # add input nodes for the non-start_nodes
197             if cur_node_name not in start_nodes and src_node_name not in visited:
198                 if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
199                     extra_start_nodes.append(cur_node_name)
200                 else:
201                     d.append(src_node_name)
202                     graph.node[src_node_name]['prev'] = cur_node_name
203                     visited.add(src_node_name)
204
205     # use forward dfs to check that all end nodes are reachable from at least one of input nodes
206     forward_visited = set()
207     for start_node in start_nodes:
208         dfs(graph, start_node, forward_visited)
209     for end_node in end_nodes:
210         if end_node not in forward_visited:
211             raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
212                         refer_to_faq_msg(74))
213
214     for node_name in sub_graph_nodes:
215         # sub-graph should not contain Placeholder nodes
216         if graph.node[node_name].get('op', '') == 'Placeholder':
217             path = list()
218             cur_node = node_name
219             while cur_node and 'prev' in graph.node[cur_node]:
220                 path.append(str(cur_node))
221                 cur_node = graph.node[cur_node]['prev']
222             log.debug("The path from input node is the following: {}".format('\n'.join(path)))
223             raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
224                         refer_to_faq_msg(75))
225     if detect_extra_start_node is None:
226         return sub_graph_nodes
227     else:
228         return sub_graph_nodes, extra_start_nodes
229
230
231 def node_neighbourhood(node_name: str, depth: int, next_node_fn):
232     """
233     Find neighbourhood of the node..
234     :param node_name: name of the node to find neighbourhood for.
235     :param depth: maximum depth of search nodes.
236     :param next_node_fn: callable that accepts node name and should return list of adjacent nodes.
237     :return: list of names of nodes in the neighbourhood.
238     """
239     dist = dict()
240     dist[node_name] = 0
241     deq = deque([node_name])
242     while len(deq) != 0:
243         cur_node_name = deq.popleft()
244         cur_dist = dist[cur_node_name]
245         if cur_dist < depth:
246             for next_node_name in next_node_fn(cur_node_name):
247                 next_dist = dist.setdefault(next_node_name, depth + 1)
248                 if next_dist > cur_dist + 1:
249                     dist[next_node_name] = cur_dist + 1
250                     deq.append(next_node_name)
251     return list(dist.keys())
252
253
254 def node_incoming_neighbourhood(graph: Graph, node_name: str, depth: int):
255     """
256     Find input neighbourhood of the node.
257     :param graph: graph to operate on.
258     :param node_name: name of the node to find neighbourhood for.
259     :param depth: maximum depth of input nodes.
260     :return: list of names of nodes in the neighbourhood.
261     """
262     return node_neighbourhood(node_name, depth, lambda node_name: [u for u, v in graph.in_edges([node_name])])
263
264
265 def node_outcoming_neighbourhood(graph: Graph, node_name: str, depth: int):
266     """
267     Find output neighbourhood of the node.
268     :param graph: graph to operate on.
269     :param node_name: name of the node to find neighbourhood for.
270     :param depth: maximum depth of output nodes.
271     :return: list of names of nodes in the neighbourhood.
272     """
273     return node_neighbourhood(node_name, depth, lambda node_name: [v for u, v in graph.out_edges([node_name])])
274
275
276 def scope_output_nodes(graph: Graph, scope: str, scope_delimiter: str='/'):
277     """
278     The function returns nodes producing output of the sub-graph defined by scope (name prefix). The node is considered
279     output of the scope if it is in this scope and it's output is outside of the scope.
280     :param graph: graph to operate on.
281     :param scope: string with scope (prefix of the node name).
282     :param scope_delimiter: delimiter between scope parts.
283     :return: list of Node objects which are outputs of the scope.
284     """
285     if scope[-1] != scope_delimiter:
286         scope += scope_delimiter
287
288     result = set()
289     for node_id in graph.nodes():
290         if node_id.startswith(scope):
291             for _, out_node_name in graph.out_edges(node_id):
292                 if not out_node_name.startswith(scope):
293                     result.add(node_id)
294                     break
295     return [Node(graph, node_id) for node_id in result]
296