"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import logging as log

import networkx as nx
import numpy as np

from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.reshape import Reshape


class FeatureShuffleReshape(MiddleReplacementPattern):
    """
    This pass finds patterns like in the ShuffleNet topology (Reshape->Transpose->Reshape) and changes attributes
    of the first Reshape and the Transpose operations to preserve the original semantics.
    """
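    # Illustrative sketch (not part of the pass, values are made up): the channel shuffle matched here is
    # equivalent to the following NumPy manipulation on an NCHW tensor with C = C1 * C2 channels:
    #
    #     x = np.random.rand(1, 6, 4, 4)    # N=1, C=6 (C1=2, C2=3), H=W=4
    #     y = x.reshape(1, 2, 3, 16)        # Reshape1: split C into C1, C2 and fuse H*W
    #     y = y.transpose(0, 2, 1, 3)       # Transpose: swap the C1 and C2 dims
    #     y = y.reshape(1, 6, 4, 4)         # Reshape2: restore the original shape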

    enabled = True

    def pattern(self):
        return dict(
            nodes=[
                ('reshape1', dict(kind='op', type='Reshape')),
                ('reshape1_data', dict(kind='data')),
                ('transpose', dict(kind='op', type='Permute')),
                ('transpose_data', dict(kind='data')),
                ('reshape2', dict(kind='op', type='Reshape')),
                ('reshape2_data', dict(kind='data')),
            ],
            edges=[('reshape1', 'reshape1_data'),
                   ('reshape1_data', 'transpose'),
                   ('transpose', 'transpose_data'),
                   ('transpose_data', 'reshape2'),
                   ('reshape2', 'reshape2_data'),
                   ])

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        reshape1 = match['reshape1']
        reshape2 = match['reshape2']
        transpose = match['transpose']

        # Check that Reshape->Transpose->Reshape shuffles only the feature channel
        input_shape = np.array(reshape1.in_node(0).shape)
        reshape1_shape = np.array(reshape1.out_node().shape)
        output_shape = np.array(reshape2.out_node().shape)

        # Check that the input shape is 4D
        if len(input_shape) != 4:
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence because the input shape should be '
                        '4D (instead of {}D)'.format(transpose.name, len(input_shape)))
            return

        # Check that the output shape contains the same number of elements as the input shape
        if not np.prod(input_shape) == np.prod(output_shape):
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence because the output shape should be '
                        'equal to the input shape: {} and {}'.format(transpose.name, input_shape, output_shape))
            return

        # Input shapes can be either NCHW or NHWC; depending on the layout the feature channel can be split as
        # shown in the comments below. The feature_dims_split array contains the dims responsible for the feature dim.
        if graph.graph['layout'] == 'NCHW':
            # NC1C2HW or NC1C2(H*W)
            feature_dim = 1
            spatial_dims = np.array([2, 3])
            feature_dims_split = np.array([feature_dim, feature_dim + 1])
        else:
            # NHWC1C2 or N(H*W)C1C2 or (N*H*W)C1C2
            feature_dim = len(input_shape) - 1
            spatial_dims = np.array([1, 2])
            feature_dims_split = np.array([len(reshape1_shape) - 2, len(reshape1_shape) - 1])

        # Check that feature_dims_split suits the reshape layer shape
        for dim in feature_dims_split:
            if dim < 0 or dim >= len(reshape1_shape):
                log.warning('Can\'t convert Reshape({}:{})->Transpose->Reshape sequence. Can\'t detect feature '
                            'shuffle.'.format(reshape1.name, reshape1_shape))
                return

        if not np.prod(np.delete(reshape1_shape, feature_dims_split)) == np.prod(np.delete(input_shape, feature_dim)):
            log.warning('Can\'t convert Reshape->Transpose->Reshape sequence. Can\'t detect feature shuffle. {} '
                        'should be equal to {}'.format(np.prod(np.delete(reshape1_shape, feature_dims_split)),
                                                       np.prod(np.delete(input_shape, feature_dim))))
            return

        # Check the transpose order
        if not np.array_equal(feature_dims_split[::-1], transpose.order[feature_dims_split]):
            log.warning('Can\'t convert Reshape->Transpose({})->Reshape sequence. Transpose operation should switch '
                        'the feature order (given order: {})'.format(transpose.name, transpose.order))
            return

        # Now we are sure that Reshape->Transpose->Reshape shuffles the feature dims,
        # so we change the Reshape and Transpose attrs to suit the NCHW layout

        # The resulting shape for the Reshape1 layer: [N, C1, C2, (H*W)]
        new_reshape1_shape = np.concatenate((np.array([input_shape[0]]),
                                             np.array(reshape1_shape[feature_dims_split]),
                                             np.array([np.prod(input_shape[spatial_dims])])))

        new_transpose_order = np.array([0, 2, 1, 3])
        new_transpose_shape = np.array(new_reshape1_shape[new_transpose_order])
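
        # For example (hypothetical values): an NHWC input of shape [1, 4, 4, 6] with reshape1_shape
        # [1, 4, 4, 2, 3] yields new_reshape1_shape = [1, 2, 3, 16]; applying new_transpose_order [0, 2, 1, 3]
        # then gives new_transpose_shape = [1, 3, 2, 16], i.e. the same shuffle expressed in NCHW terms.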

        reshape1.out_node().shape = new_reshape1_shape
        transpose.order = new_transpose_order
        transpose.out_node().shape = new_transpose_shape

        # Preserve these layers from conversion to NCHW (in case of NHWC topology layout)
        reshape1['nchw_layout'] = True
        reshape1.out_node()['nchw_layout'] = True
        transpose['nchw_layout'] = True
        transpose.out_node()['nchw_layout'] = True


class ReshapeSoftmaxReshape(MiddleReplacementPattern):
    """
    In case of NHWC layout this pass finds Reshape(-1, 2)->Softmax patterns and changes the dims of the first
    Reshape for the NCHW format. This transformation is necessary because after conversion to NCHW this sequence
    would have a wrong interpretation.
    """
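    # Illustrative sketch (not part of the pass, values are made up): the matched pattern corresponds to the
    # following NumPy-style manipulation of a confidence-like tensor:
    #
    #     x = np.random.rand(1, 19, 19, 2)    # NHWC input, C=2 scores per location
    #     y = x.reshape(-1, 2)                # Reshape1: [361, 2]
    #     y = softmax(y)                      # Softmax over the 2-element feature dim
    #
    # After conversion to NCHW the same Reshape would mix scores from different spatial locations, hence the
    # rewrite to [N, C, H*W] below.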

    enabled = True

    def pattern(self):
        return dict(
            nodes=[
                ('reshape1', dict(kind='op', type='Reshape')),
                ('reshape1_data', dict(kind='data')),
                ('softmax', dict(kind='op', type='SoftMax')),
                ('softmax_data', dict(kind='data')),
            ],
            edges=[('reshape1', 'reshape1_data'),
                   ('reshape1_data', 'softmax'),
                   ('softmax', 'softmax_data'),
                   ])

    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
        # This transformation is only needed for NHWC topologies
        if graph.graph['layout'] != 'NHWC':
            return

        reshape1 = match['reshape1']
        softmax = match['softmax']

        # Gather the input/output shapes used by the checks below
        input_shape = np.array(reshape1.in_node(0).shape)
        reshape1_shape = np.array(reshape1.out_node().shape)

        # Check that the input shape is 4D
        if len(input_shape) != 4:
            log.warning('Can\'t convert Reshape({})->Softmax->Reshape sequence because the input shape should be 4D '
                        '(instead of {}D {})'.format(reshape1.name, len(input_shape), input_shape))
            return

        if len(reshape1_shape) != 2:
            log.warning('This pass expects a 2D output tensor for the first Reshape {} layer (given shape: {})'
                        ''.format(reshape1.name, reshape1_shape))
            return

        # In NHWC layout the feature dim is the last one; the spatial dims are H and W
        feature_dim = len(input_shape) - 1
        spatial_dims = [1, 2]

        # Skip the transformation if the spatial dims in the input shape are equal to [1, 1]
        if np.array_equal(input_shape[spatial_dims], np.array([1, 1])):
            log.info('Skip this transformation because the spatial dims are [1, 1]')
            return

        # Check that Reshape1 has out dims [-1, feature_dims]
        if not (reshape1_shape[-1] == input_shape[-1] and reshape1_shape[0] == np.prod(
                np.delete(input_shape, feature_dim))):
            log.warning('Output shape for the Reshape operation should be [{},{}] instead of {}'.format(
                np.prod(np.delete(input_shape, feature_dim)), input_shape[-1], reshape1_shape))
            return

        # Now we are sure that the Reshape->Softmax sequence is suitable for this transformation

        # The resulting shape for the Reshape1 layer: [N, C, (H*W)]
        new_reshape1_shape = np.concatenate((np.array([input_shape[0]]),
                                             np.array([reshape1_shape[-1]]),
                                             np.array([np.prod(input_shape[spatial_dims])])))
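
        # For example (hypothetical values): an input of shape [1, 19, 19, 2] with reshape1_shape [361, 2]
        # yields new_reshape1_shape = [1, 2, 361], so the Softmax keeps normalizing over the 2-element feature dim.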

        old_shape = np.array(reshape1.out_node().shape)
        reshape1.out_node().shape = new_reshape1_shape
        softmax.out_node().shape = new_reshape1_shape

        # Preserve these layers from conversion to NCHW (in case of NHWC topology layout)
        reshape1['nchw_layout'] = True
        reshape1.out_node()['nchw_layout'] = True
        softmax['nchw_layout'] = True
        softmax.out_node()['nchw_layout'] = True

        # Create a final Reshape to keep the original shape of the Softmax output
        softmax_out_data = softmax.out_node()
        next_operation = softmax_out_data.out_node()
        # Save the edge attributes & remove the edge
        edge_attrs = graph.get_edge_data(softmax_out_data.id, next_operation.id)[0]
        graph.remove_edge(softmax_out_data.id, next_operation.id)

        reshape_op = Reshape(graph, dict(name="Reshape_", dim=np.array(old_shape)))
        reshape_out_data = reshape_op.create_node_with_data(inputs=[softmax_out_data])
        graph.add_edges_from([(reshape_out_data.id, next_operation.id, edge_attrs)])
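

if __name__ == '__main__':
    # Standalone sanity sketch (added for illustration, not part of the Model Optimizer pass): checks with plain
    # NumPy that the NCHW-style shuffle produced by FeatureShuffleReshape matches the original NHWC one.
    x = np.random.rand(1, 4, 4, 6)                       # NHWC input, C = C1 * C2 = 2 * 3
    original = x.reshape(1, 4, 4, 2, 3).transpose(0, 1, 2, 4, 3).reshape(1, 4, 4, 6)
    nchw = x.transpose(0, 3, 1, 2)                       # the same input in NCHW layout
    rewritten = nchw.reshape(1, 2, 3, 16).transpose(0, 2, 1, 3).reshape(1, 6, 4, 4)
    assert np.array_equal(rewritten, original.transpose(0, 3, 1, 2))
    print('NCHW rewrite matches the original NHWC channel shuffle')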