2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import deque
18 from re import match, compile
23 from mo.graph.graph import Node, Graph
24 from mo.utils.error import Error
25 from mo.utils.utils import refer_to_faq_msg
def backward_bfs_for_operation(start_node: Node, op_names: list):
    """
    Find nodes whose 'op' attribute is one of 'op_names', searching in the backward (input) direction.

    The search walks every input branch of 'start_node'. It continues through operation nodes that do not match and
    through data nodes whose value is None. It stops at a matching operation node, which is collected, and at data
    nodes that hold a value. A branch without a matching operation contributes nothing, but matches found in other
    branches are still returned. The result is therefore empty only if no branch contains a matching operation.

    :param start_node: Start node for BFS algorithm
    :param op_names: The list with names of operations to search
    :return: list of matching Node objects in the order they were first discovered, without duplicates.
    """
    found_ids = []
    found_ids_set = set()
    # Nodes shared by several branches must be expanded only once. Otherwise they are re-traversed once per path,
    # which grows exponentially with chained diamonds and never terminates on cycles.
    enqueued = {start_node.id}
    q = deque([start_node])
    while q:
        node = q.popleft()
        in_nodes_size = len(node.in_nodes())
        for port in range(in_nodes_size):  # in_nodes() can return either list or dict
            pnode = node.in_node(port)
            if pnode.kind == 'op' and pnode.has_valid('op') and pnode.op in op_names:
                if pnode.id not in found_ids_set:
                    found_ids_set.add(pnode.id)
                    found_ids.append(pnode.id)
            elif pnode.kind == 'op' or (pnode.kind == 'data' and pnode.value is None):
                # non-matching operations and non-constant data nodes are traversed further
                if pnode.id not in enqueued:
                    enqueued.add(pnode.id)
                    q.append(pnode)
    return [Node(start_node.graph, x) for x in found_ids]
def bfs_search(graph: Graph, start_nodes: list = None):
    """
    Performs breadth-first search over a graph and returns a list of nodes in the BFS order.
    :param graph: networkx graph to traverse.
    :param start_nodes: list of start nodes of the graph. If the list is empty or None (the default) then start from
    all nodes that do not have input nodes.
    :return: the list of nodes in the BFS order.
    """
    # 'None' replaces the former mutable default 'list()'; an empty list keeps its original meaning.
    if not start_nodes:
        start_nodes = [node_name for node_name in graph.nodes() if len(graph.in_edges(node_name)) == 0]

    result = []
    visited = set(start_nodes)
    d = deque(start_nodes)
    while d:
        cur_node_name = d.popleft()
        result.append(cur_node_name)
        for _, dst_node in graph.out_edges(cur_node_name):
            if dst_node not in visited:
                d.append(dst_node)
                visited.add(dst_node)
    return result
def dfs(graph: Graph, node_name: str, visited: set):
    """
    Implementation of the depth-first search algorithm starting from the specific node.
    :param graph: networkx graph to operate on.
    :param node_name: node name to start search from.
    :param visited: set of already visited nodes. It is updated in place with every node reached by this search, so
    it can be shared between several calls.
    :return: list of nodes in the DFS post-order: a node is listed after all of its not previously visited
    descendants.
    """
    order = []
    visited.add(node_name)
    # Each stack entry keeps a live iterator over the node's output edges. Scanning resumes where it stopped instead
    # of restarting from the first edge, and push/pop at the list end are O(1).
    stack = [(node_name, iter(graph.out_edges(node_name)))]
    while stack:
        cur_node_name, edges = stack[-1]
        for _, out_node_name in edges:
            if out_node_name not in visited:
                visited.add(out_node_name)
                stack.append((out_node_name, iter(graph.out_edges(out_node_name))))
                break
        else:
            # all children are visited: the node is finished
            stack.pop()
            order.append(cur_node_name)
    return order
def pseudo_topological_sort(graph: Graph, reverse: bool = False):
    """
    Order the graph nodes topologically without checking whether the graph has cycles.

    A DFS is started from every node that has no input edges. The concatenated post-orders are then reversed, unless
    'reverse' is set. For an acyclic graph this is a topological order. With cycles the order may be wrong for some
    applications. Nodes that cannot be reached from any input-less node (e.g. nodes only on a cycle) are not included.
    :param graph: graph to pseudo-topologically sort.
    :param reverse: if True, return the DFS post-order itself (consumers before producers) instead of its reverse.
    :return: nodes in the topological sort if cycle doesn't exist and in pseudo-topological sort if not.
    """
    sources = [name for name in graph.nodes() if len(graph.in_edges(name)) == 0]

    seen = set()
    post_order = []
    for source in sources:
        if source in seen:
            continue
        post_order.extend(dfs(graph, source, seen))

    return post_order if reverse else post_order[::-1]
def nodes_matching_name_pattern(graph: Graph, pattern: str):
    """
    Collect the names of graph nodes whose names match a regular expression.

    Matching uses re.match semantics: the pattern is anchored at the start of the name only, so it does not need to
    cover the whole name unless it ends with '$'.
    :param graph: graph to operate on.
    :param pattern: regular expression describing node name pattern.
    :return: list of matched node names, in the order produced by graph.nodes().
    """
    regex = compile(pattern)
    all_names = list(graph.nodes())
    return [name for name in all_names if match(regex, name)]
def is_connected_component(graph: Graph, node_names: list):
    """
    Checks that specified list of nodes forms a connected sub-graph. It ignores edges direction.

    The algorithm runs a BFS from the first node of 'node_names'. It ignores edge direction and visits only nodes
    from 'node_names', collecting the visited nodes. The sub-graph is connected if 'node_names' is a sub-set of the
    visited set, which is equivalent to the two being equal.
    :param graph: graph to operate on.
    :param node_names: list of node names to be checked.
    :return: Result of the check. An empty list is considered connected.
    """
    if not node_names:
        return True

    # Build the membership set once so filtering each edge is O(1) instead of O(len(node_names)).
    node_names_set = set(node_names)
    d = deque([node_names[0]])
    visited = {node_names[0]}
    while d:
        cur_node_name = d.popleft()
        # find adjacent nodes from the list of node_names. Ignoring edges direction
        adj_nodes = [src_node for src_node, _ in graph.in_edges(cur_node_name) if src_node in node_names_set] + \
                    [dst_node for _, dst_node in graph.out_edges(cur_node_name) if dst_node in node_names_set]
        for adj_node in adj_nodes:
            if adj_node not in visited:
                d.append(adj_node)
                visited.add(adj_node)
    return node_names_set.issubset(visited)
def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
    """
    Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'.

    The search is a BFS from 'start_nodes' that follows edges in both directions:
      * output edges are followed from every node except the 'end_nodes';
      * input edges are followed from every node except the 'start_nodes'. Inputs of inner nodes (for example
        constants) therefore become part of the sub-graph, while inputs of the 'start_nodes' do not.
    If 'detect_extra_start_node' is given and returns True for a node that still has not-yet-visited inputs, those
    inputs are not followed. The node is reported (once) as an extra start node instead.

    Side effect: nx.set_node_attributes(..., values=None) resets the 'prev' attribute of the graph nodes. 'prev' is
    then set to the BFS parent of each newly reached node and left on the graph. It is used to log the path to an
    unexpected Placeholder node.

    :param graph: graph to operate on.
    :param start_nodes: list of nodes names that specifies start nodes.
    :param end_nodes: list of nodes names that specifies end nodes.
    :param detect_extra_start_node: optional callable that accepts a Node and returns True if the node should be
    treated as an additional start node.
    :return: list of node names of the identified sub-graph if 'detect_extra_start_node' is None, otherwise a tuple
    (list of sub-graph node names, list of extra start node names).
    :raises Error: if some end node is not reachable from the start nodes or the sub-graph contains a Placeholder node.
    """
    # Sets make the per-node membership checks O(1) instead of scanning the lists.
    start_nodes_set = set(start_nodes)
    end_nodes_set = set(end_nodes)
    sub_graph_nodes = list()
    visited = set(start_nodes)
    d = deque(start_nodes)
    extra_start_nodes = []

    nx.set_node_attributes(G=graph, name='prev', values=None)
    while d:
        cur_node_name = d.popleft()
        sub_graph_nodes.append(cur_node_name)
        if cur_node_name not in end_nodes_set:  # do not add output nodes of the end_nodes
            for _, dst_node_name in graph.out_edges(cur_node_name):
                if dst_node_name not in visited:
                    d.append(dst_node_name)
                    visited.add(dst_node_name)
                    graph.node[dst_node_name]['prev'] = cur_node_name

        if cur_node_name in start_nodes_set:
            continue  # input nodes are added only for the non-start_nodes

        unvisited_inputs = [src_node_name for src_node_name, _ in graph.in_edges(cur_node_name)
                            if src_node_name not in visited]
        if not unvisited_inputs:
            continue
        if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
            # Report the node once, even if it has several not-yet-visited inputs.
            extra_start_nodes.append(cur_node_name)
            continue
        for src_node_name in unvisited_inputs:
            # re-check: the same input may be connected several times (multi-edges)
            if src_node_name not in visited:
                d.append(src_node_name)
                graph.node[src_node_name]['prev'] = cur_node_name
                visited.add(src_node_name)

    # use forward dfs to check that all end nodes are reachable from at least one of input nodes
    forward_visited = set()
    for start_node in start_nodes:
        dfs(graph, start_node, forward_visited)
    for end_node in end_nodes:
        if end_node not in forward_visited:
            raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
                        refer_to_faq_msg(74))

    for node_name in sub_graph_nodes:
        # sub-graph should not contain Placeholder nodes
        if graph.node[node_name].get('op', '') == 'Placeholder':
            path = list()
            cur_node = node_name
            # follow the recorded BFS parents back towards a start node to show how the Placeholder was reached
            while cur_node and 'prev' in graph.node[cur_node]:
                path.append(str(cur_node))
                cur_node = graph.node[cur_node]['prev']
            log.debug("The path from input node is the following: {}".format('\n'.join(path)))
            raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
                        refer_to_faq_msg(75))
    if detect_extra_start_node is None:
        return sub_graph_nodes
    return sub_graph_nodes, extra_start_nodes
def node_neighbourhood(node_name: str, depth: int, next_node_fn):
    """
    Find the nodes reachable from a node in at most 'depth' steps.

    A BFS is run from 'node_name', where 'next_node_fn' gives the neighbours of a node. The start node itself is
    always part of the result (at distance 0).
    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of search nodes.
    :param next_node_fn: callable that accepts node name and should return list of adjacent nodes.
    :return: list of names of nodes in the neighbourhood, in the order they were first reached.
    """
    distances = {node_name: 0}
    pending = deque([node_name])
    while pending:
        current = pending.popleft()
        current_distance = distances[current]
        if current_distance >= depth:
            continue  # neighbours of this node would lie farther than 'depth'
        candidate_distance = current_distance + 1
        for neighbour in next_node_fn(current):
            if distances.get(neighbour, depth + 1) > candidate_distance:
                distances[neighbour] = candidate_distance
                pending.append(neighbour)
    return list(distances)
def node_incoming_neighbourhood(graph: Graph, node_name: str, depth: int):
    """
    Find input neighbourhood of the node: the node itself plus every node from which it can be reached by following
    at most 'depth' edges.
    :param graph: graph to operate on.
    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of input nodes.
    :return: list of names of nodes in the neighbourhood.
    """
    def producers(name):
        # source nodes of the edges entering 'name'
        return [src for src, _ in graph.in_edges([name])]

    return node_neighbourhood(node_name, depth, producers)
def node_outcoming_neighbourhood(graph: Graph, node_name: str, depth: int):
    """
    Find output neighbourhood of the node: the node itself plus every node reachable from it by following at most
    'depth' edges.
    :param graph: graph to operate on.
    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of output nodes.
    :return: list of names of nodes in the neighbourhood.
    """
    def consumers(name):
        # destination nodes of the edges leaving 'name'
        return [dst for _, dst in graph.out_edges([name])]

    return node_neighbourhood(node_name, depth, consumers)
def scope_output_nodes(graph: Graph, scope: str, scope_delimiter: str='/'):
    """
    The function returns nodes producing output of the sub-graph defined by scope (name prefix). A node is an output
    of the scope if it is in this scope and at least one of its output edges goes to a node outside of the scope.
    :param graph: graph to operate on.
    :param scope: string with scope (prefix of the node name). The delimiter is appended when the scope does not
    already end with it, so scope 'a/b' does not match nodes of scope 'a/bc'. An empty scope becomes the delimiter.
    :param scope_delimiter: delimiter between scope parts (may be longer than one character).
    :return: list of Node objects which are outputs of the scope, in the order of graph.nodes().
    """
    # endswith() also works for an empty scope and for multi-character delimiters, unlike comparing scope[-1].
    if not scope.endswith(scope_delimiter):
        scope += scope_delimiter

    # Collect in graph.nodes() order instead of in a set, so the returned order does not depend on string hashing.
    result = [node_id for node_id in graph.nodes()
              if node_id.startswith(scope) and
              any(not out_node_name.startswith(scope) for _, out_node_name in graph.out_edges(node_id))]
    return [Node(graph, node_id) for node_id in result]