# model-optimizer/mo/middle/passes/shape.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.utils import int64_array
22 from mo.front.extractor import update_attrs
23 from mo.graph.graph import Node, Graph
24 from mo.middle.passes.eliminate import remove_op_node_with_data_node, merge_data_nodes, graph_clean_up_tf
25 from mo.middle.passes.fusing.helpers import get_next_operation
26 from mo.middle.pattern_match import apply_pattern
27 from mo.ops.op import PermuteAttrs, Op
28 from mo.ops.permute import Permute
29 from mo.utils.error import Error
30 from mo.utils.utils import refer_to_faq_msg
31
32
def reshape_squeeze_transform(graph: Graph, match: dict):
    reshape = match['reshape']
    output = match['output']
    if output.shape is None:
        return  # cannot really do anything if the shape is dynamic
    reshape['shape'] = output.shape
    reshape.op = 'Reshape'
    reshape['type'] = 'Reshape'
    if not reshape.has_valid('dim'):
        # do not override the 'dim' value if it is already set; it may contain specific values like -1 and 0
        reshape['dim'] = reshape.shape.copy()
    update_attrs(reshape, 'shape_attrs', 'dim')


def convert_squeeze(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('reshape', dict(kind='op', op='Squeeze')),
            ('output', dict(kind='data'))],
        edges=[('reshape', 'output')],
        action=reshape_squeeze_transform
    )


def convert_reshape(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('shape', dict(kind='data')),
            ('reshape', dict(kind='op', op='Reshape')),
            ('output', dict(kind='data'))],
        edges=[('shape', 'reshape', {'in': 1}), ('reshape', 'output')],
        action=reshape_squeeze_transform
    )


def can_repack_fully_connected_weights_nhwc_to_nchw(fc_node: Node):
    """
    Checks that it is possible to repack the weights of the FullyConnected layer when a Reshape layer is the input of
    the FullyConnected layer and satisfies several conditions.
    :param fc_node: the FullyConnected node to check
    :return: the result of the check
    """
    if len(fc_node.in_node(0).in_nodes()) != 1:
        return False

    reshape_node = fc_node.in_node(0).in_node(0)
    if not reshape_node.has_valid('type') or reshape_node.type != 'Reshape':
        return False

    if not reshape_node.in_node(0).has_valid('shape') or not reshape_node.out_node().has_valid('shape'):
        return False

    orig_shape = reshape_node.in_node(0).shape
    new_shape = reshape_node.out_node().shape

    # TODO a bit conservative condition; relax it by checking the specific dimensions that are involved in the
    # NHWC to NCHW translation
    if len(orig_shape) == len(new_shape) and all(orig_shape == new_shape):
        return False

    # TODO there are a couple of limitations here that make this pass simpler; consider relaxing them
    if len(orig_shape) == 4 and len(new_shape) == 2 and orig_shape[0] == new_shape[0]:
        # that means orig_shape is in NCHW and new_shape is in NC
        # and we need to map the CHW part to C after the HWC to CHW transform.
        # Assuming that the FullyConnected weights have not been converted from IO to OI yet,
        # so the format is IO.
        return True
    else:
        log.warning("Cannot do the complete NHWC to NCHW translation for FullyConnected weights. "
                    "The final model can be broken.")
        return False


def repack_fully_connected_weights_nhwc_to_nchw(graph: Graph):
    """
    Repacks the weights of a FullyConnected layer as a part of the NHWC to NCHW translation when a Reshape that
    involves the dimensions being repacked appears right before the FullyConnected layer.
    """
    for node_id in graph.get_nodes_with_attributes(type='FullyConnected'):
        fc_node = Node(graph, node_id)

        if not can_repack_fully_connected_weights_nhwc_to_nchw(fc_node):
            continue

        reshape_node = fc_node.in_node(0).in_node(0)

        orig_shape = reshape_node.in_node(0).shape
        new_shape = reshape_node.out_node().shape

        # OK, here we are; need to repack fc_node.in_node(1) to keep it compatible with the original input order

        assert all(orig_shape != -1), 'Input shape for {} cannot be negative.'.format(fc_node.id)
        assert all(new_shape != -1), 'Output shape for {} cannot be negative.'.format(fc_node.id)
        assert orig_shape[1] * orig_shape[2] * orig_shape[3] == new_shape[1], \
            'Input shape does not correspond to output shape for layer {}.'.format(fc_node.id)
        assert fc_node.in_node(1).has_valid('value'), 'Node {} does not have value.'.format(fc_node.id)

        weights = fc_node.in_node(1)

        log.debug("orig_shape = {}".format(orig_shape))
        log.debug("new_shape = {}".format(new_shape))
        log.debug("weights.shape = {}".format(weights.shape))
        log.debug("weights.shape[1] = {}, new_shape[1] = {}".format(weights.shape[1], new_shape[1]))

        assert weights.shape[0] == new_shape[1], \
            'First dim of weights does not correspond to output shape of {}'.format(fc_node.id)
        # interpret the I dimension of the weights as packed HWC
        # orig_shape is already converted to NCHW, so provide a transposed order for the I repacking
        tmp_shape = (orig_shape[2], orig_shape[3], orig_shape[1], weights.shape[1])
        weights.value = np.transpose(weights.value.reshape(tmp_shape), (2, 0, 1, 3)).reshape(weights.shape)


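# Illustrative sketch for repack_fully_connected_weights_nhwc_to_nchw above (hypothetical shapes,
# not part of the pass): for an input reshaped from NCHW [1, 3, 2, 2] to NC [1, 12], the IO-format
# weights of shape [12, out_dim] are viewed as a packed [H=2, W=2, C=3, out_dim] block, transposed
# to [C, H, W, out_dim] and flattened back, roughly:
#
#     w = weights.value.reshape((2, 2, 3, out_dim))                         # (H, W, C, O) view of I
#     weights.value = np.transpose(w, (2, 0, 1, 3)).reshape((12, out_dim))  # repacked (C, H, W, O) -> IO

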
def apply_nhwc_to_nchw_permutation(graph: Graph):
    # Add the NHWC to NCHW permutation for all data nodes (only for nodes without a permutation)
    if graph.graph['layout'] == 'NCHW':
        return
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind == 'data':
            if node.has_and_set('nchw_layout'):
                continue

            # Get the NHWC to NCHW permutation for N dims, where N = len(node.shape)
            permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))

            # Check whether the data node already has a permutation
            skip_permutation = False
            for in_node in node.in_nodes():
                edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True
            for out_node in node.out_nodes():
                edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
                if 'permutation' in edge_attrs:
                    skip_permutation = True

            if skip_permutation:
                continue

            # Set the permutation on all incoming and outgoing edges
            for in_node in node.in_nodes():
                PermuteAttrs.set_permutation(in_node, node, permutation)

            for out_node in node.out_nodes():
                PermuteAttrs.set_permutation(node, out_node, permutation)


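# Note (assumption based on the usage in permute_data_nodes_attrs below): for the common 4D case the
# NHWC to NCHW permutation attached to the edges here is expected to be [0, 3, 1, 2], so a data node
# with shape [1, 224, 224, 3] would later be permuted to [1, 3, 224, 224].

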
def merge_nodes_permutations(graph: Graph):
    # Iterate over all data nodes and check all their permutations for equality.
    # If all permutations are equal, the permutation is set as an attribute of the data node;
    # otherwise an exception is raised.
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind != 'data':
            continue

        permutations = []

        # Get all permutations from the incoming edges
        for in_node in node.in_nodes():
            edge_attrs = node.graph.get_edge_data(in_node.id, node.id)[0]
            if 'permutation' in edge_attrs:
                permutations.append(edge_attrs['permutation'])

        # Get all permutations from the outgoing edges
        for out_node in node.out_nodes():
            edge_attrs = node.graph.get_edge_data(node.id, out_node.id)[0]
            if 'permutation' in edge_attrs:
                permutations.append(edge_attrs['permutation'])

        # Check that all permutations are equal
        final_permutations = []
        for p in permutations:
            if p is not None:
                final_permutations.append(p.perm)
            else:
                final_permutations.append(np.arange(node.shape.size))

        if len(final_permutations) == 0:
            continue

        if not all([np.array_equal(final_permutations[0], perm) for perm in final_permutations]):
            raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'.format(
                node.name, [p.perm for p in permutations]))

        assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
        node['permutation'] = permutations[0]
        if node.permutation is not None and node.permutation.perm.size == 0:
            node.permutation = None


def permute_data_nodes_attrs(graph: Graph):
    # Iterate over all data nodes and apply the permutation if it exists
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind != 'data' or not node.has_valid('permutation'):
            continue

        # Apply the permutation to the shape and to the value if it exists
        node.shape = np.array(node.shape)[node.permutation.perm]
        if node.has_valid('value'):
            if len(node.value.shape) != len(node.permutation.perm):
                log.warning('Node {} has value of shape {} that does not match permutation {}'.format(
                    node.name, node.value.shape, node.permutation.perm))
                continue
            node.value = np.array(node.value.transpose(node.permutation.perm))


def permute_op_nodes_attrs(graph: Graph):
    for node in graph.nodes():
        node = Node(graph, node)
        if node.kind == 'op' and node.has_valid('permute_attrs'):
            try:
                node.permute_attrs.permute_attrs(node)
            except Exception as e:
                raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e))


def reverse_input_channels(graph: Graph):
    """
    Searches for all type=Input nodes with 4D output tensors, tracks the tensors down through non-shape-changing ops
    to the first type=Convolution or other channel-dependent nodes and reverses the input channels in the convolution
    weights.
    """
    candidates = set()
    for node in graph.nodes():
        node = Node(graph, node)
        if node.has_valid('type') and node.type == 'Input' and len(node.out_nodes()) == 1 and node.out_node(
                0).shape.size == 4:
            candidates.add(node)
    log.debug('reverse_input_channels found candidates: {}'.format([c.node for c in candidates]))
    # Track down to the first convolutions
    convolutions = set()
    flip_passthrough = set()
    while len(candidates) > 0:
        op_node = candidates.pop()
        assert (len(op_node.out_nodes()) == 1)
        tensor_node = op_node.out_node(0)
        for consumer in tensor_node.out_nodes():
            if (consumer.has_valid('type') and
                    consumer.type == 'Convolution' and
                    consumer.in_node(1).has_valid('input_channel_dim') and
                    consumer.in_node(1).has_valid('shape') and
                    consumer.in_node(1).shape[consumer.in_node(1).input_channel_dim] == 3 and
                    consumer.in_node(1).has_valid('value')):
                convolutions.add(consumer)
            else:
                # TODO Use a more reliable way
                if len(consumer.out_nodes()) == 1 and np.all(consumer.out_node().shape == tensor_node.shape):
                    candidates.add(consumer)
                    if consumer.has_valid('type') and (
                            consumer.type == 'ScaleShift' or consumer.type == 'BatchNormalization'):
                        flip_passthrough.add(consumer)
                else:
                    log.debug('Stop searching for a conv candidate for channel reversing at node {}'.format(consumer.id))

    if len(convolutions) == 0:
        log.error('Reverse input channels are not applied -- appropriate convolutions were not found')

    for node in flip_passthrough:
        log.debug("Applying flip for ScaleShift: {}".format(node.name))
        assert node.has_valid('type') and (node.type == 'ScaleShift' or node.type == 'BatchNormalization')
        blobs = [node.in_node(i) for i in range(1, len(node.in_nodes()))]
        for blob in blobs:
            assert blob.has_valid('value')
            non_one_dimensions = np.where(blob.shape != 1)[0]
            assert len(non_one_dimensions) == 1
            assert blob.shape[non_one_dimensions[0]] == 3
            blob.value = np.flip(blob.value, non_one_dimensions[0])

    for conv in convolutions:
        if conv.op == 'DepthwiseConv2dNative':
            log.debug('out nodes: {}'.format(conv.out_node()))
            bottoms = conv.out_node().out_nodes()
            if len(bottoms) == 1 and bottoms[0].op == 'FakeQuantWithMinMaxVars':
                bottoms = bottoms[0].out_node().out_nodes()
            log.debug('bottoms: {}'.format(bottoms))
            if len(bottoms) > 0 and bottoms[0].op == 'Conv2D':
                bottom_conv = bottoms[0]
                log.debug('assumed conv: name = {}, op = {}'.format(bottom_conv.name, bottom_conv.op))
                # Flipping the input channels for DepthwiseConv2dNative alone does not complete the job:
                # we also need to flip the input channels for the next convolution, in groups
                ngroups = conv.group
                log.debug('ngroups = {}'.format(ngroups))
                bottom_channel_dim = bottom_conv.channel_dims[0]
                log.debug('bottom_channel_dim = {}'.format(bottom_channel_dim))
                bottom_channels = bottom_conv.in_node(0).shape[bottom_channel_dim]
                log.debug('bottom_channels = {}'.format(bottom_channels))
                assert (bottom_channels % ngroups == 0)
                multiplier = int(bottom_channels / ngroups)
                log.debug('multiplier = {}'.format(multiplier))
                bottom_weights = bottom_conv.in_node(1)
                tmp_shape_for_reorder = list(bottom_weights.value.shape)
                src_shape = list(tmp_shape_for_reorder)
                log.debug('weights shape = {}'.format(tmp_shape_for_reorder))
                assert (tmp_shape_for_reorder[bottom_weights.input_channel_dim] == bottom_channels)
                tmp_shape_for_reorder[bottom_weights.input_channel_dim] = ngroups
                tmp_shape_for_reorder = tmp_shape_for_reorder + [multiplier]
                log.debug('tmp_shape_for_reorder = {}'.format(tmp_shape_for_reorder))
                # temporarily reshape the weights to do the reordering
                bottom_weights.value = np.flip(bottom_weights.value.reshape(tuple(tmp_shape_for_reorder)),
                                               bottom_weights.input_channel_dim)
                # change the shape of the weights back
                log.debug('back to shape = {}'.format(tuple(src_shape)))
                bottom_weights.value = bottom_weights.value.reshape(tuple(src_shape))
                log.debug('final shape of weights = {}'.format(bottom_weights.value.shape))
                log.debug('shape as attr = {}'.format(bottom_weights.shape))
            else:
                log.error(
                    'Reverse input channels are not applied: there is no Conv2D after DepthwiseConv2dNative to ' +
                    'complete the flip')

        conv.in_node(1).value = np.flip(conv.in_node(1).value, conv.in_node(1).input_channel_dim)
        conv.in_node(1).shape = int64_array(conv.in_node(1).value.shape)
        log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
        log.debug('Shape was (shape){}, (value.shape){}'.format(conv.in_node(1).shape, conv.in_node(1).value.shape))
        log.debug('Flipped dim: {}'.format(conv.in_node(1).input_channel_dim))


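# Illustrative sketch for reverse_input_channels above (hypothetical TensorFlow-style HWIO weights,
# not part of the pass): for weights of shape [kH, kW, 3, out_channels] with input_channel_dim == 2,
# np.flip along axis 2 reorders the channel slices from (R, G, B) to (B, G, R):
#
#     w = np.arange(2 * 2 * 3 * 4).reshape((2, 2, 3, 4))
#     flipped = np.flip(w, 2)    # flipped[:, :, 0, :] == w[:, :, 2, :], etc.

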
def conv_flatten_concat_action(graph: Graph, match: dict):
    assert graph.graph['layout'] == 'NHWC'
    reshape_node = match['reshape']
    reshape_data_node = match['reshape_data']
    conv_name = match['conv'].name
    conv_data_node = match['conv_data']
    # the pattern should be applied only when the reshape operation changes the number of dimensions
    if len(reshape_data_node.shape) == len(conv_data_node.shape) or reshape_node.has_and_set('nchw_layout'):
        return

    if len(reshape_data_node.out_nodes()) == 1 and reshape_data_node.out_node().has_valid('type') and \
        reshape_data_node.out_node().type == 'FullyConnected' and \
            can_repack_fully_connected_weights_nhwc_to_nchw(reshape_data_node.out_node()):
        log.info('There is a FullyConnected layer after the node "{}" whose weights will be repacked, so there is no '
                 'need to insert a Permute'.format(reshape_node.soft_get('name')))
        return
    graph.remove_edge(conv_data_node.id, reshape_node.id)

    permutation_order = PermuteAttrs.get_nchw_to_nhwc_permutation(len(conv_data_node.shape)).perm
    new_permute_op = Permute(graph, {'order': permutation_order})
    permute_data_node = new_permute_op.create_node_with_data([conv_data_node], dict(name=conv_name + '/Permute_'))
    graph.create_edge(permute_data_node, reshape_node)
    # Disable permutation of the Reshape and Concat layer attributes
    PermuteAttrs.set_permutation(reshape_node, reshape_data_node, None)
    reshape_node['nchw_layout'] = True


def conv_flatten_concat(graph: Graph):
    apply_pattern(
        graph,
        nodes=[
            ('conv', dict(kind='op', type='Convolution')),
            ('conv_data', dict(kind='data')),
            ('reshape', dict(kind='op', type='Reshape')),
            ('reshape_data', dict(kind='data')),
        ],
        edges=[
            ('conv', 'conv_data'),
            ('conv_data', 'reshape'),
            ('reshape', 'reshape_data'),
        ],
        action=conv_flatten_concat_action
    )

    apply_pattern(
        graph,
        nodes=[
            ('real_conv', dict(kind='op', type='Convolution')),
            ('real_conv_data', dict(kind='data')),
            ('conv', dict(kind='op', type='ReLU')),
            ('conv_data', dict(kind='data')),
            ('reshape', dict(kind='op', type='Reshape')),
            ('reshape_data', dict(kind='data')),
        ],
        edges=[
            ('real_conv', 'real_conv_data'),
            ('real_conv_data', 'conv'),
            ('conv', 'conv_data'),
            ('conv_data', 'reshape'),
            ('reshape', 'reshape_data'),
        ],
        action=conv_flatten_concat_action
    )


def fuse_sequence_of_reshapes(graph: Graph):
    for node in list(graph.nodes()):
        if not graph.has_node(node):
            # the data node may have already been removed
            continue
        node = Node(graph, node)
        if (
                node.has_valid('type') and node.type == 'Reshape' and
                len(node.out_nodes()) == 1 and node.out_node().has_valid('kind') and node.out_node().kind == 'data' and
                len(node.out_node().out_nodes()) == 1):

            log.debug('First phase for Reshape: {}'.format(node.name))

            next_op = node.out_node().out_node()
            log.debug('second node: {}'.format(next_op.graph.node[next_op.id]))
            if next_op.has_valid('type') and next_op.type == 'Reshape':
                # Detected a Reshape1 --> data --> Reshape2 pattern without side edges; remove Reshape1
                log.debug('Second phase for Reshape: {}'.format(node.name))
                remove_op_node_with_data_node(graph, node)

    reshape_nodes = graph.get_op_nodes(op='Reshape')
    for reshape_node in reshape_nodes:
        in_ports = [port for port in reshape_node.in_ports().values() if not port.disconnected()]
        assert len(in_ports) in [1, 2], "`Reshape` node must have 2 inputs or 1 input with `dim`"
        if len(in_ports) == 2:
            previous_dim_op = reshape_node.in_port(1).get_source().node.op
            if previous_dim_op != 'Const':
                continue
            dim = reshape_node.in_port(1).get_connection().data.get_value()
        else:
            assert reshape_node.has_valid('dim'), "`Reshape` node with 1 input must have the `dim` attribute"
            dim = reshape_node.dim

        in_shape = reshape_node.in_port(0).get_connection().data.get_shape()

        if np.array_equal(dim, in_shape) and len(reshape_node.out_nodes()):
            log.debug("Useless reshape with dim {} was deleted: {}".format(str(dim), reshape_node.name))
            reshape_node.out_port(0).get_connection().set_source(reshape_node.in_port(0).get_source())