Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / eliminate.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17 from collections import deque
18
19 import networkx as nx
20 import numpy as np
21
22 from mo.graph.graph import Node, Graph
23 from mo.middle.pattern_match import apply_pattern
24 from mo.utils.error import Error
25 from mo.utils.graph import bfs_search, pseudo_topological_sort
26
27
28 # TODO: dep warning
29 def get_nodes_with_attributes(graph: Graph, **attrs: dict):
30     node_attrs = graph.nodes(data=True)
31     return [n for n, d in node_attrs if all(a in d.items() for a in attrs.items())]
32
33
34 def reverse_dfs(graph: Graph, node_name: str, update_func: callable, visited: set = None):
35     d = deque()
36
37     if visited is None:
38         visited = set()
39     visited.add(node_name)
40     d.appendleft(node_name)
41     while len(d) != 0:
42         cur_node = d.popleft()
43         update_func(graph, cur_node)
44         for in_node_name, _ in graph.in_edges(cur_node):
45             if in_node_name not in visited:
46                 visited.add(in_node_name)
47                 d.append(in_node_name)
48
49
50 def mark_input_nodes(graph: Graph, node_name: str, key: str, value):
51     for input, _ in graph.in_edges(node_name):
52         graph.node[input][key] = value
53
54
55 def mark_output_nodes(graph: Graph, node_name: str, key: str, value):
56     for output, _ in graph.out_edges(node_name):
57         graph.node[output][key] = value
58
59
60 def mark_output_reachable_nodes(graph: Graph):
61     """
62     Mark nodes whether they are outputs reachable or not. The node is considered output reachable if it is connected to
63     one of the nodes that has attribute op=OpOutput.
64     """
65     nx.set_node_attributes(G=graph, name='is_output_reachable', values=False)
66     outputs = graph.get_nodes_with_attributes(op='OpOutput')
67     log.debug('The following nodes are seeded as output reachable:\n{}'.format('\n'.join(sorted(map(str, outputs)))))
68     nx.set_node_attributes(G=graph, name='is_output_reachable', values={n: True for n in outputs})
69     visited = set()
70     for output_name in outputs:
71         reverse_dfs(graph, output_name,
72                     lambda graph, node_name: mark_input_nodes(graph, node_name, 'is_output_reachable', True), visited)
73
74
75 def mark_undead_nodes(graph: Graph, undead_types: list):
76     """
77     Mark output nodes and nodes of the specific type as undead, meaning that they should survive the dead nodes
78     elimination phase. Then mark all children nodes of the undead nodes (except children of inputs) as undead.
79     :param graph: graph to operate on.
80     :param undead_types: list of node types that should be marked as undead.
81     :return: updated graph where each has attribute 'is_undead'.
82     """
83     nx.set_node_attributes(G=graph, name='is_undead', values=False)
84
85     # mark output nodes as undead
86     outputs = graph.get_nodes_with_attributes(op='OpOutput')
87     nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in outputs})
88
89     # mark specifically defined with node type set of nodes
90     for type in undead_types:
91         node_of_specific_type = graph.get_nodes_with_attributes(type=type)
92         nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in node_of_specific_type})
93
94     undead_nodes = graph.get_nodes_with_attributes(is_undead=True)
95     # propagate 'undead' attribute to children nodes of undead nodes if the node produces constant value
96     for node_name in bfs_search(graph, undead_nodes):
97         if graph.node[node_name]['is_undead']:
98             for _, dst_node_name in graph.out_edges(node_name):
99                 node_attrs = graph.node[dst_node_name]
100                 if 'kind' in node_attrs and (
101                         node_attrs['kind'] == 'data' and node_attrs['value'] is not None or node_attrs['kind'] == 'op'):
102                     graph.node[dst_node_name]['is_undead'] = True
103
104     # mark input nodes as undead
105     inputs = graph.get_nodes_with_attributes(is_input=True)
106     nx.set_node_attributes(G=graph, name='is_undead', values={n: True for n in inputs})
107
108
109 def mark_const_producer_nodes(graph: Graph):
110     """
111     Mark nodes that produce constant values.
112     :param graph: graph to operate on.
113     :return: .
114     """
115     nx.set_node_attributes(G=graph, name='is_const_producer', values=True)
116
117     for n in pseudo_topological_sort(graph):
118         node = Node(graph, n)
119         for input, output, attrs in graph.in_edges(n, data=True):
120             if 'control_flow_edge' in attrs and attrs['control_flow_edge']:
121                 graph.node[input]['is_const_producer'] = False
122                 graph.node[output]['is_const_producer'] = False
123
124         if not node.has('value') or node.value is None:
125             for input, _ in graph.in_edges(n):
126                 graph.node[input]['is_const_producer'] = False
127
128
129 def eliminate_dead_nodes(graph: Graph):
130     nodes_to_remove = set()
131     for node_name, node_attrs in graph.nodes(data=True):
132         if not node_attrs['is_output_reachable'] or (node_attrs['is_const_producer'] and not node_attrs['is_undead']):
133             nodes_to_remove.add(node_name)
134     log.debug('Removing the following dead nodes: {}'.format('\n'.join(sorted(map(str, nodes_to_remove)))))
135     graph.remove_nodes_from(nodes_to_remove)
136
137
138 def add_constant_operations(graph: Graph):
139     data_nodes = graph.get_data_nodes(has_value=True)
140     for node in data_nodes:
141         # If data node has no producers we create Const operation
142         if len(node.in_nodes()) == 0 and len(node.out_nodes()) != 0:
143             # It's necessary to import here due to cycle dependencies
144             from mo.ops.const import Const
145             Const(graph, dict(value=node.value, shape=np.array(node.value.shape))).create_node_with_data(data_nodes=node)
146
147
148 def remove_const_ops(graph: Graph):
149     ops = [node for node in graph.get_op_nodes() if node.soft_get('type') == 'Const']
150     for node in ops:
151         graph.remove_edge(node.id, node.out_node().id)
152         graph.remove_node(node.id)
153
154
155 def shape_inference(graph: Graph):
156     nodes = pseudo_topological_sort(graph)
157     for node in nodes:
158         node = Node(graph, node)
159         if node.has_and_set('need_shape_inference'):
160             old_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
161             node.infer(node)
162             new_out_shapes = [port.data.get_shape() for port in node.out_ports().values()]
163             for shape1, shape2 in zip(old_out_shapes, new_out_shapes):
164                 if shape1 is not None and not np.array_equal(shape1, shape2):
165                     raise Error("After partial shape inference were found shape collision for node {} (old shape: {}, new shape: {})".format(node.name, shape1, shape2))
166             node.need_shape_inference = False
167
168
169 def graph_clean_up(graph: Graph, undead_node_types: list = None):
170     if undead_node_types is None:
171         undead_node_types = []
172
173     if 'Shape' in undead_node_types and not graph.graph['cmd_params'].keep_shape_ops:
174         undead_node_types.remove('Shape')
175
176     mark_output_reachable_nodes(graph)
177     mark_undead_nodes(graph, undead_node_types)
178     mark_const_producer_nodes(graph)
179     eliminate_dead_nodes(graph)
180     # Add Const op for constant data nodes
181     add_constant_operations(graph)
182     shape_inference(graph)
183
184
185 def graph_clean_up_tf(graph: Graph):
186     graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
187
188
189 def graph_clean_up_onnx(graph: Graph):
190     graph_clean_up(graph, ['Shape'])
191
192
193 def remove_identity_action(graph: Graph, matches: dict):
194     remove_op_node_with_data_node(graph, matches['identity'])
195
196
197 # TODO: unit tests
198 def merge_data_nodes(graph: Graph, survived: Node, removed: Node):
199     if survived.has_and_set('op') and survived.op == 'OpOutput':
200         graph.node[removed.id].update({'op': 'OpOutput'})
201
202     for u, v, d in list(graph.in_edges(removed.id, data=True)):
203         graph.add_edges_from([(u, survived.id, d)])
204         graph.remove_edge(u, v)
205
206     for u, v, d in list(graph.out_edges(removed.id, data=True)):
207         graph.add_edges_from([(survived.id, v, d)])
208         graph.remove_edge(u, v)
209
210     for attr in graph.node[removed.id]:
211         if not attr in ['name']:
212             # We need to save debug info from removed data node
213             if attr == 'fw_tensor_debug_info':
214                 if not survived.has_valid(attr):
215                     survived[attr] = []
216                 for fw_tensor_debug_info in removed[attr]:
217                     survived[attr].append(fw_tensor_debug_info)
218             else:
219                 survived[attr] = removed[attr]
220
221
222 # TODO: unit tests
223 def remove_op_node_with_data_node(graph: Graph, node_to_remove: Node):
224     assert node_to_remove.kind == 'op'
225     input_data_node = node_to_remove.in_node()
226     output_node = [v for _, v in graph.out_edges(node_to_remove.id)]
227     assert len(output_node) == 1, "Cannot remove node producing two or more output tensors"
228     output_node = Node(graph, output_node[0])
229     assert output_node.kind == 'data', "The function must be used after partial infer"
230
231     graph.remove_edge(input_data_node.id, node_to_remove.id)
232     graph.remove_edge(node_to_remove.id, output_node.id)
233
234     merge_data_nodes(graph, output_node, input_data_node)
235
236     # we just have saved all output edges from 'input' by reconnecting them to 'output', now we can delete 'input'
237     log.debug('Removing op node: {}'.format(node_to_remove.id))
238     graph.remove_nodes_from([node_to_remove.id, input_data_node.id])
239
240
241 def remove_op_nodes(graph: Graph, attrs: dict):
242     op_attrs = {'kind': 'op'}
243     op_attrs.update(attrs)
244     apply_pattern(
245         graph,
246         nodes=[('identity', op_attrs)],
247         edges=[],
248         action=remove_identity_action
249     )
250
251
252 def remove_edges_for_nodes(graph: Graph, node_attrs: dict, edge_attrs: dict):
253     for node in graph.nodes():
254         node = Node(graph, node)
255         if all([node.has(attr) and node[attr] == node_attrs[attr] for attr in node_attrs]):
256             nodes_edges = node.in_nodes_edges()
257             for port in nodes_edges:
258                 src_node, edge = nodes_edges[port]
259                 if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
260                     graph.remove_edge(src_node.id, node.id)
261
262