Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / ConvertGroupedStridedSlice_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from extensions.middle.ConvertGroupedStridedSlice import ConvertGroupedStridedSlice
22 from mo.graph.graph import Node
23 from mo.utils.unittest.graph import build_graph, compare_graphs
24
25 nodes_attributes = {
26     'placeholder_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
27     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
28     'placeholder_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
29     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
30     'placeholder_begin_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
31     'placeholder_end_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
32     'placeholder_stride_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
33     # StridedSlice layers
34     'sslice_1': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
35                  'shrink_axis_mask': np.array([0, 0, 0, 0])},
36     'sslice_1_data': {'value': None, 'shape': None, 'kind': 'data'},
37     'sslice_2': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
38                  'shrink_axis_mask': np.array([0, 0, 0, 0])},
39     'sslice_2_data': {'value': None, 'shape': None, 'kind': 'data'},
40     'sslice_3': {'type': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
41                  'shrink_axis_mask': np.array([0, 0, 0, 0])},
42     'sslice_3_data': {'value': None, 'shape': None, 'kind': 'data'},
43     # Split layer
44     'split_1': {'type': 'Split', 'kind': 'op', 'op': 'SplitV'},
45     'split_1_data': {'value': None, 'shape': None, 'kind': 'data'},
46     'split_2_data': {'value': None, 'shape': None, 'kind': 'data'},
47     'split_3_data': {'value': None, 'shape': None, 'kind': 'data'},
48     'split_4_data': {'value': None, 'shape': None, 'kind': 'data'},
49     # Concat1 operation
50     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
51     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
52     'op_output': {'kind': 'op', 'op': 'OpOutput'},
53     'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
54     'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
55     # Reshape layer
56     'sslice_1/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
57     'sslice_1/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
58     'sslice_2/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
59     'sslice_2/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
60     'sslice_2/Reshape_new': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
61     'sslice_2/Reshape_new_data': {'value': None, 'shape': None, 'kind': 'data'},
62 }
63
64
65 class ConvertGroupedStridedSliceTests(unittest.TestCase):
66     def test_1(self):
67         graph = build_graph(nodes_attributes,
68                             [('placeholder_1', 'placeholder_1_data'),
69                              ('placeholder_1_data', 'sslice_1'),
70                              ('sslice_1', 'sslice_1_data'),
71                              ('placeholder_1_data', 'sslice_2'),
72                              ('sslice_2', 'sslice_2_data'),
73                              ('placeholder_1_data', 'sslice_3'),
74                              ('sslice_3', 'sslice_3_data'),
75                              ('sslice_1_data', 'concat_1'),
76                              ('sslice_2_data', 'concat_1'),
77                              ('sslice_3_data', 'concat_1'),
78                              ('concat_1', 'concat_1_data'),
79                              ('concat_1_data', 'op_output')
80                              ],
81                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
82
83                              'sslice_1': {'slices': np.array(
84                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 18, 1)])},
85                              'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
86
87                              'sslice_2': {'slices': np.array(
88                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(18, 36, 1)])},
89                              'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
90
91                              'sslice_3': {'slices': np.array(
92                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(36, 54, 1)])},
93                              'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
94
95                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
96                              })
97         graph.graph['layout'] = 'NHWC'
98
99         graph_ref = build_graph(nodes_attributes,
100                                 [('placeholder_1', 'placeholder_1_data'),
101                                  ('placeholder_1_data', 'split_1'),
102                                  ('split_1', 'split_1_data'),
103                                  ('split_1', 'split_2_data'),
104                                  ('split_1', 'split_3_data'),
105                                  ('split_1_data', 'concat_1'),
106                                  ('split_2_data', 'concat_1'),
107                                  ('split_3_data', 'concat_1'),
108                                  ('concat_1', 'concat_1_data'),
109                                  ('concat_1_data', 'op_output')
110
111                                  ],
112                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
113                                  'split_1': {'axis': 3},
114                                  'split_1_data': {'shape': np.array([1, 227, 227, 18])},
115                                  'split_2_data': {'shape': np.array([1, 227, 227, 18])},
116                                  'split_3_data': {'shape': np.array([1, 227, 227, 18])},
117                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
118                                  })
119
120         pattern = ConvertGroupedStridedSlice()
121         pattern.find_and_replace_pattern(graph)
122
123         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
124         self.assertTrue(flag, resp)
125
126     def test_2(self):
127         graph = build_graph(nodes_attributes,
128                             [('placeholder_1', 'placeholder_1_data'),
129                              ('placeholder_1_data', 'sslice_1'),
130                              ('sslice_1', 'sslice_1_data'),
131                              ('placeholder_1_data', 'sslice_2'),
132                              ('sslice_2', 'sslice_2_data'),
133                              ('placeholder_1_data', 'sslice_3'),
134                              ('sslice_3', 'sslice_3_data'),
135                              ('sslice_1_data', 'concat_1'),
136                              ('sslice_2_data', 'concat_1'),
137                              ('sslice_3_data', 'concat_1'),
138                              ('concat_1', 'concat_1_data'),
139                              ('concat_1_data', 'op_output')
140                              ],
141                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
142
143                              'sslice_1': {'slices': np.array(
144                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
145                              'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
146
147                              'sslice_2': {'slices': np.array(
148                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
149                              'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
150
151                              'sslice_3': {'slices': np.array(
152                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
153                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
154
155                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
156                              })
157         graph.graph['layout'] = 'NHWC'
158
159         graph_ref = build_graph(nodes_attributes,
160                                 [('placeholder_1', 'placeholder_1_data'),
161                                  ('placeholder_1_data', 'split_1'),
162                                  ('split_1', 'split_1_data'),
163                                  ('split_1', 'split_2_data'),
164                                  ('split_1', 'split_3_data'),
165                                  ('split_1_data', 'concat_1'),
166                                  ('split_2_data', 'concat_1'),
167                                  ('split_3_data', 'concat_1'),
168                                  ('concat_1', 'concat_1_data'),
169                                  ('concat_1_data', 'op_output')
170                                  ],
171                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
172                                  'split_1': {'axis': 3},
173                                  'split_1_data': {'shape': np.array([1, 227, 227, 18])},
174                                  'split_2_data': {'shape': np.array([1, 227, 227, 17])},
175                                  'split_3_data': {'shape': np.array([1, 227, 227, 19])},
176                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
177                                  })
178
179         pattern = ConvertGroupedStridedSlice()
180         pattern.find_and_replace_pattern(graph)
181
182         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
183         self.assertTrue(flag, resp)
184
185     # Intersection of split ranges in feature dimension
186     def test_3_neg(self):
187         graph = build_graph(nodes_attributes,
188                             [('placeholder_1', 'placeholder_1_data'),
189                              ('placeholder_1_data', 'sslice_1'),
190                              ('sslice_1', 'sslice_1_data'),
191                              ('placeholder_1_data', 'sslice_2'),
192                              ('sslice_2', 'sslice_2_data'),
193                              ('placeholder_1_data', 'sslice_3'),
194                              ('sslice_3', 'sslice_3_data'),
195                              ('sslice_1_data', 'concat_1'),
196                              ('sslice_2_data', 'concat_1'),
197                              ('sslice_3_data', 'concat_1'),
198                              ('concat_1', 'concat_1_data'),
199                              ('concat_1_data', 'op_output')
200                              ],
201                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
202
203                              'sslice_1': {'slices': np.array(
204                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 39, 1)])},
205                              'sslice_1_data': {'shape': np.array([1, 227, 227, 20])},
206
207                              'sslice_2': {'slices': np.array(
208                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
209                              'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
210
211                              'sslice_3': {'slices': np.array(
212                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
213                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
214
215                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
216                              })
217         graph.graph['layout'] = 'NHWC'
218
219         graph_ref = build_graph(nodes_attributes,
220                                 [('placeholder_1', 'placeholder_1_data'),
221                                  ('placeholder_1_data', 'sslice_1'),
222                                  ('sslice_1', 'sslice_1_data'),
223                                  ('placeholder_1_data', 'sslice_2'),
224                                  ('sslice_2', 'sslice_2_data'),
225                                  ('placeholder_1_data', 'sslice_3'),
226                                  ('sslice_3', 'sslice_3_data'),
227                                  ('sslice_1_data', 'concat_1'),
228                                  ('sslice_2_data', 'concat_1'),
229                                  ('sslice_3_data', 'concat_1'),
230                                  ('concat_1', 'concat_1_data'),
231                                  ('concat_1_data', 'op_output')
232                                  ],
233                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
234
235                                  'sslice_1': {'slices': np.array(
236                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 39, 1)])},
237                                  'sslice_1_data': {'shape': np.array([1, 227, 227, 20])},
238
239                                  'sslice_2': {'slices': np.array(
240                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
241                                  'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
242
243                                  'sslice_3': {'slices': np.array(
244                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
245                                  'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
246
247                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
248                                  })
249
250         pattern = ConvertGroupedStridedSlice()
251         pattern.find_and_replace_pattern(graph)
252
253         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
254         self.assertTrue(flag, resp)
255
256     # Split range overflow in feature dimension
257     def test_4_neg(self):
258         graph = build_graph(nodes_attributes,
259                             [('placeholder_1', 'placeholder_1_data'),
260                              ('placeholder_1_data', 'sslice_1'),
261                              ('sslice_1', 'sslice_1_data'),
262                              ('placeholder_1_data', 'sslice_2'),
263                              ('sslice_2', 'sslice_2_data'),
264                              ('placeholder_1_data', 'sslice_3'),
265                              ('sslice_3', 'sslice_3_data'),
266                              ('sslice_1_data', 'concat_1'),
267                              ('sslice_2_data', 'concat_1'),
268                              ('sslice_3_data', 'concat_1'),
269                              ('concat_1', 'concat_1_data'),
270                              ('concat_1_data', 'op_output')
271                              ],
272                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
273
274                              'sslice_1': {'slices': np.array(
275                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
276                              'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
277
278                              'sslice_2': {'slices': np.array(
279                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 55, 1)])},
280                              'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
281
282                              'sslice_3': {'slices': np.array(
283                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
284                              'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
285
286                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
287                              })
288         graph.graph['layout'] = 'NHWC'
289
290         graph_ref = build_graph(nodes_attributes,
291                                 [('placeholder_1', 'placeholder_1_data'),
292                                  ('placeholder_1_data', 'sslice_1'),
293                                  ('sslice_1', 'sslice_1_data'),
294                                  ('placeholder_1_data', 'sslice_2'),
295                                  ('sslice_2', 'sslice_2_data'),
296                                  ('placeholder_1_data', 'sslice_3'),
297                                  ('sslice_3', 'sslice_3_data'),
298                                  ('sslice_1_data', 'concat_1'),
299                                  ('sslice_2_data', 'concat_1'),
300                                  ('sslice_3_data', 'concat_1'),
301                                  ('concat_1', 'concat_1_data'),
302                                  ('concat_1_data', 'op_output')
303                                  ],
304                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
305
306                                  'sslice_1': {'slices': np.array(
307                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
308                                  'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
309
310                                  'sslice_2': {'slices': np.array(
311                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 55, 1)])},
312                                  'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
313
314                                  'sslice_3': {'slices': np.array(
315                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 19, 1)])},
316                                  'sslice_3_data': {'shape': np.array([1, 227, 227, 19])},
317
318                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
319                                  })
320
321         pattern = ConvertGroupedStridedSlice()
322         pattern.find_and_replace_pattern(graph)
323
324         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
325         self.assertTrue(flag, resp)
326
327     # Split(1,H,W,54)--->Fake_data (1,H,W,1)
328     #       |`---->Sslice1_out (1,H,W,18)
329     #       |`---->Sslice2_out (1,H,W,18)
330     #       `----->Sslice3_out (1,H,W,17)
331     def test_5(self):
332         graph = build_graph(nodes_attributes,
333                             [('placeholder_1', 'placeholder_1_data'),
334                              ('placeholder_1_data', 'sslice_1'),
335                              ('sslice_1', 'sslice_1_data'),
336                              ('placeholder_1_data', 'sslice_2'),
337                              ('sslice_2', 'sslice_2_data'),
338                              ('placeholder_1_data', 'sslice_3'),
339                              ('sslice_3', 'sslice_3_data'),
340                              ('sslice_1_data', 'concat_1'),
341                              ('sslice_2_data', 'concat_1'),
342                              ('sslice_3_data', 'concat_1'),
343                              ('concat_1', 'concat_1_data'),
344                              ('concat_1_data', 'op_output'),
345                              ],
346                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
347
348                              'sslice_1': {'slices': np.array(
349                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(19, 37, 1)])},
350                              'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
351
352                              'sslice_2': {'slices': np.array(
353                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(37, 54, 1)])},
354                              'sslice_2_data': {'shape': np.array([1, 227, 227, 17])},
355
356                              'sslice_3': {'slices': np.array(
357                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(1, 19, 1)])},
358                              'sslice_3_data': {'shape': np.array([1, 227, 227, 18])},
359
360                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
361                              })
362         graph.graph['layout'] = 'NHWC'
363
364         graph_ref = build_graph(nodes_attributes,
365                                 [('placeholder_1', 'placeholder_1_data'),
366                                  ('placeholder_1_data', 'split_1'),
367                                  ('split_1', 'split_1_data'),
368                                  ('split_1', 'split_2_data'),
369                                  ('split_1', 'split_3_data'),
370                                  ('split_1', 'split_4_data'),
371                                  ('split_2_data', 'concat_1'),
372                                  ('split_3_data', 'concat_1'),
373                                  ('split_4_data', 'concat_1'),
374                                  ('concat_1', 'concat_1_data'),
375                                  ('concat_1_data', 'op_output'),
376                                  ('split_1_data', 'op_output_1')
377                                  ],
378                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
379                                  'split_1': {'axis': 3},
380                                  'split_1_data': {'shape': np.array([1, 227, 227, 1])},
381                                  'split_2_data': {'shape': np.array([1, 227, 227, 18])},
382                                  'split_3_data': {'shape': np.array([1, 227, 227, 17])},
383                                  'split_4_data': {'shape': np.array([1, 227, 227, 18])},
384                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
385                                  })
386
387         pattern = ConvertGroupedStridedSlice()
388         pattern.find_and_replace_pattern(graph)
389
390         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
391         self.assertTrue(flag, resp)
392
393     # Split(1,H,W,54)
394     #       |`---->Sslice1_out (1,H,W,(0,18))
395     #       |`---->Fake_data (1,H,W,(18,27))
396     #       |`---->Sslice3_out (1,H,W,(27,45))
397     #       `----->Fake_data (1,H,W,(45,54))
398     def test_6(self):
399         graph = build_graph(nodes_attributes,
400                             [('placeholder_1', 'placeholder_1_data'),
401                              ('placeholder_1_data', 'sslice_1'),
402                              ('sslice_1', 'sslice_1_data'),
403                              ('placeholder_1_data', 'sslice_2'),
404                              ('sslice_2', 'sslice_2_data'),
405                              ('sslice_1_data', 'concat_1'),
406                              ('sslice_2_data', 'concat_1'),
407                              ('concat_1', 'concat_1_data'),
408                              ('concat_1_data', 'op_output')
409                              ],
410                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
411
412                              'sslice_1': {'slices': np.array(
413                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 18, 1)])},
414                              'sslice_1_data': {'shape': np.array([1, 227, 227, 18])},
415
416                              'sslice_2': {'slices': np.array(
417                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
418                              'sslice_2_data': {'shape': np.array([1, 227, 227, 18])},
419
420                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
421                              })
422         graph.graph['layout'] = 'NHWC'
423
424         graph_ref = build_graph(nodes_attributes,
425                                 [('placeholder_1', 'placeholder_1_data'),
426                                  ('placeholder_1_data', 'split_1'),
427                                  ('split_1', 'split_1_data'),
428                                  ('split_1', 'split_2_data'),
429                                  ('split_1', 'split_3_data'),
430                                  ('split_1', 'split_4_data'),
431                                  ('split_1_data', 'concat_1'),
432                                  ('split_3_data', 'concat_1'),
433                                  ('concat_1', 'concat_1_data'),
434                                  ('concat_1_data', 'op_output'),
435                                  ('split_2_data', 'op_output_1'),
436                                  ('split_4_data', 'op_output_2'),
437                                  ],
438                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
439                                  'split_1': {'axis': 3},
440                                  'split_1_data': {'shape': np.array([1, 227, 227, 18])},
441                                  'split_2_data': {'shape': np.array([1, 227, 227, 9])},
442                                  'split_3_data': {'shape': np.array([1, 227, 227, 18])},
443                                  'split_4_data': {'shape': np.array([1, 227, 227, 9])},
444                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
445                                  })
446
447         pattern = ConvertGroupedStridedSlice()
448         pattern.find_and_replace_pattern(graph)
449
450         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
451         self.assertTrue(flag, resp)
452
453     def test_7_neg(self):
454         graph = build_graph(nodes_attributes,
455                             [('placeholder_1', 'placeholder_1_data'),
456                              ('placeholder_1_data', 'sslice_1'),
457                              ('sslice_1', 'sslice_1_data'),
458                              ('placeholder_1_data', 'sslice_2'),
459                              ('sslice_2', 'sslice_2_data'),
460                              ('sslice_1_data', 'concat_1'),
461                              ('sslice_2_data', 'concat_1'),
462                              ('concat_1', 'concat_1_data'),
463                              ('concat_1_data', 'op_output')
464                              ],
465                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
466
467                              'sslice_1': {'slices': np.array(
468                                  [slice(0, 1, 1), slice(0, 10, 1), slice(0, 227, 1), slice(0, 18, 1)])},
469                              'sslice_1_data': {'shape': np.array([1, 10, 227, 18])},
470
471                              'sslice_2': {'slices': np.array(
472                                  [slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
473                              'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
474
475                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
476                              })
477         graph.graph['layout'] = 'NHWC'
478
479         graph_ref = build_graph(nodes_attributes,
480                                 [('placeholder_1', 'placeholder_1_data'),
481                                  ('placeholder_1_data', 'sslice_1'),
482                                  ('sslice_1', 'sslice_1_data'),
483                                  ('placeholder_1_data', 'sslice_2'),
484                                  ('sslice_2', 'sslice_2_data'),
485                                  ('sslice_1_data', 'concat_1'),
486                                  ('sslice_2_data', 'concat_1'),
487                                  ('concat_1', 'concat_1_data'),
488                                  ('concat_1_data', 'op_output')
489                                  ],
490                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
491
492                                  'sslice_1': {'slices': np.array(
493                                      [slice(0, 1, 1), slice(0, 10, 1), slice(0, 227, 1), slice(0, 18, 1)])},
494                                  'sslice_1_data': {'shape': np.array([1, 10, 227, 18])},
495
496                                  'sslice_2': {'slices': np.array(
497                                      [slice(0, 1, 1), slice(10, 227, 1), slice(0, 227, 1), slice(27, 45, 1)])},
498                                  'sslice_2_data': {'shape': np.array([1, 217, 227, 18])},
499
500                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
501                                  })
502
503         pattern = ConvertGroupedStridedSlice()
504         pattern.find_and_replace_pattern(graph)
505
506         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
507         self.assertTrue(flag, resp)
508
509     # Split(1,54,W,C)
510     #       |`---->Sslice1_out (1,(0,18),W,C)
511     #       |`---->Sslice2_out (1,(18,36),W,C)
512     #       `----->Fake_data (1,(36,54),W,C)
513     def test_8(self):
514         graph = build_graph(nodes_attributes,
515                             [('placeholder_1', 'placeholder_1_data'),
516                              ('placeholder_1_data', 'sslice_1'),
517                              ('sslice_1', 'sslice_1_data'),
518                              ('placeholder_1_data', 'sslice_2'),
519                              ('sslice_2', 'sslice_2_data'),
520                              ('sslice_1_data', 'concat_1'),
521                              ('sslice_2_data', 'concat_1'),
522                              ('concat_1', 'concat_1_data'),
523                              ('concat_1_data', 'op_output')
524                              ],
525                             {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
526
527                              'sslice_1': {'slices': np.array(
528                                  [slice(0, 1, 1), slice(0, 18, 1), slice(0, 54, 1), slice(0, 3, 1)])},
529                              'sslice_1_data': {'shape': np.array([1, 18, 54, 3])},
530
531                              'sslice_2': {'slices': np.array(
532                                  [slice(0, 1, 1), slice(18, 36, 1), slice(0, 54, 1), slice(0, 3, 1)])},
533                              'sslice_2_data': {'shape': np.array([1, 18, 54, 3])},
534
535                              'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
536                              })
537         graph.graph['layout'] = 'NHWC'
538
539         graph_ref = build_graph(nodes_attributes,
540                                 [('placeholder_1', 'placeholder_1_data'),
541                                  ('placeholder_1_data', 'split_1'),
542                                  ('split_1', 'split_1_data'),
543                                  ('split_1', 'split_2_data'),
544                                  ('split_1', 'split_3_data'),
545                                  ('split_1_data', 'concat_1'),
546                                  ('split_3_data', 'concat_1'),
547                                  ('concat_1', 'concat_1_data'),
548                                  ('concat_1_data', 'op_output'),
549                                  ('split_2_data', 'op_output_1')
550                                  ],
551                                 {'placeholder_1_data': {'shape': np.array([1, 54, 54, 3])},
552                                  'split_1': {'axis': 1},
553                                  'split_1_data': {'shape': np.array([1, 18, 54, 3])},
554                                  'split_2_data': {'shape': np.array([1, 18, 54, 3])},
555                                  'split_3_data': {'shape': np.array([1, 18, 54, 3])},
556                                  'concat_1_data': {'shape': np.array([1, 54, 54, 3])},
557                                  })
558
559         pattern = ConvertGroupedStridedSlice()
560         pattern.find_and_replace_pattern(graph)
561
562         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
563         self.assertTrue(flag, resp)
564
565
566 class AddReshapeAfterStridedSliceTests(unittest.TestCase):
567     def test_ss_1_shrink_last(self):
568         graph = build_graph(nodes_attributes,
569                             [('placeholder_1', 'placeholder_1_data'),
570                              ('placeholder_1_data', 'sslice_1'),
571                              ('placeholder_begin_data', 'sslice_1'),
572                              ('placeholder_end_data', 'sslice_1'),
573                              ('placeholder_stride_data', 'sslice_1'),
574                              ('sslice_1', 'sslice_1_data'),
575                              ('sslice_1_data', 'op_output')
576                              ],
577                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
578                              'sslice_1': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
579                                           'shrink_axis_mask': [0, 0, 1, 0],
580                                           'new_axis_mask': np.array([0, 0, 0, 0])},
581                              'sslice_1_data': {'shape': np.array([1, 227, 54])},
582                              })
583         graph.graph['layout'] = 'NHWC'
584
585         graph_ref = build_graph(nodes_attributes,
586                                 [('placeholder_1', 'placeholder_1_data'),
587                                  ('placeholder_1_data', 'sslice_1'),
588                                  ('placeholder_begin_data', 'sslice_1'),
589                                  ('placeholder_end_data', 'sslice_1'),
590                                  ('placeholder_stride_data', 'sslice_1'),
591                                  ('sslice_1', 'sslice_1/Reshape_shrink_data'),
592                                  ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
593                                  ('sslice_1/Reshape_shrink', 'sslice_1_data'),
594                                  ('sslice_1_data', 'op_output')
595                                  ],
596                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
597                                  'sslice_1': {'slices': np.array(
598                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
599                                      'shrink_axis_mask': np.array([0, 0, 0, 0]),
600                                      'new_axis_mask': np.array([0, 0, 0, 0])},
601                                  'sslice_1_data': {'shape': np.array([1, 227, 54])},
602                                  'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
603                                  'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])}
604                                  })
605
606         pattern = ConvertGroupedStridedSlice()
607         pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_1'))
608
609         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_1_data', check_op_attrs=True)
610         graph.clear()
611         graph_ref.clear()
612         self.assertTrue(flag, resp)
613
614     def test_ss_1_shrink(self):
615         graph = build_graph(nodes_attributes,
616                             [('placeholder_1', 'placeholder_1_data'),
617                              ('placeholder_1_data', 'sslice_2'),
618                              ('placeholder_begin_data', 'sslice_2'),
619                              ('placeholder_end_data', 'sslice_2'),
620                              ('placeholder_stride_data', 'sslice_2'),
621                              ('sslice_2', 'sslice_2_data'),
622                              ('sslice_2_data', 'placeholder_2'),
623                              ('placeholder_2', 'placeholder_2_data'),
624                              ('sslice_2_data', 'op_output')
625                              ],
626                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
627                              'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
628                                           'shrink_axis_mask': [0, 0, 1, 0],
629                                           'new_axis_mask': np.array([0, 0, 0, 0])},
630                              'sslice_2_data': {'shape': np.array([1, 227, 54])}
631                              })
632         graph.graph['layout'] = 'NHWC'
633
634         graph_ref = build_graph(nodes_attributes,
635                                 [('placeholder_1', 'placeholder_1_data'),
636                                  ('placeholder_1_data', 'sslice_2'),
637                                  ('placeholder_begin_data', 'sslice_2'),
638                                  ('placeholder_end_data', 'sslice_2'),
639                                  ('placeholder_stride_data', 'sslice_2'),
640                                  ('sslice_2', 'sslice_2/Reshape_shrink_data'),
641                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
642                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
643                                  ('sslice_2_data', 'placeholder_2'),
644                                  ('placeholder_2', 'placeholder_2_data'),
645                                  ('sslice_2_data', 'op_output')
646                                  ],
647                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
648                                  'sslice_2': {'slices': np.array([slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
649                                               'shrink_axis_mask': np.array([0, 0, 0, 0]),
650                                               'new_axis_mask': np.array([0, 0, 0, 0])},
651                                  'sslice_2_data': {'shape': np.array([1, 227, 54])},
652                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
653                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
654                                  })
655
656         pattern = ConvertGroupedStridedSlice()
657         pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
658
659         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
660         graph.clear()
661         graph_ref.clear()
662         self.assertTrue(flag, resp)
663
664     def test_ss_2_shrink(self):
665         graph = build_graph(nodes_attributes,
666                             [('placeholder_1', 'placeholder_1_data'),
667                              ('placeholder_1_data', 'sslice_2'),
668                              ('placeholder_begin_data', 'sslice_2'),
669                              ('placeholder_end_data', 'sslice_2'),
670                              ('placeholder_stride_data', 'sslice_2'),
671                              ('sslice_2', 'sslice_2_data'),
672                              ('sslice_2_data', 'placeholder_2'),
673                              ('placeholder_2', 'placeholder_2_data'),
674                              ('sslice_2_data', 'op_output')
675                              ],
676                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
677                              'sslice_2': {
678                                  'slices': np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
679                                  'shrink_axis_mask': np.array([0, 1, 0, 1]),
680                                  'new_axis_mask': np.array([0, 0, 0, 0])},
681                              'sslice_2_data': {'shape': np.array([1, 227])}
682                              })
683         graph.graph['layout'] = 'NHWC'
684
685         graph_ref = build_graph(nodes_attributes,
686                                 [('placeholder_1', 'placeholder_1_data'),
687                                  ('placeholder_1_data', 'sslice_2'),
688                                  ('placeholder_begin_data', 'sslice_2'),
689                                  ('placeholder_end_data', 'sslice_2'),
690                                  ('placeholder_stride_data', 'sslice_2'),
691                                  ('sslice_2', 'sslice_2/Reshape_shrink_data'),
692                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
693                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
694                                  ('sslice_2_data', 'placeholder_2'),
695                                  ('placeholder_2', 'placeholder_2_data'),
696                                  ('sslice_2_data', 'op_output')
697                                  ],
698                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
699                                  'sslice_2': {'slices': np.array(
700                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
701                                      'shrink_axis_mask': np.array([0, 0, 0, 0]),
702                                      'new_axis_mask': np.array([0, 0, 0, 0])},
703                                  'sslice_2_data': {'shape': np.array([1, 227])},
704                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
705                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
706                                  })
707
708         pattern = ConvertGroupedStridedSlice()
709         pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
710
711         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
712         graph.clear()
713         graph_ref.clear()
714         self.assertTrue(flag, resp)
715
716     def test_ss_1_new(self):
717         graph = build_graph(nodes_attributes,
718                             [('placeholder_1', 'placeholder_1_data'),
719                              ('placeholder_1_data', 'sslice_2'),
720                              ('placeholder_begin_data', 'sslice_2'),
721                              ('placeholder_end_data', 'sslice_2'),
722                              ('placeholder_stride_data', 'sslice_2'),
723                              ('sslice_2', 'sslice_2_data'),
724                              ('sslice_2_data', 'placeholder_2'),
725                              ('placeholder_2', 'placeholder_2_data'), ],
726                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
727                              'sslice_2': {'slices': np.array(
728                                  [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)]),
729                                  'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
730                                  'new_axis_mask': np.array([0, 1, 0, 0, 0])},
731                              'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])}
732                              })
733         graph.graph['layout'] = 'NHWC'
734
735         graph_ref = build_graph(nodes_attributes,
736                                 [('placeholder_1', 'placeholder_1_data'),
737                                  ('placeholder_1_data', 'sslice_2'),
738                                  ('placeholder_begin_data', 'sslice_2'),
739                                  ('placeholder_end_data', 'sslice_2'),
740                                  ('placeholder_stride_data', 'sslice_2'),
741                                  ('sslice_2', 'sslice_2/Reshape_new_data'),
742                                  ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
743                                  ('sslice_2/Reshape_new', 'sslice_2_data'),
744                                  ('sslice_2_data', 'placeholder_2'),
745                                  ('placeholder_2', 'placeholder_2_data')],
746                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
747                                  'sslice_2': {'slices': np.array(
748                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1),
749                                       slice(0, 54, 1)]),
750                                      'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
751                                      'new_axis_mask': np.array([0, 0, 0, 0, 0])},
752                                  'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
753                                  'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
754                                  'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
755                                  })
756
757         pattern = ConvertGroupedStridedSlice()
758         pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
759
760         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
761         graph.clear()
762         graph_ref.clear()
763         self.assertTrue(flag, resp)
764
765     def test_ss_shrink_new(self):
766         graph = build_graph(nodes_attributes,
767                             [('placeholder_1', 'placeholder_1_data'),
768                              ('placeholder_1_data', 'sslice_2'),
769                              ('placeholder_begin_data', 'sslice_2'),
770                              ('placeholder_end_data', 'sslice_2'),
771                              ('placeholder_stride_data', 'sslice_2'),
772                              ('sslice_2', 'sslice_2_data'),
773                              ('sslice_2_data', 'placeholder_2'),
774                              ('placeholder_2', 'placeholder_2_data'),
775                              ('sslice_2_data', 'op_output')
776                              ],
777                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
778                              'sslice_2': {'slices': np.array(
779                                  [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
780                                  'shrink_axis_mask': np.array([0, 0, 0, 1, 0]),
781                                  'new_axis_mask': np.array([0, 1, 0, 0, 0])},
782                              'sslice_2_data': {'shape': np.array([1, 1, 227, 54])}
783                              })
784         graph.graph['layout'] = 'NHWC'
785
786         graph_ref = build_graph(nodes_attributes,
787                                 [('placeholder_1', 'placeholder_1_data'),
788                                  ('placeholder_1_data', 'sslice_2'),
789                                  ('placeholder_begin_data', 'sslice_2'),
790                                  ('placeholder_end_data', 'sslice_2'),
791                                  ('placeholder_stride_data', 'sslice_2'),
792                                  ('sslice_2', 'sslice_2/Reshape_new_data'),
793                                  ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
794                                  ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
795                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
796                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
797                                  ('sslice_2_data', 'placeholder_2'),
798                                  ('placeholder_2', 'placeholder_2_data'),
799                                  ('sslice_2_data', 'op_output')
800                                  ],
801                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
802                                  'sslice_2': {'slices': np.array(
803                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1),
804                                       slice(0, 54, 1)]),
805                                      'shrink_axis_mask': np.array([0, 0, 0, 0, 0]),
806                                      'new_axis_mask': np.array([0, 0, 0, 0, 0])},
807                                  'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
808                                  'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
809                                  'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
810                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
811                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
812                                  })
813
814         pattern = ConvertGroupedStridedSlice()
815         pattern.add_reshape_for_shrink(graph, Node(graph, 'sslice_2'))
816         pattern.add_reshape_for_new(graph, Node(graph, 'sslice_2'))
817
818         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
819         graph.clear()
820         graph_ref.clear()
821         self.assertTrue(flag, resp)
822
823     # test case with 2 strided slices with the same parameters but different outputs
824     def test_1(self):
825         graph = build_graph(nodes_attributes,
826                             [('placeholder_1', 'placeholder_1_data'),
827                              ('placeholder_1_data', 'sslice_1'),
828                              ('sslice_1', 'sslice_1_data'),
829                              ('placeholder_1_data', 'sslice_2'),
830                              ('sslice_2', 'sslice_2_data'),
831                              ('placeholder_1_data', 'sslice_3'),
832                              ('sslice_3', 'sslice_3_data'),
833                              ('sslice_1_data', 'concat_1'),
834                              ('sslice_2_data', 'concat_1'),
835                              ('sslice_3_data', 'placeholder_2'),
836                              ('placeholder_2', 'placeholder_2_data'),
837                              ('concat_1', 'concat_1_data'),
838                              ('concat_1_data', 'op_output'),
839                              ('placeholder_2_data', 'op_output')
840                              ],
841                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
842
843                              'sslice_1': {'slices': np.array(
844                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
845                              'sslice_1_data': {'shape': np.array([1, 227, 227, 27])},
846
847                              'sslice_2': {'slices': np.array(
848                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(27, 54, 1)])},
849                              'sslice_2_data': {'shape': np.array([1, 227, 227, 27])},
850
851                              'sslice_3': {'slices': np.array(
852                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 27, 1)])},
853                              'sslice_3_data': {'shape': np.array([1, 227, 227, 27])},
854
855                              'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
856                              })
857         graph.graph['layout'] = 'NHWC'
858
859         graph_ref = build_graph(nodes_attributes,
860                                 [('placeholder_1', 'placeholder_1_data'),
861                                  ('placeholder_1_data', 'split_1'),
862                                  ('split_1', 'split_1_data'),
863                                  ('split_1', 'split_2_data'),
864                                  ('split_1_data', 'concat_1'),
865                                  ('split_2_data', 'concat_1'),
866                                  ('split_1_data', 'placeholder_2'),
867                                  ('placeholder_2', 'placeholder_2_data'),
868                                  ('concat_1', 'concat_1_data'),
869                                  ('concat_1_data', 'op_output'),
870                                  ('placeholder_2_data', 'op_output')
871                                  ],
872                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
873                                  'split_1': {'axis': 3},
874                                  'split_1_data': {'shape': np.array([1, 227, 227, 27])},
875                                  'split_2_data': {'shape': np.array([1, 227, 227, 27])},
876                                  'concat_1_data': {'shape': np.array([1, 227, 227, 54])},
877                                  })
878
879         pattern = ConvertGroupedStridedSlice()
880         pattern.find_and_replace_pattern(graph)
881
882         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
883         self.assertTrue(flag, resp)
884
885
886 if __name__ == '__main__':
887     unittest.main()