2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from extensions.middle.AddReshapeAfterStridedSlice import AddReshapeAfterStridedSlice
21 from mo.graph.graph import Node
22 from mo.middle.passes.fusing.fuse_linear_ops_test import compare_graphs
23 from mo.middle.passes.eliminate_test import build_graph
# Attributes for every node that the input (pre-transformation) graphs may contain.
# Each key is a node name and each value is that node's attribute dictionary, as
# passed to build_graph when the test graphs are constructed.
nodes_attributes_test = dict(
    placeholder_1=dict(shape=None, type='Placeholder', kind='op', op='Placeholder'),
    placeholder_1_data=dict(shape=None, kind='data', data_type=None),
    placeholder_2=dict(shape=None, type='Placeholder', kind='op', op='Placeholder'),
    placeholder_2_data=dict(shape=None, kind='data', data_type=None),
    # Data nodes wired into the StridedSlice ops as their begin/end/stride inputs.
    placeholder_begin_data=dict(shape=None, kind='data', data_type=None),
    placeholder_end_data=dict(shape=None, kind='data', data_type=None),
    placeholder_stride_data=dict(shape=None, kind='data', data_type=None),
    # Default masks: shrink index 2, no new axes. Individual tests override
    # them through build_graph's update-attributes argument.
    sslice_1=dict(type='StridedSlice', kind='op', op='StridedSlice', slices=None,
                  shrink_axis_mask=np.array([False, False, True, False]),
                  new_axis_mask=np.array([False, False, False, False])),
    sslice_1_data=dict(shape=None, kind='data'),
    sslice_2=dict(type='StridedSlice', kind='op', op='StridedSlice', slices=None,
                  shrink_axis_mask=np.array([False, False, True, False]),
                  new_axis_mask=np.array([False, False, False, False])),
    sslice_2_data=dict(shape=None, kind='data'),
)
# Attributes for the nodes of the reference (expected) graphs after the
# AddReshapeAfterStridedSlice transformation.
# The first entries repeat the input-graph nodes, adding an explicit
# 'value': None on the StridedSlice ops and their output data nodes. The rest
# are the Reshape op/data nodes that the reference graphs place after the
# StridedSlice ops.
nodes_reshape = {
    'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_2_data': {'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_begin_data': {'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_end_data': {'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_stride_data': {'shape': None, 'kind': 'data', 'data_type': None},

    'sslice_1': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
                 'shrink_axis_mask': np.array([False, False, True, False]),
                 'new_axis_mask': np.array([False, False, False, False])},
    'sslice_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    'sslice_2': {'type': 'StridedSlice', 'value': None, 'kind': 'op', 'op': 'StridedSlice', 'slices': None,
                 'shrink_axis_mask': np.array([False, False, True, False]),
                 'new_axis_mask': np.array([False, False, False, False])},
    'sslice_2_data': {'value': None, 'shape': None, 'kind': 'data'},

    # Reshape inserted to remove the axes listed in shrink_axis_mask.
    'sslice_1/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'sslice_1/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
    'sslice_2/Reshape_shrink': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'sslice_2/Reshape_shrink_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Reshape inserted to add the axes listed in new_axis_mask.
    'sslice_2/Reshape_new': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
    'sslice_2/Reshape_new_data': {'value': None, 'shape': None, 'kind': 'data'},
}
class AddReshapeAfterStridedSliceTests(unittest.TestCase):
    """Tests for the AddReshapeAfterStridedSlice transformation.

    Each test builds a graph with one StridedSlice whose shrink_axis_mask
    and/or new_axis_mask is set, then runs the transformation. The result is
    compared with a reference graph in which both masks are all-False and
    Reshape op/data nodes sit between the StridedSlice and its original
    output data node.
    """

    @staticmethod
    def _input_edges(sslice):
        """Return the edges feeding the data and begin/end/stride inputs into `sslice`.

        A new list is returned on every call, so callers may concatenate to it freely.
        """
        return [('placeholder_1', 'placeholder_1_data'),
                ('placeholder_1_data', sslice),
                ('placeholder_begin_data', sslice),
                ('placeholder_end_data', sslice),
                ('placeholder_stride_data', sslice)]

    def _run_and_compare(self, graph, graph_ref, last_node):
        """Apply AddReshapeAfterStridedSlice to `graph` and assert it matches `graph_ref`.

        :param graph: graph under test; its layout is set to NHWC before the pass runs.
        :param graph_ref: expected graph after the transformation.
        :param last_node: name of the output data node handed to compare_graphs.
        """
        graph.graph['layout'] = 'NHWC'

        pattern = AddReshapeAfterStridedSlice()
        pattern.find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, last_node, check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_ss_1_shrink_last(self):
        """One shrunk axis (index 2); the StridedSlice output has no consumers."""
        graph = build_graph(nodes_attributes_test,
                            self._input_edges('sslice_1') +
                            [('sslice_1', 'sslice_1_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                             'sslice_1': {'slices': np.array(
                                 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])},
                             'sslice_1_data': {'shape': np.array([1, 227, 54]), 'is_output': True},
                             })

        graph_ref = build_graph(nodes_reshape,
                                self._input_edges('sslice_1') +
                                [('sslice_1', 'sslice_1/Reshape_shrink_data'),
                                 ('sslice_1/Reshape_shrink_data', 'sslice_1/Reshape_shrink'),
                                 ('sslice_1/Reshape_shrink', 'sslice_1_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                 'sslice_1': {'slices': np.array(
                                     [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
                                     'shrink_axis_mask': np.array([False, False, False, False]),
                                     'new_axis_mask': np.array([False, False, False, False])},
                                 'sslice_1_data': {'shape': np.array([1, 227, 54]), 'is_output': True},
                                 'sslice_1/Reshape_shrink': {'dim': np.array([1, 227, 54])},
                                 'sslice_1/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])}
                                 })

        self._run_and_compare(graph, graph_ref, 'sslice_1_data')

    def test_ss_1_shrink(self):
        """One shrunk axis (index 2, default mask); the output feeds placeholder_2."""
        graph = build_graph(nodes_attributes_test,
                            self._input_edges('sslice_2') +
                            [('sslice_2', 'sslice_2_data'),
                             ('sslice_2_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                             'sslice_2': {'slices': np.array(
                                 [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)])},
                             'sslice_2_data': {'shape': np.array([1, 227, 54]), 'is_output': True}
                             })

        graph_ref = build_graph(nodes_reshape,
                                self._input_edges('sslice_2') +
                                [('sslice_2', 'sslice_2/Reshape_shrink_data'),
                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
                                 ('sslice_2_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                 'sslice_2': {'slices': np.array(
                                     [slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
                                     'shrink_axis_mask': np.array([False, False, False, False]),
                                     'new_axis_mask': np.array([False, False, False, False])},
                                 'sslice_2_data': {'shape': np.array([1, 227, 54])},
                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227, 54])},
                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 227, 1, 54])},
                                 })

        self._run_and_compare(graph, graph_ref, 'sslice_2_data')

    def test_ss_2_shrink(self):
        """Two shrunk axes (indices 1 and 3)."""
        graph = build_graph(nodes_attributes_test,
                            self._input_edges('sslice_2') +
                            [('sslice_2', 'sslice_2_data'),
                             ('sslice_2_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                             'sslice_2': {
                                 'slices': np.array([slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
                                 'shrink_axis_mask': np.array([False, True, False, True])},
                             'sslice_2_data': {'shape': np.array([1, 227]), 'is_output': True}
                             })

        graph_ref = build_graph(nodes_reshape,
                                self._input_edges('sslice_2') +
                                [('sslice_2', 'sslice_2/Reshape_shrink_data'),
                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
                                 ('sslice_2_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                 'sslice_2': {'slices': np.array(
                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1)]),
                                     'shrink_axis_mask': np.array([False, False, False, False]),
                                     'new_axis_mask': np.array([False, False, False, False])},
                                 'sslice_2_data': {'shape': np.array([1, 227])},
                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 227])},
                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1])},
                                 })

        self._run_and_compare(graph, graph_ref, 'sslice_2_data')

    def test_ss_1_new(self):
        """One new axis (index 1) and no shrunk axes."""
        graph = build_graph(nodes_attributes_test,
                            self._input_edges('sslice_2') +
                            [('sslice_2', 'sslice_2_data'),
                             ('sslice_2_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                             'sslice_2': {'slices': np.array(
                                 [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1), slice(0, 54, 1)]),
                                 'shrink_axis_mask': np.array([False, False, False, False, False]),
                                 'new_axis_mask': np.array([False, True, False, False, False])},
                             'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])}
                             })

        graph_ref = build_graph(nodes_reshape,
                                self._input_edges('sslice_2') +
                                [('sslice_2', 'sslice_2/Reshape_new_data'),
                                 ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
                                 ('sslice_2/Reshape_new', 'sslice_2_data'),
                                 ('sslice_2_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                 'sslice_2': {'slices': np.array(
                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 227, 1),
                                      slice(0, 54, 1)]),
                                     'shrink_axis_mask': np.array([False, False, False, False, False]),
                                     'new_axis_mask': np.array([False, False, False, False, False])},
                                 'sslice_2_data': {'shape': np.array([1, 1, 227, 227, 54])},
                                 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 227, 54])},
                                 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 227, 54])},
                                 })

        self._run_and_compare(graph, graph_ref, 'sslice_2_data')

    def test_ss_shrink_new(self):
        """One new axis (index 1) combined with one shrunk axis (index 3).

        The reference chains the new-axis Reshape before the shrink Reshape.
        """
        graph = build_graph(nodes_attributes_test,
                            self._input_edges('sslice_2') +
                            [('sslice_2', 'sslice_2_data'),
                             ('sslice_2_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data')],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                             'sslice_2': {'slices': np.array(
                                 [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1), slice(0, 54, 1)]),
                                 'shrink_axis_mask': np.array([False, False, False, True, False]),
                                 'new_axis_mask': np.array([False, True, False, False, False])},
                             'sslice_2_data': {'shape': np.array([1, 1, 227, 54]), 'is_output': True}
                             })

        graph_ref = build_graph(nodes_reshape,
                                self._input_edges('sslice_2') +
                                [('sslice_2', 'sslice_2/Reshape_new_data'),
                                 ('sslice_2/Reshape_new_data', 'sslice_2/Reshape_new'),
                                 ('sslice_2/Reshape_new', 'sslice_2/Reshape_shrink_data'),
                                 ('sslice_2/Reshape_shrink_data', 'sslice_2/Reshape_shrink'),
                                 ('sslice_2/Reshape_shrink', 'sslice_2_data'),
                                 ('sslice_2_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data')],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 54])},
                                 'sslice_2': {'slices': np.array(
                                     [slice(0, 1, 1), slice(0, 1, 1), slice(0, 227, 1), slice(0, 1, 1),
                                      slice(0, 54, 1)]),
                                     'shrink_axis_mask': np.array([False, False, False, False, False]),
                                     'new_axis_mask': np.array([False, False, False, False, False])},
                                 'sslice_2_data': {'shape': np.array([1, 1, 227, 54])},
                                 'sslice_2/Reshape_new': {'dim': np.array([1, 1, 227, 1, 54])},
                                 'sslice_2/Reshape_new_data': {'shape': np.array([1, 227, 1, 54])},
                                 'sslice_2/Reshape_shrink': {'dim': np.array([1, 1, 227, 54])},
                                 'sslice_2/Reshape_shrink_data': {'shape': np.array([1, 1, 227, 1, 54])},
                                 })

        self._run_and_compare(graph, graph_ref, 'sslice_2_data')
311 if __name__ == '__main__':