"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import logging as log
from copy import deepcopy

import numpy as np

from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
from extensions.middle.FusePermutesSequence import FusePermutesSequence
from extensions.middle.ShufflenetReshape import ReshapeSoftmaxReshape
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.permute import Permute


class PixelLinkReshape(MiddleReplacementPattern):
    """
    Add Permutes around Reshapes that pack a 5D tensor into 2D, then do SoftMax and then
    unpack the result back to 5D.

    Only graphs with the NHWC layout are processed. For every match accepted by
    `is_reshape_bad`:
      * a Permute with order [0, 2, 3, 1] is inserted in front of the split Reshape;
      * the pack Reshape is changed to produce [N, d1 * d2 * d3, d4] and a Permute with
        order [0, 2, 1] is inserted after it;
      * if ConvertGroupedStridedSlice is enabled, the StridedSlice is moved in front of the
        unpack Reshape; otherwise a Permute with order [0, 3, 1, 2, 4] is inserted after
        the unpack Reshape.
    """
    enabled = True

    def run_before(self):
        return [FusePermutesSequence, ReshapeSoftmaxReshape, ConvertGroupedStridedSlice]

    def run_after(self):
        # Imported locally; presumably to avoid a circular import (TODO confirm).
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    def pattern(self):
        return dict(nodes=[('reshape_split', dict(kind='op', type='Reshape')),
                           ('reshape_split_data', dict(kind='data')),
                           ('reshape_pack', dict(kind='op', type='Reshape')),
                           ('reshape_data', dict(kind='data')),
                           ('softmax', dict(kind='op', type='SoftMax')),
                           ('softmax_data', dict(kind='data')),
                           ('reshape_unpack', dict(kind='op', type='Reshape')),
                           ('reshape_unpack_data', dict(kind='data')),
                           ('strided_slice', dict(kind='op', op='StridedSlice')),
                           ],
                    edges=[('reshape_split', 'reshape_split_data'),
                           ('reshape_split_data', 'reshape_pack'),
                           ('reshape_pack', 'reshape_data'),
                           ('reshape_data', 'softmax'),
                           ('softmax', 'softmax_data'),
                           ('softmax_data', 'reshape_unpack'),
                           ('reshape_unpack', 'reshape_unpack_data'),
                           ('reshape_unpack_data', 'strided_slice')])

    def is_reshape_bad(self, node_pack, node_unpack, node_ss):
        """
        Check whether the matched Reshape -> SoftMax -> Reshape -> StridedSlice chain must be rewritten.

        The chain qualifies when the pack Reshape turns a 5D tensor into 2D, the unpack Reshape
        turns a 2D tensor back into 5D, and the StridedSlice takes axes 1-3 of the unpacked
        tensor whole (start 0, stop equal to the full extent, step 1, not shrunk) while
        shrinking axis 4.

        :param node_pack: the Reshape whose output feeds SoftMax.
        :param node_unpack: the Reshape consuming the SoftMax result.
        :param node_ss: the StridedSlice consuming the unpack Reshape result.
        :return: True if the chain should be transformed, False otherwise.
        """
        pack_in = node_pack.in_node(0).shape
        pack_out = node_pack.out_node(0).shape
        if len(pack_in) != 5 or len(pack_out) != 2:
            return False

        unpack_in = node_unpack.in_node(0).shape
        unpack_out = node_unpack.out_node(0).shape
        if len(unpack_out) != 5 or len(unpack_in) != 2:
            return False

        # axes 1..3 must be taken as a whole and must not be shrunk
        for axis in (1, 2, 3):
            axis_slice = node_ss.slices[axis]
            if axis_slice.start != 0 or axis_slice.stop != unpack_out[axis] or axis_slice.step != 1:
                return False
            if node_ss.shrink_axis_mask[axis]:
                return False

        return bool(node_ss.shrink_axis_mask[4])

    def replace_pattern(self, graph: Graph, match: dict):
        """
        Rewrite the matched chain in place when `is_reshape_bad` accepts it.

        :param graph: graph being transformed; nothing is done unless its layout is NHWC.
        :param match: nodes matched by `pattern`, keyed by the pattern node names.
        """
        if graph.graph['layout'] != 'NHWC':
            return

        node_pack = match['reshape_pack']
        node_unpack = match['reshape_unpack']
        node_ss = match['strided_slice']
        if not self.is_reshape_bad(node_pack, node_unpack, node_ss):
            return

        log.info("Reshape that pack/unpack several dimensions detected {}".format(node_pack.id))

        # insert Permute before the split Reshape
        node_split = match['reshape_split']
        self._insert_permute_before(graph, node_split, node_split.name + "/Permute_before",
                                    np.array([0, 2, 3, 1]))

        # pack to [N, d1 * d2 * d3, d4] instead of 2D, then transpose to [N, d4, d1 * d2 * d3]
        node_pack['nchw_layout'] = True
        pack_in_shape = node_pack.in_node(0).shape
        new_reshape_shape = np.array([pack_in_shape[0],
                                      np.prod(pack_in_shape[[1, 2, 3]]),
                                      pack_in_shape[-1]])
        node_pack.dim = new_reshape_shape
        packed_data = self._insert_permute_after(graph, node_pack, node_pack.name + "/Permute_after",
                                                 np.array([0, 2, 1]), new_reshape_shape)

        # SoftMax now works on the transposed packed tensor
        node_softmax = match['softmax']
        node_softmax.out_node(0).shape = packed_data.shape

        if ConvertGroupedStridedSlice.enabled is True:
            self._move_strided_slice_before_unpack(graph, node_softmax, node_unpack, node_ss,
                                                   packed_data.shape)
        else:
            # permute the unpacked 5D result; the new nodes are named after the pack Reshape
            unpack_shape = node_unpack.out_node(0).shape
            self._insert_permute_after(graph, node_unpack, node_pack.name + "/Permute_after_unpack",
                                       np.array([0, 3, 1, 2, 4]), unpack_shape)

    @staticmethod
    def _insert_permute_before(graph: Graph, node, name: str, order: np.ndarray):
        """
        Insert a Permute (named `name`, with `order`) between `node`'s first input data node and `node`.

        The original edge attributes are carried over to the new Permute-data -> `node` edge.
        """
        data_node = Op._create_data_node(graph, name + "_data")
        permute = Permute(graph, dict(name=name, order=order))
        in_node = node.in_node(0)
        attrs = deepcopy(graph.get_edge_data(in_node.id, node.id)[0])
        graph.remove_edge(in_node.id, node.id)
        permute_node = permute.create_node_with_data([in_node], permute.attrs, data_nodes=[data_node])
        graph.add_edge(permute_node.id, node.id, **attrs)

    @staticmethod
    def _insert_permute_after(graph: Graph, node, name: str, order: np.ndarray, shape: np.ndarray):
        """
        Insert a Permute (named `name`, with `order`) between `node` and its first output data node.

        `node` now writes into a new data node of shape `shape`; the Permute reads that and
        writes into the original output data node, whose shape becomes `shape[order]`.

        :return: the original output data node, now produced by the Permute.
        """
        data_node = Op._create_data_node(graph, name + "_data", {'shape': shape})
        permute = Permute(graph, dict(name=name, order=order))
        out_node = node.out_node(0)
        out_node.shape = shape[order]
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        permute.create_node_with_data([data_node], permute.attrs, data_nodes=[out_node])
        graph.add_edge(node.id, data_node.id, **attrs)
        return out_node

    @staticmethod
    def _move_strided_slice_before_unpack(graph: Graph, node_softmax, node_unpack, node_ss, packed_shape):
        """
        Turn SoftMax -> unpack Reshape -> StridedSlice into SoftMax -> StridedSlice -> unpack Reshape.

        :param packed_shape: shape of the (transposed, 3D) SoftMax output the StridedSlice now reads.
        """
        softmax_out = node_softmax.out_node(0).id
        unpack_out = node_unpack.out_node(0).id
        ss_out = node_ss.out_node(0).id

        # gather edge attributes
        soft_reshape_attrs = deepcopy(graph.get_edge_data(softmax_out, node_unpack.id)[0])
        reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
        reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
        ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])

        # remove all edges in Softmax->Reshape->StridedSlice chain
        graph.remove_edge(softmax_out, node_unpack.id)
        graph.remove_edge(node_unpack.id, unpack_out)
        graph.remove_edge(unpack_out, node_ss.id)
        graph.remove_edge(node_ss.id, ss_out)

        # add new edges to get chain Softmax->StridedSlice->Reshape; the data nodes keep their
        # positions, so `unpack_out` now holds the StridedSlice result and `ss_out` the Reshape result
        graph.add_edge(softmax_out, node_ss.id, **soft_reshape_attrs)
        graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
        graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
        graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)

        # StridedSlice output: [N, 1, full axis 2]; an integer shape, not float
        node_ss.out_node(0).shape = np.array([packed_shape[0], 1, packed_shape[2]], dtype=np.int64)

        # rebuild slices for the 3D input: the batch slice, the slice of old axis 4 (no longer
        # shrunk, so it stays as a size-1 axis) and the whole of axis 2
        old_slices = node_ss.slices
        node_ss.slices = [old_slices[0], old_slices[-1], slice(0, packed_shape[2], 1)]
        node_ss.shrink_axis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.new_axis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.ellipsis_mask = np.array([0, 0, 0], dtype=np.int64)
        node_ss.begin_mask = np.array([0, 1, 0], dtype=np.int64)
        node_ss.end_mask = np.array([0, 1, 0], dtype=np.int64)

        # update Reshape attribute: drop the last target dimension (the one sliced away above)
        node_unpack.dim = np.delete(node_unpack.dim, 4)
        # prevent permute for reshape because it gives wrong result
        node_unpack['nchw_layout'] = True
        node_unpack.out_node(0)['nchw_layout'] = True