Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / ops / strided_slice_test.py
1 """
2  Copyright (c) 2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16 import unittest
17
18 import numpy as np
19 from generator import generator
20
21 from mo.graph.graph import Node
22 from mo.ops.op import PermuteAttrs
23 from mo.ops.strided_slice import permute_masks, permute_array_with_ellipsis
24 from mo.utils.unittest.graph import build_graph
25
26 nodes_attributes = {
27     'data_1': {
28         'kind': 'data',
29         'shape': None,
30         'value': None,
31     },
32     'begin': {
33         'kind': 'data',
34         'shape': None,
35         'value': None,
36     },
37     'end': {
38         'kind': 'data',
39         'shape': None,
40         'value': None,
41     },
42     'stride': {
43         'kind': 'data',
44         'shape': None,
45         'value': None,
46     },
47     'strided_slice': {
48         'op': 'StridedSlice',
49         'begin_mask': None,
50         'end_mask': None,
51         'new_axis_mask': None,
52         'shrink_axis_mask': None,
53         'ellipsis_mask': None,
54         'kind': 'op',
55     },
56     'data_2': {
57         'kind': 'data',
58         'shape': None,
59         'value': None,
60     }
61 }
62
63
64 @generator
65 class TestPermutationStridedSlice(unittest.TestCase):
66     def test_permute_begin_end(self):
67         # Testing constant path case
68         graph = build_graph(nodes_attributes,
69                             [('data_1', 'strided_slice'),
70                              ('begin', 'strided_slice'),
71                              ('end', 'strided_slice'),
72                              ('stride', 'strided_slice'),
73                              ('strided_slice', 'data_2')],
74                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
75                              'strided_slice': {'begin_mask': np.array([1, 1, 0, 0]), 'end_mask': np.array([0, 1, 0, 0]),
76                                                'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
77                                                'ellipsis_mask': np.array([0, 0, 0])},
78                              'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
79                              })
80
81         slice_node = Node(graph, 'strided_slice')
82         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
83         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 0, 1, 0])))
84
85         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
86         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 0, 1, 0])))
87
88     def test_permute_begin_end_short(self):
89         # Testing constant path case
90         graph = build_graph(nodes_attributes,
91                             [('data_1', 'strided_slice'),
92                              ('begin', 'strided_slice'),
93                              ('end', 'strided_slice'),
94                              ('stride', 'strided_slice'),
95                              ('strided_slice', 'data_2')],
96                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
97                              'strided_slice': {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
98                                                'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
99                                                'ellipsis_mask': np.array([0, 0, 0])},
100                              'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
101                              })
102
103         slice_node = Node(graph, 'strided_slice')
104         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
105         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0])))
106
107         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
108         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0])))
109
110     def test_permute_begin_end_long(self):
111         # Testing constant path case
112         graph = build_graph(nodes_attributes,
113                             [('data_1', 'strided_slice'),
114                              ('begin', 'strided_slice'),
115                              ('end', 'strided_slice'),
116                              ('stride', 'strided_slice'),
117                              ('strided_slice', 'data_2')],
118                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
119                              'strided_slice': {'begin_mask': np.array([1, 0, 0, 1, 0]), 'end_mask': np.array([0, 1, 0, 1, 1]),
120                                                'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
121                                                'ellipsis_mask': np.array([0, 0, 0])},
122                              'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
123                              })
124
125         slice_node = Node(graph, 'strided_slice')
126         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
127         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0, 0])))
128
129         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
130         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0, 1])))
131
132     def test_permute_begin_end_new(self):
133         # Testing constant path case
134         graph = build_graph(nodes_attributes,
135                             [('data_1', 'strided_slice'),
136                              ('begin', 'strided_slice'),
137                              ('end', 'strided_slice'),
138                              ('stride', 'strided_slice'),
139                              ('strided_slice', 'data_2')],
140                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
141                              'strided_slice': {'begin_mask': np.array([1, 0, 0, 1, 0]), 'end_mask': np.array([0, 1, 0, 1, 1]),
142                                                'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
143                                                'ellipsis_mask': np.array([0, 0, 0])},
144                              'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
145                              })
146
147         slice_node = Node(graph, 'strided_slice')
148         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'begin_mask')
149         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 0, 0, 0, 1])))
150
151         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'end_mask')
152         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0, 1])))
153
154     def test_permute_begin_end_new_short(self):
155         # Testing constant path case
156         graph = build_graph(nodes_attributes,
157                             [('data_1', 'strided_slice'),
158                              ('begin', 'strided_slice'),
159                              ('end', 'strided_slice'),
160                              ('stride', 'strided_slice'),
161                              ('strided_slice', 'data_2')],
162                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
163                              'strided_slice': {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
164                                                'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
165                                                'ellipsis_mask': np.array([0, 0, 0])},
166                              'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
167                              })
168
169         slice_node = Node(graph, 'strided_slice')
170         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'begin_mask')
171         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0, 1])))
172
173         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'end_mask')
174         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0, 1])))
175
176     def test_permute_begin_end_shrink(self):
177         # Testing constant path case
178         graph = build_graph(nodes_attributes,
179                             [('data_1', 'strided_slice'),
180                              ('begin', 'strided_slice'),
181                              ('end', 'strided_slice'),
182                              ('stride', 'strided_slice'),
183                              ('strided_slice', 'data_2')],
184                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
185                              'strided_slice': {'begin_mask': np.array([1, 0, 0, 1]), 'end_mask': np.array([0, 1, 0, 1]),
186                                                'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [1, 0, 0],
187                                                'ellipsis_mask': np.array([0, 0, 0])},
188                              'data_2': {'shape': np.array([2, 3, 4]), 'value': None},
189                              })
190
191         slice_node = Node(graph, 'strided_slice')
192         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
193
194         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0])))
195
196         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
197         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0])))
198
199     def test_permute_begin_end_shrink_short(self):
200         # Testing constant path case
201         graph = build_graph(nodes_attributes,
202                             [('data_1', 'strided_slice'),
203                              ('begin', 'strided_slice'),
204                              ('end', 'strided_slice'),
205                              ('stride', 'strided_slice'),
206                              ('strided_slice', 'data_2')],
207                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
208                              'strided_slice': {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
209                                                'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [1, 0, 0],
210                                                'ellipsis_mask': np.array([0, 0, 0])},
211                              'data_2': {'shape': np.array([2, 3, 4]), 'value': None},
212                              })
213
214         slice_node = Node(graph, 'strided_slice')
215         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
216         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([1, 1, 0, 0])))
217
218         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
219         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([0, 1, 1, 0])))
220
221     def test_permute_begin_end_ellipsis(self):
222         # Testing constant path case
223         graph = build_graph(nodes_attributes,
224                             [('data_1', 'strided_slice'),
225                              ('begin', 'strided_slice'),
226                              ('end', 'strided_slice'),
227                              ('stride', 'strided_slice'),
228                              ('strided_slice', 'data_2')],
229                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
230                              'strided_slice': {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
231                                                'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
232                                                'ellipsis_mask': np.array([1, 0])},
233                              'data_2': {'shape': np.array([1, 2, 3, 4]), 'value': None},
234                              })
235
236         slice_node = Node(graph, 'strided_slice')
237         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'begin_mask')
238         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([0, 0, 1, 1])))
239
240         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]), 'end_mask')
241         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([1, 0, 1, 1])))
242
243     def test_permute_begin_end_ellipsis_new(self):
244         # Testing constant path case
245         graph = build_graph(nodes_attributes,
246                             [('data_1', 'strided_slice'),
247                              ('begin', 'strided_slice'),
248                              ('end', 'strided_slice'),
249                              ('stride', 'strided_slice'),
250                              ('strided_slice', 'data_2')],
251                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
252                              'strided_slice': {'begin_mask': np.array([0, 0, 0]), 'end_mask': np.array([1, 0, 0]),
253                                                'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0],
254                                                'ellipsis_mask': np.array([0, 1, 0])},
255                              'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
256                              })
257
258         slice_node = Node(graph, 'strided_slice')
259         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'begin_mask')
260         self.assertTrue(np.array_equal(slice_node.begin_mask, np.array([0, 0, 0, 1, 1])))
261
262         permute_masks(slice_node, PermuteAttrs.Permutation(perm=[0, 4, 1, 2, 3], inv=[0, 2, 3, 4, 1]), 'end_mask')
263         self.assertTrue(np.array_equal(slice_node.end_mask, np.array([1, 0, 0, 1, 1])))
264
265     def test_permute_begin_end_ellipsis_new_inputs(self):
266         # Testing constant path case
267         graph = build_graph(nodes_attributes,
268                             [('data_1', 'strided_slice'),
269                              ('begin', 'strided_slice'),
270                              ('end', 'strided_slice'),
271                              ('stride', 'strided_slice'),
272                              ('strided_slice', 'data_2')],
273                             {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
274                              'strided_slice': {'begin_mask': np.array([0, 0, 0]), 'end_mask': np.array([1, 0, 0]),
275                                                'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0],
276                                                'ellipsis_mask': np.array([0, 1, 0])},
277                              'begin': {'value': np.array([0, 1, 2])},
278                              'end': {'value': np.array([1, 2, 3])},
279                              'stride': {'value': np.array([1, 1, 1])},
280                              'data_2': {'shape': np.array([1, 1, 2, 3, 4]), 'value': None},
281                              })
282
283         slice_node = Node(graph, 'strided_slice')
284         slice_node.in_node(1).value = permute_array_with_ellipsis(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]),
285                                                                   slice_node.in_node(1).value, 0)
286         self.assertTrue(np.array_equal(slice_node.in_node(1).value, np.array([0, 2, 1, 0, 0])))
287
288         slice_node.in_node(2).value = permute_array_with_ellipsis(slice_node, PermuteAttrs.Permutation(perm=[0, 3, 1, 2], inv=[0, 2, 3, 1]),
289                                                                   slice_node.in_node(2).value, 0)
290         self.assertTrue(np.array_equal(slice_node.in_node(2).value, np.array([1, 3, 2, 0, 0])))