"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.front.common.partial_infer.slice import caffe_slice_infer, tf_strided_slice_infer, \
    convert_negative_indices, mxnet_slice_axis_infer
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
# Node attribute templates passed to build_graph by the tests in this module; each test supplies
# its per-node values (shapes, constant inputs, op attributes) through build_graph's update dict.
nodes_attributes = {'node_1': {'value': None, 'kind': 'data'},
                    'Slice_node': {'type': 'Slice', 'kind': 'op'},
                    'node_2': {'value': None, 'kind': 'data'},
                    'node_3': {'value': None, 'kind': 'data'},
                    'node_4': {'value': None, 'kind': 'data'},
                    # StridedSlice op together with its data/begin/end/stride inputs and its output
                    'sslice_input': {'value': None, 'shape': None, 'kind': 'data'},
                    'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice'},
                    'sslice_begin_1': {'value': None, 'shape': None, 'kind': 'data'},
                    'sslice_end_1': {'value': None, 'shape': None, 'kind': 'data'},
                    'sslice_stride_1': {'value': None, 'shape': None, 'kind': 'data'},
                    'sslice_data_1': {'value': None, 'shape': None, 'kind': 'data'},
                    # TF Slice op together with its input/begin/size inputs and its output.
                    # NOTE(review): not referenced by any test visible in this file.
                    'tf_slice_input': {'value': None, 'shape': None, 'kind': 'data'},
                    'tf_slice_begin': {'value': None, 'shape': None, 'kind': 'data'},
                    'tf_slice_size': {'value': None, 'shape': None, 'kind': 'data'},
                    'tf_slice': {'kind': 'op'},
                    'tf_slice_output': {'value': None, 'shape': None, 'kind': 'data'},
                    # OpOutput consumers attached to the output data nodes above
                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
                    'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
                    'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
                    }
# Edges of the TF Slice fixture: the data, begin and size inputs feed 'tf_slice', which produces
# 'tf_slice_output'. NOTE(review): not used by any test visible in this file.
tf_slice_edges = [(producer, 'tf_slice') for producer in ('tf_slice_input', 'tf_slice_begin', 'tf_slice_size')]
tf_slice_edges.append(('tf_slice', 'tf_slice_output'))
class TestSSliceInfer(unittest.TestCase):
    """Shape inference tests for the Caffe ``Slice`` layer (``caffe_slice_infer``).

    Every test slices an input of shape [1, 288, 56, 56] along axis 1.
    """

    @staticmethod
    def _build_slice_graph(out_nodes, slice_point):
        """Build ``node_1 -> Slice_node -> out_nodes`` for ``caffe_slice_infer``.

        :param out_nodes: names of the Slice output data nodes in output-port order (at most
            three, since each one is wired to its own OpOutput node).
        :param slice_point: value of the ``slice_point`` attribute of ``Slice_node``; the tests
            below expect an empty one to split the axis evenly.
        :return: the built graph; the output data nodes start with ``shape`` set to None.
        """
        op_outputs = ('op_output', 'op_output_1', 'op_output_2')
        edges = [('node_1', 'Slice_node')]
        edges += [('Slice_node', out) for out in out_nodes]
        # Give every output data node its own OpOutput consumer.
        edges += list(zip(out_nodes, op_outputs))
        update_attrs = {'node_1': {'shape': np.array([1, 288, 56, 56])},
                        'Slice_node': {'axis': 1, 'slice_point': slice_point}}
        update_attrs.update({out: {'shape': None} for out in out_nodes})
        return build_graph(nodes_attributes, edges, update_attrs)

    def _assert_shape(self, graph, node_name, expected):
        """Assert that ``node_name`` received exactly the ``expected`` shape, rank included."""
        actual = graph.node[node_name]['shape']
        self.assertTrue(np.array_equal(actual, np.array(expected)),
                        'Wrong output shape detected for {}: {}'.format(node_name, actual))

    def test_slice_infer_ideal(self):
        graph = self._build_slice_graph(['node_2', 'node_3'], np.array([256]))
        caffe_slice_infer(Node(graph, 'Slice_node'))
        # slice_point 256 splits the 288 channels into 256 + 32.
        self._assert_shape(graph, 'node_2', [1, 256, 56, 56])
        self._assert_shape(graph, 'node_3', [1, 32, 56, 56])

    def test_slice_infer_no_slice_point(self):
        graph = self._build_slice_graph(['node_2', 'node_3'], [])
        caffe_slice_infer(Node(graph, 'Slice_node'))
        # No slice_point: 288 channels are split evenly between two outputs.
        self._assert_shape(graph, 'node_2', [1, 144, 56, 56])
        self._assert_shape(graph, 'node_3', [1, 144, 56, 56])

    def test_slice_infer_3_outs_no_slice_point(self):
        graph = self._build_slice_graph(['node_2', 'node_3', 'node_4'], [])
        caffe_slice_infer(Node(graph, 'Slice_node'))
        # No slice_point: 288 channels are split evenly between three outputs.
        for out in ('node_2', 'node_3', 'node_4'):
            self._assert_shape(graph, out, [1, 96, 56, 56])

    def test_slice_infer_3_outs(self):
        graph = self._build_slice_graph(['node_2', 'node_3', 'node_4'], [100, 150])
        caffe_slice_infer(Node(graph, 'Slice_node'))
        # slice_point [100, 150] yields the channel ranges [0, 100), [100, 150), [150, 288).
        self._assert_shape(graph, 'node_2', [1, 100, 56, 56])
        self._assert_shape(graph, 'node_3', [1, 50, 56, 56])
        self._assert_shape(graph, 'node_4', [1, 138, 56, 56])
class TestTFStridedSliceInfer(unittest.TestCase):
    """Shape/value inference tests for TensorFlow ``StridedSlice`` (``tf_strided_slice_infer``).

    Masks are per-dimension lists. As the expected results below show (e.g. tests 2 and 4), a ``0``
    in ``begin_mask``/``end_mask`` makes the slice ignore that dimension's begin/end value, which
    is the inverse of TensorFlow's bit meaning. ``shrink_axis_mask`` and ``new_axis_mask`` keep
    TensorFlow's meaning. The ``# TF ...`` comments give the equivalent TensorFlow integer bitmask.
    """

    @staticmethod
    def _build_sslice_graph(input_attrs, begin, end, stride):
        """Build ``sslice_input/begin/end/stride -> sslice_1 -> sslice_data_1 -> op_output``.

        :param input_attrs: ``value``/``shape`` attributes of the sliced input data node.
        :param begin: values of the begin input; its shape is derived from the length.
        :param end: values of the end input; its shape is derived from the length.
        :param stride: values of the stride input; its shape is derived from the length.
        :return: the built graph. ``sslice_1`` starts with single-element default masks that
            individual tests override on the node.
        """
        def const_input(values):
            values = np.array(values)
            return {'value': values, 'shape': np.array([len(values)])}

        return build_graph(nodes_attributes,
                           [('sslice_input', 'sslice_1'),
                            ('sslice_begin_1', 'sslice_1'),
                            ('sslice_end_1', 'sslice_1'),
                            ('sslice_stride_1', 'sslice_1'),
                            ('sslice_1', 'sslice_data_1'),
                            ('sslice_data_1', 'op_output')
                            ],
                           {'sslice_input': input_attrs,
                            'sslice_begin_1': const_input(begin),
                            'sslice_end_1': const_input(end),
                            'sslice_stride_1': const_input(stride),
                            'sslice_1': {'shrink_axis_mask': [0], 'ellipsis_mask': [0], 'new_axis_mask': [0],
                                         'begin_mask': [1], 'end_mask': [1]},
                            })

    def build_test_graph2(self):
        """1-D constant input [1, 34, 34, 62] sliced with begin=[0], end=[4], stride=[1]."""
        # NOTE(review): the declared shape [3] does not match the 4-element value;
        # test_slice_infer_14 overrides it with [4]. Confirm whether the mismatch is intentional.
        return self._build_sslice_graph({'value': np.array([1, 34, 34, 62]), 'shape': np.array([3])},
                                        begin=[0], end=[4], stride=[1])

    def build_test_graph(self):
        """Non-constant input of shape [1, 35, 35, 3] sliced with begin=0, end=[1, 34, 30, 2], stride=1."""
        return self._build_sslice_graph({'value': None, 'shape': np.array([1, 35, 35, 3])},
                                        begin=[0, 0, 0, 0], end=[1, 34, 30, 2], stride=[1, 1, 1, 1])

    def build_test_graph_dim_beg(self):
        """Constant input [[1, 34, 34, 62]] (shape [1, 4]) sliced along the first dimension only."""
        return self._build_sslice_graph({'value': np.array([[1, 34, 34, 62]]), 'shape': np.array([1, 4])},
                                        begin=[0], end=[4], stride=[1])

    def test_slice_infer_1(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')

    def test_slice_infer_2(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.end_mask = [1, 0, 0, 1]  # TF end_mask=6
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 2])), 'Wrong output shape detected')

    def test_slice_infer_3(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.in_node(1).value = np.array([0, 10, 10, 0])
        node.end_mask = [1, 0, 0, 1]  # TF end_mask=6
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 25, 25, 2])), 'Wrong output shape detected')

    def test_slice_infer_4(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.in_node(1).value = np.array([0, 10, 10, 0])
        node.begin_mask = [1, 0, 0, 1]  # TF begin_mask=6
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 34, 30, 2])), 'Wrong output shape detected')

    def test_slice_infer_5(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.in_node(1).value = np.array([0, 10, 10, 0])
        node.begin_mask = [0, 0, 0, 0]  # TF begin_mask=15
        node.end_mask = [0, 0, 0, 0]  # TF end_mask=15
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 35, 35, 3])), 'Wrong output shape detected')

    def test_slice_infer_6(self):
        graph = self.build_test_graph2()
        node = Node(graph, 'sslice_1')
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])), 'Wrong output value detected')

    def test_slice_infer_7(self):
        graph = self.build_test_graph2()
        node = Node(graph, 'sslice_1')
        node.in_node(1).value = np.array([1])
        node.in_node(2).value = np.array([3])
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([2])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array([34, 34])), 'Wrong output value detected')

    def test_slice_infer_8(self):
        graph = self.build_test_graph2()
        node = Node(graph, 'sslice_1')
        node.new_axis_mask = [1]  # TF new_axis_mask=1
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([1, 4])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array([[1, 34, 34, 62]])),
                        'Wrong output value detected')

    def test_slice_infer_9(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.begin_mask = [0, 0, 0, 0]  # TF begin_mask=15
        node.end_mask = [0, 0, 0, 0]  # TF end_mask=15
        node.shrink_axis_mask = [1]  # TF shrink_axis_mask=1
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 3])), 'Wrong output shape detected')

    def test_slice_infer_10(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.begin_mask = [0, 0, 0, 0]  # TF begin_mask=15
        node.end_mask = [0, 0, 0, 0]  # TF end_mask=15
        node.shrink_axis_mask = [1, 0, 0, 0]  # TF shrink_axis_mask=1
        node.new_axis_mask = [0, 0, 0, 1]  # TF new_axis_mask=8
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 35, 1, 3])), 'Wrong output shape detected')

    def test_slice_infer_11(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.begin_mask = [0, 0, 0, 0]  # TF begin_mask=15
        node.end_mask = [0, 0, 0, 0]  # TF end_mask=15
        node.shrink_axis_mask = [1, 0, 1, 0]  # TF shrink_axis_mask=5
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([35, 3])), 'Wrong output shape detected')

    def test_slice_infer_12(self):
        graph = self.build_test_graph()
        node = Node(graph, 'sslice_1')
        node.begin_mask = [0, 0, 0, 0]  # TF begin_mask=15
        node.end_mask = [0, 0, 0, 0]  # TF end_mask=15
        node.shrink_axis_mask = [1, 1, 1, 0]  # TF shrink_axis_mask=7
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([3])), 'Wrong output shape detected')

    def test_slice_infer_13(self):
        graph = self.build_test_graph2()
        node = Node(graph, 'sslice_1')
        node.in_node(1).value = np.array([1])
        node.shrink_axis_mask = [1]  # TF shrink_axis_mask=1
        tf_strided_slice_infer(node)
        # Shrinking the only axis leaves a scalar.
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array(34)), 'Wrong output value detected')

    def test_slice_infer_14(self):
        # Stride -1 with both begin and end ignored reverses the whole 1-D input.
        graph = self.build_test_graph2()
        node = Node(graph, 'sslice_1')
        node.in_node(3).value = np.array([-1])
        node.end_mask = [0]  # TF end_mask=1
        node.begin_mask = [0]  # TF begin_mask=1
        node.in_node(0).shape = [4]
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array([62, 34, 34, 1])),
                        'Wrong output value detected')

    def test_slice_infer_dim_beg(self):
        graph = self.build_test_graph_dim_beg()
        node = Node(graph, 'sslice_1')
        node.shrink_axis_mask = [1]  # TF shrink_axis_mask=1
        tf_strided_slice_infer(node)
        self.assertTrue(np.array_equal(node.out_node().shape, np.array([4])), 'Wrong output shape detected')
        self.assertTrue(np.array_equal(node.out_node().value, np.array([1, 34, 34, 62])),
                        'Wrong output value detected')
class TestConvertNegativeIndices(unittest.TestCase):
    """Tests for ``convert_negative_indices``, which updates the given index array in place."""

    def test_convert_negative_indices(self):
        dims = np.array([3, 4, 8, 10])
        idx = np.array([2, 0, -3, -4])

        convert_negative_indices(idx, dims)

        # Negative entries count from the end of their dimension: 8 - 3 = 5 and 10 - 4 = 6;
        # non-negative entries are left untouched.
        expected = np.array([2, 0, 5, 6])
        self.assertTrue(np.array_equal(idx, expected), 'Wrong dimension indices')
class TestMXNetSliceAxisInfer(unittest.TestCase):
    """Shape inference test for the MXNet ``slice_axis`` op (``mxnet_slice_axis_infer``)."""

    def test_slice_axis_infer_layer(self):
        graph = build_graph(
            {'node_1': {'name': 'data', 'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
             # NOTE(review): 'type': 'sigmoid' looks like a copy-paste leftover for a slice_axis op;
             # kept unchanged because it is not visible here whether the infer function reads it.
             'slice_axis_node': {'name': 'slice_axis_node', 'type': 'sigmoid', 'value': None,
                                 'kind': 'op', 'op': 'slice_axis', },
             'node_3': {'name': 'node_3', 'type': 'Identity', 'value': None, 'kind': 'op'},
             },
            [
                ('node_1', 'slice_axis_node'),
                ('slice_axis_node', 'node_3'),
            ],
            {
                'node_1': {'shape': np.array([1, 1024, 19, 19])},
                'slice_axis_node': {'axis': 1, 'offset': 10, 'dim': 25},
            })

        slice_axis_node = Node(graph, 'slice_axis_node')
        mxnet_slice_axis_infer(slice_axis_node)

        # Axis 1 shrinks to dim - offset = 25 - 10 = 15; the other dimensions are unchanged.
        # Compare whole arrays so a result with a different rank (or an empty shape) fails too.
        res_shape = graph.node['node_3']['shape']
        self.assertTrue(np.array_equal(res_shape, np.array([1, 15, 19, 19])),
                        'Wrong output shape detected: {}'.format(res_shape))