 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
23 from mo.front.extractor import update_ie_fields
24 from mo.graph.graph import Node, Graph
25 from mo.graph.graph import dict_includes
26 from mo.middle.pattern_match import for_each_sub_graph
27 from mo.utils.error import Error
28 from mo.utils.utils import refer_to_faq_msg, shrink_str_value
def log_debug_dict(nodes_per_port: dict, direction_name: str):
    """
    Emit one debug log line per port describing the node's shape and (shrunk) value.

    :param nodes_per_port: mapping of port index to the data node attached to that port
    :param direction_name: label prepended to each entry (e.g. 'input' / 'output')
    """
    for port_idx, port_node in nodes_per_port.items():
        printable_value = shrink_str_value(port_node.soft_get('value'))
        log.debug('{}[{}]: shape = {}, value = {}'.format(
            direction_name, port_idx, port_node.soft_get('shape'), printable_value))
# Predicate over a shape vector (np.ndarray). The function body is absent from
# this extracted chunk (source lines appear sampled and each line carries a
# fused original line number). NOTE(review): presumably it reports whether the
# shape contains no unknown (-1) dimensions — recover the full body from
# version control before editing; callers below treat a False result as a
# partially-defined shape.
37 def is_fully_defined_shape(shape: np.ndarray):
# NOTE(review): garbled chunk — every line carries a fused original line
# number and several statements are missing (e.g. the bodies that should
# follow the 'if' on fused lines 47 and 51 — presumably early 'return's, and
# an 'else:' before the final loop — TODO confirm against version control).
# Only comments were added; the code lines are untouched.
43 def control_flow_infer(graph: Graph, node_name: str):
45     Executes constant control flow. Propagates nodes executability
# Data nodes are skipped here (the branch body is missing from this chunk).
47     if graph.node[node_name]['kind'] == 'data':
# Callback handed to node-specific 'cf_infer' functions: stores the computed
# executability flag on a node. The statement that should follow the 'if' is
# missing — presumably a 'return' preventing an executable node from being
# re-marked; TODO confirm.
50     def mark_executability(node_id: str, is_executable: bool):
51         if is_executable and not graph.node[node_id]['executable']:
53         graph.node[node_id]['executable'] = is_executable
# Partition incoming edges into plain data-flow edges and control-flow edges
# (the latter carry a truthy 'control_flow_edge' attribute).
55     in_edges_with_data = graph.in_edges(node_name, data=True)
56     in_df_edges_with_data = [(u, v, attrs) for u, v, attrs in in_edges_with_data
57                              if 'control_flow_edge' not in attrs or not attrs['control_flow_edge']]
58     in_cf_edges_with_data = [(u, v, attrs) for u, v, attrs in in_edges_with_data
59                              if 'control_flow_edge' in attrs and attrs['control_flow_edge']]
# Data-flow rule: executable if at least one data-flow producer is executable;
# vacuously True when there are no data-flow inputs (fallback list [False]).
60     is_executable_df = not all([not graph.node[u]['executable'] for u, _, attrs in in_df_edges_with_data]
61                                if len(in_df_edges_with_data) else [False])
# Control-flow rule: executable only if EVERY control-flow producer is
# executable; vacuously True when there are none.
62     is_executable_cf = not any([not graph.node[u]['executable'] for u, _, attrs in in_cf_edges_with_data]
63                                if len(in_cf_edges_with_data) else [False])
64     is_executable = is_executable_df and is_executable_cf
66     node = Node(graph, node_name)
# Delegate to a node-specific control-flow infer function when one is
# attached; otherwise propagate the computed flag to all output data nodes
# (a missing line — presumably 'else:' — precedes the loop; TODO confirm).
67     if 'cf_infer' in graph.node[node_name] and callable(node.cf_infer):
68         node.cf_infer(node, is_executable, mark_executability)
70     for _, out_data in graph.out_edges(node_name):
71         mark_executability(out_data, is_executable)
# NOTE(review): garbled chunk — fused line numbers, and the statements that
# should initialize the accumulator (presumably 'result = []'), the outer loop
# over 'sources' (presumably 'for node in sources:') and the final
# 'return result' are missing — TODO recover from version control.
74 def exit_bound_edges(graph: Graph, sources: list, end_node_attrs: dict):
76     Finds all descendant nodes for each node from 'sources' that have given attributes from end_node_attrs.
77     For each found node, create a tuple with a given element from 'source' and the node.
# For every descendant reachable from 'node' whose attributes include all of
# 'end_node_attrs', record an edge tuple (source, end_node, key=0, {}) —
# the shape expected by graph.add_edges_from in partial_infer below.
81     for end_node in nx.descendants(graph, node):
82         if dict_includes(big=graph.node[end_node], sub_dict=end_node_attrs):
83             result.append((node, end_node, 0, {}))
# NOTE(review): heavily garbled chunk — every line carries a fused original
# line number and many statements are missing: the try/except around the
# topological sort, the main 'for n in nodes:' loop header and its 'try:',
# the actual node.infer(node) call, the first lines of two raise/log calls,
# and the body of the final re-infer loop. Comments below describe only the
# visible code; recover the full function from version control before editing.
87 def partial_infer(graph: Graph, start_node: str = None):
89     Tries to execute constant parts of the graph and deduce as much as possible
90     information following the data flow, e.g. calculate and propagate shapes and
91     constant values. Partially or completely defined values are stored in data
# Temporarily cut the out-edges of cyclic nodes and reconnect them to their
# 'Exit' descendants so the graph becomes acyclic for the topological sort;
# the original edge sets are swapped back right after sorting.
94     cycle_nodes = graph.get_nodes_with_attributes(is_cyclic=True)
95     cycle_nodes = [Node(graph, node).out_node().id for node in cycle_nodes]
96     ebunch_cyclic = list(graph.out_edges(nbunch=cycle_nodes, data=True, keys=True))
97     ebunch_reconnected = exit_bound_edges(graph, sources=cycle_nodes, end_node_attrs={'op': 'Exit'})
98     graph.remove_edges_from(ebunch_cyclic)
99     graph.add_edges_from(ebunch_reconnected)
# (the enclosing try/except for the sort is missing from this chunk; the raise
# below presumably sits in its except clause — TODO confirm)
102         nodes = list(nx.topological_sort(graph))
104         raise Error('Graph contains a cycle. Can not proceed. ' + refer_to_faq_msg(97))
106     graph.remove_edges_from(ebunch_reconnected)
107     graph.add_edges_from(ebunch_cyclic)
109     # Mark all nodes as not inferred yet
# When a start node is given, only nodes from it onward (in topological order)
# are re-marked un-inferred; otherwise the whole graph is reset
# (a missing 'else:' presumably precedes fused line 114 — TODO confirm).
110     if not start_node is None:
111         start_index = nodes.index(start_node)
112         nx.set_node_attributes(G=graph.subgraph(nodes[start_index:]), name='is_partial_inferred', values=False)
114         nx.set_node_attributes(G=graph, name='is_partial_inferred', values=False)
# Cache whether DEBUG logging is on so the per-node debug dumps can be skipped.
115     debug_logger = log.getLogger().isEnabledFor(log.DEBUG)
# Every data node starts out marked executable; control_flow_infer (called at
# the bottom of the loop) may later flip this flag.
117     nx.set_node_attributes(G=graph, name='executable',
118                            values={n: True for n in graph.get_nodes_with_attributes(kind='data')})
# (the main loop header — presumably 'for n in nodes:' followed by 'try:' —
# is missing from this chunk)
123         node = Node(graph, n)
124         node_name = node.soft_get('name')
# Run the attached 'infer' callable for each node not yet inferred.
125         if node.has('is_partial_inferred') and not node.is_partial_inferred:
126             if node.has('infer') and not node.infer is None:
128                 log.debug('Partial infer for {}'.format(node.soft_get('name')))
129                 log.debug('Op: {}'.format(node.soft_get('op')))
131                 out_nodes = node.out_nodes()
133                 # propagate nchw_layout attributes to data nodes
134                 if node.has('nchw_layout'):
135                     for out_node in out_nodes.values():
136                         out_node['nchw_layout'] = node.nchw_layout
138                 # In debug print current node attributes, input shapes/values and output shape/values
# (the 'if debug_logger:' guard and the node.infer(node) call itself are
# missing from this chunk — TODO confirm placement)
141                     log_debug_dict(node.in_nodes(), 'input')
142                     log.debug('Outputs:')
143                     log_debug_dict(node.out_nodes(), 'output')
# Validate that inference produced a fully defined shape on every output.
145                 not_all_output_shapes = False
147                 for out_port, out_node in out_nodes.items():
# NOTE(review): the flag is re-initialized to False on every iteration, so the
# init above is redundant and only the LAST output's status survives the loop
# — looks like a latent bug in the original; confirm upstream intent before
# changing.
148                     not_all_output_shapes = False
149                     if not out_node.has_valid('shape'):
150                         log.error('Shape is not defined for output {} of "{}".'.format(out_port, node_name))
151                         not_all_output_shapes = True
152                     elif not is_fully_defined_shape(out_node.shape):
# (the first line of the call wrapping this message is missing)
154                             ('Shape {} is not fully defined for output {} of "{}". ' +
155                              'Use --input_shape with positive integers to override model input shapes.').format(
161                             not_all_output_shapes = True
163                 if not_all_output_shapes:
164                     raise Error('Not all output shapes were inferred or fully defined for node "{}". ' +
165                                 refer_to_faq_msg(40),
# Nodes of kind 'data' legitimately have no 'infer' function; any other kind
# without one is an error (the raise statement's first line is missing here).
167             elif node.kind != 'data':
169                     'There is no registered "infer" function for node "{}" with op = "{}". ' +
170                     'Please implement this function in the extensions. ' +
171                     refer_to_faq_msg(37),
175             node.is_partial_inferred = True
# On any failure: log diagnostics (including per-layer 'debug_message'
# attributes gathered from the whole graph) and re-raise as an Error chained
# to the original exception.
177         except Exception as err:
178             log.error('Cannot infer shapes or values for node "{}".'.format(node.soft_get('name')))
181             log.error('It can happen due to bug in custom shape infer function {}.'.format(node.soft_get('infer')))
182             log.error('Or because the node inputs have incorrect values/shapes.')
183             log.error('Or because input shapes are incorrect (embedded to the model or passed via --input_shape).')
184             debug_messages = '\n'.join(
185                 ['Layer "' + node_name + '": ' + node_attrs['debug_message'] for node_name, node_attrs in
186                  graph.nodes(data=True) if 'debug_message' in node_attrs])
187             if debug_messages != "":
189                 log.error('Other possible failure reasons are listed below:')
190                 log.error(debug_messages)
192                 log.error('Run Model Optimizer with --log_level=DEBUG for more information.')
194             log.debug('Node "{}" attributes: {}'.format(node.soft_get('name'), node.graph.node[node.id]))
195             raise Error('Stopped shape/value propagation at "{}" node. '.format(node.soft_get('name')) +
196                         refer_to_faq_msg(38)) from err
# Propagate constant control-flow executability for the just-processed node.
197         control_flow_infer(graph, n)
# Second pass: nodes that flagged themselves 'is_not_fully_inferred' get their
# infer function re-run (the call itself — presumably node.infer(node) — is
# missing from this chunk).
199     not_fully_inferred = graph.get_nodes_with_attributes(is_not_fully_inferred=True)
200     for n in not_fully_inferred:
201         node = Node(graph, n)
202         if node.has('infer') and not node.infer is None:
def override_batch(graph: Graph, batch: int):
    """
    Set the batch (first shape dimension) on every 'Placeholder' node.

    :param graph: graph to operate on
    :param batch: user defined integer value to override batch; None disables the pass
    :raises Error: when a placeholder's first dimension is not -1, 0 or 1,
                   i.e. the batch dimension cannot be identified unambiguously
    """
    if batch is None:
        return
    for _, attrs in graph.nodes(data=True):
        # Only placeholders without an explicitly fixed batch are eligible.
        if attrs.get('op') != 'Placeholder' or attrs.get('fixed_batch', False):
            continue
        if len(attrs['shape']) == 0 or attrs['shape'][0] not in (-1, 0, 1):
            message = ('The input layer {} has a shape {} defined in the model. \n\n'
                       'When you use -b (--batch) option, Model Optimizer applies its value to the first '
                       'element of the shape if it is equal to -1, 0 or 1. Otherwise, this is the ambiguous '
                       'situation - Model Optimizer can not know in advance whether the layer has the batch '
                       'dimension or not.\n\n For example, you want to set batch dimension equals 100 '
                       'for the input layer "data" with shape (10,34). Although you can not use --batch, '
                       'you should pass --input_shape (100,34) instead of --batch 100. \n\n'
                       + refer_to_faq_msg(39))
            raise Error(message.format(attrs['name'], attrs['shape']))
        # Shape is mutated in place; first dimension becomes the user batch.
        attrs['shape'][0] = batch
# NOTE(review): garbled chunk — fused line numbers; the docstring delimiters,
# an early 'return' after the 'if user_shapes is None:' guard, a
# 'shape = None' initialization and the 'for value in values:' loop header
# appear to be missing — TODO confirm against version control.
232 def override_placeholder_shapes(graph: Graph, user_shapes: dict, batch=None):
234     This function overrides shapes for nodes with 'op' param set to 'Placeholder' with shapes defined by users (only
235     for inputs without in/out port specified).
236     And override batch if batch was specified and shape for input is not None.
237     :param graph: graph to operate on
238     :param user_shapes: dictionary, that represents user defined nodes and shapes
239     :param batch: user defined integer value to override batch
241     if user_shapes is None:
242         # DON'T MOVE UPPER!!! WE NEED TO SET BATCH FIRST
243         # user did not specify neither shapes nor inputs, keep models values
# Apply the user-specified shape (from the first spec without an explicit
# in/out port) to each Placeholder node, then optionally override its batch.
245     placeholders = graph.get_nodes_with_attributes(kind='op', op='Placeholder')
246     for node_id in placeholders:
247         node_attrs = graph.node[node_id]
249         if node_id in user_shapes:
250             values = user_shapes[node_id]
# (the loop over the per-input shape specs is missing here)
252             if 'in' not in value and 'out' not in value:
253                 shape = value['shape'] if value['shape'] is not None else None
254                 break # we assume only one specified shape for one input
255         if shape is not None:
256             node_attrs['shape'] = shape
# Batch override applies only when the node ends up with a non-empty shape.
257         if batch is not None and node_attrs['shape'] is not None and len(node_attrs['shape']) > 0:
258             node_attrs['shape'][0] = batch
# NOTE(review): garbled chunk — fused line numbers; the loop header(s) between
# 'nodes = ...' and 'node = Node(graph, n)' and the tail of the function are
# missing — TODO recover from version control. Comments describe only the
# visible code.
261 def update_fully_connected_shapes(graph: Graph):
# NOTE(review): nx.topological_sort returns a generator here; the missing
# surrounding code presumably materializes or iterates it — confirm.
262     nodes = nx.topological_sort(graph)
266         node = Node(graph, n)
# For FullyConnected nodes fed by a rank-3 input: drop dimension 1 of the
# input shape, decrement channel_dims when the output is also rank-3, and
# re-run partial inference from that input to propagate the reshaped tensor.
267         if node.has('type') and node.type == 'FullyConnected' and node.in_node(0).shape.size == 3:
268             log.debug("node.in_node(0).shape = {}".format(node.in_node(0).shape))
269             log.debug("channel_dims = {}".format(node.channel_dims))
270             assert (node.in_node(0).shape.size == 3 and node.channel_dims > 0)
271             node.in_node(0).shape = np.delete(node.in_node(0).shape, 1)
272             if node.out_node().shape.size == 3:
273                 node.channel_dims = node.channel_dims - 1
274                 log.debug("Initiated partial infer from update_fully_connected_shapes")
275                 graph = partial_infer(graph, node.in_node(0).id)
277 # graph = mark_dead_nodes(graph)
278 # graph = eliminate_dead_nodes(graph)
285 # Convert MUL operation to Power layer in case when
286 # mul op takes two inputs (scalar constant and tensor)
# NOTE(review): garbled chunk — fused line numbers; the loop header over
# 'nodes' (and a graph-containment check, per the comment below), at least one
# assignment between 'scale' and 'shift' (presumably "node['power'] = 1"), and
# possibly the function tail are missing — TODO confirm against version
# control. Comments describe only the visible code.
287 def convert_mul_add_to_power(graph: Graph):
# Recurse into sub-graphs first so nested graphs get the same rewrite.
288     for_each_sub_graph(graph, convert_mul_add_to_power)
289     nodes = list(graph.nodes())
291         # As we remove nodes from graph, we should check that node exists in graph
293             node = Node(graph, n)
# Candidate: a binary Mul/Add not explicitly excluded via 'can_be_scaleshift'.
294             if node.has('op') and (node.op == 'Mul' or node.op == 'Add') and len(node.in_nodes()) == 2 and \
295                     node.soft_get('can_be_scaleshift') is not False:
# Determine which input holds the constant value and which holds the tensor.
296                 scalar_idx, tensor_idx = (0, 1) if not node.in_node(0).value is None else (1, 0)
297                 if not node.in_node(scalar_idx).value is None and node.in_node(tensor_idx).value is None:
# Only a 0-d (scalar) constant qualifies for the Power rewrite.
298                     if np.squeeze(node.in_node(scalar_idx).value).ndim == 0:
# Rewrite as a Power layer: Mul contributes 'scale', Add contributes 'shift'.
299                         node['type'] = 'Power'
300                         node['scale'] = node.in_node(scalar_idx).value.item() if node.op == 'Mul' else 1
302                         node['shift'] = node.in_node(scalar_idx).value.item() if node.op == 'Add' else 0
304                         if node.has('operation'):
305                             del node.graph.node[node.id]['operation']
306                         update_ie_fields(graph.node[node.id])
# Detach and delete the now-folded scalar constant input node.
307                         scalar_node = node.in_node(scalar_idx)
308                         graph.remove_edge(scalar_node.id, node.id)
309                         graph.remove_node(scalar_node.id)