2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import deque
18 from re import match, compile
23 from mo.graph.graph import Node, Graph
24 from mo.utils.error import Error
25 from mo.utils.utils import refer_to_faq_msg
def backward_bfs_for_operation(start_node: Node, op_names: list):
    """
    Find nodes with 'op' attribute equal to one of 'op_names', searching in the backward direction.

    Producers of 'start_node' are examined breadth-first. A branch stops at an operation node whose 'op' is listed
    in 'op_names' (the node is reported and its own producers are not examined). Any other operation node and any
    data node without a value is traversed further; data nodes that carry a value are not traversed.

    :param start_node: Start node for BFS algorithm
    :param op_names: The list with names of operations to search
    :return: list of Node objects for the matched operations, without duplicates, in the order of discovery.
    """
    found = []  # ids of the matched operations in discovery order
    found_ids = set()  # the same ids, kept in a set for constant-time duplicate checks
    # ids of the nodes that were ever put into the queue; prevents re-walking shared producers (diamond-shaped
    # graphs) and guarantees termination if the graph contains a cycle
    enqueued = {start_node.id}
    q = deque([start_node])
    while len(q) != 0:
        node = q.popleft()
        # in_nodes() can return either list or dict, so producers are addressed by index via in_node()
        for port in range(len(node.in_nodes())):
            pnode = node.in_node(port)
            if pnode.kind == 'op':
                if pnode.has_valid('op') and pnode.op in op_names:
                    if pnode.id not in found_ids:
                        found_ids.add(pnode.id)
                        found.append(pnode.id)
                    continue  # the search along this branch ends at the matched operation
                traverse = True
            else:
                # only data nodes without a value are walked through
                traverse = pnode.kind == 'data' and pnode.value is None
            if traverse and pnode.id not in enqueued:
                enqueued.add(pnode.id)
                q.append(pnode)
    return [Node(start_node.graph, x) for x in found]
def bfs_search(graph: Graph, start_nodes: list = None):
    """
    Performs breadth-first search over a graph and returns a list of nodes in the BFS order.

    :param graph: networkx graph to traverse.
    :param start_nodes: list of start nodes of the graph. If the list is empty or None then start from all nodes
     that do not have input edges.
    :return: the list of nodes in the BFS order.
    """
    # None is used as the default instead of a shared mutable list object
    if not start_nodes:
        start_nodes = [node_name for node_name in graph.nodes() if len(graph.in_edges(node_name)) == 0]

    result = list()
    visited = set(start_nodes)
    d = deque(start_nodes)

    while len(d) != 0:
        cur_node_name = d.popleft()
        result.append(cur_node_name)
        for _, dst_node in graph.out_edges(cur_node_name):
            # a node is marked as visited when it is enqueued so it can never be put into the queue twice
            if dst_node not in visited:
                d.append(dst_node)
                visited.add(dst_node)
    return result
def nodes_matching_name_pattern(graph: Graph, pattern: str):
    """
    Collects names of the graph nodes that match the regular expression. The expression is applied at the
    beginning of the node name.

    :param graph: graph to operate on.
    :param pattern: regular expression describing node name pattern.
    :return: list of matched node names in the order the graph reports its nodes.
    """
    name_regex = compile(pattern)
    matched_names = []
    for name in graph.nodes():
        if name_regex.match(name):
            matched_names.append(name)
    return matched_names
def is_connected_component(graph: Graph, node_names: list):
    """
    Checks that specified list of nodes forms a connected sub-graph. It ignores edges direction.
    The algorithm is the following. Run BFS from the first node of the node_names list ignoring edges direction and
    visiting only nodes from the node_names list. If every node from node_names has been visited then the
    sub-graph is connected.

    :param graph: graph to operate on.
    :param node_names: list of node names to be checked.
    :return: True if the nodes form a connected sub-graph (an empty list is considered connected), False otherwise.
    """
    if len(node_names) == 0:
        return True

    # membership is tested for every edge end point, so use a set instead of scanning the list each time
    allowed = set(node_names)
    d = deque([node_names[0]])
    visited = {node_names[0]}
    while len(d) != 0:
        cur_node_name = d.popleft()
        # adjacent nodes regardless of the edge direction: producers first, then consumers
        adj_nodes = [src_node for src_node, _ in graph.in_edges(cur_node_name)] + \
                    [dst_node for _, dst_node in graph.out_edges(cur_node_name)]
        for adj_node in adj_nodes:
            # nodes outside of the requested list are never walked through
            if adj_node in allowed and adj_node not in visited:
                d.append(adj_node)
                visited.add(adj_node)
    return allowed.issubset(visited)
def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
    """
    Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. The graph is traversed breadth-first from
    the 'start_nodes' along output edges; consumers of the 'end_nodes' are not followed. Producers of every visited
    node that is not one of the 'start_nodes' are also added to the sub-graph; producers of the 'start_nodes'
    themselves are not followed.

    As a side effect the 'prev' attribute of every graph node is reset and then set for each discovered node to
    the name of the node it was discovered from.

    :param graph: graph to operate on.
    :param start_nodes: list of nodes names that specifies start nodes.
    :param end_nodes: list of nodes names that specifies end nodes.
    :param detect_extra_start_node: optional callable that accepts a Node. When it returns True for a visited
     non-start node that has a not yet visited producer, the node name is recorded as an extra start node and that
     producer is not added to the sub-graph.
    :return: list of node names of the identified sub-graph if 'detect_extra_start_node' is None, otherwise a tuple
     with this list and the list of extra start node names.
    :raises Error: if one of the 'end_nodes' is not reachable from the 'start_nodes' or if the sub-graph contains a
     node with 'op' equal to 'Parameter'.
    """
    # membership of these collections is tested for every visited node, so keep set copies for fast lookups
    start_set = set(start_nodes)
    end_set = set(end_nodes)

    sub_graph_nodes = list()
    visited = set(start_nodes)
    d = deque(start_nodes)
    extra_start_nodes = []

    # reset back-pointers possibly left by a previous traversal
    nx.set_node_attributes(G=graph, name='prev', values=None)
    while len(d) != 0:
        cur_node_name = d.popleft()
        sub_graph_nodes.append(cur_node_name)
        if cur_node_name not in end_set:  # do not add output nodes of the end_nodes
            for _, dst_node_name in graph.out_edges(cur_node_name):
                if dst_node_name not in visited:
                    d.append(dst_node_name)
                    visited.add(dst_node_name)
                    graph.node[dst_node_name]['prev'] = cur_node_name

        if cur_node_name in start_set:
            continue  # input nodes are added for the non-start_nodes only

        for src_node_name, _ in graph.in_edges(cur_node_name):
            if src_node_name not in visited:
                if detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
                    extra_start_nodes.append(cur_node_name)
                else:
                    d.append(src_node_name)
                    graph.node[src_node_name]['prev'] = cur_node_name
                    visited.add(src_node_name)

    # use forward dfs to check that all end nodes are reachable from at least one of input nodes
    forward_visited = set()
    for start_node in start_nodes:
        graph.dfs(start_node, forward_visited)
    for end_node in end_nodes:
        if end_node not in forward_visited:
            raise Error('End node "{}" is not reachable from start nodes: {}. '.format(end_node, start_nodes) +
                        refer_to_faq_msg(74))

    for node_name in sub_graph_nodes:
        # sub-graph should not contain Placeholder nodes
        if graph.node[node_name].get('op', '') == 'Parameter':
            # follow the back-pointers to log how the traversal reached the network input
            path = list()
            cur_node = node_name
            while cur_node and 'prev' in graph.node[cur_node]:
                path.append(str(cur_node))
                cur_node = graph.node[cur_node]['prev']
            log.debug("The path from input node is the following: {}".format('\n'.join(path)))
            raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
                        refer_to_faq_msg(75))
    if detect_extra_start_node is None:
        return sub_graph_nodes
    else:
        return sub_graph_nodes, extra_start_nodes
def invert_sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
    """
    Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. But doing it from start_nodes stepping
    backward by in edges.

    The graph is traversed breadth-first from the 'start_nodes' along input edges; producers of the 'end_nodes'
    are not followed. As a side effect the 'prev' attribute of every graph node is reset and then set for each
    discovered node to the name of the node it was discovered from.

    :param graph: graph to operate on.
    :param start_nodes: list of nodes names that specifies start nodes.
    :param end_nodes: list of nodes names that specifies end nodes.
    :param detect_extra_start_node: optional callable that accepts a Node. When it returns True for a visited
     non-start node, the node name is recorded as an extra start node and its producers are not followed.
    :return: list of node names of the identified sub-graph if 'detect_extra_start_node' is None, otherwise a tuple
     with this list and the list of extra start node names.
    :raises Error: if the sub-graph contains a node with 'op' equal to 'Parameter'.
    """
    # membership of these collections is tested for every visited node, so keep set copies for fast lookups
    start_set = set(start_nodes)
    end_set = set(end_nodes)

    sub_graph_nodes = list()
    visited = set(start_nodes)
    d = deque(start_nodes)
    extra_start_nodes = []

    # reset back-pointers possibly left by a previous traversal
    nx.set_node_attributes(G=graph, name='prev', values=None)
    while len(d) != 0:
        cur_node_name = d.popleft()
        sub_graph_nodes.append(cur_node_name)
        if cur_node_name not in start_set and \
                detect_extra_start_node is not None and detect_extra_start_node(Node(graph, cur_node_name)):
            extra_start_nodes.append(cur_node_name)
        elif cur_node_name not in end_set:  # do not add input nodes of the end_nodes
            for src_node_name, _ in graph.in_edges(cur_node_name):
                if src_node_name not in visited:
                    d.append(src_node_name)
                    visited.add(src_node_name)
                    # the back-pointer belongs to the newly discovered node and points to the node it was reached
                    # from, so that the chain built for the error report below leads back to a start node
                    graph.node[src_node_name]['prev'] = cur_node_name

    for node_name in sub_graph_nodes:
        # sub-graph should not contain Input nodes
        if graph.node[node_name].get('op', '') == 'Parameter':
            # follow the back-pointers to log how the traversal reached the network input
            path = list()
            cur_node = node_name
            while cur_node and 'prev' in graph.node[cur_node]:
                path.append(str(cur_node))
                cur_node = graph.node[cur_node]['prev']
            log.debug("The path from input node is the following: {}".format('\n'.join(path)))
            raise Error('The matched sub-graph contains network input node "{}". '.format(node_name) +
                        refer_to_faq_msg(75))
    if detect_extra_start_node is None:
        return sub_graph_nodes
    else:
        return sub_graph_nodes, extra_start_nodes
def node_neighbourhood(node_name: str, depth: int, next_node_fn):
    """
    Collects the nodes that can be reached from the given node in at most 'depth' steps, where a single step is
    defined by the 'next_node_fn' callable.

    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of search nodes.
    :param next_node_fn: callable that accepts node name and should return list of adjacent nodes.
    :return: list of names of nodes in the neighbourhood; the node itself is the first element.
    """
    distances = {node_name: 0}
    pending = deque((node_name,))
    while pending:
        current = pending.popleft()
        current_distance = distances[current]
        if current_distance >= depth:
            # nodes at the maximum depth are not expanded
            continue
        candidate_distance = current_distance + 1
        for neighbour in next_node_fn(current):
            # a node seen for the first time is registered with a distance larger than any allowed one
            known_distance = distances.setdefault(neighbour, depth + 1)
            if candidate_distance < known_distance:
                distances[neighbour] = candidate_distance
                pending.append(neighbour)
    return list(distances)
def node_incoming_neighbourhood(graph: Graph, node_name: str, depth: int):
    """
    Collects the nodes that can be reached from the given node by following at most 'depth' input edges.

    :param graph: graph to operate on.
    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of input nodes.
    :return: list of names of nodes in the neighbourhood.
    """
    def producers(name):
        # source end points of the edges entering the node
        return [src for src, _ in graph.in_edges([name])]

    return node_neighbourhood(node_name, depth, producers)
def node_outcoming_neighbourhood(graph: Graph, node_name: str, depth: int):
    """
    Collects the nodes that can be reached from the given node by following at most 'depth' output edges.

    :param graph: graph to operate on.
    :param node_name: name of the node to find neighbourhood for.
    :param depth: maximum depth of output nodes.
    :return: list of names of nodes in the neighbourhood.
    """
    def consumers(name):
        # destination end points of the edges leaving the node
        return [dst for _, dst in graph.out_edges([name])]

    return node_neighbourhood(node_name, depth, consumers)
def scope_output_nodes(graph: Graph, scope: str, scope_delimiter: str='/'):
    """
    The function returns nodes producing output of the sub-graph defined by scope (name prefix). The node is considered
    output of the scope if it is in this scope and at least one of its outputs is outside of the scope.

    :param graph: graph to operate on.
    :param scope: string with scope (prefix of the node name). The delimiter is appended if the scope does not end
     with it.
    :param scope_delimiter: delimiter between scope parts.
    :return: list of Node objects which are outputs of the scope, in the order the graph reports its nodes.
    """
    # endswith() compares the whole delimiter, so delimiters longer than one character are recognised and an
    # empty scope does not fail on indexing
    if not scope.endswith(scope_delimiter):
        scope += scope_delimiter

    # every node is examined once, so a list cannot get duplicates and keeps the output order deterministic
    result = []
    for node_id in graph.nodes():
        if not node_id.startswith(scope):
            continue
        if any(not out_node_name.startswith(scope) for _, out_node_name in graph.out_edges(node_id)):
            result.append(node_id)
    return [Node(graph, node_id) for node_id in result]