Publishing 2019 R3 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / graph.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from collections import deque
18 from re import match, compile
19
20 import logging as log
21 import networkx as nx
22
23 from mo.graph.graph import Node, Graph
24 from mo.utils.error import Error
25 from mo.utils.utils import refer_to_faq_msg
26
27
28 def backward_bfs_for_operation(start_node: Node, op_names: list):
29     """
30     Find node with 'op' attribute equal to one of from 'op_name', searching in the backward direction.
31     In case of branching algorithm goes into each branch, but if it can't find layer in one of them it returns
32     empty list.
33
34     :param start_node: Start node for BFS algorithm
35     :param op_names: The list with names of operations to search
36     """
37     ret = []
38     q = deque([start_node])
39     while len(q) != 0:
40         node = q.popleft()
41         in_nodes_size = len(node.in_nodes())
42         for id in range(in_nodes_size):  # in_nodes() can return either list or dict
43             pnode = node.in_node(id)
44             if pnode.kind == 'op':
45                 if pnode.has_valid('op') and pnode.op in op_names:
46                     if pnode.id not in ret:
47                         ret.append(pnode.id)
48                 else:
49                     q.append(pnode)
50             elif pnode.kind == 'data' and pnode.value is None:
51                 q.append(pnode)
52     return [Node(start_node.graph, x) for x in ret]
53
54
55 def bfs_search(graph: Graph, start_nodes: list = list()):
56     """
57     Performs breadth-first search over a graph and returns a list of nodes in the BFS order.
58     :param graph: networkx graph to traverse.
59     :param start_nodes: list of start nodes of the graph. If the list is empty then start from all nodes that do not
60     have input nodes.
61     :return: the list of nodes in the BFS order.
62     """
63     result = list()
64     if len(start_nodes) == 0:
65         start_nodes = [node_name for node_name in graph.nodes() if len(graph.in_edges(node_name)) == 0]
66
67     visited = set(start_nodes)
68     d = deque(start_nodes)
69
70     while len(d) != 0:
71         cur_node_name = d.popleft()
72         result.append(cur_node_name)
73         for src_node, dst_node in graph.out_edges(cur_node_name):
74             if dst_node not in visited:
75                 d.append(dst_node)
76                 visited.add(dst_node)
77     return result
78
79
80 def nodes_matching_name_pattern(graph: Graph, pattern: str):
81     """
82     Returns list of node names of the graph that match regular expression.
83     :param graph: graph to operate on.
84     :param pattern: regular expression describing node name pattern.
85     :return: list of matched node names.
86     """
87     compiled_pattern = compile(pattern)
88     return [node_name for node_name in list(graph.nodes()) if match(compiled_pattern, node_name)]
89
90
91 def is_connected_component(graph: Graph, node_names: list):
92     """
93     Checks that specified list of nodes forms a connected sub-graph. It ignores edges direction.
94     The algorithm is the following. Run BFS from one of the nodes from the node_names list ignoring edges order and
95     visiting only nodes from the node_names list. Prepare list of visited nodes. If this list is equal to the
96     node_names list (we actually check that the node_names set is sub-set of 'visited' set that is equivalent) then the
97     sub-graph is connected.
98     :param graph: graph to operate on.
99     :param node_names: list of node names to be checked.
100     :return: Result of the check.
101     """
102     if len(node_names) == 0:
103         return True
104
105     d = deque([node_names[0]])
106     visited = set([node_names[0]])
107     while len(d) != 0:
108         cur_node_name = d.popleft()
109         visited.add(cur_node_name)
110         # find adjacent nodes from the list of node_names. Ignoring edges direction
111         adj_nodes = [src_node for src_node, _ in graph.in_edges(cur_node_name) if src_node in node_names] + \
112                     [dst_node for _, dst_node in graph.out_edges(cur_node_name) if dst_node in node_names]
113         for adj_node in adj_nodes:
114             if adj_node not in visited:
115                 d.append(adj_node)
116                 visited.add(adj_node)
117     return set(node_names).issubset(visited)
118
119
120 def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
121     """
122     Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
123     added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
124     :param graph: graph to operate on.
125     :param start_nodes: list of nodes names that specifies start nodes.
126     :param end_nodes: list of nodes names that specifies end nodes.
127     :return: list of nodes of the identified sub-graph or None if the sub-graph cannot be extracted.
128     """
129     sub_graph_nodes = list()
130     visited = set(start_nodes)
131     d = deque(start_nodes)
132     extra_start_nodes = []
133
134     nx.set_node_attributes(G=graph, name='prev', values=None)
135     while len(d) != 0:
136         cur_node_name = d.popleft()
137         sub_graph_nodes.append(cur_node_name)
138         if cur_node_name not in end_nodes:  # do not add output nodes of the end_nodes
139             for _, dst_node_name in graph.out_edges(cur_node_name):
140                 if dst_node_name not in visited:
141                     d.append(dst_node_name)
142                     visited.add(dst_node_name)
143                     graph.node[dst_node_name]['prev'] = cur_node_name
144
145         for src_node_name, _ in graph.in_edges(cur_node_name):
146             # add input nodes for the non-start_nodes
147             if cur_node_name not in start_nodes and src_node_name not in visited:
148                 if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
149                     extra_start_nodes.append(cur_node_name)
150                 else:
151                     d.append(src_node_name)
152                     graph.node[src_node_name]['prev'] = cur_node_name
153                     visited.add(src_node_name)
154
155     # use forward dfs to check that all end nodes are reachable from at least one of input nodes
156     forward_visited = set()
157     for start_node in start_nodes:
158         graph.dfs(start_node, forward_visited)
159     for end_node in end_nodes:
160         if end_node not in forward_visited:
161             raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
162                         refer_to_faq_msg(74))
163
164     for node_name in sub_graph_nodes:
165         # sub-graph should not contain Placeholder nodes
166         if graph.node[node_name].get('op', '') == 'Parameter':
167             path = list()
168             cur_node = node_name
169             while cur_node and 'prev' in graph.node[cur_node]:
170                 path.append(str(cur_node))
171                 cur_node = graph.node[cur_node]['prev']
172             log.debug("The path from input node is the following: {}".format('\n'.join(path)))
173             raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
174                         refer_to_faq_msg(75))
175     if detect_extra_start_node is None:
176         return sub_graph_nodes
177     else:
178         return sub_graph_nodes, extra_start_nodes
179
180
181 def invert_sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
182     """
183     Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. But doing it from start_nodes stepping
184     backward by in edges.
185
186     Input nodes for the sub-graph nodes are also added to the sub-graph. Constant inputs of the 'start_nodes'
187     are also added to the sub-graph.
188     :param graph: graph to operate on.
189     :param start_nodes: list of nodes names that specifies start nodes.
190     :param end_nodes: list of nodes names that specifies end nodes.
191     :return: list of nodes of the identified sub-graph or None if the sub-graph cannot be extracted.
192     """
193     sub_graph_nodes = list()
194     visited = set(start_nodes)
195     d = deque(start_nodes)
196     extra_start_nodes = []
197
198     nx.set_node_attributes(G=graph, name='prev', values=None)
199     while len(d) != 0:
200         cur_node_name = d.popleft()
201         sub_graph_nodes.append(cur_node_name)
202         if cur_node_name not in start_nodes and \
203                 detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
204             extra_start_nodes.append(cur_node_name)
205         else:
206             if cur_node_name not in end_nodes:  # do not add output nodes of the end_nodes
207                 for src_node_name, _ in graph.in_edges(cur_node_name):
208                     if src_node_name not in visited:
209                         d.append(src_node_name)
210                         visited.add(src_node_name)
211                         graph.node[cur_node_name]['prev'] = src_node_name
212
213     for node_name in sub_graph_nodes:
214         # sub-graph should not contain Input nodes
215         if graph.node[node_name].get('op', '') == 'Parameter':
216             path = list()
217             cur_node = node_name
218             while cur_node and 'prev' in graph.node[cur_node]:
219                 path.append(str(cur_node))
220                 cur_node = graph.node[cur_node]['prev']
221             log.debug("The path from input node is the following: {}".format('\n'.join(path)))
222             raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
223                         refer_to_faq_msg(75))
224     if detect_extra_start_node is None:
225         return sub_graph_nodes
226     else:
227         return sub_graph_nodes, extra_start_nodes
228
229
230 def node_neighbourhood(node_name: str, depth: int, next_node_fn):
231     """
232     Find neighbourhood of the node..
233     :param node_name: name of the node to find neighbourhood for.
234     :param depth: maximum depth of search nodes.
235     :param next_node_fn: callable that accepts node name and should return list of adjacent nodes.
236     :return: list of names of nodes in the neighbourhood.
237     """
238     dist = dict()
239     dist[node_name] = 0
240     deq = deque([node_name])
241     while len(deq) != 0:
242         cur_node_name = deq.popleft()
243         cur_dist = dist[cur_node_name]
244         if cur_dist < depth:
245             for next_node_name in next_node_fn(cur_node_name):
246                 next_dist = dist.setdefault(next_node_name, depth + 1)
247                 if next_dist > cur_dist + 1:
248                     dist[next_node_name] = cur_dist + 1
249                     deq.append(next_node_name)
250     return list(dist.keys())
251
252
253 def node_incoming_neighbourhood(graph: Graph, node_name: str, depth: int):
254     """
255     Find input neighbourhood of the node.
256     :param graph: graph to operate on.
257     :param node_name: name of the node to find neighbourhood for.
258     :param depth: maximum depth of input nodes.
259     :return: list of names of nodes in the neighbourhood.
260     """
261     return node_neighbourhood(node_name, depth, lambda node_name: [u for u, v in graph.in_edges([node_name])])
262
263
264 def node_outcoming_neighbourhood(graph: Graph, node_name: str, depth: int):
265     """
266     Find output neighbourhood of the node.
267     :param graph: graph to operate on.
268     :param node_name: name of the node to find neighbourhood for.
269     :param depth: maximum depth of output nodes.
270     :return: list of names of nodes in the neighbourhood.
271     """
272     return node_neighbourhood(node_name, depth, lambda node_name: [v for u, v in graph.out_edges([node_name])])
273
274
275 def scope_output_nodes(graph: Graph, scope: str, scope_delimiter: str='/'):
276     """
277     The function returns nodes producing output of the sub-graph defined by scope (name prefix). The node is considered
278     output of the scope if it is in this scope and it's output is outside of the scope.
279     :param graph: graph to operate on.
280     :param scope: string with scope (prefix of the node name).
281     :param scope_delimiter: delimiter between scope parts.
282     :return: list of Node objects which are outputs of the scope.
283     """
284     if scope[-1] != scope_delimiter:
285         scope += scope_delimiter
286
287     result = set()
288     for node_id in graph.nodes():
289         if node_id.startswith(scope):
290             for _, out_node_name in graph.out_edges(node_id):
291                 if not out_node_name.startswith(scope):
292                     result.add(node_id)
293                     break
294     return [Node(graph, node_id) for node_id in result]
295