1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """

import logging as log

import networkx as nx
import numpy as np

from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.reshape import Reshape


class FeatureShuffleReshape(MiddleReplacementPattern):
    """
    This pass finds the Reshape->Transpose->Reshape pattern used in the ShuffleNet topology (the channel shuffle)
    and changes the attributes of the first Reshape and the Transpose operations to preserve the original semantics.
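
    For illustration, the channel shuffle that this pattern implements can be sketched in NumPy
    (a minimal sketch with hypothetical names n, c, g, h, w, where g is the number of groups;
    not part of the pass itself):

        x = x.reshape(n, g, c // g, h * w)  # Reshape1: split the feature dim into g groups
        x = x.transpose(0, 2, 1, 3)         # Transpose: swap the two feature dims
        x = x.reshape(n, c, h, w)           # Reshape2: restore the original shape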
    """

    enabled = True

    def pattern(self):
        return dict(
            nodes=[
                ('reshape1', dict(kind='op', type='Reshape')),
                ('reshape1_data', dict(kind='data')),
                ('transpose', dict(kind='op', type='Permute')),
                ('transpose_data', dict(kind='data')),
                ('reshape2', dict(kind='op', type='Reshape')),
                ('reshape2_data', dict(kind='data')),
            ],
            edges=[('reshape1', 'reshape1_data'),
                   ('reshape1_data', 'transpose'),
                   ('transpose', 'transpose_data'),
                   ('transpose_data', 'reshape2'),
                   ('reshape2', 'reshape2_data'),
                   ]
        )

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        reshape1 = match['reshape1']
        reshape2 = match['reshape2']
        transpose = match['transpose']

        # Check that Reshape->Transpose->Reshape shuffles only the feature channel
        input_shape = np.array(reshape1.in_node(0).shape)
        reshape1_shape = np.array(reshape1.out_node().shape)
        output_shape = np.array(reshape2.out_node().shape)

        # Check that the input shape is 4D
        if len(input_shape) != 4:
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence because the input shape must be 4D '
                        '(got {}D)'.format(transpose.name, len(input_shape)))
            return

        # Check that the output shape is the same as the input shape
        if not np.prod(input_shape) == np.prod(output_shape):
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence because the output shape must be '
                        'equal to the input shape: {} and {}'.format(transpose.name, input_shape, output_shape))
            return
        # The input layout can be either NCHW or NHWC, so in case of a channel split, the feature dim can be split
        # as shown in the comments below.
        # The feature_dims_split array lists the dims that may hold the split feature channel.
        if graph.graph['layout'] == 'NCHW':
            # NC1C2HW or NC1C2(H*W)
            feature_dim = 1
            spatial_dims = [2, 3]
            feature_dims_split = np.array([feature_dim, feature_dim + 1])
        else:
            # NHWC1C2 or N(H*W)C1C2 or (N*H*W)C1C2
            feature_dim = 3
            spatial_dims = [1, 2]
            feature_dims_split = np.array([len(reshape1_shape) - 2, len(reshape1_shape) - 1])
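
        # Example (hypothetical shapes): an NCHW input [1, 4, 2, 2] reshaped to [1, 2, 2, 2, 2] gives
        # feature_dims_split == [1, 2]; an NHWC input [1, 2, 2, 4] reshaped to [1, 2, 2, 2, 2] gives
        # feature_dims_split == [3, 4] (the last two dims).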

        # Check that feature_dims_split fits the first Reshape output shape
        for dim in feature_dims_split:
            if dim < 0 or dim >= len(reshape1_shape):
                log.warning('Can\'t convert Reshape({}:{})->Transpose->Reshape sequence. Can\'t detect feature '
                            'shuffle.'.format(reshape1.name, reshape1_shape))
                return

        if not np.prod(np.delete(reshape1_shape, feature_dims_split)) == np.prod(np.delete(input_shape, feature_dim)):
            log.warning('Can\'t convert Reshape->Transpose->Reshape sequence. Can\'t detect feature shuffle. {} '
                        'should be equal to {}'.format(np.prod(np.delete(reshape1_shape, feature_dims_split)),
                                                       np.prod(np.delete(input_shape, feature_dim))))
            return

        # Check the transpose order: it must swap the two split feature dims
        if not np.array_equal(feature_dims_split[::-1], transpose.order[feature_dims_split]):
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence. The Transpose operation should '
                        'switch the feature order (given order: {})'.format(transpose.name, transpose.order))
            return
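
        # Example (hypothetical): with feature_dims_split == [1, 2], a valid order must satisfy
        # order[[1, 2]] == [2, 1], e.g. order == [0, 2, 1, 3].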

        # Now we are sure that Reshape->Transpose->Reshape shuffles the feature dims,
        # so we change the Reshape and Transpose attrs to suit the NCHW layout

        # The resulting shape for the Reshape1 layer: [N, C1, C2, (H*W)]
        new_reshape1_shape = np.concatenate((np.array([input_shape[0]]),
                                             np.array(reshape1_shape[feature_dims_split]),
                                             np.array([np.prod(input_shape[spatial_dims])])))

        new_transpose_order = np.array([0, 2, 1, 3])
        new_transpose_shape = np.array(new_reshape1_shape[new_transpose_order])

        reshape1.out_node().shape = new_reshape1_shape
        transpose.order = new_transpose_order
        transpose.out_node().shape = new_transpose_shape

        # Keep these layers out of the conversion to NCHW (in case of an NHWC topology layout)
        reshape1['nchw_layout'] = True
        reshape1.out_node()['nchw_layout'] = True
        transpose['nchw_layout'] = True
        transpose.out_node()['nchw_layout'] = True


class ReshapeSoftmaxReshape(MiddleReplacementPattern):
    """
    In the NHWC case this pass finds the 2D Reshape([-1, C])->Softmax pattern and changes the dims of the first
    Reshape to the NCHW format. This transformation is necessary because after conversion to NCHW this sequence
    would otherwise be interpreted incorrectly.
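
    For illustration (hypothetical names; softmax stands in for the SoftMax layer, not part of the pass),
    the matched pattern computes a per-position softmax over the feature axis of an NHWC tensor:

        x = x.reshape(-1, c)    # Reshape1 with dim [-1, C]: [N, H, W, C] -> [N*H*W, C]
        y = softmax(x, axis=1)  # softmax over the feature axis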
    """

    enabled = True

    def pattern(self):
        return dict(
            nodes=[
                ('reshape1', dict(kind='op', type='Reshape')),
                ('reshape1_data', dict(kind='data')),
                ('softmax', dict(kind='op', type='SoftMax')),
                ('softmax_data', dict(kind='data')),
            ],
            edges=[('reshape1', 'reshape1_data'),
                   ('reshape1_data', 'softmax'),
                   ('softmax', 'softmax_data'),
                   ])

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        if graph.graph['layout'] != 'NHWC':
            return

        reshape1 = match['reshape1']
        softmax = match['softmax']

        # Check that Reshape->Softmax applies the softmax over the feature channel only
        input_shape = np.array(reshape1.in_node(0).shape)
        reshape1_shape = np.array(reshape1.out_node().shape)

        # Check that the input shape is 4D
        if len(input_shape) != 4:
            log.warning('Can\'t convert Reshape({})->Softmax->Reshape sequence because the input shape must be 4D '
                        '(got {}D {})'.format(reshape1.name, len(input_shape), input_shape))
            return

        if len(reshape1_shape) != 2:
            log.warning('This pass expects a 2D output tensor for the first Reshape {} layer (given shape: {})'
                        ''.format(reshape1.name, reshape1_shape))
            return

        # Define the feature dim
        feature_dim = 3
        spatial_dims = [1, 2]

        # Skip the transform if the spatial dims of the input shape are [1, 1]
        if np.array_equal(input_shape[spatial_dims], np.array([1, 1])):
            log.info('Skipping this transformation because the spatial dims are [1, 1]')
            return

        # Check that Reshape1 has output dims [-1, feature dim size]
        if not (reshape1_shape[-1] == input_shape[-1] and reshape1_shape[0] == np.prod(
                np.delete(input_shape, feature_dim))):
            log.warning('The output shape for the Reshape operation should be [{},{}] instead of {}'.format(
                np.prod(np.delete(input_shape, feature_dim)), input_shape[-1], reshape1_shape))
            return
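
        # Example (hypothetical shapes): for an input [1, 19, 19, 2], the first Reshape output must be
        # [361, 2] (361 == 1 * 19 * 19).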

        # Now we are sure that Reshape->Softmax suits this transformation

        # The resulting shape for the Reshape1 layer: [N, C, (H*W)]
        new_reshape1_shape = np.concatenate((np.array([input_shape[0]]),
                                             np.array([reshape1_shape[-1]]),
                                             np.array([np.prod(input_shape[spatial_dims])])))

        old_shape = np.array(reshape1.out_node().shape)
        reshape1.out_node().shape = new_reshape1_shape
        softmax.out_node().shape = new_reshape1_shape

        # Keep these layers out of the conversion to NCHW (in case of an NHWC topology layout)
        reshape1['nchw_layout'] = True
        reshape1.out_node()['nchw_layout'] = True
        softmax['nchw_layout'] = True
        softmax.out_node()['nchw_layout'] = True

        # Create a final Reshape to keep the original shape of the softmax output
        softmax_out_data = softmax.out_node()
        next_operation = softmax_out_data.out_node()
        # Save the edge attributes and remove the edge
        edge_attrs = graph.get_edge_data(softmax_out_data.id, next_operation.id)[0]
        graph.remove_edge(softmax_out_data.id, next_operation.id)

        reshape_op = Reshape(graph, dict(name="Reshape_", dim=np.array(old_shape)))
        reshape_out_data = reshape_op.create_node_with_data(inputs=[softmax_out_data])
        graph.add_edges_from([(reshape_out_data.id, next_operation.id, edge_attrs)])