"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import unittest

import numpy as np

from mo.middle.passes.eliminate import graph_clean_up
from mo.middle.passes.fusing.decomposition import convert_scale_shift_to_mul_add, convert_batch_norm, \
    convert_bn_to_mul_add
from mo.utils.unittest.graph import build_graph, compare_graphs

# Default attributes for every node id the tests below may reference.
# Individual tests override selected attributes via build_graph's third argument.
nodes_attributes = {
    # Inputs
    'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # ScaleShift layer with its weight/bias const producers and data nodes
    'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift', 'axis': 0},
    'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op'},
    'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op'},
    'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Mul and Add operations (first pair)
    'mul_1': {'type': None, 'value': None, 'kind': 'op', 'op': 'Mul'},
    'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op'},
    'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    'add_1': {'type': None, 'kind': 'op', 'op': 'Add'},
    'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op'},
    'add_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'add_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Mul and Add operations (second pair)
    'mul_2': {'type': None, 'kind': 'op', 'op': 'Mul'},
    'const_mul_2_w': {'value': None, 'shape': None, 'kind': 'op'},
    'mul_2_w': {'value': None, 'shape': None, 'kind': 'data'},
    'mul_2_data': {'value': None, 'shape': None, 'kind': 'data'},
    'add_2': {'type': None, 'kind': 'op', 'op': 'Add'},
    'const_add_2_w': {'value': None, 'shape': None, 'kind': 'op'},
    'add_2_w': {'value': None, 'shape': None, 'kind': 'data'},
    'add_2_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Reshape of the second input (expected in the two-input ScaleShift cases)
    'placeholder_2/Reshape_': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
    'placeholder_2/Reshape_data': {'value': None, 'shape': None, 'kind': 'data'},
    # BatchNorm operation and its parameter inputs
    'bn_op': {'type': None, 'kind': 'op', 'op': 'BatchNorm', 'can_be_fused': True},
    'bn_const': {'value': None, 'shape': None, 'kind': 'data'},
    'bn_beta': {'value': None, 'shape': None, 'kind': 'data'},
    'bn_mean': {'value': None, 'shape': None, 'kind': 'data'},
    'bn_var': {'value': None, 'shape': None, 'kind': 'data'},
    'bn_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Concat consumer and graph output
    'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
    'concat_data': {'value': None, 'shape': None, 'kind': 'data'},
    'op_output': {'kind': 'op', 'op': 'OpOutput'}
}

class ScaleShiftToMulAdd(unittest.TestCase):
    """Tests for convert_scale_shift_to_mul_add.

    Each test builds an input graph around 'scaleshift_1' and a reference graph
    describing the expected result. It then runs the pass followed by
    graph_clean_up and compares the two graphs.
    """

    def _convert_and_compare(self, graph, graph_ref, last_node):
        """Run the ScaleShift decomposition on *graph* and assert it matches *graph_ref*.

        :param graph: input graph; modified in place by the pass and clean-up.
        :param graph_ref: expected graph after conversion and clean-up.
        :param last_node: node id forwarded to compare_graphs.
        """
        # Every test uses an NHWC layout, including the can_be_fused=False case.
        # The outcome then depends only on the attributes each test varies.
        graph.graph['layout'] = 'NHWC'
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, last_node)
        self.assertTrue(flag, resp)

    # ScaleShift with constant weights only -> Mul
    def test_scaleshift_to_mul_1(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'scaleshift_1_data': {}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('const_mul_1_w', 'mul_1_w'),
                                 # The weights must actually feed the Mul; without this
                                 # edge the reference Mul has a single input.
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'mul_1': {'can_be_fused': True},
                                 'scaleshift_1_data': {}
                                 })

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # ScaleShift with a second, non-constant input -> Reshape of that input + Mul
    def test_scaleshift2_to_mul(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('placeholder_2_data', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227])},
                             'scaleshift_1_data': {}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'placeholder_2/Reshape_'),
                                 ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('placeholder_2/Reshape_data', 'mul_1'),
                                 ('mul_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227])},
                                 'placeholder_2/Reshape_': {'dim': np.array([1, 227, 1, 1])},
                                 'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
                                 'mul_1': {'can_be_fused': True},
                                 'scaleshift_1_data': {}
                                 })

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # Same as above with axis = 1: a 1-D second input is reshaped to [1, 227, 1, 1]
    def test_scaleshift2_axis1_to_mul(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('placeholder_2_data', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([227])},
                             'scaleshift_1': {'axis': 1},
                             'scaleshift_1_data': {}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'placeholder_2/Reshape_'),
                                 ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('placeholder_2/Reshape_data', 'mul_1'),
                                 ('mul_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([227])},
                                 'placeholder_2/Reshape_': {'dim': np.array([1, 227, 1, 1])},
                                 'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
                                 'mul_1': {'can_be_fused': True},
                                 'scaleshift_1_data': {}
                                 })

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # ScaleShift -> Mul (zero biases are dropped)
    def test_scaleshift_to_mul_2(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1_b', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
                             'scaleshift_1_data': {}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('const_mul_1_w', 'mul_1_w'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'mul_1': {'can_be_fused': True},
                                 'scaleshift_1_data': {}
                                 })

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # ScaleShift -> Mul -> Add (non-trivial weights and biases)
    def test_scaleshift_to_mul_add(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1_b', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
                             'scaleshift_1_data': {}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('const_mul_1_w', 'mul_1_w'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('const_add_1_w', 'add_1_w'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                                 'const_add_1_w': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
                                 'add_1_w': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1': {'can_be_fused': True},
                                 'mul_1': {'can_be_fused': True},
                                 'scaleshift_1_data': {}
                                 })

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # ScaleShift -> nothing (unit weights and zero biases make it an identity)
    def test_scaleshift_to_nothing(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1_b', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
                             'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
                             }, nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}},
                                nodes_with_edges_only=True)

        self._convert_and_compare(graph, graph_ref, 'placeholder_1')

    # ScaleShift -> ScaleShift (can_be_fused=False leaves the node untouched)
    def test_scaleshift_can_be_fused(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('const_scaleshift_1_w', 'scaleshift_1_w'),
                             ('const_scaleshift_1_b', 'scaleshift_1_b'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1_b', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
                             'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
                             'scaleshift_1': {'can_be_fused': False},
                             'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'scaleshift_1'),
                                 ('const_scaleshift_1_w', 'scaleshift_1_w'),
                                 ('const_scaleshift_1_b', 'scaleshift_1_b'),
                                 ('scaleshift_1_w', 'scaleshift_1'),
                                 ('scaleshift_1_b', 'scaleshift_1'),
                                 ('scaleshift_1', 'scaleshift_1_data'),
                                 ('scaleshift_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
                                 'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
                                 'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
                                 'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
                                 'scaleshift_1': {'can_be_fused': False},
                                 'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
                                 })

        self._convert_and_compare(graph, graph_ref, 'scaleshift_1_data')


class BatchNormDecomposition(unittest.TestCase):
    """Tests for decomposing BatchNorm / BatchNormalization into Mul/Add chains.

    Every test uses mean = var = [1, 2, 3] and eps = 1.2. The expected first Mul
    weights equal 1 / sqrt(var + eps) = [0.67419986, 0.55901699, 0.48795004],
    and the expected first Add weights equal -mean * that scale
    = [-0.67419986, -1.11803399, -1.46385011].
    """

    @staticmethod
    def _build_bn_graph(bn_op_attrs, with_scale_and_beta=True):
        """Build placeholder -> bn_op -> concat -> output.

        :param bn_op_attrs: attributes overriding the default 'bn_op' entry.
        :param with_scale_and_beta: also connect 'bn_const' and 'bn_beta' as inputs
            of 'bn_op'. The BatchNorm tests do this; the BatchNormalization tests
            feed only mean and variance.
        :return: the graph produced by build_graph.
        """
        channel_values = [1, 2, 3]
        edges = [('placeholder_1', 'placeholder_1_data'),
                 ('placeholder_1_data', 'bn_op')]
        attrs = {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                 'bn_op': bn_op_attrs}
        if with_scale_and_beta:
            edges += [('bn_const', 'bn_op'),
                      ('bn_beta', 'bn_op')]
            attrs['bn_const'] = {'shape': np.array([3]), 'value': np.array(channel_values)}
            attrs['bn_beta'] = {'shape': np.array([3]), 'value': np.array(channel_values)}
        # Edge order into 'bn_op' matches the original per-test fixtures.
        # build_graph presumably numbers input ports in this order, so keep it stable.
        edges += [('bn_mean', 'bn_op'),
                  ('bn_var', 'bn_op'),
                  ('bn_op', 'bn_data'),
                  ('concat', 'concat_data'),
                  ('bn_data', 'concat'),
                  ('concat_data', 'op_output')]
        attrs['bn_mean'] = {'shape': np.array([3]), 'value': np.array(channel_values)}
        attrs['bn_var'] = {'shape': np.array([3]), 'value': np.array(channel_values)}
        attrs['bn_data'] = {'shape': np.array([1, 227, 227, 3])}
        return build_graph(nodes_attributes, edges, attrs)

    @staticmethod
    def _build_reference_graph(can_be_fused, with_scale_and_beta=True):
        """Build the expected decomposition: Mul -> Add [-> Mul -> Add] -> concat.

        :param can_be_fused: value expected on every generated Mul/Add node.
        :param with_scale_and_beta: expect the second Mul/Add pair, whose weights
            are the [1, 2, 3] values fed through 'bn_const' and 'bn_beta'.
        :return: the graph produced by build_graph.
        """
        scale = [0.67419986, 0.55901699, 0.48795004]
        shift = [-0.67419986, -1.11803399, -1.46385011]
        channel_values = [1, 2, 3]
        edges = [('placeholder_1', 'placeholder_1_data'),
                 ('placeholder_1_data', 'mul_1'),
                 ('const_mul_1_w', 'mul_1_w'),
                 ('mul_1_w', 'mul_1'),
                 ('mul_1', 'mul_1_data'),
                 ('mul_1_data', 'add_1'),
                 ('const_add_1_w', 'add_1_w'),
                 ('add_1_w', 'add_1'),
                 ('add_1', 'add_1_data')]
        attrs = {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array(scale)},
                 'mul_1_w': {'shape': np.array([3]), 'value': np.array(scale)},
                 'const_add_1_w': {'shape': np.array([3]), 'value': np.array(shift)},
                 'add_1_w': {'shape': np.array([3]), 'value': np.array(shift)},
                 'mul_1': {'can_be_fused': can_be_fused},
                 'add_1': {'can_be_fused': can_be_fused}}
        last_data = 'add_1_data'
        if with_scale_and_beta:
            edges += [('add_1_data', 'mul_2'),
                      ('const_mul_2_w', 'mul_2_w'),
                      ('mul_2_w', 'mul_2'),
                      ('mul_2', 'mul_2_data'),
                      ('mul_2_data', 'add_2'),
                      ('const_add_2_w', 'add_2_w'),
                      ('add_2_w', 'add_2'),
                      ('add_2', 'add_2_data')]
            attrs.update({'const_mul_2_w': {'shape': np.array([3]), 'value': np.array(channel_values)},
                          'mul_2_w': {'shape': np.array([3]), 'value': np.array(channel_values)},
                          'const_add_2_w': {'shape': np.array([3]), 'value': np.array(channel_values)},
                          'add_2_w': {'shape': np.array([3]), 'value': np.array(channel_values)},
                          'mul_2': {'can_be_fused': can_be_fused},
                          'add_2': {'can_be_fused': can_be_fused}})
            last_data = 'add_2_data'
        # Only the data node feeding the concat carries an explicit shape.
        attrs[last_data] = {'shape': np.array([1, 227, 227, 3])}
        edges += [('concat', 'concat_data'),
                  (last_data, 'concat'),
                  ('concat_data', 'op_output')]
        return build_graph(nodes_attributes, edges, attrs)

    def _convert_and_compare(self, convert, graph, graph_ref):
        """Apply *convert* to *graph* in NHWC layout and assert it matches *graph_ref*.

        :param convert: the decomposition pass under test.
        :param graph: input graph; modified in place.
        :param graph_ref: expected graph after conversion and clean-up.
        """
        graph.graph['layout'] = 'NHWC'
        convert(graph)
        graph_clean_up(graph)
        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
        self.assertTrue(flag, resp)

    def test_bn_decomposition_1(self):
        graph = self._build_bn_graph({'eps': 1.2})
        graph_ref = self._build_reference_graph(can_be_fused=True)
        self._convert_and_compare(convert_batch_norm, graph, graph_ref)

    # 'can_be_fused': False for BatchNorm is expected on every generated Mul/Add
    def test_bn_decomposition_2(self):
        graph = self._build_bn_graph({'eps': 1.2, 'can_be_fused': False})
        graph_ref = self._build_reference_graph(can_be_fused=False)
        self._convert_and_compare(convert_batch_norm, graph, graph_ref)

    def test_caffe_bn_decomposition_1(self):
        graph = self._build_bn_graph({'epsilon': 1.2, 'op': 'BatchNormalization'},
                                     with_scale_and_beta=False)
        # Strip the 'in' attribute from these two op -> data edges.
        # NOTE(review): presumably this mimics graphs produced by the Caffe front — confirm.
        del graph['placeholder_1']['placeholder_1_data'][0]['in']
        del graph['bn_op']['bn_data'][0]['in']

        graph_ref = self._build_reference_graph(can_be_fused=True, with_scale_and_beta=False)
        self._convert_and_compare(convert_bn_to_mul_add, graph, graph_ref)

    # 'can_be_fused': False for BatchNormalization is expected on every generated Mul/Add
    def test_caffe_bn_decomposition_2(self):
        graph = self._build_bn_graph({'epsilon': 1.2, 'op': 'BatchNormalization', 'can_be_fused': False},
                                     with_scale_and_beta=False)
        # Strip the 'in' attribute from these two op -> data edges (see previous test).
        del graph['placeholder_1']['placeholder_1_data'][0]['in']
        del graph['bn_op']['bn_data'][0]['in']

        graph_ref = self._build_reference_graph(can_be_fused=False, with_scale_and_beta=False)
        self._convert_and_compare(convert_bn_to_mul_add, graph, graph_ref)