2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from generator import generator
21 from mo.graph.graph import Node
22 from mo.ops.op import PermuteAttrs
23 from mo.ops.strided_slice import permute_masks, permute_array_with_ellipsis
24 from mo.utils.unittest.graph import build_graph
51 'new_axis_mask': None,
52 'shrink_axis_mask': None,
53 'ellipsis_mask': None,
class TestPermutationStridedSlice(unittest.TestCase):
    """Tests for ``permute_masks`` and ``permute_array_with_ellipsis`` on a StridedSlice node.

    Every test builds the same small graph (``data_1``, ``begin``, ``end`` and
    ``stride`` feeding ``strided_slice``, which feeds ``data_2``) with a
    ``data_1`` shape of ``[1, 2, 3, 4]``; only the StridedSlice masks, the
    output shape and the permutation differ between tests.
    """

    # (perm, inv) pairs used by the tests for 4- and 5-element permutations.
    _PERM_4D = ((0, 3, 1, 2), (0, 2, 3, 1))
    _PERM_5D = ((0, 4, 1, 2, 3), (0, 2, 3, 4, 1))

    @staticmethod
    def _make_permutation(perm_inv):
        """Return a new ``PermuteAttrs.Permutation`` built from a ``(perm, inv)`` pair.

        Fresh lists are created on every call, mirroring the original tests,
        which passed new list literals to each ``permute_*`` invocation.
        """
        perm, inv = perm_inv
        return PermuteAttrs.Permutation(perm=list(perm), inv=list(inv))

    @staticmethod
    def _make_slice_node(slice_attrs, out_shape, extra_attrs=None):
        """Build the test graph and return the ``strided_slice`` node.

        :param slice_attrs: attribute overrides for the ``strided_slice`` node (the masks)
        :param out_shape: shape assigned to the ``data_2`` output node
        :param extra_attrs: optional overrides for further nodes, keyed by node name
        :return: ``Node`` wrapping ``strided_slice`` in the freshly built graph
        """
        edges = [('data_1', 'strided_slice'),
                 ('begin', 'strided_slice'),
                 ('end', 'strided_slice'),
                 ('stride', 'strided_slice'),
                 ('strided_slice', 'data_2')]
        # Keys are inserted in the same order the original inline dictionaries used.
        overrides = {'data_1': {'shape': np.array([1, 2, 3, 4]), 'value': None},
                     'strided_slice': slice_attrs}
        if extra_attrs:
            overrides.update(extra_attrs)
        overrides['data_2'] = {'shape': np.array(out_shape), 'value': None}

        graph = build_graph(nodes_attributes, edges, overrides)
        return Node(graph, 'strided_slice')

    def _assert_array(self, actual, expected, what):
        """Assert ``np.array_equal(actual, expected)`` and report both arrays on failure."""
        expected = np.array(expected)
        self.assertTrue(np.array_equal(actual, expected),
                        msg='{}: expected {}, got {}'.format(what, expected, actual))

    def _check_permuted_masks(self, slice_node, perm_inv, begin, end):
        """Permute ``begin_mask`` then ``end_mask`` and compare each with its expected value.

        The order (permute begin, check begin, permute end, check end) is the
        same as in the original inline test bodies.
        """
        for attr, expected in (('begin_mask', begin), ('end_mask', end)):
            permute_masks(slice_node, self._make_permutation(perm_inv), attr)
            self._assert_array(getattr(slice_node, attr), expected, attr)

    def test_permute_begin_end(self):
        """4-element begin/end masks with the 4-element permutation."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 1, 0, 0]), 'end_mask': np.array([0, 1, 0, 0]),
             'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[1, 0, 1, 0], end=[0, 0, 1, 0])

    def test_permute_begin_end_short(self):
        """3-element begin/end masks are expected to come back as 4-element masks."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
             'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[1, 1, 0, 0], end=[0, 1, 1, 0])

    def test_permute_begin_end_long(self):
        """5-element begin/end masks (longer than the 4D input) with the 4-element permutation."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0, 1, 0]), 'end_mask': np.array([0, 1, 0, 1, 1]),
             'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[1, 1, 0, 0, 0], end=[0, 1, 1, 0, 1])

    def test_permute_begin_end_new(self):
        """5-element masks with ``new_axis_mask`` set, 5D output and the 5-element permutation."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0, 1, 0]), 'end_mask': np.array([0, 1, 0, 1, 1]),
             'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[1, 1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_5D,
                                   begin=[1, 0, 0, 0, 1], end=[0, 1, 1, 0, 1])

    def test_permute_begin_end_new_short(self):
        """3-element masks with ``new_axis_mask`` set are expected to come back as 5-element masks."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
             'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[1, 1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_5D,
                                   begin=[1, 1, 0, 0, 1], end=[0, 1, 1, 0, 1])

    def test_permute_begin_end_shrink(self):
        """4-element masks with ``shrink_axis_mask`` set and a 3D output."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0, 1]), 'end_mask': np.array([0, 1, 0, 1]),
             'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [1, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[1, 1, 0, 0], end=[0, 1, 1, 0])

    def test_permute_begin_end_shrink_short(self):
        """3-element masks with ``shrink_axis_mask`` set are expected to come back as 4-element masks."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([1, 0, 0]), 'end_mask': np.array([0, 1, 0]),
             'new_axis_mask': np.array([0, 0, 0]), 'shrink_axis_mask': [1, 0, 0],
             'ellipsis_mask': np.array([0, 0, 0])},
            out_shape=[2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[1, 1, 0, 0], end=[0, 1, 1, 0])

    def test_permute_begin_end_ellipsis(self):
        """2-element masks with ``ellipsis_mask`` set are expected to come back as 4-element masks."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([0, 0]), 'end_mask': np.array([1, 0]),
             'new_axis_mask': np.array([0]), 'shrink_axis_mask': [0],
             'ellipsis_mask': np.array([1, 0])},
            out_shape=[1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_4D,
                                   begin=[0, 0, 1, 1], end=[1, 0, 1, 1])

    def test_permute_begin_end_ellipsis_new(self):
        """3-element masks with both ``ellipsis_mask`` and ``new_axis_mask`` set, 5D output."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([0, 0, 0]), 'end_mask': np.array([1, 0, 0]),
             'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0],
             'ellipsis_mask': np.array([0, 1, 0])},
            out_shape=[1, 1, 2, 3, 4])

        self._check_permuted_masks(slice_node, self._PERM_5D,
                                   begin=[0, 0, 0, 1, 1], end=[1, 0, 0, 1, 1])

    def test_permute_begin_end_ellipsis_new_inputs(self):
        """``permute_array_with_ellipsis`` applied to the constant ``begin`` and ``end`` inputs."""
        slice_node = self._make_slice_node(
            {'begin_mask': np.array([0, 0, 0]), 'end_mask': np.array([1, 0, 0]),
             'new_axis_mask': np.array([1, 0, 0]), 'shrink_axis_mask': [0],
             'ellipsis_mask': np.array([0, 1, 0])},
            out_shape=[1, 1, 2, 3, 4],
            extra_attrs={'begin': {'value': np.array([0, 1, 2])},
                         'end': {'value': np.array([1, 2, 3])},
                         'stride': {'value': np.array([1, 1, 1])}})

        # NOTE(review): the 4-element permutation is used here although data_2 is 5D and the
        # sibling new-axis tests use the 5-element one; kept as in the original - confirm intent.
        # Input port 1 is wired to 'begin' and port 2 to 'end' (see the edge order above);
        # the last argument (0) is the insert value passed to permute_array_with_ellipsis.
        for port, expected in ((1, [0, 2, 1, 0, 0]), (2, [1, 3, 2, 0, 0])):
            permuted = permute_array_with_ellipsis(slice_node, self._make_permutation(self._PERM_4D),
                                                   slice_node.in_node(port).value, 0)
            slice_node.in_node(port).value = permuted
            self._assert_array(slice_node.in_node(port).value, expected,
                               'value of input {}'.format(port))