Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ConvertGroupedStridedSlice.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 from copy import deepcopy
18
19 import logging as log
20 import numpy as np
21
22 from extensions.middle.SliceConverter import ConvertSlice
23 from extensions.ops.splitv import SplitV
24 from mo.front.common.partial_infer.utils import int64_array
25 from mo.graph.graph import Node, Graph, add_opoutput
26 from mo.middle.replacement import MiddleReplacementPattern
27 from mo.ops.op import Op
28 from mo.ops.reshape import Reshape
29
30
class ConvertGroupedStridedSlice(MiddleReplacementPattern):
    """
    Replaces a group of StridedSlice ops that cut one tensor along a single axis with one SplitV op.

    All StridedSlice consumers of a non-constant data node (fed through input port 0) are handled together.
    The replacement is done only if:
      * every StridedSlice has the same number of slices and step 1 on every axis;
      * exactly one axis is sliced partially and every other axis is taken whole by every StridedSlice;
      * the slice ranges along that axis do not intersect. StridedSlices with an identical range and an
        identical output shape are merged: consumers of the duplicates are re-linked to the first one.
    If any check fails, the graph is left unchanged for this data node.

    In case the StridedSlices do not consume the entire tensor, fake outputs are created for the uncovered parts.
    For example:
        Let's suppose we have the following graph:
        Data(1,H,W,54)
           |`---->Sslice1_out (1,H,W,(10,18))
           `---->Sslice2_out (1,H,W,(18,36))

        Here the StridedSlices take only [10, 36) of the input tensor along axis 3,
        so this pass converts the graph to:
        Split(1,H,W,54)
           |`---->Fake_data (1,H,W,10)
           |`---->Sslice1_out (1,H,W,8)
           |`---->Sslice2_out (1,H,W,18)
           `----->Fake_data (1,H,W,18)
        where Fake_data are data nodes without real consumers; they are registered with add_opoutput.
    """

    enabled = True

    def run_after(self):
        return [ConvertSlice]

    def run_before(self):
        from extensions.middle.pass_separator import MiddleFinish
        return [MiddleFinish]

    def find_and_replace_pattern(self, graph: Graph):
        # Snapshot the data nodes first: the replacements below add nodes to the graph while we iterate.
        data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
        for input_data in data_nodes:
            # Constant data nodes are not processed
            if input_data.value is not None:
                continue
            self._replace_with_split(graph, input_data)

    def _replace_with_split(self, graph: Graph, input_data: Node):
        """
        Replaces the StridedSlice consumers of input_data with a single SplitV if possible.

        All validity checks are done before the graph is touched, so a rejected candidate leaves
        no partial modifications behind.
        """
        input_shape = np.array(input_data.shape)

        # Only StridedSlices that slice input_data itself (input port 0) are candidates
        out_nodes = [node for node in input_data.out_nodes()
                     if node.op == 'StridedSlice' and node.in_node(0).id == input_data.id]
        if len(out_nodes) < 1:
            return

        split_channel_dim = self._find_split_dim(out_nodes, input_shape)
        if split_channel_dim is None:
            return

        # (start, stop, output data node) of every StridedSlice along the split axis, ordered by range
        split_dims = sorted([(node.slices[split_channel_dim].start, node.slices[split_channel_dim].stop,
                              node.out_node()) for node in out_nodes],
                            key=lambda item: (item[0], item[1]))

        unique_dims, duplicates = self._merge_duplicate_slices(split_dims)
        parts = self._plan_split_parts(unique_dims, input_shape[split_channel_dim])
        if parts is None:
            return

        # All checks passed: modify the graph.
        # 1. Re-link consumers of duplicated StridedSlices to the output that is kept
        merged_data_ids = set()
        for dup_out, kept_out in duplicates:
            # Iterate over edges (with keys) so several edges to the same consumer are all moved
            for _, consumer_id, key, edge_attrs in list(graph.out_edges(dup_out.id, keys=True, data=True)):
                attrs = deepcopy(edge_attrs)
                graph.remove_edge(dup_out.id, consumer_id, key)
                graph.add_edge(kept_out.id, consumer_id, **attrs)
            merged_data_ids.add(dup_out.id)

        # 2. Create fake outputs for the parts of the tensor that no StridedSlice consumes
        size_splits = []
        final_data_nodes_list = []
        for size, out in parts:
            if out is None:
                shape = input_shape.copy()
                shape[split_channel_dim] = size
                out = Op._create_data_node(graph, 'fake_data', {'shape': shape})
                add_opoutput(graph, out.id, 0, False)
            size_splits.append(size)
            final_data_nodes_list.append(out)

        # 3. Keep shrink/new axis semantics by inserting Reshapes after the StridedSlices that are kept
        for node in out_nodes:
            out_node = node.out_node()
            if out_node.id in merged_data_ids:
                # Its consumers now read an identical output; the StridedSlice itself is removed below
                continue
            # NOTE(review): as in the previous implementation, a StridedSlice with new_axis_mask but without
            # shrink_axis_mask gets no Reshape here - confirm this is intended.
            if not np.any(node['shrink_axis_mask']):
                continue
            self.add_reshape_for_shrink(graph, node)
            if np.any(node['new_axis_mask']):
                self.add_reshape_for_new(graph, node)
            # Split has to produce the data node that now feeds the inserted Reshape
            for i, data_node in enumerate(final_data_nodes_list):
                if data_node.id == out_node.id:
                    final_data_nodes_list[i] = node.out_node()
                    break

        # 4. Remove connections from input_data to the StridedSlice ops and the ops themselves
        name_for_future_split = out_nodes[0].name
        for node in out_nodes:
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 5. Create the Split layer; its outputs follow the order of the parts along split_channel_dim
        split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
                                   size_splits=size_splits, out_ports_count=len(size_splits)))
        split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)

    @staticmethod
    def _find_split_dim(out_nodes: list, input_shape: np.ndarray):
        """
        Returns the single axis along which the StridedSlices in out_nodes cut a tensor of input_shape.

        :param out_nodes: non-empty list of StridedSlice nodes consuming the same tensor
        :param input_shape: shape of the sliced tensor
        :return: the split axis, or None if the StridedSlices cannot be replaced with one Split
        """
        slices_count = len(out_nodes[0].slices)
        if any(len(node.slices) != slices_count for node in out_nodes):
            return None

        # The axis is taken from the first StridedSlice: it must cut exactly one axis partially
        partial_dims = [dim_id for dim_id, s in enumerate(out_nodes[0].slices)
                        if s.start != 0 or s.stop != input_shape[dim_id]]
        if len(partial_dims) != 1:
            # Either nothing to split (every axis taken whole) or more than one axis is cut
            return None
        split_channel_dim = partial_dims[0]

        for node in out_nodes:
            for dim_id, s in enumerate(node.slices):
                # We don't support StridedSlice with stride != 1
                if s.step != 1:
                    return None
                # Every other axis must be taken whole, otherwise a Split along split_channel_dim differs
                if dim_id != split_channel_dim and (s.start != 0 or s.stop != input_shape[dim_id]):
                    return None
        return split_channel_dim

    @staticmethod
    def _merge_duplicate_slices(split_dims: list):
        """
        Groups StridedSlices that produce exactly the same part of the tensor.

        :param split_dims: non-empty list of (start, stop, output data node) sorted by (start, stop)
        :return: (unique_dims, duplicates): unique_dims keeps the first entry of every group of equal ranges,
                 duplicates is a list of (duplicate output data node, kept output data node) pairs
        """
        unique_dims = [split_dims[0]]
        duplicates = []
        for l, r, out in split_dims[1:]:
            kept_l, kept_r, kept_out = unique_dims[-1]
            # Outputs may be merged only if they hold the same data with the same shape
            # (shrink/new axis masks can differ between StridedSlices with an identical range)
            if l == kept_l and r == kept_r and out.id != kept_out.id and \
                    np.array_equal(out.shape, kept_out.shape):
                duplicates.append((out, kept_out))
            else:
                unique_dims.append((l, r, out))
        return unique_dims, duplicates

    @staticmethod
    def _plan_split_parts(split_dims: list, axis_size):
        """
        Computes the outputs of the future Split along the split axis.

        :param split_dims: list of (start, stop, output data node) sorted by (start, stop), without duplicates
        :param axis_size: size of the input tensor along the split axis
        :return: list of (part size, output data node or None for a fake output) covering the whole axis,
                 or None if the ranges are reversed, intersect or exceed the axis size
        """
        parts = []
        prev_r = 0
        for l, r, out in split_dims:
            # Split parts must not intersect and must have a non-negative size
            if l < prev_r or r < l:
                return None
            # The gap before this range becomes a fake output
            if l > prev_r:
                parts.append((l - prev_r, None))
            parts.append((r - l, out))
            prev_r = r
        if prev_r > axis_size:
            return None
        if prev_r < axis_size:
            # The tail of the tensor becomes a fake output
            parts.append((axis_size - prev_r, None))
        return parts

    @staticmethod
    def add_reshape_for_shrink(graph: Graph, ss_node):
        """
        Inserts a Reshape after StridedSlice ss_node to emulate its shrink_axis_mask.

        The StridedSlice gets a new output data node in which every shrunk axis is kept with size 1;
        a Reshape then converts it to the original output shape. All entries of the node's
        shrink_axis_mask are reset to 0. Does nothing unless ss_node has exactly 4 inputs and 1 output.
        """
        # add Reshape for shrink_axis_mask
        log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))
        node = ss_node

        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return

        shape_out = node.out_node().shape
        dim = shape_out.copy()
        ss_shape = []
        k = 0

        # Don't permute reshape if channels were squeezed
        dont_permute = bool(graph.graph['layout'] == 'NHWC' and node['shrink_axis_mask'][-1] == 1)

        for i in range(len(node['shrink_axis_mask'])):
            if not node['shrink_axis_mask'][i]:
                ss_shape.append(shape_out[k])
                k = k + 1
            else:
                node['shrink_axis_mask'][i] = 0
                ss_shape.append(1)

        out_node = node.out_node(0)

        # insert data node for StridedSlice
        data_node = Op._create_data_node(graph, node.name + "/Reshape_shrink_data", {'shape': int64_array(ss_shape)})
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)

        # insert Reshape restoring the original output shape
        reshape_attrs = dict(name=node.name + "/Reshape_shrink", dim=np.array(dim, dtype=np.int64))
        if dont_permute:
            reshape_attrs['nchw_layout'] = True
        reshape = Reshape(graph, reshape_attrs)
        reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])
        if dont_permute:
            reshape_data_node['nchw_layout'] = True

    @staticmethod
    def add_reshape_for_new(graph: Graph, ss_node):
        """
        Inserts a Reshape after StridedSlice ss_node to emulate its new_axis_mask.

        The StridedSlice gets a new output data node without the axes marked in new_axis_mask;
        a Reshape then restores the original output shape. Marked entries of new_axis_mask are reset to 0.
        Does nothing unless ss_node has exactly 4 inputs and 1 output.
        """
        log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
        node = ss_node

        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return

        shape_out = node.out_node().shape
        dim = shape_out.copy()
        ss_shape = []
        # NOTE(review): mask positions index shape_out directly, so output dims beyond
        # len(new_axis_mask) are not carried into ss_shape - confirm masks always cover the full rank.
        for i in range(len(node['new_axis_mask'])):
            if not node['new_axis_mask'][i]:
                ss_shape.append(shape_out[i])
            else:
                node['new_axis_mask'][i] = 0

        out_node = node.out_node(0)
        # insert data node for StridedSlice; the shape is stored as an int64 array like other data node shapes
        data_node = Op._create_data_node(graph, node.name + "/Reshape_new_data", {'shape': int64_array(ss_shape)})
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)

        # insert Reshape restoring the original output shape
        reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
                                      dim=np.array(dim, dtype=np.int64)))
        reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])