Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / common / partial_infer / slice_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.slice import caffe_slice_infer, tf_strided_slice_infer, \
22     convert_negative_indices, mxnet_slice_axis_infer
23 from mo.graph.graph import Node
24 from mo.utils.unittest.graph import build_graph
25
26 nodes_attributes = {'node_1': {'value': None, 'kind': 'data'},
27                     'Slice_node': {'type': 'Slice', 'kind': 'op'},
28                     'node_2': {'value': None, 'kind': 'data'},
29                     'node_3': {'value': None, 'kind': 'data'},
30                     'node_4': {'value': None, 'kind': 'data'},
31                     # StridedSlice node with attrs
32                     'sslice_input': {'value': None, 'shape': None, 'kind': 'data'},
33                     'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice'},
34                     'sslice_begin_1': {'value': None, 'shape': None, 'kind': 'data'},
35                     'sslice_end_1': {'value': None, 'shape': None, 'kind': 'data'},
36                     'sslice_stride_1': {'value': None, 'shape': None, 'kind': 'data'},
37                     'sslice_data_1': {'value': None, 'shape': None, 'kind': 'data'},
38                     # TF slice
39                     'tf_slice_input': {'value': None, 'shape': None, 'kind': 'data'},
40                     'tf_slice_begin': {'value': None, 'shape': None, 'kind': 'data'},
41                     'tf_slice_size': {'value': None, 'shape': None, 'kind': 'data'},
42                     'tf_slice': {'kind': 'op'},
43                     'tf_slice_output': {'value': None, 'shape': None, 'kind': 'data'},
44                     'op_output': {'kind': 'op', 'op': 'OpOutput'},
45                     'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
46                     'op_output_2': {'kind': 'op', 'op': 'OpOutput'}
47                     }
48
49 tf_slice_edges = [('tf_slice_input', 'tf_slice'), ('tf_slice_begin', 'tf_slice'), ('tf_slice_size', 'tf_slice'),
50                   ('tf_slice', 'tf_slice_output')]
51
52
53 class TestSSliceInfer(unittest.TestCase):
54     def test_slice_infer_ideal(self):
55         graph = build_graph(nodes_attributes,
56                             [('node_1', 'Slice_node'),
57                              ('Slice_node', 'node_2'),
58                              ('Slice_node', 'node_3'),
59                              ('node_2', 'op_output'),
60                              ('node_3', 'op_output_1')
61                              ],
62                             {'node_1': {'shape': np.array([1, 288, 56, 56])},
63                              'node_2': {'shape': None},
64                              'node_3': {'shape': None},
65                              'Slice_node': {'axis': 1, 'slice_point': np.array([256])}
66                              })
67
68         slice_node = Node(graph, 'Slice_node')
69
70         caffe_slice_infer(slice_node)
71         exp_shape1 = np.array([1, 256, 56, 56])
72         exp_shape2 = np.array([1, 32, 56, 56])
73         res_shape1 = graph.node['node_2']['shape']
74         res_shape2 = graph.node['node_3']['shape']
75
76         for i in range(0, len(exp_shape1)):
77             self.assertEqual(exp_shape1[i], res_shape1[i])
78
79         for i in range(0, len(exp_shape2)):
80             self.assertEqual(exp_shape2[i], res_shape2[i])
81
82     def test_slice_infer_no_slice_point(self):
83         graph = build_graph(nodes_attributes,
84                             [('node_1', 'Slice_node'),
85                              ('Slice_node', 'node_2'),
86                              ('Slice_node', 'node_3'),
87                              ('node_2', 'op_output'),
88                              ('node_3', 'op_output_1')
89                              ],
90                             {'node_1': {'shape': np.array([1, 288, 56, 56])},
91                              'node_2': {'shape': None},
92                              'node_3': {'shape': None},
93                              'Slice_node': {'axis': 1, 'slice_point': []}
94                              })
95
96         slice_node = Node(graph, 'Slice_node')
97
98         caffe_slice_infer(slice_node)
99         exp_shape = np.array([1, 144, 56, 56])
100         res_shape1 = graph.node['node_2']['shape']
101         res_shape2 = graph.node['node_3']['shape']
102
103         for i in range(0, len(exp_shape)):
104             self.assertEqual(exp_shape[i], res_shape1[i])
105
106         for i in range(0, len(exp_shape)):
107             self.assertEqual(exp_shape[i], res_shape2[i])
108
109     def test_slice_infer_3_outs_no_slice_point(self):
110         graph = build_graph(nodes_attributes,
111                             [('node_1', 'Slice_node'),
112                              ('Slice_node', 'node_2'),
113                              ('Slice_node', 'node_3'),
114                              ('Slice_node', 'node_4'),
115                              ('node_2', 'op_output'),
116                              ('node_3', 'op_output_1'),
117                              ('node_2', 'op_output_2')
118                              ],
119                             {'node_1': {'shape': np.array([1, 288, 56, 56])},
120                              'node_2': {'shape': None},
121                              'node_3': {'shape': None},
122                              'node_4': {'shape': None},
123                              'Slice_node': {'axis': 1, 'slice_point': []}
124                              })
125
126         slice_node = Node(graph, 'Slice_node')
127
128         caffe_slice_infer(slice_node)
129         exp_shape = np.array([1, 96, 56, 56])
130         res_shape1 = graph.node['node_2']['shape']
131         res_shape2 = graph.node['node_3']['shape']
132         res_shape3 = graph.node['node_4']['shape']
133
134         for i in range(0, len(exp_shape)):
135             self.assertEqual(exp_shape[i], res_shape1[i])
136
137         for i in range(0, len(exp_shape)):
138             self.assertEqual(exp_shape[i], res_shape2[i])
139
140         for i in range(0, len(exp_shape)):
141             self.assertEqual(exp_shape[i], res_shape3[i])
142
143     def test_slice_infer_3_outs(self):
144         graph = build_graph(nodes_attributes,
145                             [('node_1', 'Slice_node'),
146                              ('Slice_node', 'node_2'),
147                              ('Slice_node', 'node_3'),
148                              ('Slice_node', 'node_4'),
149                              ('node_2', 'op_output'),
150                              ('node_3', 'op_output_1'),
151                              ('node_2', 'op_output_2')
152                              ],
153                             {'node_1': {'shape': np.array([1, 288, 56, 56])},
154                              'node_2': {'shape': None},
155                              'node_3': {'shape': None},
156                              'node_4': {'shape': None},
157                              'Slice_node': {'axis': 1, 'slice_point': [100, 150]}
158                              })
159
160         slice_node = Node(graph, 'Slice_node')
161
162         caffe_slice_infer(slice_node)
163         exp_shape1 = np.array([1, 100, 56, 56])
164         exp_shape2 = np.array([1, 50, 56, 56])
165         exp_shape3 = np.array([1, 138, 56, 56])
166         res_shape1 = graph.node['node_2']['shape']
167         res_shape2 = graph.node['node_3']['shape']
168         res_shape3 = graph.node['node_4']['shape']
169
170         for i in range(0, len(exp_shape1)):
171             self.assertEqual(exp_shape1[i], res_shape1[i])
172
173         for i in range(0, len(exp_shape2)):
174             self.assertEqual(exp_shape2[i], res_shape2[i])
175
176         for i in range(0, len(exp_shape3)):
177             self.assertEqual(exp_shape3[i], res_shape3[i])
178
179
180 class TestTFStridedSliceInfer(unittest.TestCase):
181     def build_test_graph2(self):
182         return build_graph(nodes_attributes,
183                            [('sslice_input', 'sslice_1'),
184                             ('sslice_begin_1', 'sslice_1'),
185                             ('sslice_end_1', 'sslice_1'),
186                             ('sslice_stride_1', 'sslice_1'),
187                             ('sslice_1', 'sslice_data_1'),
188                             ('sslice_data_1', 'op_output')
189                             ],
190                            {
191                             'sslice_input': {'value': np.array([1, 34, 34, 62]),
192                                              'shape': np.array([3])},
193                             'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
194                             'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
195                             'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
196                             'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
197                                          'begin_mask': [1], 'end_mask': [1]},
198                             })
199
200     def build_test_graph(self):
201         return build_graph(nodes_attributes,
202                            [('sslice_input', 'sslice_1'),
203                             ('sslice_begin_1', 'sslice_1'),
204                             ('sslice_end_1', 'sslice_1'),
205                             ('sslice_stride_1', 'sslice_1'),
206                             ('sslice_1', 'sslice_data_1'),
207                             ('sslice_data_1', 'op_output')
208                             ],
209                            {
210                             'sslice_input': {'value': None, 'shape': np.array([1, 35, 35, 3])},
211                             'sslice_begin_1': {'value': np.array([0, 0, 0, 0]), 'shape': np.array([4])},
212                             'sslice_end_1': {'value': np.array([1, 34, 30, 2]), 'shape': np.array([4])},
213                             'sslice_stride_1': {'value': np.array([1, 1, 1, 1]),
214                                                 'shape': np.array([4])},
215                             'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
216                                          'begin_mask': [1], 'end_mask': [1]},
217                             })
218
219     def build_test_graph_dim_beg(self):
220         return build_graph(nodes_attributes,
221                            [('sslice_input', 'sslice_1'),
222                             ('sslice_begin_1', 'sslice_1'),
223                             ('sslice_end_1', 'sslice_1'),
224                             ('sslice_stride_1', 'sslice_1'),
225                             ('sslice_1', 'sslice_data_1'),
226                             ('sslice_data_1', 'op_output')
227                             ],
228                            {
229                             'sslice_input': {'value': np.array([[1, 34, 34, 62]]),
230                                              'shape': np.array([1, 4])},
231                             'sslice_begin_1': {'value': np.array([0]), 'shape': np.array([1])},
232                             'sslice_end_1': {'value': np.array([4]), 'shape': np.array([1])},
233                             'sslice_stride_1': {'value': np.array([1]), 'shape': np.array([1])},
234                             'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
235                                          'begin_mask': [1], 'end_mask': [1]},
236                             })
237
238     def test_slice_infer_1(self):
239         graph = self.build_test_graph()
240         node = Node(graph, 'sslice_1')
241         tf_strided_slice_infer(node)
242         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
243
244     def test_slice_infer_2(self):
245         graph = self.build_test_graph()
246         node = Node(graph, 'sslice_1')
247         node.end_mask = [1, 0, 0, 1]  # 6
248         tf_strided_slice_infer(node)
249         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 2])), 'Wrong output shape detected')
250
251     def test_slice_infer_3(self):
252         graph = self.build_test_graph()
253         node = Node(graph, 'sslice_1')
254         node.in_node(1).value = np.array([0, 10, 10, 0])
255         node.end_mask = [1, 0, 0, 1]  # 6
256         tf_strided_slice_infer(node)
257         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 25, 25, 2])), 'Wrong output shape detected')
258
259     def test_slice_infer_4(self):
260         graph = self.build_test_graph()
261         node = Node(graph, 'sslice_1')
262         node.in_node(1).value = np.array([0, 10, 10, 0])
263         node.begin_mask = [1, 0, 0, 1]  # 6
264         tf_strided_slice_infer(node)
265         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')
266
267     def test_slice_infer_5(self):
268         graph = self.build_test_graph()
269         node = Node(graph, 'sslice_1')
270         node.in_node(1).value = np.array([0, 10, 10, 0])
271         node.begin_mask = [0, 0, 0, 0]  # 15
272         node.end_mask = [0, 0, 0, 0]  # 15
273         tf_strided_slice_infer(node)
274         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 3])), 'Wrong output shape detected')
275
276     def test_slice_infer_6(self):
277         graph = self.build_test_graph2()
278         node = Node(graph, 'sslice_1')
279         tf_strided_slice_infer(node)
280         self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
281         self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')
282
283     def test_slice_infer_7(self):
284         graph = self.build_test_graph2()
285         node = Node(graph, 'sslice_1')
286         node.in_node(1).value = np.array([1])
287         node.in_node(2).value = np.array([3])
288         tf_strided_slice_infer(node)
289         self.assertTrue(np.array_equal(node.out_node().shape, np.array([2])), 'Wrong output shape detected')
290         self.assertTrue(np.array_equal(node.out_node().value, np.array([34, 34])), 'Wrong output value detected')
291
292     def test_slice_infer_8(self):
293         graph = self.build_test_graph2()
294         node = Node(graph, 'sslice_1')
295         node.new_axis_mask = [1]
296         tf_strided_slice_infer(node)
297         self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 4])), 'Wrong output shape detected')
298         self.assertTrue(np.array_equal(node.out_node().value, np.array([[1, 34, 34, 62]])),
299                         'Wrong output value detected')
300
301     def test_slice_infer_9(self):
302         graph = self.build_test_graph()
303         node = Node(graph, 'sslice_1')
304         node.begin_mask = [0, 0, 0, 0]  # 15
305         node.end_mask = [0, 0, 0, 0]  # 15
306         node.shrink_axis_mask = [1]
307         tf_strided_slice_infer(node)
308         self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 3])), 'Wrong output shape detected')
309
310     def test_slice_infer_10(self):
311         graph = self.build_test_graph()
312         node = Node(graph, 'sslice_1')
313         node.begin_mask = [0, 0, 0, 0]  # 15
314         node.end_mask = [0, 0, 0, 0]  # 15
315         node.shrink_axis_mask = [1, 0, 0, 0]
316         node.new_axis_mask = [0, 0, 0, 1]  # 8
317         tf_strided_slice_infer(node)
318         self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 1, 3])), 'Wrong output shape detected')
319
320     def test_slice_infer_11(self):
321         graph = self.build_test_graph()
322         node = Node(graph, 'sslice_1')
323         node.begin_mask = [0, 0, 0, 0]  # 15
324         node.end_mask = [0, 0, 0, 0]  # 15
325         node.shrink_axis_mask = [1, 0, 1, 0]  # 5
326         tf_strided_slice_infer(node)
327         self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 3])), 'Wrong output shape detected')
328
329     def test_slice_infer_12(self):
330         graph = self.build_test_graph()
331         node = Node(graph, 'sslice_1')
332         node.begin_mask = [0, 0, 0, 0]  # 15
333         node.end_mask = [0, 0, 0, 0]  # 15
334         node.shrink_axis_mask = [1, 1, 1, 0]  # 7
335         tf_strided_slice_infer(node)
336         self.assertTrue(np.array_equal(node.out_node().shape, np.array([3])), 'Wrong output shape detected')
337
338     def test_slice_infer_13(self):
339         graph = self.build_test_graph2()
340         node = Node(graph, 'sslice_1')
341         node.in_node(1).value = np.array([1])
342         node.shrink_axis_mask = [1]
343         tf_strided_slice_infer(node)
344         self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
345         self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output shape detected')
346
347     def test_slice_infer_14(self):
348         graph = self.build_test_graph2()
349         node = Node(graph, 'sslice_1')
350         node.in_node(3).value = np.array([-1])
351         node.end_mask = [0]
352         node.begin_mask = [0]
353         node.in_node(0).shape = [4]
354         tf_strided_slice_infer(node)
355         self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
356         print(node.out_node().value)
357         self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])), 'Wrong output shape detected')
358
359     def test_slice_infer_dim_beg(self):
360         graph = self.build_test_graph_dim_beg()
361         node = Node(graph, 'sslice_1')
362         node.shrink_axis_mask = [1]
363         tf_strided_slice_infer(node)
364         self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
365         self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output shape detected')
366
367
368 class TestConvertNegativeIndices(unittest.TestCase):
369     def test_convert_negative_indices(self):
370         dimensions = np.array([3, 4, 8, 10])
371         indices = np.array([2, 0, -3, -4])
372         convert_negative_indices(indices, dimensions)
373         self.assertTrue(np.array_equal(indices, np.array([2, 0, 5, 6])), 'Wrong dimension indices')
374
375
376 class TestMXNetSliceAxisInfer(unittest.TestCase):
377     def test_slice_axis_infer_layer(self):
378         graph = build_graph(
379             {'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
380              'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
381                                  'kind': 'op', 'op': 'slice_axis', },
382              'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
383              },
384             [
385                 ('node_1', 'slice_axis_node'),
386                 ('slice_axis_node', 'node_3'),
387             ],
388             {
389                 'node_1': {'shape': np.array([1, 1024, 19, 19])},
390                 'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
391             })
392
393         slice_axis_node = Node(graph, 'slice_axis_node')
394         mxnet_slice_axis_infer(slice_axis_node)
395         res_shape = [1, 15, 19, 19]
396         for i in range(0, len(graph.node['node_3']['shape'])):
397             self.assertEqual(graph.node['node_3']['shape'][i], res_shape[i])