"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import update_attrs
from mo.graph.graph import Node, Graph
from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf
from mo.middle.passes.fusing.helpers import get_next_operation
from mo.middle.pattern_match import apply_pattern
from mo.ops.op import PermuteAttrs, Op
from mo.ops.permute import Permute
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg


def reshape_squeeze_transform(graph: Graph, match: dict):
    reshape = match['reshape']
    output = match['output']
    if output.shape is None:
        return  # cannot really do anything if the shape is dynamic
    reshape['shape'] = output.shape
    reshape.op = 'Reshape'
    reshape['type'] = 'Reshape'
    if not reshape.has_valid('dim'):
        # do not override 'dim' if it is already set: it may contain special values like -1 and 0
        reshape['dim'] = reshape.shape.copy()
        update_attrs(reshape, 'shape_attrs', 'dim')
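

# Illustrative example for reshape_squeeze_transform (hypothetical node): a Squeeze whose output data node
# has the inferred shape [1, 256] is rewritten in place into a Reshape with shape = dim = [1, 256]; an
# already valid 'dim' is kept untouched since it may contain special values like -1 and 0.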


def convert_squeeze(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('reshape', dict(kind='op', op='Squeeze')),
            ('output', dict(kind='data'))],
        edges=[('reshape', 'output')],
        action=reshape_squeeze_transform
    )


def convert_reshape(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('shape', dict(kind='data')),
            ('reshape', dict(kind='op', op='Reshape')),
            ('output', dict(kind='data'))],
        edges=[('shape', 'reshape', {'in': 1}), ('reshape', 'output')],
        action=reshape_squeeze_transform
    )


def can_repack_fully_connected_weights_nhwc_to_nchw(fc_node: Node):
    """
    Checks that it is possible to repack weights of the FullyConnected layer if the Reshape layer is the input of
    the FullyConnected layer and satisfies several conditions.
    :param fc_node: the FullyConnected node to check
    :return: the result of the check
    """
    if len(fc_node.in_node(0).in_nodes()) != 1:
        return False

    reshape_node = fc_node.in_node(0).in_node(0)
    if not reshape_node.has_valid('type') or reshape_node.type != 'Reshape':
        return False

    if not reshape_node.in_node(0).has_valid('shape') or not reshape_node.out_node().has_valid('shape'):
        return False

    orig_shape = reshape_node.in_node(0).shape
    new_shape = reshape_node.out_node().shape

    # TODO: this condition is a bit conservative; it can be relaxed by checking the specific dimensions that are
    # involved in the NHWC to NCHW translation
    if len(orig_shape) == len(new_shape) and all(orig_shape == new_shape):
        return False

    # TODO: the following limitations make this pass simpler; consider relaxing them
    if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
        # this means that orig_shape is in NCHW and new_shape is in NC,
        # and the CHW part needs to be mapped to C after the HWC to CHW transform.
        # Assuming that the FullyConnected weights have not been converted from IO to OI yet.
        return True
    else:
        log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
                    "The final model can be broken.")
        return False


def repack_fully_connected_weights_nhwc_to_nchw(graph: Graph):
    """
    Repacks the weights of a FullyConnected layer as a part of the NHWC to NCHW translation if a Reshape that
    involves the dimensions being repacked appears right before the FullyConnected layer.
    """
    for node_id in graph.get_nodes_with_attributes(type='FullyConnected'):
        fc_node = Node(graph, node_id)

        if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
            continue

        reshape_node = fc_node.in_node(0).in_node(0)

        orig_shape = reshape_node.in_node(0).shape
        new_shape = reshape_node.out_node().shape

        # OK, here we are; fc_node.in_node(1) needs to be repacked to keep it compatible with the original
        # input order
        assert all(orig_shape != -1), 'Input shape for {} can not be negative.'.format(fc_node.id)
        assert all(new_shape != -1), 'Output shape for {} can not be negative.'.format(fc_node.id)
        assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
            'Input shape does not correspond to output shape for layer {}.'.format(fc_node.id)
        assert fc_node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(fc_node.id)

        weights = fc_node.in_node(1)

        log.debug("orig_shape = {}".format(orig_shape))
        log.debug("new_shape = {}".format(new_shape))
        log.debug("weights.shape = {}".format(weights.shape))
        log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))

        assert weights.shape[0] == new_shape[1], \
            'First dim of weights does not correspond to output shape of {}'.format(fc_node.id)
        # interpret the I dimension of the weights as packed HWC;
        # orig_shape is already converted to NCHW, so provide a transposed order for the I repacking
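        # Illustrative example (hypothetical shapes): with orig_shape = [1, 256, 7, 7] and new_shape = [1, 12544],
        # weights of shape [12544, O] are viewed as [7, 7, 256, O] (H, W, C, O), transposed to (C, H, W, O) and
        # flattened back to [12544, O], so that row i of the repacked weights matches element i of the NCHW input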
        tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
        weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)


def apply_nhwc_to_nchw_permutation(graph: Graph):
    # Add an NHWC to NCHW permutation for all data nodes (only for nodes without a permutation)
    if graph.graph['layout'] == 'NCHW':
        return

    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind == 'data':
            if node.has_and_set('nchw_layout'):
                continue

            # Get the NHWC to NCHW permutation for N dims, where N = len(node.shape)
            permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))
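            # e.g. for a 4D data node the permutation is expected to be [0, 3, 1, 2] (NHWC -> NCHW)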

            # Check that the data node does not already have a permutation on any of its edges
            skip_permutation = False
            for in_node in node.in_nodes():
                edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True
            for out_node in node.out_nodes():
                edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True

            if skip_permutation:
                continue

            # Set the permutation on all in/out edges
            for in_node in node.in_nodes():
                PermuteAttrs.set_permutation(in_node, node, permutation)

            for out_node in node.out_nodes():
                PermuteAttrs.set_permutation(node, out_node, permutation)


def merge_nodes_permutations(graph: Graph):
    # Iterate over all data nodes and check all permutations for similarity.
    # If all permutations are equal, the permutation is set as an attribute of the data node,
    # otherwise an exception is raised
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind != 'data':
            continue

        permutations = []

        # Get all permutations from the in edges
        for in_node in node.in_nodes():
            edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
            if 'permutation' in edge_attrs:
                permutations.append(edge_attrs['permutation'])

        # Get all permutations from the out edges
        for out_node in node.out_nodes():
            edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
            if 'permutation' in edge_attrs:
                permutations.append(edge_attrs['permutation'])

        # Check that all permutations are equal
        final_permutations = []
        for p in permutations:
            if p.perm is not None:
                final_permutations.append(p.perm)
            else:
                final_permutations.append(np.arange(node.shape.size))

        if len(final_permutations) == 0:
            continue

        if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
            raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'.format(
                node.name, [p.perm for p in permutations]))

        assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
        node['permutation'] = permutations[0]
        if node.permutation is not None and node.permutation.perm.size == 0:
            node.permutation = None


def permute_data_nodes_attrs(graph: Graph):
    # Iterate over all data nodes and apply the permutation if it exists
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind != 'data' or not node.has_valid('permutation'):
            continue

        # Apply the permutation to the shape and to the value if it exists
        node.shape = np.array(node.shape)[node.permutation.perm]
        if node.has_valid('value'):
            if len(node.value.shape) != len(node.permutation.perm):
                log.warning('Node {} has value of shape {} that does not match the permutation {}'.format(
                    node.name, node.value.shape, node.permutation.perm))
            else:
                node.value = np.array(node.value.transpose(node.permutation.perm))


def permute_op_nodes_attrs(graph: Graph):
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind == 'op' and node.has_valid('permute_attrs'):
            try:
                node.permute_attrs.permute_attrs(node)
            except Exception as e:
                raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))


def reverse_input_channels(graph: Graph):
    """
    Searches for all type=Input nodes with 4D output tensors, tracks the tensors down through non-shape-changing
    ops to the first type=Convolution or other channel-dependent nodes and reverses the input channels in the
    convolution weights.
    """
    candidates = set()
    for node in graph.nodes():
        node = Node(graph, node)
        if node.has_valid('type') and node.type == 'Input' and len(node.out_nodes()) == 1 and \
                node.out_node(0).shape.size == 4:
            candidates.add(node)
    log.debug('reverse_input_channels found candidates: {}'.format([c.node for c in candidates]))

    # Track down to the first convolutions
    convolutions = set()
    flip_passthrough = set()
    while len(candidates) > 0:
        op_node = candidates.pop()
        assert (len(op_node.out_nodes()) == 1)
        tensor_node = op_node.out_node(0)
        for consumer in tensor_node.out_nodes():
            if (consumer.has_valid('type') and
                    consumer.type == 'Convolution' and
                    consumer.in_node(1).has_valid('input_channel_dim') and
                    consumer.in_node(1).has_valid('shape') and
                    consumer.in_node(1).shape[consumer.in_node(1).input_channel_dim] == 3 and
                    consumer.in_node(1).has_valid('value')):
                convolutions.add(consumer)
            else:
                # TODO: use a more reliable way to detect shape-preserving operations
                if len(consumer.out_nodes()) == 1 and np.all(consumer.out_node().shape == tensor_node.shape):
                    candidates.add(consumer)
                    if consumer.has_valid('type') and (
                            consumer.type == 'ScaleShift' or consumer.type == 'BatchNormalization'):
                        flip_passthrough.add(consumer)
                else:
                    log.debug('Stop searching for a conv candidate for channel reversing at node {}'.format(
                        consumer.id))

    if len(convolutions) == 0:
        log.error('Reverse input channels are not applied -- appropriate convolutions were not found')
        return

    for node in flip_passthrough:
        log.debug("Applying flip for ScaleShift: {}".format(node.name))
        assert node.has_valid('type') and (node.type == 'ScaleShift' or node.type == 'BatchNormalization')
        blobs = [node.in_node(i) for i in range(1, len(node.in_nodes()))]
        for blob in blobs:
            assert blob.has_valid('value')
            non_one_dimensions = np.where(blob.shape != 1)[0]
            assert len(non_one_dimensions) == 1
            assert blob.shape[non_one_dimensions[0]] == 3
            blob.value = np.flip(blob.value, non_one_dimensions[0])

    for conv in convolutions:
        if conv.op == 'DepthwiseConv2dNative':
            log.debug('out nodes: {}'.format(conv.out_node()))
            bottoms = conv.out_node().out_nodes()
            if len(bottoms) == 1 and bottoms[0].op == 'FakeQuantWithMinMaxVars':
                bottoms = bottoms[0].out_node().out_nodes()
            log.debug('bottoms: {}'.format(bottoms))
            if len(bottoms) > 0:
                log.debug('assumed conv: name = {}, op = {}'.format(bottoms[0].name, bottoms[0].op))
            if len(bottoms) > 0 and bottoms[0].op == 'Conv2D':
                bottom_conv = bottoms[0]
                # Flipping the input channels of DepthwiseConv2dNative alone does not complete the job;
                # we also need to flip the input channels for the next convolution in groups
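                # Illustrative example (hypothetical shapes): a DepthwiseConv2dNative over 3 input channels with
                # channel multiplier 2 produces 6 channels, so the following Conv2D with HWIO weights of shape
                # [1, 1, 6, 8] has ngroups = 3 and multiplier = 2; the weights are viewed as [1, 1, 3, 8, 2],
                # flipped along axis 2 (the group axis) and reshaped back to [1, 1, 6, 8]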
                ngroups = conv.group
                log.debug('ngroups = {}'.format(ngroups))
                bottom_channel_dim = bottom_conv.channel_dims[0]
                log.debug('bottom_channel_dim = {}'.format(bottom_channel_dim))
                bottom_channels = bottom_conv.in_node(0).shape[bottom_channel_dim]
                log.debug('bottom_channels = {}'.format(bottom_channels))
                assert (bottom_channels % ngroups == 0)
                multiplier = int(bottom_channels / ngroups)
                log.debug('multiplier = {}'.format(multiplier))
                bottom_weights = bottom_conv.in_node(1)
                tmp_shape_for_reorder = list(bottom_weights.value.shape)
                src_shape = list(tmp_shape_for_reorder)
                log.debug('weights shape = {}'.format(tmp_shape_for_reorder))
                assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim] == bottom_channels)
                tmp_shape_for_reorder[bottom_weights.input_channel_dim] = ngroups
                tmp_shape_for_reorder = tmp_shape_for_reorder + [multiplier]
                log.debug('tmp_shape_for_reorder = {}'.format(tmp_shape_for_reorder))
                # temporarily change the shape of the weights to do the reordering
                bottom_weights.value = np.flip(bottom_weights.value.reshape(tuple(tmp_shape_for_reorder)),
                                               bottom_weights.input_channel_dim)
                # change the shape of the weights back
                log.debug('back to shape = {}'.format(tuple(src_shape)))
                bottom_weights.value = bottom_weights.value.reshape(tuple(src_shape))
                log.debug('final shape of weights = {}'.format(bottom_weights.value.shape))
                log.debug('shape as attr = {}'.format(bottom_weights.shape))
            else:
                log.error('Reverse input channels are not applied: there is no Conv2D after '
                          'DepthwiseConv2dNative to complete the flip')

        conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim)
        conv.in_node(1).shape = int64_array(conv.in_node(1).value.shape)
        log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
        log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape,
                                                                conv.in_node(1).value.shape))
        log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim))


def conv_flatten_concat_action(graph: Graph, match: dict):
    assert graph.graph['layout'] == 'NHWC'
    reshape_node = match['reshape']
    reshape_data_node = match['reshape_data']
    conv_name = match['conv'].name
    conv_data_node = match['conv_data']
    # the pattern should be applied only when the reshape operation changes the number of dimensions
    if len(reshape_data_node.shape) == len(conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
        return

    if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
            reshape_data_node.out_node().type == 'FullyConnected' and \
            can_repack_fully_connected_weights_nhwc_to_nchw(reshape_data_node.out_node()):
        log.info('There is a FullyConnected layer after the node "{}" whose weights will be repacked, so there '
                 'is no need to insert a Permute'.format(reshape_node.soft_get('name')))
        return

    graph.remove_edge(conv_data_node.id, reshape_node.id)
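
    # Insert a Permute that converts the convolution output back to NHWC before the Reshape;
    # for a 4D conv output the NCHW to NHWC order below is expected to be [0, 2, 3, 1]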
    permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
    new_permute_op = Permute(graph, {'order': permutation_order})
    permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
    graph.create_edge(permute_data_node, reshape_node)
    # Disable permutation of attributes for the Reshape and Concat layers
    PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
    reshape_node['nchw_layout'] = True


def conv_flatten_concat(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('conv', dict(kind='op', type='Convolution')),
            ('conv_data', dict(kind='data')),
            ('reshape', dict(kind='op', type='Reshape')),
            ('reshape_data', dict(kind='data')),
        ],
        edges=[
            ('conv', 'conv_data'),
            ('conv_data', 'reshape'),
            ('reshape', 'reshape_data'),
        ],
        action=conv_flatten_concat_action
    )

    apply_pattern(
        graph,
        nodes=[
            ('real_conv', dict(kind='op', type='Convolution')),
            ('real_conv_data', dict(kind='data')),
            ('conv', dict(kind='op', type='ReLU')),
            ('conv_data', dict(kind='data')),
            ('reshape', dict(kind='op', type='Reshape')),
            ('reshape_data', dict(kind='data')),
        ],
        edges=[
            ('real_conv', 'real_conv_data'),
            ('real_conv_data', 'conv'),
            ('conv', 'conv_data'),
            ('conv_data', 'reshape'),
            ('reshape', 'reshape_data'),
        ],
        action=conv_flatten_concat_action
    )


def fuse_sequence_of_reshapes(graph: Graph):
    for node in list(graph.nodes()):
        if not graph.has_node(node):
            # the data node may have already been removed
            continue
        node = Node(graph, node)
        if (node.has_valid('type') and node.type == 'Reshape' and
                len(node.out_nodes()) == 1 and node.out_node().has_valid('kind') and
                node.out_node().kind == 'data' and len(node.out_node().out_nodes()) == 1):

            log.debug('First phase for Reshape: {}'.format(node.name))

            next_op = node.out_node().out_node()
            log.debug('second node: {}'.format(next_op.graph.node[next_op.id]))
            if next_op.has_valid('type') and next_op.type == 'Reshape':
                # Detected a Reshape1 --> data --> Reshape2 pattern without side edges: remove Reshape1
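                # (e.g. Reshape(dim=[1, -1]) --> data --> Reshape(dim=[1, 64]) collapses to the second
                # Reshape alone, since only the last reshape in the chain defines the final shape)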
                log.debug('Second phase for Reshape: {}'.format(node.name))
                remove_op_node_with_data_node(graph, node)

    reshape_nodes = graph.get_op_nodes(op='Reshape')
    for reshape_node in reshape_nodes:
        in_ports = [port for port in reshape_node.in_ports().values() if not port.disconnected()]
        assert len(in_ports) in [1, 2], "`Reshape` node must have 2 inputs or 1 input with `dim`"
        if len(in_ports) == 2:
            previous_dim_op = reshape_node.in_port(1).get_source().node.op
            if previous_dim_op != 'Const':
                continue
            dim = reshape_node.in_port(1).get_connection().data.get_value()
        else:
            assert reshape_node.has_valid('dim'), "`Reshape` node with 1 input must have `dim` attribute"
            dim = reshape_node.dim

        in_shape = reshape_node.in_port(0).get_connection().data.get_shape()

        if np.array_equal(dim, in_shape) and len(reshape_node.out_nodes()):
            log.debug("Useless reshape with dim {} was deleted: {}".format(str(dim), reshape_node.name))
            reshape_node.out_port(0).get_connection().set_source(reshape_node.in_port(0).get_source())