Publishing R3
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / eliminate.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import logging as log
17 from collections import deque
18
19 import networkx as nx
20
21 from mo.graph.graph import Node, create_edge
22 from mo.middle.pattern_match import apply_pattern
23 from mo.utils.graph import bfs_search, pseudo_topological_sort
24
25
26 def get_nodes_with_attributes(graph: nx.MultiDiGraph, **attrs: dict):
27     node_attrs = graph.nodes(data=True)
28     return [n for n, d in node_attrs if all(a in d.items() for a in attrs.items())]
29
30
31 def reverse_dfs(graph: nx.MultiDiGraph, node_name: str, update_func: callable, visited: set = None):
32     d = deque()
33
34     if visited is None:
35         visited = set()
36     visited.add(node_name)
37     d.appendleft(node_name)
38     while len(d) != 0:
39         cur_node = d.popleft()
40         update_func(graph, cur_node)
41         for in_node_name, _ in graph.in_edges(cur_node):
42             if in_node_name not in visited:
43                 visited.add(in_node_name)
44                 d.append(in_node_name)
45
46
47 def mark_input_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
48     for input, _ in graph.in_edges(node_name):
49         graph.node[input][key] = value
50
51
52 def mark_output_nodes(graph: nx.MultiDiGraph, node_name: str, key: str, value):
53     for output, _ in graph.out_edges(node_name):
54         graph.node[output][key] = value
55
56
57 def mark_output_reachable_nodes(graph: nx.MultiDiGraph):
58     """
59     Mark nodes whether they are outputs reachable or not. The node is considered output reachable if it is connected to
60     one of the nodes that has attribute is_output=True.
61     """
62     nx.set_node_attributes(graph, name='is_output_reachable', values=False)
63     outputs = get_nodes_with_attributes(graph, is_output=True)
64     log.debug('The following nodes are seeded as output reachable:\n{}'.format('\n'.join(sorted(map(str, outputs)))))
65     nx.set_node_attributes(graph, name='is_output_reachable', values={n: True for n in outputs})
66     for output_name in outputs:
67         reverse_dfs(graph, output_name,
68                     lambda graph, node_name: mark_input_nodes(graph, node_name, 'is_output_reachable', True), set())
69
70
71 def mark_undead_nodes(graph: nx.MultiDiGraph, undead_types: list):
72     """
73     Mark output nodes and nodes of the specific type as undead, meaning that they should survive the dead nodes
74     elimination phase. Then mark all children nodes of the undead nodes (except children of inputs) as undead.
75     :param graph: graph to operate on.
76     :param undead_types: list of node types that should be marked as undead.
77     :return: updated graph where each has attribute 'is_undead'.
78     """
79     nx.set_node_attributes(graph, name='is_undead', values=False)
80
81     # mark output nodes as undead
82     outputs = get_nodes_with_attributes(graph, is_output=True)
83     nx.set_node_attributes(graph, name='is_undead', values={n: True for n in outputs})
84
85     # mark specifically defined with node type set of nodes
86     for type in undead_types:
87         node_of_specific_type = get_nodes_with_attributes(graph, type=type)
88         nx.set_node_attributes(graph, name='is_undead', values={n: True for n in node_of_specific_type})
89
90     undead_nodes = get_nodes_with_attributes(graph, is_undead=True)
91     # propagate 'undead' attribute to children nodes of undead nodes if the node produces constant value
92     for node_name in bfs_search(graph, undead_nodes):
93         if graph.node[node_name]['is_undead']:
94             for _, dst_node_name in graph.out_edges(node_name):
95                 node_attrs = graph.node[dst_node_name]
96                 if 'kind' in node_attrs and node_attrs['kind'] == 'data' and node_attrs['value'] is not None:
97                     graph.node[dst_node_name]['is_undead'] = True
98
99     # mark input nodes as undead
100     inputs = get_nodes_with_attributes(graph, is_input=True)
101     nx.set_node_attributes(graph, name='is_undead', values={n: True for n in inputs})
102
103
104 def mark_const_producer_nodes(graph: nx.MultiDiGraph):
105     """
106     Mark nodes that produce constant values.
107     :param graph: graph to operate on.
108     :return: .
109     """
110     nx.set_node_attributes(graph, name='is_const_producer', values=True)
111
112     for n in pseudo_topological_sort(graph):
113         node = Node(graph, n)
114         for input, output, attrs in graph.in_edges(n, data=True):
115             if 'control_flow_edge' in attrs and attrs['control_flow_edge']:
116                 graph.node[input]['is_const_producer'] = False
117                 graph.node[output]['is_const_producer'] = False
118
119         if not node.has('value') or node.value is None:
120             for input, _ in graph.in_edges(n):
121                 graph.node[input]['is_const_producer'] = False
122
123
124 def eliminate_dead_nodes(graph: nx.MultiDiGraph):
125     nodes_to_remove = set()
126     for node_name, node_attrs in graph.nodes(data=True):
127         if not node_attrs['is_output_reachable'] or (node_attrs['is_const_producer'] and not node_attrs['is_undead']):
128             nodes_to_remove.add(node_name)
129     log.debug('Removing the following dead nodes: {}'.format('\n'.join(sorted(map(str, nodes_to_remove)))))
130     graph.remove_nodes_from(nodes_to_remove)
131
132
133 def graph_clean_up(graph: nx.MultiDiGraph, undead_node_types: list = []):
134     mark_output_reachable_nodes(graph)
135     mark_undead_nodes(graph, undead_node_types)
136     mark_const_producer_nodes(graph)
137     eliminate_dead_nodes(graph)
138
139
140 def graph_clean_up_tf(graph: nx.MultiDiGraph):
141     graph_clean_up(graph, ['TFCustomSubgraphCall'])
142
143
144 def remove_identity_action(graph: nx.MultiDiGraph, matches: dict):
145     remove_op_node(graph, matches['identity'])
146
147
148 # TODO: unit tests
149 def merge_data_nodes(graph: nx.MultiDiGraph, survived: Node, removed: Node):
150     if survived.has_and_set('is_output'):
151         graph.node[removed.id].update({'is_output': True})
152
153     for u, v, d in list(graph.in_edges(removed.id, data=True)):
154         graph.add_edges_from([(u, survived.id, d)])
155         graph.remove_edge(u, v)
156
157     for u, v, d in list(graph.out_edges(removed.id, data=True)):
158         graph.add_edges_from([(survived.id, v, d)])
159         graph.remove_edge(u, v)
160
161     for attr in graph.node[removed.id]:
162         if not attr in ['name']:
163             # We need to save debug info from removed data node
164             if attr == 'fw_tensor_debug_info':
165                 if not survived.has_valid(attr):
166                     survived[attr] = []
167                 for fw_tensor_debug_info in removed[attr]:
168                     survived[attr].append(fw_tensor_debug_info)
169             else:
170                 survived[attr] = removed[attr]
171
172
173 # TODO: unit tests
174 def remove_op_node(graph: nx.MultiDiGraph, identity: Node):
175     input = identity.in_node()
176     output = [v for _, v in graph.out_edges(identity.id)]
177     assert len(output) == 1
178     output = Node(graph, output[0])
179
180     graph.remove_edge(input.id, identity.id)
181     graph.remove_edge(identity.id, output.id)
182
183     merge_data_nodes(graph, output, input)
184
185     # we just have saved all output edges from 'input' by reconnecting them to 'output', now we can delete 'input'
186     log.debug('Removing op node: {}'.format(identity.id))
187     graph.remove_node(identity.id)
188     graph.remove_node(input.id)
189
190
191 def remove_op_nodes(graph: nx.MultiDiGraph, attrs: dict):
192     op_attrs = {'kind': 'op'}
193     op_attrs.update(attrs)
194     apply_pattern(
195         graph,
196         nodes=[('identity', op_attrs)],
197         edges=[],
198         action=remove_identity_action,
199         node_attrs=['kind'] + list(attrs.keys()),
200         edge_attrs=[])
201
202
203 def remove_edges_for_nodes(graph: nx.MultiDiGraph, node_attrs: dict, edge_attrs: dict):
204     for node in graph.nodes():
205         node = Node(graph, node)
206         if all([node.has(attr) and node[attr] == node_attrs[attr] for attr in node_attrs]):
207             nodes_edges = node.in_nodes_edges()
208             for port in nodes_edges:
209                 src_node, edge = nodes_edges[port]
210                 if all([attr in edge and edge[attr] == edge_attrs[attr] for attr in edge_attrs]):
211                     graph.remove_edge(src_node.id, node.id)
212
213
214 def remove_useless_split_action(graph: nx.MultiDiGraph, matches: dict):
215     split_node = matches['split']
216     input = split_node.in_node(1)
217     output = split_node.out_node()
218     graph.remove_edge(input.id, split_node.id)
219
220     for u, v, d in list(graph.out_edges(output.id, data=True)):
221         graph.add_edges_from([(input.id, v, d)])
222         graph.remove_edge(u, v)
223
224
225 def remove_useless_split(graph: nx.MultiDiGraph):
226     apply_pattern(
227         graph,
228         nodes=[('split', {'kind': 'op', 'op': 'Split', 'num_split': 1})],
229         edges=[],
230         action=remove_useless_split_action,
231         node_attrs=['kind', 'op', 'num_split'],
232         edge_attrs=[])
233
234
235 def remove_node_from_graph(graph: nx.MultiDiGraph, previous_node: Node, removing_node: Node):
236     if len(removing_node.out_nodes()) > 0:
237         last_node_out = removing_node.out_node(0)
238         edge_data = graph.get_edge_data(removing_node.id, last_node_out.id)
239         out_port = edge_data[0]['out']
240         in_port = edge_data[0]['in']
241         graph.remove_edge(previous_node.id, removing_node.id)
242         graph.remove_edge(removing_node.id, last_node_out.id)
243         create_edge(previous_node, last_node_out, out_port=out_port, in_port=in_port)
244         graph.remove_node(removing_node.id)