2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
22 from mo.graph.graph import Node
23 from mo.utils.unittest.graph import build_graph, compare_graphs
def _op_node(op_type, op):
    # Attribute dict of an operation node, keys in the order type / kind / op.
    return {'type': op_type, 'kind': 'op', 'op': op}


def _data_node(**extra):
    # Attribute dict of a data node; keyword extras are appended after the common keys.
    attrs = {'value': None, 'shape': None, 'kind': 'data'}
    attrs.update(extra)
    return attrs


def _strided_slice_op():
    # Each StridedSlice op gets its own all-zero shrink mask array.
    attrs = _op_node(None, 'StridedSlice')
    attrs['slices'] = None
    attrs['shrink_axis_mask'] = np.array([0, 0, 0, 0])
    return attrs


def _reshape_op():
    return {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}


def _output_op():
    return {'kind': 'op', 'op': 'OpOutput'}


# Node attribute templates handed to build_graph by the tests in this file.
nodes_attributes = {
    'placeholder_1': _op_node('Placeholder', 'Placeholder'),
    'placeholder_1_data': _data_node(data_type=None),
    'placeholder_2': _op_node('Placeholder', 'Placeholder'),
    'placeholder_2_data': _data_node(data_type=None),
    'placeholder_begin_data': _data_node(data_type=None),
    'placeholder_end_data': _data_node(data_type=None),
    'placeholder_stride_data': _data_node(data_type=None),
    # StridedSlice layers
    'sslice_1': _strided_slice_op(),
    'sslice_1_data': _data_node(),
    'sslice_2': _strided_slice_op(),
    'sslice_2_data': _data_node(),
    'sslice_3': _strided_slice_op(),
    'sslice_3_data': _data_node(),
    # Split layer
    'split_1': _op_node('Split', 'SplitV'),
    'split_1_data': _data_node(),
    'split_2_data': _data_node(),
    'split_3_data': _data_node(),
    'split_4_data': _data_node(),
    # Concat layer and graph outputs
    'concat_1': _op_node('Concat', 'Concat'),
    'concat_1_data': _data_node(),
    'op_output': _output_op(),
    'op_output_1': _output_op(),
    'op_output_2': _output_op(),
    # Reshape layers that appear after a StridedSlice in the reference graphs
    'sslice_1/Reshape_shrink': _reshape_op(),
    'sslice_1/Reshape_shrink_data': _data_node(),
    'sslice_2/Reshape_shrink': _reshape_op(),
    'sslice_2/Reshape_shrink_data': _data_node(),
    'sslice_2/Reshape_new': _reshape_op(),
    'sslice_2/Reshape_new_data': _data_node(),
}
65 class ConvertGroupedStridedSliceTests(unittest.TestCase):
67 graph = build_graph(nodes_attributes,
68 [('placeholder_1', 'placeholder_1_data'),
69 ('placeholder_1_data', 'sslice_1'),
70 ('sslice_1', 'sslice_1_data'),
71 ('placeholder_1_data', 'sslice_2'),
72 ('sslice_2', 'sslice_2_data'),
73 ('placeholder_1_data', 'sslice_3'),
74 ('sslice_3', 'sslice_3_data'),
75 ('sslice_1_data', 'concat_1'),
76 ('sslice_2_data', 'concat_1'),
77 ('sslice_3_data', 'concat_1'),
78 ('concat_1', 'concat_1_data'),
79 ('concat_1_data', 'op_output')
81 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
83 'sslice_1': {'slices': np.array(
84 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 18, 1)])},
85 'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
87 'sslice_2': {'slices': np.array(
88 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(18, 36, 1)])},
89 'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
91 'sslice_3': {'slices': np.array(
92 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(36, 54, 1)])},
93 'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
95 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
97 graph.graph['layout'] = 'NHWC'
99 graph_ref = build_graph(nodes_attributes,
100 [('placeholder_1', 'placeholder_1_data'),
101 ('placeholder_1_data', 'split_1'),
102 ('split_1', 'split_1_data'),
103 ('split_1', 'split_2_data'),
104 ('split_1', 'split_3_data'),
105 ('split_1_data', 'concat_1'),
106 ('split_2_data', 'concat_1'),
107 ('split_3_data', 'concat_1'),
108 ('concat_1', 'concat_1_data'),
109 ('concat_1_data', 'op_output')
112 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
113 'split_1': {'axis': 3},
114 'split_1_data': {'shape': np.array([1, 227, 227, 18])},
115 'split_2_data': {'shape': np.array([1, 227, 227, 18])},
116 'split_3_data': {'shape': np.array([1, 227, 227, 18])},
117 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
120 pattern = ConvertGroupedStridedSlice()
121 pattern.find_and_replace_pattern(graph)
123 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
124 self.assertTrue(flag, resp)
127 graph = build_graph(nodes_attributes,
128 [('placeholder_1', 'placeholder_1_data'),
129 ('placeholder_1_data', 'sslice_1'),
130 ('sslice_1', 'sslice_1_data'),
131 ('placeholder_1_data', 'sslice_2'),
132 ('sslice_2', 'sslice_2_data'),
133 ('placeholder_1_data', 'sslice_3'),
134 ('sslice_3', 'sslice_3_data'),
135 ('sslice_1_data', 'concat_1'),
136 ('sslice_2_data', 'concat_1'),
137 ('sslice_3_data', 'concat_1'),
138 ('concat_1', 'concat_1_data'),
139 ('concat_1_data', 'op_output')
141 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
143 'sslice_1': {'slices': np.array(
144 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
145 'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
147 'sslice_2': {'slices': np.array(
148 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
149 'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
151 'sslice_3': {'slices': np.array(
152 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
153 'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
155 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
157 graph.graph['layout'] = 'NHWC'
159 graph_ref = build_graph(nodes_attributes,
160 [('placeholder_1', 'placeholder_1_data'),
161 ('placeholder_1_data', 'split_1'),
162 ('split_1', 'split_1_data'),
163 ('split_1', 'split_2_data'),
164 ('split_1', 'split_3_data'),
165 ('split_1_data', 'concat_1'),
166 ('split_2_data', 'concat_1'),
167 ('split_3_data', 'concat_1'),
168 ('concat_1', 'concat_1_data'),
169 ('concat_1_data', 'op_output')
171 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
172 'split_1': {'axis': 3},
173 'split_1_data': {'shape': np.array([1, 227, 227, 18])},
174 'split_2_data': {'shape': np.array([1, 227, 227, 17])},
175 'split_3_data': {'shape': np.array([1, 227, 227, 19])},
176 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
179 pattern = ConvertGroupedStridedSlice()
180 pattern.find_and_replace_pattern(graph)
182 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
183 self.assertTrue(flag, resp)
185 # Intersection of split ranges in feature dimension
def test_3_neg(self):
    # sslice_1 takes channels [19, 39) and sslice_2 takes [37, 54), so the two ranges
    # overlap. The graph is expected to come out of the pass unchanged, hence the
    # reference is an independently built copy of the input graph.
    slice_names = ('sslice_1', 'sslice_2', 'sslice_3')
    channel_ranges = ((19, 39), (37, 54), (0, 19))

    def make_graph():
        # Called once per graph so that input and reference never share arrays.
        edges = [('placeholder_1', 'placeholder_1_data')]
        for name in slice_names:
            edges.extend([('placeholder_1_data', name), (name, name + '_data')])
        edges.extend((name + '_data', 'concat_1') for name in slice_names)
        edges.extend([('concat_1', 'concat_1_data'), ('concat_1_data', 'op_output')])

        attrs = {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])}}
        for name, (begin, end) in zip(slice_names, channel_ranges):
            attrs[name] = {'slices': np.array(
                [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(begin, end, 1)])}
            attrs[name + '_data'] = {'shape': np.array([1, 227, 227, end - begin])}
        attrs['concat_1_data'] = {'shape': np.array([1, 227, 227, 54])}

        return build_graph(nodes_attributes, edges, attrs)

    graph = make_graph()
    graph.graph['layout'] = 'NHWC'
    graph_ref = make_graph()

    ConvertGroupedStridedSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
256 # Split range overflow in feature dimension
def test_4_neg(self):
    # sslice_2 asks for channels [37, 55) although the input has only 54 channels.
    # The graph is expected to come out of the pass unchanged, hence the reference
    # is an independently built copy of the input graph.
    slice_names = ('sslice_1', 'sslice_2', 'sslice_3')
    channel_ranges = ((19, 37), (37, 55), (0, 19))

    def make_graph():
        # Called once per graph so that input and reference never share arrays.
        edges = [('placeholder_1', 'placeholder_1_data')]
        for name in slice_names:
            edges.extend([('placeholder_1_data', name), (name, name + '_data')])
        edges.extend((name + '_data', 'concat_1') for name in slice_names)
        edges.extend([('concat_1', 'concat_1_data'), ('concat_1_data', 'op_output')])

        attrs = {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])}}
        for name, (begin, end) in zip(slice_names, channel_ranges):
            attrs[name] = {'slices': np.array(
                [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(begin, end, 1)])}
            attrs[name + '_data'] = {'shape': np.array([1, 227, 227, end - begin])}
        attrs['concat_1_data'] = {'shape': np.array([1, 227, 227, 54])}

        return build_graph(nodes_attributes, edges, attrs)

    graph = make_graph()
    graph.graph['layout'] = 'NHWC'
    graph_ref = make_graph()

    ConvertGroupedStridedSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
327 # Split(1,H,W,54)--->Fake_data (1,H,W,1)
328 # |`---->Sslice1_out (1,H,W,18)
329 # |`---->Sslice2_out (1,H,W,18)
330 # `----->Sslice3_out (1,H,W,17)
332 graph = build_graph(nodes_attributes,
333 [('placeholder_1', 'placeholder_1_data'),
334 ('placeholder_1_data', 'sslice_1'),
335 ('sslice_1', 'sslice_1_data'),
336 ('placeholder_1_data', 'sslice_2'),
337 ('sslice_2', 'sslice_2_data'),
338 ('placeholder_1_data', 'sslice_3'),
339 ('sslice_3', 'sslice_3_data'),
340 ('sslice_1_data', 'concat_1'),
341 ('sslice_2_data', 'concat_1'),
342 ('sslice_3_data', 'concat_1'),
343 ('concat_1', 'concat_1_data'),
344 ('concat_1_data', 'op_output'),
346 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
348 'sslice_1': {'slices': np.array(
349 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
350 'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
352 'sslice_2': {'slices': np.array(
353 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
354 'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
356 'sslice_3': {'slices': np.array(
357 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(1, 19, 1)])},
358 'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
360 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
362 graph.graph['layout'] = 'NHWC'
364 graph_ref = build_graph(nodes_attributes,
365 [('placeholder_1', 'placeholder_1_data'),
366 ('placeholder_1_data', 'split_1'),
367 ('split_1', 'split_1_data'),
368 ('split_1', 'split_2_data'),
369 ('split_1', 'split_3_data'),
370 ('split_1', 'split_4_data'),
371 ('split_2_data', 'concat_1'),
372 ('split_3_data', 'concat_1'),
373 ('split_4_data', 'concat_1'),
374 ('concat_1', 'concat_1_data'),
375 ('concat_1_data', 'op_output'),
376 ('split_1_data', 'op_output_1')
378 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
379 'split_1': {'axis': 3},
380 'split_1_data': {'shape': np.array([1, 227, 227, 1])},
381 'split_2_data': {'shape': np.array([1, 227, 227, 18])},
382 'split_3_data': {'shape': np.array([1, 227, 227, 17])},
383 'split_4_data': {'shape': np.array([1, 227, 227, 18])},
384 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
387 pattern = ConvertGroupedStridedSlice()
388 pattern.find_and_replace_pattern(graph)
390 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
391 self.assertTrue(flag, resp)
394 # |`---->Sslice1_out (1,H,W,(0,18))
395 # |`---->Fake_data (1,H,W,(18,27))
396 # |`---->Sslice3_out (1,H,W,(27,45))
397 # `----->Fake_data (1,H,W,(45,54))
399 graph = build_graph(nodes_attributes,
400 [('placeholder_1', 'placeholder_1_data'),
401 ('placeholder_1_data', 'sslice_1'),
402 ('sslice_1', 'sslice_1_data'),
403 ('placeholder_1_data', 'sslice_2'),
404 ('sslice_2', 'sslice_2_data'),
405 ('sslice_1_data', 'concat_1'),
406 ('sslice_2_data', 'concat_1'),
407 ('concat_1', 'concat_1_data'),
408 ('concat_1_data', 'op_output')
410 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
412 'sslice_1': {'slices': np.array(
413 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 18, 1)])},
414 'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
416 'sslice_2': {'slices': np.array(
417 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
418 'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
420 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
422 graph.graph['layout'] = 'NHWC'
424 graph_ref = build_graph(nodes_attributes,
425 [('placeholder_1', 'placeholder_1_data'),
426 ('placeholder_1_data', 'split_1'),
427 ('split_1', 'split_1_data'),
428 ('split_1', 'split_2_data'),
429 ('split_1', 'split_3_data'),
430 ('split_1', 'split_4_data'),
431 ('split_1_data', 'concat_1'),
432 ('split_3_data', 'concat_1'),
433 ('concat_1', 'concat_1_data'),
434 ('concat_1_data', 'op_output'),
435 ('split_2_data', 'op_output_1'),
436 ('split_4_data', 'op_output_2'),
438 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
439 'split_1': {'axis': 3},
440 'split_1_data': {'shape': np.array([1, 227, 227, 18])},
441 'split_2_data': {'shape': np.array([1, 227, 227, 9])},
442 'split_3_data': {'shape': np.array([1, 227, 227, 18])},
443 'split_4_data': {'shape': np.array([1, 227, 227, 9])},
444 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
447 pattern = ConvertGroupedStridedSlice()
448 pattern.find_and_replace_pattern(graph)
450 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
451 self.assertTrue(flag, resp)
def test_7_neg(self):
    # The two slices cover different ranges along axis 1 (0:10 vs 10:227) as well as
    # along the last axis (0:18 vs 27:45). The graph is expected to come out of the
    # pass unchanged, hence the reference is an independently built copy of the input.
    def make_graph():
        # Called once per graph so that input and reference never share arrays.
        edges = [('placeholder_1', 'placeholder_1_data')]
        for name in ('sslice_1', 'sslice_2'):
            edges += [('placeholder_1_data', name), (name, name + '_data')]
        edges += [
            ('sslice_1_data', 'concat_1'),
            ('sslice_2_data', 'concat_1'),
            ('concat_1', 'concat_1_data'),
            ('concat_1_data', 'op_output'),
        ]

        attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
            'sslice_1': {'slices': np.array(
                [slice(0, 1, 1), slice(0, 10, 1), slice(0, 227, 1), slice(0, 18, 1)])},
            'sslice_1_data': {'shape': np.array([1, 10, 227, 18])},
            'sslice_2': {'slices': np.array(
                [slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
            'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
            'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
        }
        return build_graph(nodes_attributes, edges, attrs)

    graph = make_graph()
    graph.graph['layout'] = 'NHWC'
    graph_ref = make_graph()

    ConvertGroupedStridedSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
510 # |`---->Sslice1_out (1,(0,18),W,C)
511 # |`---->Sslice2_out (1,(18,36),W,C)
512 # `----->Fake_data (1,(36,54),W,C)
514 graph = build_graph(nodes_attributes,
515 [('placeholder_1', 'placeholder_1_data'),
516 ('placeholder_1_data', 'sslice_1'),
517 ('sslice_1', 'sslice_1_data'),
518 ('placeholder_1_data', 'sslice_2'),
519 ('sslice_2', 'sslice_2_data'),
520 ('sslice_1_data', 'concat_1'),
521 ('sslice_2_data', 'concat_1'),
522 ('concat_1', 'concat_1_data'),
523 ('concat_1_data', 'op_output')
525 {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
527 'sslice_1': {'slices': np.array(
528 [slice(0, 1, 1), slice(0, 18, 1), slice(0, 54, 1), slice(0, 3, 1)])},
529 'sslice_1_data': {'shape': np.array([1, 18, 54, 3])},
531 'sslice_2': {'slices': np.array(
532 [slice(0, 1, 1), slice(18, 36, 1), slice(0, 54, 1), slice(0, 3, 1)])},
533 'sslice_2_data': {'shape': np.array([1, 18, 54, 3])},
535 'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
537 graph.graph['layout'] = 'NHWC'
539 graph_ref = build_graph(nodes_attributes,
540 [('placeholder_1', 'placeholder_1_data'),
541 ('placeholder_1_data', 'split_1'),
542 ('split_1', 'split_1_data'),
543 ('split_1', 'split_2_data'),
544 ('split_1', 'split_3_data'),
545 ('split_1_data', 'concat_1'),
546 ('split_3_data', 'concat_1'),
547 ('concat_1', 'concat_1_data'),
548 ('concat_1_data', 'op_output'),
549 ('split_2_data', 'op_output_1')
551 {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
552 'split_1': {'axis': 1},
553 'split_1_data': {'shape': np.array([1, 18, 54, 3])},
554 'split_2_data': {'shape': np.array([1, 18, 54, 3])},
555 'split_3_data': {'shape': np.array([1, 18, 54, 3])},
556 'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
559 pattern = ConvertGroupedStridedSlice()
560 pattern.find_and_replace_pattern(graph)
562 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
563 self.assertTrue(flag, resp)
566 class AddReshapeAfterStridedSliceTests(unittest.TestCase):
def test_ss_1_shrink_last(self):
    # sslice_1 shrinks axis 2 (mask [0, 0, 1, 0]). The reference graph has the mask
    # zeroed and a Reshape between the StridedSlice and its original output data node.
    upstream = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_1'),
        ('placeholder_begin_data', 'sslice_1'),
        ('placeholder_end_data', 'sslice_1'),
        ('placeholder_stride_data', 'sslice_1'),
    ]
    downstream = [('sslice_1_data', 'op_output')]

    def make_slices():
        # Fresh object array per graph; both graphs use the same slice values.
        return np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])

    graph_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_1': {'slices': make_slices(),
                     'shrink_axis_mask': [0, 0, 1, 0],
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_1_data': {'shape': np.array([1, 227, 54])},
    }
    graph = build_graph(nodes_attributes,
                        upstream + [('sslice_1', 'sslice_1_data')] + downstream,
                        graph_attrs)
    graph.graph['layout'] = 'NHWC'

    ref_edges = upstream + [
        ('sslice_1', 'sslice_1/Reshape_shrink_data'),
        ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
        ('sslice_1/Reshape_shrink', 'sslice_1_data'),
    ] + downstream
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_1': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_1_data': {'shape': np.array([1, 227, 54])},
        'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
        'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    ConvertGroupedStridedSlice().add_reshape_for_shrink(graph, Node(graph, 'sslice_1'))

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ss_1_shrink(self):
    # sslice_2 shrinks axis 2 (mask [0, 0, 1, 0]) and its output feeds both
    # placeholder_2 and op_output. The reference graph has the mask zeroed and a
    # Reshape between the StridedSlice and its original output data node.
    upstream = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_2'),
        ('placeholder_begin_data', 'sslice_2'),
        ('placeholder_end_data', 'sslice_2'),
        ('placeholder_stride_data', 'sslice_2'),
    ]
    downstream = [
        ('sslice_2_data', 'placeholder_2'),
        ('placeholder_2', 'placeholder_2_data'),
        ('sslice_2_data', 'op_output'),
    ]

    def make_slices():
        # Fresh object array per graph; both graphs use the same slice values.
        return np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])

    graph_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': [0, 0, 1, 0],
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 227, 54])},
    }
    graph = build_graph(nodes_attributes,
                        upstream + [('sslice_2', 'sslice_2_data')] + downstream,
                        graph_attrs)
    graph.graph['layout'] = 'NHWC'

    ref_edges = upstream + [
        ('sslice_2', 'sslice_2/Reshape_shrink_data'),
        ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
        ('sslice_2/Reshape_shrink', 'sslice_2_data'),
    ] + downstream
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 227, 54])},
        'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
        'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    ConvertGroupedStridedSlice().add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ss_2_shrink(self):
    # sslice_2 shrinks two axes at once (mask [0, 1, 0, 1]). The reference graph has
    # the mask zeroed and a single Reshape producing the [1, 227] output.
    upstream = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_2'),
        ('placeholder_begin_data', 'sslice_2'),
        ('placeholder_end_data', 'sslice_2'),
        ('placeholder_stride_data', 'sslice_2'),
    ]
    downstream = [
        ('sslice_2_data', 'placeholder_2'),
        ('placeholder_2', 'placeholder_2_data'),
        ('sslice_2_data', 'op_output'),
    ]

    def make_slices():
        # Fresh object array per graph; both graphs use the same slice values.
        return np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)])

    graph_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 1, 0, 1]),
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 227])},
    }
    graph = build_graph(nodes_attributes,
                        upstream + [('sslice_2', 'sslice_2_data')] + downstream,
                        graph_attrs)
    graph.graph['layout'] = 'NHWC'

    ref_edges = upstream + [
        ('sslice_2', 'sslice_2/Reshape_shrink_data'),
        ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
        ('sslice_2/Reshape_shrink', 'sslice_2_data'),
    ] + downstream
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 227])},
        'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
        'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    ConvertGroupedStridedSlice().add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ss_1_new(self):
    # sslice_2 inserts a new axis at position 1 (new_axis_mask [0, 1, 0, 0, 0]). The
    # reference graph has the mask zeroed and a Reshape producing the 5-D output.
    upstream = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_2'),
        ('placeholder_begin_data', 'sslice_2'),
        ('placeholder_end_data', 'sslice_2'),
        ('placeholder_stride_data', 'sslice_2'),
    ]
    downstream = [
        ('sslice_2_data', 'placeholder_2'),
        ('placeholder_2', 'placeholder_2_data'),
    ]

    def make_slices():
        # Fresh object array per graph; both graphs use the same slice values.
        return np.array(
            [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)])

    graph_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 1, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
    }
    graph = build_graph(nodes_attributes,
                        upstream + [('sslice_2', 'sslice_2_data')] + downstream,
                        graph_attrs)
    graph.graph['layout'] = 'NHWC'

    ref_edges = upstream + [
        ('sslice_2', 'sslice_2/Reshape_new_data'),
        ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
        ('sslice_2/Reshape_new', 'sslice_2_data'),
    ] + downstream
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
        'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
        'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    ConvertGroupedStridedSlice().add_reshape_for_new(graph, Node(graph, 'sslice_2'))

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ss_shrink_new(self):
    # sslice_2 both inserts an axis (new_axis_mask [0, 1, 0, 0, 0]) and shrinks one
    # (shrink_axis_mask [0, 0, 0, 1, 0]). The reference graph has both masks zeroed
    # and two chained Reshapes: Reshape_new first, then Reshape_shrink.
    upstream = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'sslice_2'),
        ('placeholder_begin_data', 'sslice_2'),
        ('placeholder_end_data', 'sslice_2'),
        ('placeholder_stride_data', 'sslice_2'),
    ]
    downstream = [
        ('sslice_2_data', 'placeholder_2'),
        ('placeholder_2', 'placeholder_2_data'),
        ('sslice_2_data', 'op_output'),
    ]

    def make_slices():
        # Fresh object array per graph; both graphs use the same slice values.
        return np.array(
            [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])

    graph_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 1, 0]),
                     'new_axis_mask': np.array([0, 1, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
    }
    graph = build_graph(nodes_attributes,
                        upstream + [('sslice_2', 'sslice_2_data')] + downstream,
                        graph_attrs)
    graph.graph['layout'] = 'NHWC'

    ref_edges = upstream + [
        ('sslice_2', 'sslice_2/Reshape_new_data'),
        ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
        ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
        ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
        ('sslice_2/Reshape_shrink', 'sslice_2_data'),
    ] + downstream
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
        'sslice_2': {'slices': make_slices(),
                     'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
                     'new_axis_mask': np.array([0, 0, 0, 0, 0])},
        'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
        'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
        'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
        'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
        'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    # Shrink handling runs first, then new-axis handling, each on a fresh Node wrapper.
    transform = ConvertGroupedStridedSlice()
    transform.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
    transform.add_reshape_for_new(graph, Node(graph, 'sslice_2'))

    flag, resp = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
# Test case with two StridedSlice ops that have identical parameters but feed different consumers
825 graph = build_graph(nodes_attributes,
826 [('placeholder_1', 'placeholder_1_data'),
827 ('placeholder_1_data', 'sslice_1'),
828 ('sslice_1', 'sslice_1_data'),
829 ('placeholder_1_data', 'sslice_2'),
830 ('sslice_2', 'sslice_2_data'),
831 ('placeholder_1_data', 'sslice_3'),
832 ('sslice_3', 'sslice_3_data'),
833 ('sslice_1_data', 'concat_1'),
834 ('sslice_2_data', 'concat_1'),
835 ('sslice_3_data', 'placeholder_2'),
836 ('placeholder_2', 'placeholder_2_data'),
837 ('concat_1', 'concat_1_data'),
838 ('concat_1_data', 'op_output'),
839 ('placeholder_2_data', 'op_output')
841 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
843 'sslice_1': {'slices': np.array(
844 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
845 'sslice_1_data': {'shape': np.array([1, 227, 227, 27])},
847 'sslice_2': {'slices': np.array(
848 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 54, 1)])},
849 'sslice_2_data': {'shape': np.array([1, 227, 227, 27])},
851 'sslice_3': {'slices': np.array(
852 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
853 'sslice_3_data': {'shape': np.array([1, 227, 227, 27])},
855 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
857 graph.graph['layout'] = 'NHWC'
859 graph_ref = build_graph(nodes_attributes,
860 [('placeholder_1', 'placeholder_1_data'),
861 ('placeholder_1_data', 'split_1'),
862 ('split_1', 'split_1_data'),
863 ('split_1', 'split_2_data'),
864 ('split_1_data', 'concat_1'),
865 ('split_2_data', 'concat_1'),
866 ('split_1_data', 'placeholder_2'),
867 ('placeholder_2', 'placeholder_2_data'),
868 ('concat_1', 'concat_1_data'),
869 ('concat_1_data', 'op_output'),
870 ('placeholder_2_data', 'op_output')
872 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
873 'split_1': {'axis': 3},
874 'split_1_data': {'shape': np.array([1, 227, 227, 27])},
875 'split_2_data': {'shape': np.array([1, 227, 227, 27])},
876 'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
879 pattern = ConvertGroupedStridedSlice()
880 pattern.find_and_replace_pattern(graph)
882 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
883 self.assertTrue(flag, resp)
886 if __name__ == '__main__':