Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / PixelLinkReshape.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import logging as log
18 import numpy as np
19
20 from copy import deepcopy
21
22 from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
23 from extensions.middle.FusePermutesSequence import FusePermutesSequence
24 from extensions.middle.ShufflenetReshape import ReshapeSoftmaxReshape
25 from mo.graph.graph import Graph
26 from mo.middle.replacement import MiddleReplacementPattern
27 from mo.ops.op import Op
28 from mo.ops.permute import Permute
29
30
class PixelLinkReshape(MiddleReplacementPattern):
    """
      Transform adds Permutes around Reshapes that pack 4 dimensions in 2, then
      do Softmax and then unpack it back to 5 dims.

      The matched sub-graph is
          Reshape (split) -> Reshape (pack) -> SoftMax -> Reshape (unpack) -> StridedSlice
      and it is rewritten only when the graph layout is 'NHWC' and the shapes and
      slice parameters satisfy `is_reshape_bad`.
    """
    enabled = True

    def run_before(self):
        """Return the transforms that must be executed after this one."""
        return [FusePermutesSequence, ReshapeSoftmaxReshape, ConvertGroupedStridedSlice]

    def run_after(self):
        """Return the transforms that must be executed before this one."""
        # NOTE(review): imported locally, presumably to avoid a circular import - confirm.
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def pattern(self):
        """Describe the op/data chain to match: Reshape -> Reshape -> SoftMax -> Reshape -> StridedSlice."""
        return dict(nodes=[('reshape_split', dict(kind='op', type='Reshape')),
                           ('reshape_split_data', dict(kind='data')),
                           ('reshape_pack', dict(kind='op', type='Reshape')),
                           ('reshape_data', dict(kind='data')),
                           ('softmax', dict(kind='op', type='SoftMax')),
                           ('softmax_data', dict(kind='data')),
                           ('reshape_unpack', dict(kind='op', type='Reshape')),
                           ('reshape_unpack_data', dict(kind='data')),
                           ('strided_slice', dict(kind='op', op='StridedSlice')),
                           ],
                    edges=[('reshape_split', 'reshape_split_data'),
                           ('reshape_split_data', 'reshape_pack'),
                           ('reshape_pack', 'reshape_data'),
                           ('reshape_data', 'softmax'),
                           ('softmax', 'softmax_data'),
                           ('softmax_data', 'reshape_unpack'),
                           ('reshape_unpack', 'reshape_unpack_data'),
                           ('reshape_unpack_data', 'strided_slice')])

    def is_reshape_bad(self, node_pack, node_unpack, node_ss):
        """
        Check whether the matched chain is the pack/unpack case this transform handles.

        :param node_pack: Reshape op expected to turn a 5D input into a 2D output
        :param node_unpack: Reshape op expected to turn a 2D input into a 5D output
        :param node_ss: StridedSlice op consuming the output of `node_unpack`
        :return: True when the pack Reshape goes 5D -> 2D, the unpack Reshape goes 2D -> 5D,
                 the StridedSlice takes axes 1, 2 and 3 of the unpacked tensor completely
                 (start 0, stop equal to the dimension, step 1) without shrinking them and
                 shrinks axis 4; False in every other case
        """
        pack_in_shape = node_pack.in_node(0).shape
        pack_out_shape = node_pack.out_node(0).shape
        if len(pack_in_shape) != 5 or len(pack_out_shape) != 2:
            return False

        unpack_in_shape = node_unpack.in_node(0).shape
        unpacked_shape = node_unpack.out_node(0).shape
        if len(unpacked_shape) != 5 or len(unpack_in_shape) != 2:
            # the original fell through here and implicitly returned None
            return False

        slices = node_ss.slices
        shrink_mask = node_ss.shrink_axis_mask
        middle_axes = (1, 2, 3)

        takes_middle_axes_fully = all(
            slices[axis].stop == unpacked_shape[axis] and slices[axis].start == 0 and slices[axis].step == 1
            for axis in middle_axes)
        if not takes_middle_axes_fully:
            return False

        return bool(shrink_mask[4]) and not any(shrink_mask[axis] for axis in middle_axes)

    @staticmethod
    def _insert_permute_before_split(graph, node_split):
        """Insert a Permute with order [0, 2, 3, 1] between `node_split` and its first input."""
        data_node = Op._create_data_node(graph, node_split.name + "/Permute_before_data")
        permute_before = Permute(graph, dict(name=node_split.name + "/Permute_before",
                                             order=np.array([0, 2, 3, 1])))
        in_node = node_split.in_node(0)
        # keep the attributes of the edge being replaced so that they can be put on the new edge
        attrs = deepcopy(graph.get_edge_data(in_node.id, node_split.id)[0])
        graph.remove_edge(in_node.id, node_split.id)
        permute_before_node = permute_before.create_node_with_data([in_node], permute_before.attrs,
                                                                   data_nodes=[data_node])
        graph.add_edge(permute_before_node.id, node_split.id, **attrs)

    @staticmethod
    def _pack_with_permute_after(graph, node):
        """
        Make the pack Reshape produce a 3D tensor and insert a Permute [0, 2, 1] after it.

        :param graph: graph being transformed
        :param node: the 'reshape_pack' op
        :return: the data node that was the output of `node`; after this call it is produced
                 by the inserted Permute and its shape is the new Reshape shape reordered
                 with [0, 2, 1]
        """
        node['nchw_layout'] = True
        in_shape = node.in_node(0).shape
        # [dim 0, product of dims 1..3, last dim] of the 5D input
        new_reshape_shape = np.array([in_shape[0], np.prod(in_shape[1:4]), in_shape[-1]])
        node.dim = new_reshape_shape

        data_node = Op._create_data_node(graph, node.name + "/Permute_after_data", {'shape': node.dim})
        permute_after = Permute(graph, dict(name=node.name + "/Permute_after",
                                            order=np.array([0, 2, 1])))
        out_node = node.out_node(0)
        out_node.shape = new_reshape_shape[np.array([0, 2, 1])]
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)

        # Reshape -> data_node -> Permute -> out_node
        permute_after.create_node_with_data([data_node], permute_after.attrs,
                                            data_nodes=[out_node])
        graph.add_edge(node.id, data_node.id, **attrs)
        return out_node

    @staticmethod
    def _swap_unpack_and_strided_slice(graph, node_softmax, node_unpack, node_ss, out_node):
        """
        Turn SoftMax -> Reshape -> StridedSlice into SoftMax -> StridedSlice -> Reshape.

        The existing data nodes are reused: the data node that was the Reshape output becomes
        the StridedSlice output, and the data node that was the StridedSlice output becomes
        the Reshape output.

        :param graph: graph being transformed
        :param node_softmax: the 'softmax' op
        :param node_unpack: the 'reshape_unpack' op
        :param node_ss: the 'strided_slice' op
        :param out_node: 3D data node produced by the Permute inserted after the pack Reshape
        """
        softmax_out = node_softmax.out_node(0).id
        unpack_out = node_unpack.out_node(0).id
        ss_out = node_ss.out_node(0).id

        # gather edge attributes
        soft_reshape_attrs = deepcopy(graph.get_edge_data(softmax_out, node_unpack.id)[0])
        reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
        reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
        ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])

        # remove all edges in Softmax->Reshape->StridedSlice chain
        graph.remove_edge(softmax_out, node_unpack.id)
        graph.remove_edge(node_unpack.id, unpack_out)
        graph.remove_edge(unpack_out, node_ss.id)
        graph.remove_edge(node_ss.id, ss_out)

        # add new edges to get chain Softmax->StridedSlice->Reshape
        graph.add_edge(softmax_out, node_ss.id, **soft_reshape_attrs)
        graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
        graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
        graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)

        # update output shape and parameters for StridedSlice; the shape is built as an integer
        # array (np.zeros(3) would silently make it float64)
        node_ss.out_node(0).shape = np.array([out_node.shape[0], 1, out_node.shape[2]], dtype=np.int64)

        # the sliced tensor is 3D now: keep the slices of the first and the last original axes
        # and take the new last axis completely
        old_slices = node_ss.slices.copy()
        node_ss.slices = [old_slices[0], old_slices[-1], slice(0, out_node.shape[2], 1)]
        node_ss.shrink_axis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.new_axis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.ellipsis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.begin_mask = np.array([0, 1, 0], dtype=np.int64)
        node_ss.end_mask = np.array([0, 1, 0], dtype=np.int64)

        # update Reshape attribute: drop the entry for axis 4
        node_unpack.dim = np.delete(node_unpack.dim, 4)
        # prevent permute for reshape because it gives wrong result
        node_unpack['nchw_layout'] = True
        node_unpack.out_node(0)['nchw_layout'] = True

    @staticmethod
    def _insert_permute_after_unpack(graph, node, node_unpack):
        """
        Insert a Permute with order [0, 3, 1, 2, 4] after the unpack Reshape.

        :param graph: graph being transformed
        :param node: the 'reshape_pack' op; only its name is used to name the new nodes
        :param node_unpack: the 'reshape_unpack' op
        """
        data_node = Op._create_data_node(graph, node.name + "/Permute_after_unpack_data",
                                         {'shape': node_unpack.out_node().shape})
        permute_after_unpack = Permute(graph, dict(name=node.name + "/Permute_after_unpack",
                                                   order=np.array([0, 3, 1, 2, 4])))
        out_node = node_unpack.out_node(0)
        # np.int64 instead of the np.int alias, which recent NumPy releases no longer provide
        out_node.shape = out_node.shape[np.array([0, 3, 1, 2, 4], dtype=np.int64)]
        attrs = deepcopy(graph.get_edge_data(node_unpack.id, out_node.id)[0])
        graph.remove_edge(node_unpack.id, out_node.id)
        # the original called create_node_with_data on the Permute created after the pack Reshape
        permute_after_unpack.create_node_with_data([data_node], permute_after_unpack.attrs,
                                                   data_nodes=[out_node])
        graph.add_edge(node_unpack.id, data_node.id, **attrs)

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Rewrite one matched chain in place.

        Nothing is done unless graph.graph['layout'] is 'NHWC' and `is_reshape_bad` accepts the
        matched nodes. Otherwise a Permute is inserted before the split Reshape, the pack Reshape
        is changed to produce a 3D tensor followed by a Permute, the SoftMax output shape is
        updated, and then either the unpack Reshape and the StridedSlice are swapped (when
        ConvertGroupedStridedSlice is enabled) or a Permute is inserted after the unpack Reshape.

        :param graph: graph being transformed
        :param match: mapping from the names used in `pattern` to the matched nodes
        """
        if graph.graph['layout'] != 'NHWC':
            return

        if not self.is_reshape_bad(match['reshape_pack'], match['reshape_unpack'], match['strided_slice']):
            return

        log.info("Reshape that pack/unpack several dimensions detected {}".format(match['reshape_pack'].id))

        # insert Permute before reshape
        self._insert_permute_before_split(graph, match['reshape_split'])

        # change the pack Reshape and insert Permute after it
        node = match['reshape_pack']
        out_node = self._pack_with_permute_after(graph, node)

        # update softmax shape
        node_softmax = match['softmax']
        node_softmax.out_node(0).shape = out_node.shape

        if ConvertGroupedStridedSlice.enabled is True:
            # revert strided slice and reshape
            self._swap_unpack_and_strided_slice(graph, node_softmax, match['reshape_unpack'],
                                                match['strided_slice'], out_node)
        else:
            # reshape unpack: permute correctly
            self._insert_permute_after_unpack(graph, node, match['reshape_unpack'])