Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / infer.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import networkx as nx
20 import numpy as np
21
22 # TODO remove it
23 from mo.front.extractor import update_ie_fields
24 from mo.graph.graph import Node, Graph
25 from mo.graph.graph import dict_includes
26 from mo.middle.pattern_match import for_each_sub_graph
27 from mo.utils.error import Error
28 from mo.utils.utils import refer_to_faq_msg, shrink_str_value
29
30
31 def log_debug_dict(nodes_per_port: dict, direction_name: str):
32     for port, node in nodes_per_port.items():
33         value = shrink_str_value(node.soft_get('value'))
34         log.debug('{}[{}]: shape = {}, value = {}'.format(direction_name, port, node.soft_get('shape'), value))
35
36
37 def is_fully_defined_shape(shape: np.ndarray):
38     if -1 in shape:
39         return False
40     return True
41
42
43 def control_flow_infer(graph: Graph, node_name: str):
44     """
45        Executes constant control flow. Propagates nodes executability
46     """
47     if graph.node[node_name]['kind'] == 'data':
48         return
49
50     def mark_executability(node_id: str, is_executable: bool):
51         if is_executable and not graph.node[node_id]['executable']:
52             return
53         graph.node[node_id]['executable'] = is_executable
54
55     in_edges_with_data = graph.in_edges(node_name, data=True)
56     in_df_edges_with_data = [(u, v, attrs) for u, v, attrs in in_edges_with_data
57                              if 'control_flow_edge' not in attrs or not attrs['control_flow_edge']]
58     in_cf_edges_with_data = [(u, v, attrs) for u, v, attrs in in_edges_with_data
59                              if 'control_flow_edge' in attrs and attrs['control_flow_edge']]
60     is_executable_df = not all([not graph.node[u]['executable'] for u, _, attrs in in_df_edges_with_data]
61                                if len(in_df_edges_with_data) else [False])
62     is_executable_cf = not any([not graph.node[u]['executable'] for u, _, attrs in in_cf_edges_with_data]
63                                if len(in_cf_edges_with_data) else [False])
64     is_executable = is_executable_df and is_executable_cf
65
66     node = Node(graph, node_name)
67     if 'cf_infer' in graph.node[node_name] and callable(node.cf_infer):
68         node.cf_infer(node, is_executable, mark_executability)
69     else:
70         for _, out_data in graph.out_edges(node_name):
71             mark_executability(out_data, is_executable)
72
73
74 def exit_bound_edges(graph: Graph, sources: list, end_node_attrs: dict):
75     """
76     Finds all descendant nodes for each node from 'sources' that have given attributes from end_node_attrs.
77     For each found node, create a tuple with a given element from 'source' and the node.
78     """
79     result = []
80     for node in sources:
81         for end_node in nx.descendants(graph, node):
82             if dict_includes(big=graph.node[end_node], sub_dict=end_node_attrs):
83                 result.append((node, end_node, 0, {}))
84     return result
85
86
87 def partial_infer(graph: Graph, start_node: str = None):
88     """
89     Tries to execute constant parts of the graph and deduce as much as possible
90     information following the data flow, e.g. calculate and propagate shapes and
91     constant values. Partially or completely defined values are stored in data
92     nodes (kind='data').
93     """
94     cycle_nodes = graph.get_nodes_with_attributes(is_cyclic=True)
95     cycle_nodes = [Node(graph, node).out_node().id for node in cycle_nodes]
96     ebunch_cyclic = list(graph.out_edges(nbunch=cycle_nodes, data=True, keys=True))
97     ebunch_reconnected = exit_bound_edges(graph, sources=cycle_nodes, end_node_attrs={'op': 'Exit'})
98     graph.remove_edges_from(ebunch_cyclic)
99     graph.add_edges_from(ebunch_reconnected)
100
101     try:
102         nodes = list(nx.topological_sort(graph))
103     except:
104         raise Error('Graph contains a cycle. Can not proceed. ' + refer_to_faq_msg(97))
105
106     graph.remove_edges_from(ebunch_reconnected)
107     graph.add_edges_from(ebunch_cyclic)
108
109     # Mark all nodes as not inferred yet
110     if not start_node is None:
111         start_index = nodes.index(start_node)
112         nx.set_node_attributes(G=graph.subgraph(nodes[start_index:]), name='is_partial_inferred', values=False)
113     else:
114         nx.set_node_attributes(G=graph, name='is_partial_inferred', values=False)
115     debug_logger = log.getLogger().isEnabledFor(log.DEBUG)
116
117     nx.set_node_attributes(G=graph, name='executable',
118                            values={n: True for n in graph.get_nodes_with_attributes(kind='data')})
119
120     for n in nodes:
121         # Data Flow Infer
122         try:
123             node = Node(graph, n)
124             node_name = node.soft_get('name')
125             if node.has('is_partial_inferred') and not node.is_partial_inferred:
126                 if node.has('infer') and not node.infer is None:
127                     log.debug('-' * 20)
128                     log.debug('Partial infer for {}'.format(node.soft_get('name')))
129                     log.debug('Op: {}'.format(node.soft_get('op')))
130                     node.infer(node)
131                     out_nodes = node.out_nodes()
132
133                     # propagate nchw_layout attributes to data nodes
134                     if node.has('nchw_layout'):
135                         for out_node in out_nodes.values():
136                             out_node['nchw_layout'] = node.nchw_layout
137
138                     # In debug print current node attributes, input shapes/values and output shape/values
139                     if debug_logger:
140                         log.debug('Inputs:')
141                         log_debug_dict(node.in_nodes(), 'input')
142                         log.debug('Outputs:')
143                         log_debug_dict(node.out_nodes(), 'output')
144
145                     not_all_output_shapes = False
146
147                     for out_port, out_node in out_nodes.items():
148                         not_all_output_shapes = False
149                         if not out_node.has_valid('shape'):
150                             log.error('Shape is not defined for output {} of "{}".'.format(out_port, node_name))
151                             not_all_output_shapes = True
152                         elif not is_fully_defined_shape(out_node.shape):
153                             log.error(
154                                 ('Shape {} is not fully defined for output {} of "{}". ' +
155                                  'Use --input_shape with positive integers to override model input shapes.').format(
156                                     out_node.shape,
157                                     out_port,
158                                     node_name
159                                 )
160                             )
161                             not_all_output_shapes = True
162
163                     if not_all_output_shapes:
164                         raise Error('Not all output shapes were inferred or fully defined for node "{}". ' +
165                                     refer_to_faq_msg(40),
166                                     node_name)
167                 elif node.kind != 'data':
168                     raise Error(
169                         'There is no registered "infer" function for node "{}" with op = "{}". ' +
170                         'Please implement this function in the extensions. ' +
171                         refer_to_faq_msg(37),
172                         node_name,
173                         node.soft_get('op')
174                     )
175                 node.is_partial_inferred = True
176
177         except Exception as err:
178             log.error('Cannot infer shapes or values for node "{}".'.format(node.soft_get('name')))
179             log.error(str(err))
180             log.error('')
181             log.error('It can happen due to bug in custom shape infer function {}.'.format(node.soft_get('infer')))
182             log.error('Or because the node inputs have incorrect values/shapes.')
183             log.error('Or because input shapes are incorrect (embedded to the model or passed via --input_shape).')
184             debug_messages = '\n'.join(
185                 ['Layer "' + node_name + '": ' + node_attrs['debug_message'] for node_name, node_attrs in
186                  graph.nodes(data=True) if 'debug_message' in node_attrs])
187             if debug_messages != "":
188                 log.error('')
189                 log.error('Other possible failure reasons are listed below:')
190                 log.error(debug_messages)
191             if not debug_logger:
192                 log.error('Run Model Optimizer with --log_level=DEBUG for more information.')
193             else:
194                 log.debug('Node "{}" attributes: {}'.format(node.soft_get('name'), node.graph.node[node.id]))
195             raise Error('Stopped shape/value propagation at "{}" node. '.format(node.soft_get('name')) +
196                         refer_to_faq_msg(38)) from err
197         control_flow_infer(graph, n)
198
199     not_fully_inferred = graph.get_nodes_with_attributes(is_not_fully_inferred=True)
200     for n in not_fully_inferred:
201         node = Node(graph, n)
202         if node.has('infer') and not node.infer is None:
203             node.infer(node)
204
205     return graph
206
207
208 def override_batch(graph: Graph, batch: int):
209     """
210     Overrides batch for nodes with 'op' param set to 'Placeholder'
211     Parameters
212     ----------
213     graph: graph to operate on
214     batch: user defined integer value to override batch
215     """
216     if batch is not None:
217         for node_id, data in graph.nodes(data=True):
218             if 'op' in data and data['op'] == 'Placeholder' and not data.get('fixed_batch', False):
219                 if len(data['shape']) == 0 or data['shape'][0] not in (-1, 0, 1):
220                     raise Error(('The input layer {} has a shape {} defined in the model. \n\n' +
221                                  'When you use -b (--batch) option, Model Optimizer applies its value to the first ' +
222                                  'element of the shape if it is equal to -1, 0 or 1. Otherwise, this is the ambiguous ' +
223                                  'situation - Model Optimizer can not know in advance whether the layer has the batch ' +
224                                  'dimension or not.\n\n For example, you want to set batch dimension equals 100 ' +
225                                  'for the input layer "data" with shape (10,34). Although you can not use --batch, ' +
226                                  'you should pass --input_shape (100,34) instead of --batch 100. \n\n' +
227                                  refer_to_faq_msg(39))
228                                 .format(data['name'], data['shape']))
229                 data['shape'][0] = batch
230
231
232 def override_placeholder_shapes(graph: Graph, user_shapes: dict, batch=None):
233     """
234     This function overrides shapes for nodes with 'op' param set to 'Placeholder' with shapes defined by users (only
235     for inputs without in/out port specified).
236     And override batch if batch was specified and shape for input is not None.
237     :param graph: graph to operate on
238     :param user_shapes: dictionary, that represents user defined nodes and shapes
239     :param batch: user defined integer value to override batch
240     """
241     if user_shapes is None:
242         # DON'T MOVE UPPER!!! WE NEED TO SET BATCH FIRST
243         # user did not specify neither shapes nor inputs, keep models values
244         return
245     placeholders = graph.get_nodes_with_attributes(kind='op', op='Placeholder')
246     for node_id in placeholders:
247         node_attrs = graph.node[node_id]
248         shape = None
249         if node_id in user_shapes:
250             values = user_shapes[node_id]
251             for value in values:
252                 if 'in' not in value and 'out' not in value:
253                     shape = value['shape'] if value['shape'] is not None else None
254                     break  # we assume only one specified shape for one input
255         if shape is not None:
256             node_attrs['shape'] = shape
257         if batch is not None and node_attrs['shape'] is not None and len(node_attrs['shape']) > 0:
258             node_attrs['shape'][0] = batch
259
260
261 def update_fully_connected_shapes(graph: Graph):
262     nodes = nx.topological_sort(graph)
263     while True:
264         should_infer = False
265         for n in nodes:
266             node = Node(graph, n)
267             if node.has('type') and node.type == 'FullyConnected' and node.in_node(0).shape.size == 3:
268                 log.debug("node.in_node(0).shape = {}".format(node.in_node(0).shape))
269                 log.debug("channel_dims = {}".format(node.channel_dims))
270                 assert (node.in_node(0).shape.size == 3 and node.channel_dims > 0)
271                 node.in_node(0).shape = np.delete(node.in_node(0).shape, 1)
272                 if node.out_node().shape.size == 3:
273                     node.channel_dims = node.channel_dims - 1
274                     log.debug("Initiated partial infer from update_fully_connected_shapes")
275                     graph = partial_infer(graph, node.in_node(0).id)
276                     # Not working
277                     # graph = mark_dead_nodes(graph)
278                     # graph = eliminate_dead_nodes(graph)
279                     should_infer = True
280                     break
281         if not should_infer:
282             break
283
284
285 # Convert MUL operation to Power layer in case when
286 # mul op takes two inputs (scalar constant and tensor)
287 def convert_mul_add_to_power(graph: Graph):
288     for_each_sub_graph(graph, convert_mul_add_to_power)
289     nodes = list(graph.nodes())
290     for n in nodes:
291         # As we remove nodes from graph, we should check that node exists in graph
292         if n in graph:
293             node = Node(graph, n)
294             if node.has('op') and (node.op == 'Mul' or node.op == 'Add') and len(node.in_nodes()) == 2 and \
295                     node.soft_get('can_be_scaleshift') is not False:
296                 scalar_idx, tensor_idx = (0, 1) if not node.in_node(0).value is None else (1, 0)
297                 if not node.in_node(scalar_idx).value is None and node.in_node(tensor_idx).value is None:
298                     if np.squeeze(node.in_node(scalar_idx).value).ndim == 0:
299                         node['type'] = 'Power'
300                         node['scale'] = node.in_node(scalar_idx).value.item() if node.op == 'Mul' else 1
301                         node['power'] = 1
302                         node['shift'] = node.in_node(scalar_idx).value.item() if node.op == 'Add' else 0
303                         node['op'] = 'Power'
304                         if node.has('operation'):
305                             del node.graph.node[node.id]['operation']
306                         update_ie_fields(graph.node[node.id])
307                         scalar_node = node.in_node(scalar_idx)
308                         graph.remove_edge(scalar_node.id, node.id)
309                         graph.remove_node(scalar_node.id)