Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / slice_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import unittest
17
18 import numpy as np
19 from generator import generator
20
21 from mo.graph.graph import Node
22 from mo.ops.slice import Slice
23 from mo.utils.unittest.graph import build_graph
24
25 nodes_attributes = {
26     'data_1': {
27         'kind': 'data',
28         'shape': None,
29         'value': None,
30     },
31     'begin': {
32         'kind': 'data',
33         'shape': None,
34         'value': None,
35     },
36     'size': {
37         'kind': 'data',
38         'shape': None,
39         'value': None,
40     },
41     'slice': {
42         'op': 'Slice',
43         'axis': None,
44         'start': None,
45         'end': None,
46         'kind': 'op',
47     },
48     'data_2': {
49         'kind': 'data',
50         'shape': None,
51         'value': None,
52     }
53 }
54
55
56 @generator
57 class TestSliceOp(unittest.TestCase):
58     def test_slice_infer_constant(self):
59         # Testing constant path case
60         graph = build_graph(nodes_attributes,
61                             [('data_1', 'slice'),
62                              ('begin', 'slice'),
63                              ('size', 'slice'),
64                              ('slice', 'data_2')],
65                             {'data_1': {'shape': np.array([4]), 'value': np.array([1, 3, 224, 224])},
66                              'slice': {'start': np.array([1]), 'end': np.array([2])},
67                              'size': {'value': np.array([1])},
68                              'begin': {'value': np.array([1])}})
69
70         slice_node = Node(graph, 'slice')
71         Slice.infer(slice_node)
72
73         self.assertTrue(np.array_equal(slice_node.out_node().value, np.array([3])))
74         self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([1])))
75         self.assertTrue(np.array_equal(slice_node['slices'], np.array([slice(1, 2, 1)])))
76
77     def test_slice_infer_non_constant(self):
78         # Testing non-constant path case (when value in input is None)
79         # with multiply params
80         graph = build_graph(nodes_attributes,
81                             [('data_1', 'slice'),
82                              ('begin', 'slice'),
83                              ('size', 'slice'),
84                              ('slice', 'data_2')],
85                             {'data_1': {'shape': np.array([4, 5, 6])},
86                              'slice': {'start': np.array([1, 2]),
87                                        'end': np.array([4, 3])},
88                              'size': {'value': np.array([3, 1])},
89                              'begin': {'value': np.array([1, 2])}})
90
91         slice_node = Node(graph, 'slice')
92
93         Slice.infer(slice_node)
94         self.assertTrue(np.array_equal(slice_node.out_node().value, None))
95         self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([3, 1, 6])))
96         self.assertTrue(np.array_equal(slice_node['slices'], np.array([slice(1, 4, 1), slice(2, 3, 1), slice(0, 6, 1)])))
97
98     def test_slice_infer_multiply_params(self):
99         # Test case when size[i] == -1 (that means all
100         # remaining elements in dimension i are included in the slice)
101         graph = build_graph(nodes_attributes,
102                             [('data_1', 'slice'),
103                              ('begin', 'slice'),
104                              ('size', 'slice'),
105                              ('slice', 'data_2')],
106                             {'data_1': {'shape': np.array([4, 5, 6])},
107                              'slice': {'start': np.array([1, 2]),
108                                        'end': np.array([4, 1])},
109                              'size': {'value': np.array([3, -1])},
110                              'begin': {'value': np.array([1, 2])}})
111
112         slice_node = Node(graph, 'slice')
113
114         Slice.infer(slice_node)
115         self.assertTrue(np.array_equal(slice_node.out_node().value, None))
116         self.assertTrue(np.array_equal(slice_node.out_node().shape, np.array([3, 3, 6])))
117         self.assertTrue(np.array_equal(slice_node['slices'], np.array([slice(1, 4, 1), slice(2, 5, 1), slice(0, 6, 1)])))