Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / model-optimizer / extensions / middle / AddReshapeAfterStridedSlice_test.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import numpy as np
18 import unittest
19
20 from extensions.middle.AddReshapeAfterStridedSlice import AddReshapeAfterStridedSlice
21 from mo.graph.graph import Node
22 from mo.middle.passes.fusing.fuse_linear_ops_test import compare_graphs
23 from mo.middle.passes.eliminate_test import build_graph
24
25 # The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the
26 # dictionary with node attributes.
27 nodes_attributes_test = {
28     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
29     'placeholder_1_data': {'shape': None, 'kind': 'data', 'data_type': None},
30     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
31     'placeholder_2_data': {'shape': None, 'kind': 'data', 'data_type': None},
32     'placeholder_begin_data': {'shape': None, 'kind': 'data', 'data_type': None},
33     'placeholder_end_data': {'shape': None, 'kind': 'data', 'data_type': None},
34     'placeholder_stride_data': {'shape': None, 'kind': 'data', 'data_type': None},
35     # StridedSlice layers
36     'sslice_1': {'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
37                  'shrink_axis_mask': np.array([False, False, True, False]),
38                  'new_axis_mask': np.array([False, False, False, False])},
39     'sslice_1_data': {'shape': None, 'kind': 'data'},
40     'sslice_2': {'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
41                  'shrink_axis_mask': np.array([False, False, True, False]),
42                  'new_axis_mask': np.array([False, False, False, False])},
43     'sslice_2_data': {'shape': None, 'kind': 'data'}}
44
45 nodes_reshape = {
46     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
47     'placeholder_1_data': {'shape': None, 'kind': 'data', 'data_type': None},
48     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
49     'placeholder_2_data': {'shape': None, 'kind': 'data', 'data_type': None},
50     'placeholder_begin_data': {'shape': None, 'kind': 'data', 'data_type': None},
51     'placeholder_end_data': {'shape': None, 'kind': 'data', 'data_type': None},
52     'placeholder_stride_data': {'shape': None, 'kind': 'data', 'data_type': None},
53     # StridedSlice layers
54     'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
55                  'shrink_axis_mask': np.array([False, False, True, False]),
56                  'new_axis_mask': np.array([False, False, False, False])},
57     'sslice_1_data': {'value': None, 'shape': None, 'kind': 'data'},
58     'sslice_2': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
59                  'shrink_axis_mask': np.array([False, False, True, False]),
60                  'new_axis_mask': np.array([False, False, False, False])},
61     'sslice_2_data': {'value': None, 'shape': None, 'kind': 'data'},
62     # Reshape layer
63     'sslice_1/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
64     'sslice_1/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
65     'sslice_2/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
66     'sslice_2/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
67     'sslice_2/Reshape_new': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
68     'sslice_2/Reshape_new_data': {'value': None, 'shape': None, 'kind': 'data'},
69 }
70
71
72 class AddReshapeAfterStridedSliceTests(unittest.TestCase):
73     def test_ss_1_shrink_last(self):
74         graph = build_graph(nodes_attributes_test,
75                             [('placeholder_1', 'placeholder_1_data'),
76                              ('placeholder_1_data', 'sslice_1'),
77                              ('placeholder_begin_data', 'sslice_1'),
78                              ('placeholder_end_data', 'sslice_1'),
79                              ('placeholder_stride_data', 'sslice_1'),
80                              ('sslice_1', 'sslice_1_data')],
81                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
82                              'sslice_1': {'slices': np.array(
83                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])},
84                              'sslice_1_data': {'shape': np.array([1, 227, 54]), 'is_output': True},
85                              })
86         graph.graph['layout'] = 'NHWC'
87
88         graph_ref = build_graph(nodes_reshape,
89                                 [('placeholder_1', 'placeholder_1_data'),
90                                  ('placeholder_1_data', 'sslice_1'),
91                                  ('placeholder_begin_data', 'sslice_1'),
92                                  ('placeholder_end_data', 'sslice_1'),
93                                  ('placeholder_stride_data', 'sslice_1'),
94                                  ('sslice_1', 'sslice_1/Reshape_shrink_data'),
95                                  ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
96                                  ('sslice_1/Reshape_shrink', 'sslice_1_data')],
97                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
98                                  'sslice_1': {'slices': np.array(
99                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
100                                      'shrink_axis_mask': np.array([False, False, False, False]),
101                                      'new_axis_mask': np.array([False, False, False, False])},
102                                  'sslice_1_data': {'shape': np.array([1, 227, 54]), 'is_output': True},
103                                  'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
104                                  'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])}
105                                  })
106
107         pattern = AddReshapeAfterStridedSlice()
108         pattern.find_and_replace_pattern(graph)
109
110         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_1_data', check_op_attrs=True)
111         graph.clear()
112         graph_ref.clear()
113         self.assertTrue(flag, resp)
114
115     def test_ss_1_shrink(self):
116         graph = build_graph(nodes_attributes_test,
117                             [('placeholder_1', 'placeholder_1_data'),
118                              ('placeholder_1_data', 'sslice_2'),
119                              ('placeholder_begin_data', 'sslice_2'),
120                              ('placeholder_end_data', 'sslice_2'),
121                              ('placeholder_stride_data', 'sslice_2'),
122                              ('sslice_2', 'sslice_2_data'),
123                              ('sslice_2_data', 'placeholder_2'),
124                              ('placeholder_2', 'placeholder_2_data'), ],
125                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
126                              'sslice_2': {'slices': np.array(
127                                  [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]), },
128                              'sslice_2_data': {'shape': np.array([1, 227, 54]), 'is_output': True}
129                              })
130         graph.graph['layout'] = 'NHWC'
131
132         graph_ref = build_graph(nodes_reshape,
133                                 [('placeholder_1', 'placeholder_1_data'),
134                                  ('placeholder_1_data', 'sslice_2'),
135                                  ('placeholder_begin_data', 'sslice_2'),
136                                  ('placeholder_end_data', 'sslice_2'),
137                                  ('placeholder_stride_data', 'sslice_2'),
138                                  ('sslice_2', 'sslice_2/Reshape_shrink_data'),
139                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
140                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
141                                  ('sslice_2_data', 'placeholder_2'),
142                                  ('placeholder_2', 'placeholder_2_data')],
143                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
144                                  'sslice_2': {'slices': np.array(
145                                      [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
146                                      'shrink_axis_mask': np.array([False, False, False, False]),
147                                      'new_axis_mask': np.array([False, False, False, False])},
148                                  'sslice_2_data': {'shape': np.array([1, 227, 54])},
149                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
150                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
151                                  })
152
153         pattern = AddReshapeAfterStridedSlice()
154         pattern.find_and_replace_pattern(graph)
155
156         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
157         graph.clear()
158         graph_ref.clear()
159         self.assertTrue(flag, resp)
160
161     def test_ss_2_shrink(self):
162         graph = build_graph(nodes_attributes_test,
163                             [('placeholder_1', 'placeholder_1_data'),
164                              ('placeholder_1_data', 'sslice_2'),
165                              ('placeholder_begin_data', 'sslice_2'),
166                              ('placeholder_end_data', 'sslice_2'),
167                              ('placeholder_stride_data', 'sslice_2'),
168                              ('sslice_2', 'sslice_2_data'),
169                              ('sslice_2_data', 'placeholder_2'),
170                              ('placeholder_2', 'placeholder_2_data'), ],
171                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
172                              'sslice_2': {
173                                  'slices': np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
174                                  'shrink_axis_mask': np.array([False, True, False, True])},
175                              'sslice_2_data': {'shape': np.array([1, 227]), 'is_output': True}
176                              })
177         graph.graph['layout'] = 'NHWC'
178
179         graph_ref = build_graph(nodes_reshape,
180                                 [('placeholder_1', 'placeholder_1_data'),
181                                  ('placeholder_1_data', 'sslice_2'),
182                                  ('placeholder_begin_data', 'sslice_2'),
183                                  ('placeholder_end_data', 'sslice_2'),
184                                  ('placeholder_stride_data', 'sslice_2'),
185                                  ('sslice_2', 'sslice_2/Reshape_shrink_data'),
186                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
187                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
188                                  ('sslice_2_data', 'placeholder_2'),
189                                  ('placeholder_2', 'placeholder_2_data')],
190                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
191                                  'sslice_2': {'slices': np.array(
192                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
193                                      'shrink_axis_mask': np.array([False, False, False, False]),
194                                      'new_axis_mask': np.array([False, False, False, False])},
195                                  'sslice_2_data': {'shape': np.array([1, 227])},
196                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
197                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
198                                  })
199
200         pattern = AddReshapeAfterStridedSlice()
201         pattern.find_and_replace_pattern(graph)
202
203         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
204         graph.clear()
205         graph_ref.clear()
206         self.assertTrue(flag, resp)
207
208     def test_ss_1_new(self):
209         graph = build_graph(nodes_attributes_test,
210                             [('placeholder_1', 'placeholder_1_data'),
211                              ('placeholder_1_data', 'sslice_2'),
212                              ('placeholder_begin_data', 'sslice_2'),
213                              ('placeholder_end_data', 'sslice_2'),
214                              ('placeholder_stride_data', 'sslice_2'),
215                              ('sslice_2', 'sslice_2_data'),
216                              ('sslice_2_data', 'placeholder_2'),
217                              ('placeholder_2', 'placeholder_2_data'), ],
218                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
219                              'sslice_2': {'slices': np.array(
220                                  [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)]),
221                                  'shrink_axis_mask': np.array([False, False, False, False, False]),
222                                  'new_axis_mask': np.array([False, True, False, False, False])},
223                              'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])}
224                              })
225         graph.graph['layout'] = 'NHWC'
226
227         graph_ref = build_graph(nodes_reshape,
228                                 [('placeholder_1', 'placeholder_1_data'),
229                                  ('placeholder_1_data', 'sslice_2'),
230                                  ('placeholder_begin_data', 'sslice_2'),
231                                  ('placeholder_end_data', 'sslice_2'),
232                                  ('placeholder_stride_data', 'sslice_2'),
233                                  ('sslice_2', 'sslice_2/Reshape_new_data'),
234                                  ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
235                                  ('sslice_2/Reshape_new', 'sslice_2_data'),
236                                  ('sslice_2_data', 'placeholder_2'),
237                                  ('placeholder_2', 'placeholder_2_data')],
238                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
239                                  'sslice_2': {'slices': np.array(
240                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1),
241                                       slice(0, 54, 1)]),
242                                      'shrink_axis_mask': np.array([False, False, False, False, False]),
243                                      'new_axis_mask': np.array([False, False, False, False, False])},
244                                  'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
245                                  'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
246                                  'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
247                                  })
248
249         pattern = AddReshapeAfterStridedSlice()
250         pattern.find_and_replace_pattern(graph)
251
252         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
253         graph.clear()
254         graph_ref.clear()
255         self.assertTrue(flag, resp)
256
257     def test_ss_shrink_new(self):
258         graph = build_graph(nodes_attributes_test,
259                             [('placeholder_1', 'placeholder_1_data'),
260                              ('placeholder_1_data', 'sslice_2'),
261                              ('placeholder_begin_data', 'sslice_2'),
262                              ('placeholder_end_data', 'sslice_2'),
263                              ('placeholder_stride_data', 'sslice_2'),
264                              ('sslice_2', 'sslice_2_data'),
265                              ('sslice_2_data', 'placeholder_2'),
266                              ('placeholder_2', 'placeholder_2_data'), ],
267                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
268                              'sslice_2': {'slices': np.array(
269                                  [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
270                                  'shrink_axis_mask': np.array([False, False, False, True, False]),
271                                  'new_axis_mask': np.array([False, True, False, False, False])},
272                              'sslice_2_data': {'shape': np.array([1, 1, 227, 54]), 'is_output': True}
273                              })
274         graph.graph['layout'] = 'NHWC'
275
276         graph_ref = build_graph(nodes_reshape,
277                                 [('placeholder_1', 'placeholder_1_data'),
278                                  ('placeholder_1_data', 'sslice_2'),
279                                  ('placeholder_begin_data', 'sslice_2'),
280                                  ('placeholder_end_data', 'sslice_2'),
281                                  ('placeholder_stride_data', 'sslice_2'),
282                                  ('sslice_2', 'sslice_2/Reshape_new_data'),
283                                  ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
284                                  ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
285                                  ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
286                                  ('sslice_2/Reshape_shrink', 'sslice_2_data'),
287                                  ('sslice_2_data', 'placeholder_2'),
288                                  ('placeholder_2', 'placeholder_2_data')],
289                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
290                                  'sslice_2': {'slices': np.array(
291                                      [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1),
292                                       slice(0, 54, 1)]),
293                                      'shrink_axis_mask': np.array([False, False, False, False, False]),
294                                      'new_axis_mask': np.array([False, False, False, False, False])},
295                                  'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
296                                  'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
297                                  'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
298                                  'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
299                                  'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
300                                  })
301
302         pattern = AddReshapeAfterStridedSlice()
303         pattern.find_and_replace_pattern(graph)
304
305         (flag, resp) = compare_graphs(graph, graph_ref, 'sslice_2_data', check_op_attrs=True)
306         graph.clear()
307         graph_ref.clear()
308         self.assertTrue(flag, resp)
309
310
311 if __name__ == '__main__':
312     unittest.main()