"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log
from copy import deepcopy

import numpy as np

from extensions.middle.SliceConverter import ConvertSlice
from extensions.ops.splitv import SplitV
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph, add_opoutput
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.ops.reshape import Reshape
class ConvertGroupedStridedSlice(MiddleReplacementPattern):
    """
        This pass converts subgraphs where several StridedSlices split a single axis of one tensor
        into a single Split layer. If the StridedSlices do not consume the entire tensor, fake outputs
        are created for the Split layer to cover the unused parts.

        For example:
            Let's suppose we have the following graph:
            Data (1,H,W,54)
               |`---->Sslice1_out (1,H,W,(10,18))
               `---->Sslice2_out (1,H,W,(18,36))

            In this case the StridedSlices take only [10, 36] from the input tensor in dim 3,
            so this pass converts the graph to the following one:
            Split (1,H,W,54)
               |`---->Fake_data (1,H,W,10)
               |`---->Sslice1_out (1,H,W,8)
               |`---->Sslice2_out (1,H,W,18)
               `----->Fake_data (1,H,W,18)
            where Fake_data are data nodes without real consumers (each one only gets the output
            registered for it by add_opoutput).

        Each candidate tensor is analysed completely before the graph is touched, so a tensor whose
        StridedSlices cannot be expressed as one Split is left unchanged.
    """

    enabled = True

    def run_after(self):
        return [ConvertSlice]

    def run_before(self):
        from extensions.middle.pass_separator import MiddleFinish
        return [MiddleFinish]

    def find_and_replace_pattern(self, graph: Graph):
        """
        Replace every group of StridedSlice consumers of one non-constant data node that together
        split a single axis of it by one SplitV node.

        :param graph: graph transformed in place
        """
        # Iterate over all data nodes and find those consumed by at least one StridedSlice
        data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
        for input_data in data_nodes:
            # We don't use constant data nodes
            if input_data.value is not None:
                continue

            input_shape = np.array(input_data.shape)

            # Get all StridedSlice consumers that slice this tensor (port 0), not e.g. its begin/end inputs
            out_nodes = [node for node in input_data.out_nodes()
                         if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
            if not out_nodes:
                continue

            # ---- Analysis phase: the graph is not modified until every check has passed ----
            split_channel_dim = self._find_split_dim(out_nodes, input_shape)
            if split_channel_dim is None:
                continue

            split_ranges = self._collect_split_ranges(out_nodes, input_shape, split_channel_dim)
            if split_ranges is None:
                continue

            sorted_ranges = sorted(split_ranges, key=lambda item: (item[0], item[1]))
            unique_ranges, duplicates = self._group_duplicate_ranges(graph, sorted_ranges)

            size_splits, planned_outputs = self._plan_split(unique_ranges, input_shape[split_channel_dim])
            if size_splits is None:
                continue

            # ---- Rewrite phase ----
            # Identical StridedSlices: consumers of the duplicate's output read the kept slice's output
            merged_node_ids = set()
            for dup_out, dup_node, keep_out in duplicates:
                for consumer in dup_out.out_nodes():
                    attrs = deepcopy(graph.get_edge_data(dup_out.id, consumer.id)[0])
                    graph.remove_edge(dup_out.id, consumer.id)
                    graph.add_edge(keep_out.id, consumer.id, **attrs)
                merged_node_ids.add(dup_node.id)

            # Create fake data nodes for the tensor parts that no StridedSlice reads
            final_data_nodes_list = []
            for size, out in zip(size_splits, planned_outputs):
                if out is None:
                    shape = input_shape.copy()
                    shape[split_channel_dim] = size
                    out = Op._create_data_node(graph, 'fake_data', {'shape': shape})
                    add_opoutput(graph, out.id, 0, False)
                final_data_nodes_list.append(out)

            for node in out_nodes:
                if node.id in merged_node_ids:
                    # Its output has no consumers any more, so a Reshape would be dead code
                    continue
                # NOTE(review): new_axis_mask is only handled when shrink_axis_mask is non-zero as well;
                # confirm this is intended.
                if not np.all([x == 0 for x in node.shrink_axis_mask]):
                    out_node = node.out_node()
                    if np.any(node['shrink_axis_mask']):
                        self.add_reshape_for_shrink(graph, node)
                    if np.any(node['new_axis_mask']):
                        self.add_reshape_for_new(graph, node)

                    # The Split has to produce the data node that now feeds the inserted Reshape
                    for i, data_node in enumerate(final_data_nodes_list):
                        if data_node.name == out_node.name:
                            final_data_nodes_list[i] = node.out_node()
                            break

            # Insert Split layer and remove old StridedSlice layers
            # 1. Remove connections from input_data to StridedSlice ops and the ops themselves
            name_for_future_split = out_nodes[0].name
            for node in out_nodes:
                graph.remove_edge(input_data.id, node.id)
                graph.remove_edge(node.id, node.out_node().id)
                graph.remove_node(node.id)
                log.debug("Removed: {}".format(node.id))

            # 2. Create Split layer whose outputs are ordered along the split axis
            split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim,
                                       size_splits=size_splits, out_ports_count=len(size_splits)))
            split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)

    @staticmethod
    def _find_split_dim(out_nodes, input_shape):
        """
        Detect the single axis along which the first StridedSlice reads a partial range.

        :param out_nodes: non-empty list of StridedSlice nodes consuming the same tensor
        :param input_shape: shape of that tensor
        :return: the axis index, or None when the slices have different lengths, when more than one
                 axis is cut, or when no axis is cut at all (nothing to split)
        """
        rank = len(out_nodes[0].slices)
        if any(len(node.slices) != rank for node in out_nodes):
            return None

        split_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            if s.start != 0 or s.stop != input_shape[dim_id]:
                if split_dim is not None:
                    return None
                split_dim = dim_id
        return split_dim

    @staticmethod
    def _collect_split_ranges(out_nodes, input_shape, split_dim):
        """
        Collect the range every StridedSlice reads along split_dim.

        :return: list of (start, stop, output data node, StridedSlice node) tuples, or None when a slice
                 has a stride other than 1 or reads a partial range along any axis other than split_dim
        """
        split_ranges = []
        for node in out_nodes:
            for dim_id, s in enumerate(node.slices):
                # We don't support StridedSlice with stride != 1
                if s.step != 1:
                    return None
                if dim_id == split_dim:
                    split_ranges.append((s.start, s.stop, node.out_node(), node))
                elif s.start != 0 or s.stop != input_shape[dim_id]:
                    # Only the split axis may be sliced; anything else is not expressible by Split
                    return None
        return split_ranges

    @staticmethod
    def _same_masks(graph: Graph, node_a, node_b):
        """Return True when two StridedSlice nodes have equal shrink_axis_mask and new_axis_mask."""
        attrs_a, attrs_b = graph.node[node_a.id], graph.node[node_b.id]
        return all(np.array_equal(attrs_a.get(mask), attrs_b.get(mask))
                   for mask in ('shrink_axis_mask', 'new_axis_mask'))

    @classmethod
    def _group_duplicate_ranges(cls, graph: Graph, sorted_ranges):
        """
        Separate StridedSlices that read exactly the same range, with the same masks, as an earlier one.

        :param sorted_ranges: tuples from _collect_split_ranges sorted by (start, stop)
        :return: (unique_ranges, duplicates) where duplicates holds
                 (duplicate output data node, duplicate StridedSlice node, kept output data node) tuples
        """
        unique_ranges, duplicates = [], []
        for item in sorted_ranges:
            if unique_ranges:
                kept = unique_ranges[-1]
                if item[0] == kept[0] and item[1] == kept[1] and item[2].name != kept[2].name \
                        and cls._same_masks(graph, item[3], kept[3]):
                    duplicates.append((item[2], item[3], kept[2]))
                    continue
            unique_ranges.append(item)
        return unique_ranges, duplicates

    @staticmethod
    def _plan_split(unique_ranges, dim_size):
        """
        Compute the Split sizes covering the whole split axis.

        :param unique_ranges: non-overlapping candidates sorted by (start, stop)
        :param dim_size: size of the input tensor along the split axis
        :return: (size_splits, outputs) where outputs[i] is the existing data node for part i or None
                 for a gap that needs a fake data node; (None, None) when ranges overlap or overrun
        """
        size_splits, outputs = [], []
        prev_r = 0
        for l, r, out, _ in unique_ranges:
            # Split dims shouldn't intersect
            if l < prev_r:
                return None, None
            # Gap before this range: will become a fake output
            if l > prev_r:
                size_splits.append(l - prev_r)
                outputs.append(None)
            size_splits.append(r - l)
            outputs.append(out)
            prev_r = r

        if prev_r > dim_size:
            return None, None
        if prev_r != dim_size:
            # Last part of tensor
            size_splits.append(dim_size - prev_r)
            outputs.append(None)
        return size_splits, outputs

    @staticmethod
    def _insert_data_node_after(graph: Graph, node, name: str, shape):
        """
        Make `node` produce a new data node instead of its current output data node.

        :return: (new data node, previous output data node); the previous one is left without producer
        """
        out_node = node.out_node(0)
        data_node = Op._create_data_node(graph, name, {'shape': shape})
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)
        return data_node, out_node

    @staticmethod
    def add_reshape_for_shrink(graph: Graph, ss_node):
        """
        Replace the axis shrinking of a StridedSlice by an explicit Reshape.

        The StridedSlice is rewired to a new data node in which every shrunk axis is kept with size 1
        (its shrink_axis_mask bits are cleared in place); a Reshape to the original output shape then
        feeds the original output data node. Does nothing unless the node has exactly 4 inputs and
        1 output.
        """
        log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))
        node = ss_node

        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return

        shape_out = node.out_node().shape
        dim = shape_out.copy()

        # Don't permute reshape if channels were squeezed
        dont_permute = graph.graph['layout'] == 'NHWC' and node['shrink_axis_mask'][-1] == 1

        # Output shape with the shrunk axes restored as size-1 axes
        ss_shape = []
        k = 0
        for i in range(len(node['shrink_axis_mask'])):
            if node['shrink_axis_mask'][i]:
                node['shrink_axis_mask'][i] = 0
                ss_shape.append(1)
            else:
                ss_shape.append(shape_out[k])
                k += 1

        data_node, out_node = ConvertGroupedStridedSlice._insert_data_node_after(
            graph, node, node.name + "/Reshape_shrink_data", int64_array(ss_shape))

        reshape_attrs = dict(name=node.name + "/Reshape_shrink", dim=np.array(dim, dtype=np.int64))
        if dont_permute:
            reshape_attrs['nchw_layout'] = True
        reshape = Reshape(graph, reshape_attrs)
        reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
                                                          data_nodes=[out_node])
        if dont_permute:
            reshape_data_node['nchw_layout'] = True

    @staticmethod
    def add_reshape_for_new(graph: Graph, ss_node):
        """
        Replace the new-axis insertion of a StridedSlice by an explicit Reshape.

        The StridedSlice is rewired to a new data node without the new axes (its new_axis_mask bits
        are cleared in place); a Reshape to the original output shape then feeds the original output
        data node. Does nothing unless the node has exactly 4 inputs and 1 output.
        """
        log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
        node = ss_node

        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return

        shape_out = node.out_node().shape
        dim = shape_out.copy()

        # Output shape without the inserted axes
        ss_shape = []
        for i in range(len(node['new_axis_mask'])):
            if node['new_axis_mask'][i]:
                node['new_axis_mask'][i] = 0
            else:
                ss_shape.append(shape_out[i])

        # Shape stored as int64 array, consistent with add_reshape_for_shrink
        data_node, out_node = ConvertGroupedStridedSlice._insert_data_node_after(
            graph, node, node.name + "/Reshape_new_data", int64_array(ss_shape))

        reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
                                      dim=np.array(dim, dtype=np.int64)))
        reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])