2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from extensions.middle.Reduce import ReduceReplacer
22 from mo.middle.passes.eliminate_test import build_graph
23 from mo.middle.passes.fusing.fuse_linear_ops_test import compare_graphs
# Attribute templates for the nodes of the graphs built in this module. `nodes_attributes` maps each
# node name to the dictionary of attributes that node starts with. The two helpers below return a
# fresh dictionary on every call, so no two nodes share the same attribute dict.


def _op_attrs(op_type, **extra):
    """Return the attributes of an operation node whose 'type' and 'op' are both `op_type`.

    Keys passed in `extra` come first, followed by 'type', 'kind' (always 'op') and 'op'.
    """
    attrs = dict(extra)
    attrs.update(type=op_type, kind='op', op=op_type)
    return attrs


def _data_attrs(**extra):
    """Return the attributes of a data node with no value and no shape, extended by `extra`."""
    attrs = {'value': None, 'shape': None, 'kind': 'data'}
    attrs.update(extra)
    return attrs


nodes_attributes = {
    # Placeholder layers
    'placeholder_1': _op_attrs('Placeholder', shape=None),
    'placeholder_1_data': _data_attrs(data_type=None),
    'placeholder_2_data': _data_attrs(data_type=None),
    'placeholder_3_data': _data_attrs(data_type=None),
    'placeholder_4_data': _data_attrs(data_type=None),

    # Reduce layers
    'reduce_1': _op_attrs('Reduce'),
    'reduce_1_data': _data_attrs(),

    # Reshape layers
    'reshape_1': _op_attrs('Reshape'),
    'reshape_1_data': _data_attrs(),
    'reshape_2': _op_attrs('Reshape'),
    'reshape_2_data': _data_attrs(),

    # Pooling
    'pooling': _op_attrs('Pooling'),
    'pooling_data': _data_attrs(),

    # Power
    'power': _op_attrs('Power'),
    'power_data': _data_attrs(),

    # Concat
    'concat': _op_attrs('Concat'),
}
class ReduceReplacerTest(unittest.TestCase):
    """Tests for the ReduceReplacer middle pass.

    Each test builds the graph ``placeholder_1_data -> reduce_1 -> reduce_1_data -> concat`` with
    layout NCHW. It runs ReduceReplacer on that graph and compares the result, starting from
    ``concat``, against a reference ``Reshape -> Pooling -> Reshape`` chain. For the Sum case the
    chain is followed by a Power node.
    """

    def _check_reduce_replacement(self, input_shape, reduce_attrs, output_shape, reshape_dim,
                                  pool_window, pool_output_shape, power_scale=None):
        """Run ReduceReplacer on a single-Reduce graph and assert it matches the expected graph.

        :param input_shape: shape of 'placeholder_1_data', the input of the Reduce node.
        :param reduce_attrs: attributes of 'reduce_1' ('axis', 'keep_dims', 'reduce_type').
        :param output_shape: shape of 'reduce_1_data'. It is also the 'dim' of the final Reshape
            and the shape of every data node after it in the reference graph.
        :param reshape_dim: 'dim' of the first Reshape and the shape of its output data node.
        :param pool_window: 'window' of the Pooling node.
        :param pool_output_shape: shape of 'pooling_data'.
        :param power_scale: if not None, the reference graph inserts a Power node with this 'scale'
            between the final Reshape and 'concat'.
        """
        # A fresh np.array is created for every attribute, so that no array object is shared
        # between nodes or between the two graphs.
        graph = build_graph(nodes_attributes,
                            [('placeholder_1_data', 'reduce_1'),
                             ('reduce_1', 'reduce_1_data'),
                             ('reduce_1_data', 'concat'),
                             ],
                            {'placeholder_1_data': {'shape': np.array(input_shape)},
                             'reduce_1': reduce_attrs,
                             'reduce_1_data': {'shape': np.array(output_shape)},
                             }, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'

        ref_edges = [('placeholder_1_data', 'reshape_1'),
                     ('reshape_1', 'reshape_1_data'),
                     ('reshape_1_data', 'pooling'),
                     ('pooling', 'pooling_data'),
                     ('pooling_data', 'reshape_2'),
                     ('reshape_2', 'reshape_2_data'),
                     ]
        ref_attrs = {'placeholder_1_data': {'shape': np.array(input_shape)},
                     'reshape_1': {'dim': np.array(reshape_dim)},
                     'reshape_1_data': {'shape': np.array(reshape_dim)},
                     'pooling': {'window': np.array(pool_window)},
                     'pooling_data': {'shape': np.array(pool_output_shape)},
                     'reshape_2': {'dim': np.array(output_shape)},
                     'reshape_2_data': {'shape': np.array(output_shape)},
                     }
        if power_scale is None:
            ref_edges.append(('reshape_2_data', 'concat'))
        else:
            # The reference for a Sum reduction scales the reshaped pooling result with a Power node.
            ref_edges.extend([('reshape_2_data', 'power'),
                              ('power', 'power_data'),
                              ('power_data', 'concat'),
                              ])
            ref_attrs['power'] = {'scale': power_scale}
            ref_attrs['power_data'] = {'shape': np.array(output_shape)}

        graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs, nodes_with_edges_only=True)

        pattern = ReduceReplacer()
        pattern.find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test1(self):
        #   Original graph
        #   data(1,64,1)-->Reduce(axis=1,keep_dims=True)-->data(1,1,1)
        #
        #   Reference graph
        #   data(1,64,1)->Reshape(1,1,64,1)->Pool(1,1,1,1)->Reshape(1,1,1)
        #
        self._check_reduce_replacement(
            input_shape=[1, 64, 1],
            reduce_attrs={'axis': np.array([1]), 'keep_dims': True, 'reduce_type': 'Mean'},
            output_shape=[1, 1, 1],
            reshape_dim=[1, 1, 64, 1],
            pool_window=[1, 1, 64, 1],
            pool_output_shape=[1, 1, 1, 1])

    def test2(self):
        #   Original graph
        #   data(1,3,64,64)-->Reduce(axis=2,keep_dims=True)-->data(1,3,1,64)
        #
        #   Reference graph
        #   data(1,3,64,64)->Reshape->Pool(1,3,1,64)->Reshape(1,3,1,64)
        #
        self._check_reduce_replacement(
            input_shape=[1, 3, 64, 64],
            reduce_attrs={'axis': np.array([2]), 'keep_dims': True, 'reduce_type': 'Mean'},
            output_shape=[1, 3, 1, 64],
            reshape_dim=[1, 3, 64, 64],
            pool_window=[1, 1, 64, 1],
            pool_output_shape=[1, 3, 1, 64])

    def test3(self):
        #   Original graph
        #   data(1,3,64,64)-->Reduce(axis=[2,3],keep_dims=True)-->data(1,3,1,1)
        #
        #   Reference graph
        #   data(1,3,64,64)->Reshape->Pool(1,3,1,1)->Reshape(1,3,1,1)
        #
        self._check_reduce_replacement(
            input_shape=[1, 3, 64, 64],
            reduce_attrs={'axis': np.array([2, 3]), 'keep_dims': True, 'reduce_type': 'Mean'},
            output_shape=[1, 3, 1, 1],
            reshape_dim=[1, 3, 64 * 64, 1],
            pool_window=[1, 1, 64 * 64, 1],
            pool_output_shape=[1, 3, 1, 1])

    def test4(self):
        #   Original graph
        #   data(2,3,64,64)-->Reduce(axis=[1,2,3],keep_dims=False)-->data(2)
        #
        #   Reference graph
        #   data(2,3,64,64)->Reshape(2,1,3*64*64,1)->Pool(2,1,1,1)->Reshape(2)
        #
        self._check_reduce_replacement(
            input_shape=[2, 3, 64, 64],
            reduce_attrs={'axis': np.array([1, 2, 3]), 'keep_dims': False, 'reduce_type': 'Mean'},
            output_shape=[2],
            reshape_dim=[2, 1, 3 * 64 * 64, 1],
            pool_window=[1, 1, 3 * 64 * 64, 1],
            pool_output_shape=[2, 1, 1, 1])

    def test5(self):
        #   Original graph
        #   data(1, 16, 64, 64, 64, 4)-->Reduce(axis=[5],keep_dims=False,reduce_type=max)-->data(1, 16, 64, 64, 64)
        #
        #   Reference graph
        #   data(1, 16, 64, 64, 64, 4)->Reshape(1*16*64*64, 64, 4, 1)->Pool(1, 1, 4, 1)->Reshape(1, 16, 64, 64, 64)
        #
        self._check_reduce_replacement(
            input_shape=[1, 16, 64, 64, 64, 4],
            reduce_attrs={'axis': np.array([5]), 'keep_dims': False, 'reduce_type': 'max'},
            output_shape=[1, 16, 64, 64, 64],
            reshape_dim=[65536, 64, 4, 1],
            pool_window=[1, 1, 4, 1],
            pool_output_shape=[65536, 64, 1, 1])

    def test6(self):
        #   Original graph
        #   data(1,64,1)-->Reduce(axis=-2,keep_dims=True, reduce_type=Sum)-->data(1,1,1)
        #
        #   Reference graph
        #   data(1,64,1)->Reshape(1,1,64,1)->Pool(1,1,1,1)->Reshape(1,1,1)->Power(scale=64)
        #
        self._check_reduce_replacement(
            input_shape=[1, 64, 1],
            reduce_attrs={'axis': np.array([-2]), 'keep_dims': True, 'reduce_type': 'Sum'},
            output_shape=[1, 1, 1],
            reshape_dim=[1, 1, 64, 1],
            pool_window=[1, 1, 64, 1],
            pool_output_shape=[1, 1, 1, 1],
            power_scale=64.0)