2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import deque
22 from mo.graph.graph import Node, Graph
23 from mo.middle.pattern_match import apply_pattern
24 from mo.utils.error import Error
25 from mo.utils.graph import bfs_search, pseudo_topological_sort
def get_nodes_with_attributes(graph: Graph, **attrs: dict):
    """
    Collect the names of the nodes whose attributes contain every requested key/value pair.
    :param graph: graph to search in.
    :param attrs: attribute names together with the values a matching node must hold for them.
    :return: list of matching node names, in the order produced by graph.nodes(data=True).
    """
    matched = []
    for name, node_data in graph.nodes(data=True):
        # a node qualifies only if each requested (key, value) pair is present in its attributes
        if all(pair in node_data.items() for pair in attrs.items()):
            matched.append(name)
    return matched
def reverse_dfs(graph: Graph, node_name: str, update_func: callable, visited: set = None):
    """
    Traverse the graph against the edge direction starting from 'node_name' and call 'update_func' for each node
    that is reached. Nodes are taken from the front of the queue and newly discovered producers are appended to its
    back, so nodes are processed in the order they were discovered.
    :param graph: graph to traverse.
    :param node_name: name of the node to start from; it is always processed, even if it is already in 'visited'.
    :param update_func: callable invoked as update_func(graph, node_name) for every processed node.
    :param visited: optional set of already seen node names. It is updated in place, so passing the same set to
    several calls prevents nodes from being processed more than once. A fresh set is used when omitted.
    :return: None.
    """
    if visited is None:
        visited = set()

    visited.add(node_name)
    pending = deque([node_name])
    while pending:
        cur_node = pending.popleft()
        update_func(graph, cur_node)
        for in_node_name, _ in graph.in_edges(cur_node):
            # mark on discovery so a node reachable through several paths is queued only once
            if in_node_name not in visited:
                visited.add(in_node_name)
                pending.append(in_node_name)
def mark_input_nodes(graph: Graph, node_name: str, key: str, value):
    """
    Set attribute 'key' to 'value' on every node that is the source of an edge entering 'node_name'.
    :param graph: graph to operate on.
    :param node_name: name of the node whose producers are updated.
    :param key: name of the attribute to set.
    :param value: value to store under 'key'.
    :return: None.
    """
    all_node_attrs = graph.node
    for producer_name, _ in graph.in_edges(node_name):
        all_node_attrs[producer_name][key] = value
def mark_output_nodes(graph: Graph, node_name: str, key: str, value):
    """
    Set attribute 'key' to 'value' on every node that is the destination of an edge leaving 'node_name'.
    :param graph: graph to operate on.
    :param node_name: name of the node whose consumers are updated.
    :param key: name of the attribute to set.
    :param value: value to store under 'key'.
    :return: None.
    """
    # out_edges yields (source, destination) pairs and the source is 'node_name' itself,
    # so the consumer is the second element of each pair
    for _, output in graph.out_edges(node_name):
        graph.node[output][key] = value
def mark_output_reachable_nodes(graph: Graph):
    """
    Mark nodes whether they are output reachable or not. A node is considered output reachable if it has
    attribute op=OpOutput or if one of such nodes can be reached from it by following the edges.
    :param graph: graph to operate on.
    :return: None; every node gets the boolean attribute 'is_output_reachable'.
    """
    nx.set_node_attributes(G=graph, name='is_output_reachable', values=False)
    outputs = graph.get_nodes_with_attributes(op='OpOutput')
    log.debug('The following nodes are seeded as output reachable:\n{}'.format('\n'.join(sorted(map(str, outputs)))))
    nx.set_node_attributes(G=graph, name='is_output_reachable', values={n: True for n in outputs})

    # one 'visited' set is shared by all traversals so that the part of the graph
    # already handled for one output is not walked again for the next one
    visited = set()
    for output_name in outputs:
        reverse_dfs(graph, output_name,
                    lambda graph, node_name: mark_input_nodes(graph, node_name, 'is_output_reachable', True), visited)
def mark_undead_nodes(graph: Graph, undead_types: list):
    """
    Mark output nodes, nodes of the given types and input nodes as undead, meaning that they should survive the
    dead nodes elimination phase. The mark is also propagated from undead nodes to their direct consumers; input
    nodes are marked after the propagation, so the mark is not propagated from them.
    :param graph: graph to operate on.
    :param undead_types: list of node types that should be marked as undead.
    :return: None; every node of the graph gets the attribute 'is_undead'.
    """
    nx.set_node_attributes(G=graph, name='is_undead', values=False)

    # nodes with op=OpOutput always survive
    output_names = graph.get_nodes_with_attributes(op='OpOutput')
    nx.set_node_attributes(G=graph, name='is_undead', values={name: True for name in output_names})

    # so do the nodes of the explicitly requested types
    for undead_type in undead_types:
        typed_names = graph.get_nodes_with_attributes(type=undead_type)
        nx.set_node_attributes(G=graph, name='is_undead', values={name: True for name in typed_names})

    seeds = graph.get_nodes_with_attributes(is_undead=True)
    # pass the mark on to consumers that are op nodes or data nodes holding a value
    for node_name in bfs_search(graph, seeds):
        if not graph.node[node_name]['is_undead']:
            continue
        for _, consumer_name in graph.out_edges(node_name):
            consumer_attrs = graph.node[consumer_name]
            if 'kind' not in consumer_attrs:
                continue
            is_valued_data = consumer_attrs['kind'] == 'data' and consumer_attrs['value'] is not None
            if is_valued_data or consumer_attrs['kind'] == 'op':
                graph.node[consumer_name]['is_undead'] = True

    # input nodes are marked last
    input_names = graph.get_nodes_with_attributes(is_input=True)
    nx.set_node_attributes(G=graph, name='is_undead', values={name: True for name in input_names})
def mark_const_producer_nodes(graph: Graph):
    """
    Mark nodes that produce constant values. Every node starts as a constant producer; the mark is cleared on both
    ends of each control flow edge and on every producer of a node that has no value.
    :param graph: graph to operate on.
    :return: None; every node gets the boolean attribute 'is_const_producer'.
    """
    nx.set_node_attributes(G=graph, name='is_const_producer', values=True)

    for name in pseudo_topological_sort(graph):
        current = Node(graph, name)

        # a control flow edge disqualifies both of its ends
        for src_name, dst_name, edge_attrs in graph.in_edges(name, data=True):
            if edge_attrs.get('control_flow_edge'):
                graph.node[src_name]['is_const_producer'] = False
                graph.node[dst_name]['is_const_producer'] = False

        if current.has('value') and current.value is not None:
            continue

        # the node has no value, so nothing feeding it is a constant producer
        for src_name, _ in graph.in_edges(name):
            graph.node[src_name]['is_const_producer'] = False
def eliminate_dead_nodes(graph: Graph):
    """
    Remove from the graph every node that is not output reachable, and every constant producer that is not undead.
    Relies on the node attributes 'is_output_reachable', 'is_const_producer' and 'is_undead' being already set.
    :param graph: graph to operate on.
    :return: None.
    """
    def is_dead(attrs):
        if not attrs['is_output_reachable']:
            return True
        return attrs['is_const_producer'] and not attrs['is_undead']

    nodes_to_remove = {name for name, attrs in graph.nodes(data=True) if is_dead(attrs)}
    log.debug('Removing the following dead nodes: {}'.format('\n'.join(sorted(map(str, nodes_to_remove)))))
    graph.remove_nodes_from(nodes_to_remove)
def add_constant_operations(graph: Graph):
    """
    Create a Const operation for every data node with a value that has consumers but no producer.
    :param graph: graph to operate on.
    :return: None.
    """
    for data_node in graph.get_data_nodes(has_value=True):
        # only data nodes without a producer but with at least one consumer get a Const operation
        if len(data_node.in_nodes()) != 0 or len(data_node.out_nodes()) == 0:
            continue
        # It's necessary to import here due to cycle dependencies
        from mo.ops.const import Const
        const_attrs = dict(value=data_node.value, shape=np.array(data_node.value.shape))
        Const(graph, const_attrs).create_node_with_data(data_nodes=data_node)
def remove_const_ops(graph: Graph):
    """
    Remove every operation node of type 'Const' together with the edge to its output node.
    The output node itself is kept in the graph.
    :param graph: graph to operate on.
    :return: None.
    """
    # collect first: the graph must not change while get_op_nodes() is being consumed
    ops = [node for node in graph.get_op_nodes() if node.soft_get('type') == 'Const']
    for node in ops:
        graph.remove_edge(node.id, node.out_node().id)
        graph.remove_node(node.id)
def shape_inference(graph: Graph):
    """
    Re-run shape inference for the nodes that have the 'need_shape_inference' attribute set and check that the
    output shapes known before the inference were not changed by it.
    :param graph: graph to operate on.
    :return: None.
    :raises Error: if an output shape that was known before the inference differs from the inferred one.
    """
    nodes = pseudo_topological_sort(graph)
    for node in nodes:
        node = Node(graph, node)
        if node.has_and_set('need_shape_inference'):
            old_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
            # NOTE(review): assumes the node stores its shape inference function in 'infer' and that the function
            # takes the node itself as the only argument - confirm against the Node/Op definitions
            node.infer(node)
            new_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
            for shape1, shape2 in zip(old_out_shapes, new_out_shapes):
                # a shape that was not known before (None) cannot collide with the inferred one
                if shape1 is not None and not np.array_equal(shape1, shape2):
                    raise Error("After partial shape inference were found shape collision for node {} (old shape: {}, new shape: {})".format(node.name, shape1, shape2))
            node.need_shape_inference = False
def graph_clean_up(graph: Graph, undead_node_types: list = None):
    """
    Run the graph clean up sequence: mark output reachable, undead and constant producing nodes, eliminate the
    dead nodes, add Const operations for constant data nodes and re-run shape inference where it is requested.
    :param graph: graph to operate on.
    :param undead_node_types: optional list of node types that must survive the elimination. The list is not
    modified. 'Shape' is ignored unless graph.graph['cmd_params'].keep_shape_ops is set.
    :return: None.
    """
    # work on a private copy so that dropping 'Shape' below does not leak into the caller's list
    undead_node_types = [] if undead_node_types is None else list(undead_node_types)

    if 'Shape' in undead_node_types and not graph.graph['cmd_params'].keep_shape_ops:
        undead_node_types.remove('Shape')

    mark_output_reachable_nodes(graph)
    mark_undead_nodes(graph, undead_node_types)
    mark_const_producer_nodes(graph)
    eliminate_dead_nodes(graph)
    # Add Const op for constant data nodes
    add_constant_operations(graph)
    shape_inference(graph)
def graph_clean_up_tf(graph: Graph):
    """
    Run graph_clean_up with 'TFCustomSubgraphCall' and 'Shape' passed as the undead node types.
    :param graph: graph to operate on.
    :return: None.
    """
    undead_types = ['TFCustomSubgraphCall', 'Shape']
    graph_clean_up(graph, undead_types)
def graph_clean_up_onnx(graph: Graph):
    """
    Run graph_clean_up with 'Shape' passed as the only undead node type.
    :param graph: graph to operate on.
    :return: None.
    """
    undead_types = ['Shape']
    graph_clean_up(graph, undead_types)
def remove_identity_action(graph: Graph, matches: dict):
    """
    Pattern action that removes the operation node matched under the 'identity' key.
    :param graph: graph to operate on.
    :param matches: matched nodes; only the entry stored under 'identity' is used.
    :return: None.
    """
    identity_op = matches['identity']
    remove_op_node_with_data_node(graph, identity_op)
def merge_data_nodes(graph: Graph, survived: Node, removed: Node):
    """
    Merge data node 'removed' into data node 'survived': all edges of 'removed' are re-attached to 'survived' and
    the attributes of 'removed' (except 'name') are copied to 'survived'. The 'removed' node itself is not deleted
    from the graph.
    :param graph: graph to operate on.
    :param survived: node that receives the edges and the attributes.
    :param removed: node whose edges and attributes are transferred.
    :return: None.
    """
    # the attributes of 'removed' overwrite the ones of 'survived' below, so propagate the output marker
    # to 'removed' first to keep it on 'survived' after the copy
    if survived.has_and_set('op') and survived.op == 'OpOutput':
        graph.node[removed.id].update({'op': 'OpOutput'})

    # edge views are materialized with list() because the edges are modified inside the loops
    for u, v, d in list(graph.in_edges(removed.id, data=True)):
        graph.add_edges_from([(u, survived.id, d)])
        graph.remove_edge(u, v)

    for u, v, d in list(graph.out_edges(removed.id, data=True)):
        graph.add_edges_from([(survived.id, v, d)])
        graph.remove_edge(u, v)

    for attr in graph.node[removed.id]:
        if attr == 'name':
            continue
        if attr == 'fw_tensor_debug_info':
            # We need to save debug info from removed data node: it is appended to the debug info
            # of the survived node instead of replacing it
            if not survived.has_valid(attr):
                survived[attr] = []
            for fw_tensor_debug_info in removed[attr]:
                survived[attr].append(fw_tensor_debug_info)
        else:
            survived[attr] = removed[attr]
def remove_op_node_with_data_node(graph: Graph, node_to_remove: Node):
    """
    Remove operation node 'node_to_remove' together with its input data node. The input data node is merged into
    the single output data node of the operation before both nodes are deleted.
    :param graph: graph to operate on.
    :param node_to_remove: operation node to remove; it must have exactly one outgoing edge, leading to a data node.
    :return: None.
    """
    assert node_to_remove.kind == 'op'
    op_id = node_to_remove.id
    producer_data = node_to_remove.in_node()

    consumer_ids = [dst for _, dst in graph.out_edges(op_id)]
    assert len(consumer_ids) == 1, "Cannot remove node producing two or more output tensors"
    consumer_data = Node(graph, consumer_ids[0])
    assert consumer_data.kind == 'data', "The function must be used after partial infer"

    # detach the operation from both of its data nodes
    graph.remove_edge(producer_data.id, op_id)
    graph.remove_edge(op_id, consumer_data.id)

    merge_data_nodes(graph, consumer_data, producer_data)

    # the merge re-attached all edges of the input data node to the output one, so the input can be deleted now
    log.debug('Removing op node: {}'.format(op_id))
    graph.remove_nodes_from([op_id, producer_data.id])
def remove_op_nodes(graph: Graph, attrs: dict):
    """
    Remove the operation nodes that have the given attributes by applying remove_identity_action to every node
    matched by a single node pattern.
    :param graph: graph to operate on.
    :param attrs: attributes a node must have to be removed; 'kind' defaults to 'op' unless given in 'attrs'.
    :return: None.
    """
    op_attrs = {'kind': 'op'}
    op_attrs.update(attrs)

    # the pattern consists of one node and no edges, the node is available in the action under the 'identity' key
    apply_pattern(
        graph,
        nodes=[('identity', op_attrs)],
        edges=[],
        action=remove_identity_action
    )
def remove_edges_for_nodes(graph: Graph, node_attrs: dict, edge_attrs: dict):
    """
    For every node that has all attributes from 'node_attrs', remove its incoming edges that have all attributes
    from 'edge_attrs'.
    :param graph: graph to operate on.
    :param node_attrs: attributes (name -> value) that select the nodes.
    :param edge_attrs: attributes (name -> value) that select the incoming edges of the selected nodes.
    :return: None.
    """
    for node_id in graph.nodes():
        candidate = Node(graph, node_id)
        node_matches = all([candidate.has(key) and candidate[key] == node_attrs[key] for key in node_attrs])
        if not node_matches:
            continue

        incoming = candidate.in_nodes_edges()
        for port in incoming:
            producer, edge = incoming[port]
            edge_matches = all([key in edge and edge[key] == edge_attrs[key] for key in edge_attrs])
            if edge_matches:
                graph.remove_edge(producer.id, candidate.id)