2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import deque
21 from mo.graph.graph import Node, create_edge
22 from mo.middle.pattern_match import apply_pattern
23 from mo.utils.graph import bfs_search, pseudo_topological_sort
def get_nodes_with_attributes(graph: nx.MultiDiGraph, **attrs: dict):
    """
    Collect the names of all nodes whose attribute dictionary contains every requested key/value pair.
    :param graph: graph to search.
    :param attrs: attribute name/value pairs a node must carry; values are compared by equality.
    :return: list of matching node names, in the order produced by graph.nodes(data=True).
    """
    required_pairs = attrs.items()
    matching_names = []
    for name, data in graph.nodes(data=True):
        available_pairs = data.items()
        if all(pair in available_pairs for pair in required_pairs):
            matching_names.append(name)
    return matching_names
def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, update_func: callable, visited: set = None):
    """
    Walk the graph against the edge direction starting from node_name and call update_func on every node reached.

    Despite the name, nodes are taken from a FIFO queue, so the visiting order is breadth-first.
    :param graph: graph to traverse.
    :param node_name: name of the start node; update_func is always called on it.
    :param update_func: callable invoked as update_func(graph, node) for every reached node.
    :param visited: optional set of node names that must not be enqueued again; it is updated in place.
                    A fresh set is used when None.
    """
    if visited is None:
        visited = set()
    queue = deque()
    visited.add(node_name)
    queue.appendleft(node_name)
    while queue:
        cur_node = queue.popleft()
        update_func(graph, cur_node)
        # follow incoming edges only: (source, target) pairs whose target is cur_node
        for in_node_name, _ in graph.in_edges(cur_node):
            if in_node_name not in visited:
                visited.add(in_node_name)
                queue.append(in_node_name)
def mark_input_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
    """
    Set attribute `key` to `value` on every direct predecessor of node_name (the source of each incoming edge).
    """
    predecessors = [src for src, _ in graph.in_edges(node_name)]
    for src in predecessors:
        graph.node[src][key] = value
def mark_output_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
    """
    Set attribute `key` to `value` on every direct successor of node_name (the target of each outgoing edge).
    """
    successors = [dst for _, dst in graph.out_edges(node_name)]
    for dst in successors:
        graph.node[dst][key] = value
def mark_output_reachable_nodes(graph: nx.MultiDiGraph):
    """
    Set the boolean attribute 'is_output_reachable' on every node of the graph.

    A node is output reachable if it has attribute is_output=True itself, or if there is a directed path from it to
    such a node.
    :param graph: graph to operate on; modified in place.
    """
    nx.set_node_attributes(graph, name='is_output_reachable', values=False)
    outputs = get_nodes_with_attributes(graph, is_output=True)
    log.debug('The following nodes are seeded as output reachable:\n{}'.format('\n'.join(sorted(map(str, outputs)))))
    nx.set_node_attributes(graph, name='is_output_reachable', values={n: True for n in outputs})
    # One visited set is shared by all traversals: a node already processed from an earlier output has had all of
    # its ancestors marked, so walking them again from another output cannot mark anything new. A fresh set per
    # output produced the same marks but cost O(#outputs * (V + E)).
    visited = set()
    for output_name in outputs:
        reverse_dfs(graph, output_name,
                    lambda g, name: mark_input_nodes(g, name, 'is_output_reachable', True), visited)
def mark_undead_nodes(graph: nx.MultiDiGraph, undead_types: list):
    """
    Mark output nodes and nodes of the specific type as undead, meaning that they should survive the dead nodes
    elimination phase.

    The flag is then propagated from undead nodes to their data-node children that hold a non-None 'value'. Input
    nodes (is_input=True) are marked undead only after propagation, so nodes that are undead merely because they are
    inputs do not pass the flag on.
    :param graph: graph to operate on; modified in place.
    :param undead_types: list of node types (values of the 'type' attribute) that should be marked as undead.
    :return: None; every node of the graph gets the boolean attribute 'is_undead'.
    """
    nx.set_node_attributes(graph, name='is_undead', values=False)

    # mark output nodes as undead
    outputs = get_nodes_with_attributes(graph, is_output=True)
    nx.set_node_attributes(graph, name='is_undead', values={n: True for n in outputs})

    # mark every node whose 'type' attribute is one of undead_types as undead
    for node_type in undead_types:
        nodes_of_type = get_nodes_with_attributes(graph, type=node_type)
        nx.set_node_attributes(graph, name='is_undead', values={n: True for n in nodes_of_type})

    undead_nodes = get_nodes_with_attributes(graph, is_undead=True)
    # propagate 'undead' attribute to data children of undead nodes if the data node holds a constant value
    for node_name in bfs_search(graph, undead_nodes):
        if graph.node[node_name]['is_undead']:
            for _, dst_node_name in graph.out_edges(node_name):
                dst_attrs = graph.node[dst_node_name]
                # .get(): a data node without a 'value' key is treated as non-constant instead of raising KeyError
                if dst_attrs.get('kind') == 'data' and dst_attrs.get('value') is not None:
                    graph.node[dst_node_name]['is_undead'] = True

    # mark input nodes as undead
    inputs = get_nodes_with_attributes(graph, is_input=True)
    nx.set_node_attributes(graph, name='is_undead', values={n: True for n in inputs})
def mark_const_producer_nodes(graph: nx.MultiDiGraph):
    """
    Set the boolean attribute 'is_const_producer' on every node of the graph.

    Every node starts as True. Then, for each node n yielded by pseudo_topological_sort(graph):
    - both ends of each incoming edge of n with a truthy 'control_flow_edge' attribute are reset to False;
    - if n has no 'value' attribute or its value is None, the source of every incoming edge of n is reset to False.
    :param graph: graph to operate on; modified in place.
    """
    nx.set_node_attributes(graph, name='is_const_producer', values=True)

    for node_name in pseudo_topological_sort(graph):
        node = Node(graph, node_name)
        incoming = list(graph.in_edges(node_name, data=True))
        for src, dst, edge_attrs in incoming:
            if edge_attrs.get('control_flow_edge'):
                graph.node[src]['is_const_producer'] = False
                graph.node[dst]['is_const_producer'] = False

        if node.has('value') and node.value is not None:
            continue
        # n does not carry a value, so nothing feeding it counts as a constant producer
        for src, _, _ in incoming:
            graph.node[src]['is_const_producer'] = False
def eliminate_dead_nodes(graph: nx.MultiDiGraph):
    """
    Delete every node that the preceding marking passes identified as dead.

    A node is dead when its 'is_output_reachable' attribute is falsy, or when its 'is_const_producer' attribute is
    truthy while its 'is_undead' attribute is falsy. All three attributes are read directly, so they must be set on
    every node beforehand.
    :param graph: graph to operate on; modified in place.
    """
    def is_dead(attrs):
        if not attrs['is_output_reachable']:
            return True
        return attrs['is_const_producer'] and not attrs['is_undead']

    nodes_to_remove = {name for name, attrs in graph.nodes(data=True) if is_dead(attrs)}
    removed_listing = '\n'.join(sorted(map(str, nodes_to_remove)))
    log.debug('Removing the following dead nodes: {}'.format(removed_listing))
    graph.remove_nodes_from(nodes_to_remove)
def graph_clean_up(graph: nx.MultiDiGraph, undead_node_types: list = None):
    """
    Run the dead-node elimination pipeline on the graph in place.

    The passes run in order: mark output-reachable nodes, mark undead nodes, mark constant producers, then delete
    the nodes found dead.
    :param graph: graph to clean up.
    :param undead_node_types: node types (values of the 'type' attribute) whose nodes must survive; None means none.
    """
    # None instead of a shared mutable [] default
    if undead_node_types is None:
        undead_node_types = []
    mark_output_reachable_nodes(graph)
    mark_undead_nodes(graph, undead_node_types)
    mark_const_producer_nodes(graph)
    eliminate_dead_nodes(graph)
def graph_clean_up_tf(graph: nx.MultiDiGraph):
    """
    Run graph_clean_up while keeping nodes of type 'TFCustomSubgraphCall' alive.

    Presumably intended for graphs converted from TensorFlow models (inferred from the name; verify with callers).
    """
    tf_undead_types = ['TFCustomSubgraphCall']
    graph_clean_up(graph, undead_node_types=tf_undead_types)
def remove_identity_action(graph: nx.MultiDiGraph, matches: dict):
    """Pattern-match callback: remove the op node matched under the 'identity' label via remove_op_node."""
    identity_node = matches['identity']
    remove_op_node(graph, identity_node)
def merge_data_nodes(graph: nx.MultiDiGraph, survived: Node, removed: Node):
    """
    Fold node `removed` into node `survived`.

    All incoming and outgoing edges of `removed` are re-attached to `survived` with their attributes. Every attribute
    of `removed` except 'name' is copied onto `survived`, overwriting existing values. The exception is
    'fw_tensor_debug_info': its entries from `removed` are appended to the list held by `survived`. The node
    `removed` itself stays in the graph, without edges.
    :param graph: graph containing both nodes.
    :param survived: node that receives the edges and attributes.
    :param removed: node whose edges and attributes are transferred.
    """
    # The attribute copy below overwrites survived's attributes with removed's ones; flag removed as an output
    # first so that survived keeps its 'is_output' flag after the copy.
    if survived.has_and_set('is_output'):
        graph.node[removed.id].update({'is_output': True})

    for u, v, d in list(graph.in_edges(removed.id, data=True)):
        graph.add_edges_from([(u, survived.id, d)])
        graph.remove_edge(u, v)

    for u, v, d in list(graph.out_edges(removed.id, data=True)):
        graph.add_edges_from([(survived.id, v, d)])
        graph.remove_edge(u, v)

    for attr in graph.node[removed.id]:
        if attr == 'name':
            continue
        if attr == 'fw_tensor_debug_info':
            # We need to save debug info from removed data node
            if not survived.has_valid(attr):
                survived[attr] = []
            # extend() also terminates if both nodes happen to share the same list object
            survived[attr].extend(removed[attr])
        else:
            survived[attr] = removed[attr]
def remove_op_node(graph: nx.MultiDiGraph, identity: Node):
    """
    Remove op node `identity` together with its input node.

    The op node must have exactly one outgoing edge (asserted). Its input node (identity.in_node()) is merged into
    the target of that edge via merge_data_nodes. Afterwards both the op node and the former input node are deleted.
    """
    in_data = identity.in_node()
    consumers = [dst for _, dst in graph.out_edges(identity.id)]
    assert len(consumers) == 1
    out_data = Node(graph, consumers[0])

    graph.remove_edge(in_data.id, identity.id)
    graph.remove_edge(identity.id, out_data.id)

    merge_data_nodes(graph, out_data, in_data)

    # merge_data_nodes moved every remaining edge of the input node onto the output node,
    # so the op node and the input node can now be dropped
    log.debug('Removing op node: {}'.format(identity.id))
    graph.remove_node(identity.id)
    graph.remove_node(in_data.id)
def remove_op_nodes(graph: nx.MultiDiGraph, attrs: dict):
    """
    Remove every op node matching `attrs` by applying remove_identity_action to each match.
    :param graph: graph to modify in place.
    :param attrs: attribute name/value pairs an op node (kind='op') must have to be removed.
    """
    op_attrs = {'kind': 'op'}
    op_attrs.update(attrs)
    apply_pattern(
        graph,
        nodes=[('identity', op_attrs)],
        edges=[],
        action=remove_identity_action,
        node_attrs=['kind'] + list(attrs.keys()),
        edge_attrs=[],
    )
def remove_edges_for_nodes(graph: nx.MultiDiGraph, node_attrs: dict, edge_attrs: dict):
    """
    Remove selected incoming edges of selected nodes.

    A node is selected when it has every attribute in node_attrs with an equal value. For each selected node, every
    incoming edge reported by Node.in_nodes_edges() whose attributes include all of edge_attrs with equal values is
    removed.
    :param graph: graph to modify in place.
    :param node_attrs: attribute name/value pairs a node must match.
    :param edge_attrs: attribute name/value pairs an incoming edge must match.
    """
    for node_name in graph.nodes():
        node = Node(graph, node_name)
        # generator (not a list) so all() stops at the first mismatch
        if not all(node.has(attr) and node[attr] == value for attr, value in node_attrs.items()):
            continue
        nodes_edges = node.in_nodes_edges()
        for port in nodes_edges:
            src_node, edge = nodes_edges[port]
            if all(attr in edge and edge[attr] == value for attr, value in edge_attrs.items()):
                # NOTE(review): on a MultiDiGraph, remove_edge(u, v) without a key drops one edge between u and v
                # chosen by networkx, which may not be the matched one when the nodes share parallel edges.
                graph.remove_edge(src_node.id, node.id)
def remove_useless_split_action(graph: nx.MultiDiGraph, matches: dict):
    """
    Pattern-match callback that bypasses the Split node matched under the 'split' label.

    The edge from split_node.in_node(1) to the Split is removed. Every outgoing edge of the Split's output node
    (split_node.out_node()) is then re-attached, with its attributes, so that it starts at that input node instead.
    Neither the Split nor its output node is deleted here.
    """
    split_node = matches['split']
    split_input = split_node.in_node(1)
    split_output = split_node.out_node()
    graph.remove_edge(split_input.id, split_node.id)

    for src, dst, edge_data in list(graph.out_edges(split_output.id, data=True)):
        graph.add_edges_from([(split_input.id, dst, edge_data)])
        graph.remove_edge(src, dst)
def remove_useless_split(graph: nx.MultiDiGraph):
    """
    Bypass every Split op node with num_split == 1 by applying remove_useless_split_action to each match.
    :param graph: graph to modify in place.
    """
    apply_pattern(
        graph,
        nodes=[('split', {'kind': 'op', 'op': 'Split', 'num_split': 1})],
        edges=[],
        action=remove_useless_split_action,
        node_attrs=['kind', 'op', 'num_split'],
        edge_attrs=[],
    )
def remove_node_from_graph(graph: nx.MultiDiGraph, previous_node: Node, removing_node: Node):
    """
    Delete removing_node from the graph, reconnecting previous_node to its first output node.

    If removing_node has at least one output node, the edges previous_node -> removing_node and
    removing_node -> out_node(0) are removed. They are replaced by one edge previous_node -> out_node(0) created via
    create_edge, which reuses the 'out'/'in' port numbers of the removed outgoing edge. removing_node is then deleted
    together with any edges it still has.
    """
    if len(removing_node.out_nodes()) > 0:
        first_out_node = removing_node.out_node(0)
        edge_data = graph.get_edge_data(removing_node.id, first_out_node.id)
        # On a multigraph get_edge_data returns {edge_key: attrs}. Key 0 is not guaranteed to exist (e.g. after earlier
        # edge removals), so fall back to the first edge between the two nodes instead of raising KeyError.
        edge_attrs = edge_data[0] if 0 in edge_data else next(iter(edge_data.values()))
        out_port = edge_attrs['out']
        in_port = edge_attrs['in']
        graph.remove_edge(previous_node.id, removing_node.id)
        graph.remove_edge(removing_node.id, first_out_node.id)
        create_edge(previous_node, first_out_node, out_port=out_port, in_port=in_port)
    graph.remove_node(removing_node.id)