Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / decomposition_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.middle.passes.eliminate import graph_clean_up
22 from mo.middle.passes.fusing.decomposition import convert_scale_shift_to_mul_add, convert_batch_norm, \
23     convert_bn_to_mul_add
24 from mo.utils.unittest.graph import build_graph, compare_graphs
25
26 nodes_attributes = {
27     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
28     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
29     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
30     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
31     # ScaleShift layer
32     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift', 'axis': 0},
33     'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op'},
34     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
35     'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op'},
36     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
37     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
38     # Mul and Add operations
39     'mul_1': {'type': None, 'value': None, 'kind': 'op', 'op': 'Mul'},
40     'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op'},
41     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
42     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
43     'add_1': {'type': None, 'kind': 'op', 'op': 'Add'},
44     'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op'},
45     'add_1_w': {'value': None, 'shape': None, 'kind': 'data'},
46     'add_1_data': {'value': None, 'shape': None, 'kind': 'data'},
47     # Mul and Add operations
48     'mul_2': {'type': None, 'kind': 'op', 'op': 'Mul'},
49     'const_mul_2_w': {'value': None, 'shape': None, 'kind': 'op'},
50     'mul_2_w': {'value': None, 'shape': None, 'kind': 'data'},
51     'mul_2_data': {'value': None, 'shape': None, 'kind': 'data'},
52     'add_2': {'type': None, 'kind': 'op', 'op': 'Add'},
53     'const_add_2_w': {'value': None, 'shape': None, 'kind': 'op'},
54     'add_2_w': {'value': None, 'shape': None, 'kind': 'data'},
55     'add_2_data': {'value': None, 'shape': None, 'kind': 'data'},
56     # Reshape
57     'placeholder_2/Reshape_': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
58     'placeholder_2/Reshape_data': {'value': None, 'shape': None, 'kind': 'data'},
59     # BatchNorm operation
60     'bn_op': {'type': None, 'kind': 'op', 'op': 'BatchNorm', 'can_be_fused': True},
61     'bn_const': {'value': None, 'shape': None, 'kind': 'data'},
62     'bn_beta': {'value': None, 'shape': None, 'kind': 'data'},
63     'bn_mean': {'value': None, 'shape': None, 'kind': 'data'},
64     'bn_var': {'value': None, 'shape': None, 'kind': 'data'},
65     'bn_data': {'value': None, 'shape': None, 'kind': 'data'},
66     # Concat1 operation
67     'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
68     'concat_data': {'value': None, 'shape': None, 'kind': 'data'},
69     'op_output': {'kind': 'op', 'op': 'OpOutput'}
70 }
71
72
73 class ScaleShiftToMulAdd(unittest.TestCase):
74     # ScaleShift -> Mul
75     def test_scaleshift_to_mul_1(self):
76         graph = build_graph(nodes_attributes,
77                             [('placeholder_1', 'placeholder_1_data'),
78                              ('placeholder_1_data', 'scaleshift_1'),
79                              ('const_scaleshift_1_w', 'scaleshift_1_w'),
80                              ('scaleshift_1_w', 'scaleshift_1'),
81                              ('scaleshift_1', 'scaleshift_1_data'),
82                              ('scaleshift_1_data', 'op_output')
83                              ],
84                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
85                              'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
86                              'scaleshift_1_data': {}
87                              })
88
89         graph_ref = build_graph(nodes_attributes,
90                                 [('placeholder_1', 'placeholder_1_data'),
91                                  ('placeholder_1_data', 'mul_1'),
92                                  ('const_mul_1_w', 'mul_1_w'),
93                                  ('mul_1_w', 'mul_1'),
94                                  ('mul_1', 'scaleshift_1_data'),
95                                  ('scaleshift_1_data', 'op_output')
96                                  ],
97                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
98                                  'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
99                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
100                                  'mul_1': {'can_be_fused': True},
101                                  'scaleshift_1_data': {}
102                                  })
103
104         graph.graph['layout'] = 'NHWC'
105         convert_scale_shift_to_mul_add(graph)
106         graph_clean_up(graph)
107         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
108         self.assertTrue(flag, resp)
109
110     # ScaleShift  2 inputs-> Mul
111     def test_scaleshift2_to_mul(self):
112         graph = build_graph(nodes_attributes,
113                             [('placeholder_1', 'placeholder_1_data'),
114                              ('placeholder_2', 'placeholder_2_data'),
115                              ('placeholder_1_data', 'scaleshift_1'),
116                              ('placeholder_2_data', 'scaleshift_1'),
117                              ('scaleshift_1', 'scaleshift_1_data'),
118                              ('scaleshift_1_data', 'op_output')
119                              ],
120                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
121                              'placeholder_2_data': {'shape': np.array([1, 227])},
122                              'scaleshift_1_data': {}
123                              })
124
125         graph_ref = build_graph(nodes_attributes,
126                                 [('placeholder_1', 'placeholder_1_data'),
127                                  ('placeholder_2', 'placeholder_2_data'),
128                                  ('placeholder_2_data', 'placeholder_2/Reshape_'),
129                                  ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
130                                  ('placeholder_1_data', 'mul_1'),
131                                  ('placeholder_2/Reshape_data', 'mul_1'),
132                                  ('mul_1', 'scaleshift_1_data'),
133                                  ('scaleshift_1_data', 'op_output')
134                                  ],
135                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
136                                  'placeholder_2_data': {'shape': np.array([1, 227])},
137                                  'placeholder_2/Reshape_': {'dim': np.array([1, 227, 1, 1])},
138                                  'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
139                                  'mul_1': {'can_be_fused': True},
140                                  'scaleshift_1_data': {}
141                                  })
142
143         graph.graph['layout'] = 'NHWC'
144         convert_scale_shift_to_mul_add(graph)
145         graph_clean_up(graph)
146         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
147         self.assertTrue(flag, resp)
148
149     # ScaleShift  2 inputs-> Mul (axis = 1)
150     def test_scaleshift2_axis1_to_mul(self):
151         graph = build_graph(nodes_attributes,
152                             [('placeholder_1', 'placeholder_1_data'),
153                              ('placeholder_2', 'placeholder_2_data'),
154                              ('placeholder_1_data', 'scaleshift_1'),
155                              ('placeholder_2_data', 'scaleshift_1'),
156                              ('scaleshift_1', 'scaleshift_1_data'),
157                              ('scaleshift_1_data', 'op_output')
158                              ],
159                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
160                              'placeholder_2_data': {'shape': np.array([227])},
161                              'scaleshift_1': {'axis': 1},
162                              'scaleshift_1_data': {}
163                              })
164
165         graph_ref = build_graph(nodes_attributes,
166                                 [('placeholder_1', 'placeholder_1_data'),
167                                  ('placeholder_2', 'placeholder_2_data'),
168                                  ('placeholder_2_data', 'placeholder_2/Reshape_'),
169                                  ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
170                                  ('placeholder_1_data', 'mul_1'),
171                                  ('placeholder_2/Reshape_data', 'mul_1'),
172                                  ('mul_1', 'scaleshift_1_data'),
173                                  ('scaleshift_1_data', 'op_output')
174                                  ],
175                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
176                                  'placeholder_2_data': {'shape': np.array([227])},
177                                  'placeholder_2/Reshape_': {'dim': np.array([1, 227, 1, 1])},
178                                  'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
179                                  'mul_1': {'can_be_fused': True},
180                                  'scaleshift_1_data': {}
181                                  })
182
183         graph.graph['layout'] = 'NHWC'
184         convert_scale_shift_to_mul_add(graph)
185         graph_clean_up(graph)
186         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
187         self.assertTrue(flag, resp)
188
189     # ScaleShift -> Mul (Zero biases)
190     def test_scaleshift_to_mul_2(self):
191         graph = build_graph(nodes_attributes,
192                             [('placeholder_1', 'placeholder_1_data'),
193                              ('placeholder_1_data', 'scaleshift_1'),
194                              ('const_scaleshift_1_w', 'scaleshift_1_w'),
195                              ('const_scaleshift_1_b', 'scaleshift_1_b'),
196                              ('scaleshift_1_w', 'scaleshift_1'),
197                              ('scaleshift_1_b', 'scaleshift_1'),
198                              ('scaleshift_1', 'scaleshift_1_data'),
199                              ('scaleshift_1_data', 'op_output')
200                              ],
201                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
202                              'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
203                              'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
204                              'scaleshift_1_data': {}
205                              })
206
207         graph_ref = build_graph(nodes_attributes,
208                                 [('placeholder_1', 'placeholder_1_data'),
209                                  ('placeholder_1_data', 'mul_1'),
210                                  ('const_mul_1_w', 'mul_1_w'),
211                                  ('mul_1_w', 'mul_1'),
212                                  ('mul_1', 'scaleshift_1_data'),
213                                  ('scaleshift_1_data', 'op_output')
214                                  ],
215                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
216                                  'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
217                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
218                                  'mul_1': {'can_be_fused': True},
219                                  'scaleshift_1_data': {}
220                                  })
221
222         graph.graph['layout'] = 'NHWC'
223         convert_scale_shift_to_mul_add(graph)
224         graph_clean_up(graph)
225         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
226         self.assertTrue(flag, resp)
227
228     # ScaleShift -> Mul->Add
229     def test_scaleshift_to_mul_add(self):
230         graph = build_graph(nodes_attributes,
231                             [('placeholder_1', 'placeholder_1_data'),
232                              ('placeholder_1_data', 'scaleshift_1'),
233                              ('const_scaleshift_1_w', 'scaleshift_1_w'),
234                              ('const_scaleshift_1_b', 'scaleshift_1_b'),
235                              ('scaleshift_1_w', 'scaleshift_1'),
236                              ('scaleshift_1_b', 'scaleshift_1'),
237                              ('scaleshift_1', 'scaleshift_1_data'),
238                              ('scaleshift_1_data', 'op_output')
239                              ],
240                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
241                              'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
242                              'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
243                              'scaleshift_1_data': {}
244                              })
245
246         graph_ref = build_graph(nodes_attributes,
247                                 [('placeholder_1', 'placeholder_1_data'),
248                                  ('placeholder_1_data', 'mul_1'),
249                                  ('const_mul_1_w', 'mul_1_w'),
250                                  ('mul_1_w', 'mul_1'),
251                                  ('mul_1', 'mul_1_data'),
252                                  ('mul_1_data', 'add_1'),
253                                  ('const_add_1_w', 'add_1_w'),
254                                  ('add_1_w', 'add_1'),
255                                  ('add_1', 'scaleshift_1_data'),
256                                  ('scaleshift_1_data', 'op_output')
257                                  ],
258                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
259                                  'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
260                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
261                                  'const_add_1_w': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
262                                  'add_1_w': {'shape': np.array([3]), 'value': np.array([3, 2, 1])},
263                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
264                                  'add_1': {'can_be_fused': True},
265                                  'mul_1': {'can_be_fused': True},
266                                  'scaleshift_1_data': {}
267                                  })
268
269         graph.graph['layout'] = 'NHWC'
270         convert_scale_shift_to_mul_add(graph)
271         graph_clean_up(graph)
272         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
273         self.assertTrue(flag, resp)
274
275     # ScaleShift -> None (Zero weights and biases)
276     def test_scaleshift_to_nothing(self):
277         graph = build_graph(nodes_attributes,
278                             [('placeholder_1', 'placeholder_1_data'),
279                              ('placeholder_1_data', 'scaleshift_1'),
280                              ('const_scaleshift_1_w', 'scaleshift_1_w'),
281                              ('const_scaleshift_1_b', 'scaleshift_1_b'),
282                              ('scaleshift_1_w', 'scaleshift_1'),
283                              ('scaleshift_1_b', 'scaleshift_1'),
284                              ('scaleshift_1', 'scaleshift_1_data'),
285                              ('scaleshift_1_data', 'op_output')
286                              ],
287                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
288                              'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
289                              'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
290                              'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
291                              }, nodes_with_edges_only=True)
292
293         graph_ref = build_graph(nodes_attributes,
294                                 [('placeholder_1', 'placeholder_1_data'),
295                                  ('placeholder_1_data', 'op_output')
296                                  ],
297                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}}
298                                 ,nodes_with_edges_only=True)
299
300         graph.graph['layout'] = 'NHWC'
301         convert_scale_shift_to_mul_add(graph)
302         graph_clean_up(graph)
303         (flag, resp) = compare_graphs(graph, graph_ref, 'placeholder_1')
304         self.assertTrue(flag, resp)
305
306     # ScaleShift -> ScaleShift (can_be_fused=False)
307     def test_scaleshift_can_be_fused(self):
308         graph = build_graph(nodes_attributes,
309                             [('placeholder_1', 'placeholder_1_data'),
310                              ('placeholder_1_data', 'scaleshift_1'),
311                              ('const_scaleshift_1_w', 'scaleshift_1_w'),
312                              ('const_scaleshift_1_b', 'scaleshift_1_b'),
313                              ('scaleshift_1_w', 'scaleshift_1'),
314                              ('scaleshift_1_b', 'scaleshift_1'),
315                              ('scaleshift_1', 'scaleshift_1_data'),
316                              ('scaleshift_1_data', 'op_output')
317                              ],
318                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
319                              'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
320                              'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
321                              'scaleshift_1': {'can_be_fused': False},
322                              'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
323                              })
324
325         graph_ref = build_graph(nodes_attributes,
326                                 [('placeholder_1', 'placeholder_1_data'),
327                                  ('placeholder_1_data', 'scaleshift_1'),
328                                  ('const_scaleshift_1_w', 'scaleshift_1_w'),
329                                  ('const_scaleshift_1_b', 'scaleshift_1_b'),
330                                  ('scaleshift_1_w', 'scaleshift_1'),
331                                  ('scaleshift_1_b', 'scaleshift_1'),
332                                  ('scaleshift_1', 'scaleshift_1_data'),
333                                  ('scaleshift_1_data', 'op_output')
334                                  ],
335                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
336                                  'const_scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
337                                  'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array([1, 1, 1])},
338                                  'const_scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
339                                  'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array([0, 0, 0])},
340                                  'scaleshift_1': {'can_be_fused': False},
341                                  'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])}
342                                  })
343
344         convert_scale_shift_to_mul_add(graph)
345         graph_clean_up(graph)
346
347         (flag, resp) = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
348         self.assertTrue(flag, resp)
349
350
351 class BatchNormDecomposition(unittest.TestCase):
352     def test_bn_decomposition_1(self):
353         graph = build_graph(nodes_attributes,
354                             [('placeholder_1', 'placeholder_1_data'),
355                              ('placeholder_1_data', 'bn_op'),
356                              ('bn_const', 'bn_op'),
357                              ('bn_beta', 'bn_op'),
358                              ('bn_mean', 'bn_op'),
359                              ('bn_var', 'bn_op'),
360                              ('bn_op', 'bn_data'),
361                              ('concat', 'concat_data'),
362                              ('bn_data', 'concat'),
363                              ('concat_data', 'op_output')
364                              ],
365                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
366                              'bn_op': {'eps': 1.2},
367                              'bn_const': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
368                              'bn_beta': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
369                              'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
370                              'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
371                              'bn_data': {'shape': np.array([1, 227, 227, 3])},
372                              'concat_data': {}
373                              })
374
375         graph_ref = build_graph(nodes_attributes,
376                                 [('placeholder_1', 'placeholder_1_data'),
377                                  ('placeholder_1_data', 'mul_1'),
378                                  ('const_mul_1_w', 'mul_1_w'),
379                                  ('mul_1_w', 'mul_1'),
380                                  ('mul_1', 'mul_1_data'),
381                                  ('mul_1_data', 'add_1'),
382                                  ('const_add_1_w', 'add_1_w'),
383                                  ('add_1_w', 'add_1'),
384                                  ('add_1', 'add_1_data'),
385                                  ('add_1_data', 'mul_2'),
386                                  ('const_mul_2_w', 'mul_2_w'),
387                                  ('mul_2_w', 'mul_2'),
388                                  ('mul_2', 'mul_2_data'),
389                                  ('mul_2_data', 'add_2'),
390                                  ('const_add_2_w', 'add_2_w'),
391                                  ('add_2_w', 'add_2'),
392                                  ('add_2', 'add_2_data'),
393                                  ('concat', 'concat_data'),
394                                  ('add_2_data', 'concat'),
395                                  ('concat_data', 'op_output')
396                                  ],
397                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
398                                  'const_mul_1_w': {'shape': np.array([3]),
399                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
400                                  'mul_1_w': {'shape': np.array([3]),
401                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
402                                  'const_mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
403                                  'mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
404                                  'const_add_1_w': {'shape': np.array([3]),
405                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
406                                  'add_1_w': {'shape': np.array([3]),
407                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
408                                  'const_add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
409                                  'add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
410                                  'add_2_data': {'shape': np.array([1, 227, 227, 3])},
411                                  'mul_1': {'can_be_fused': True},
412                                  'mul_2': {'can_be_fused': True},
413                                  'add_1': {'can_be_fused': True},
414                                  'add_2': {'can_be_fused': True},
415                                  'concat_data': {}
416                                  })
417
418         graph.graph['layout'] = 'NHWC'
419         convert_batch_norm(graph)
420         graph_clean_up(graph)
421
422         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
423         self.assertTrue(flag, resp)
424
425     # 'can_be_fused': False for BatchNorm
426     def test_bn_decomposition_2(self):
427         graph = build_graph(nodes_attributes,
428                             [('placeholder_1', 'placeholder_1_data'),
429                              ('placeholder_1_data', 'bn_op'),
430                              ('bn_const', 'bn_op'),
431                              ('bn_beta', 'bn_op'),
432                              ('bn_mean', 'bn_op'),
433                              ('bn_var', 'bn_op'),
434                              ('bn_op', 'bn_data'),
435                              ('concat', 'concat_data'),
436                              ('bn_data', 'concat'),
437                              ('concat_data', 'op_output')
438                              ],
439                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
440                              'bn_op': {'eps': 1.2, 'can_be_fused': False},
441                              'bn_const': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
442                              'bn_beta': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
443                              'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
444                              'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
445                              'bn_data': {'shape': np.array([1, 227, 227, 3])},
446                              'concat_data': {}
447                              })
448
449         graph_ref = build_graph(nodes_attributes,
450                                 [('placeholder_1', 'placeholder_1_data'),
451                                  ('placeholder_1_data', 'mul_1'),
452                                  ('const_mul_1_w', 'mul_1_w'),
453                                  ('mul_1_w', 'mul_1'),
454                                  ('mul_1', 'mul_1_data'),
455                                  ('mul_1_data', 'add_1'),
456                                  ('const_add_1_w', 'add_1_w'),
457                                  ('add_1_w', 'add_1'),
458                                  ('add_1', 'add_1_data'),
459                                  ('add_1_data', 'mul_2'),
460                                  ('const_mul_2_w', 'mul_2_w'),
461                                  ('mul_2_w', 'mul_2'),
462                                  ('mul_2', 'mul_2_data'),
463                                  ('mul_2_data', 'add_2'),
464                                  ('const_add_2_w', 'add_2_w'),
465                                  ('add_2_w', 'add_2'),
466                                  ('add_2', 'add_2_data'),
467                                  ('concat', 'concat_data'),
468                                  ('add_2_data', 'concat'),
469                                  ('concat_data', 'op_output')
470                                  ],
471                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
472                                  'const_mul_1_w': {'shape': np.array([3]),
473                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
474                                  'mul_1_w': {'shape': np.array([3]),
475                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
476                                  'const_mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
477                                  'mul_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
478                                  'const_add_1_w': {'shape': np.array([3]),
479                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
480                                  'add_1_w': {'shape': np.array([3]),
481                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
482                                  'const_add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
483                                  'add_2_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
484                                  'add_2_data': {'shape': np.array([1, 227, 227, 3])},
485                                  'mul_1': {'can_be_fused': False},
486                                  'mul_2': {'can_be_fused': False},
487                                  'add_1': {'can_be_fused': False},
488                                  'add_2': {'can_be_fused': False},
489                                  'concat_data': {}
490                                  })
491
492         graph.graph['layout'] = 'NHWC'
493         convert_batch_norm(graph)
494         graph_clean_up(graph)
495
496         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
497         self.assertTrue(flag, resp)
498
499     def test_caffe_bn_decomposition_1(self):
500         graph = build_graph(nodes_attributes,
501                             [('placeholder_1', 'placeholder_1_data'),
502                              ('placeholder_1_data', 'bn_op'),
503                              ('bn_mean', 'bn_op'),
504                              ('bn_var', 'bn_op'),
505                              ('bn_op', 'bn_data'),
506                              ('concat', 'concat_data'),
507                              ('bn_data', 'concat'),
508                              ('concat_data', 'op_output')
509                              ],
510                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
511                              'bn_op': {'epsilon': 1.2, 'op': 'BatchNormalization'},
512                              'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
513                              'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
514                              'bn_data': {'shape': np.array([1, 227, 227, 3])},
515                              'concat_data': {}
516                              })
517
518         del graph['placeholder_1']['placeholder_1_data'][0]['in']
519         del graph['bn_op']['bn_data'][0]['in']
520
521         graph_ref = build_graph(nodes_attributes,
522                                 [('placeholder_1', 'placeholder_1_data'),
523                                  ('placeholder_1_data', 'mul_1'),
524                                  ('const_mul_1_w', 'mul_1_w'),
525                                  ('mul_1_w', 'mul_1'),
526                                  ('mul_1', 'mul_1_data'),
527                                  ('mul_1_data', 'add_1'),
528                                  ('const_add_1_w', 'add_1_w'),
529                                  ('add_1_w', 'add_1'),
530                                  ('add_1', 'add_1_data'),
531                                  ('concat', 'concat_data'),
532                                  ('add_1_data', 'concat'),
533                                  ('concat_data', 'op_output')
534                                  ],
535                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
536                                  'const_mul_1_w': {'shape': np.array([3]),
537                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
538                                  'mul_1_w': {'shape': np.array([3]),
539                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
540                                  'const_add_1_w': {'shape': np.array([3]),
541                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
542                                  'add_1_w': {'shape': np.array([3]),
543                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
544                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
545                                  'mul_1': {'can_be_fused': True},
546                                  'add_1': {'can_be_fused': True},
547                                  'concat_data': {}
548                                  })
549
550         graph.graph['layout'] = 'NHWC'
551         convert_bn_to_mul_add(graph)
552         graph_clean_up(graph)
553
554         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
555         self.assertTrue(flag, resp)
556
557     # 'can_be_fused': False for BatchNormalization
558     def test_caffe_bn_decomposition_2(self):
559         graph = build_graph(nodes_attributes,
560                             [('placeholder_1', 'placeholder_1_data'),
561                              ('placeholder_1_data', 'bn_op'),
562                              ('bn_mean', 'bn_op'),
563                              ('bn_var', 'bn_op'),
564                              ('bn_op', 'bn_data'),
565                              ('concat', 'concat_data'),
566                              ('bn_data', 'concat'),
567                              ('concat_data', 'op_output')
568                              ],
569                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
570                              'bn_op': {'epsilon': 1.2, 'op': 'BatchNormalization', 'can_be_fused': False},
571                              'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
572                              'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
573                              'bn_data': {'shape': np.array([1, 227, 227, 3])},
574                              'concat_data': {}
575                              })
576
577         del graph['placeholder_1']['placeholder_1_data'][0]['in']
578         del graph['bn_op']['bn_data'][0]['in']
579
580         graph_ref = build_graph(nodes_attributes,
581                                 [('placeholder_1', 'placeholder_1_data'),
582                                  ('placeholder_1_data', 'mul_1'),
583                                  ('const_mul_1_w', 'mul_1_w'),
584                                  ('mul_1_w', 'mul_1'),
585                                  ('mul_1', 'mul_1_data'),
586                                  ('mul_1_data', 'add_1'),
587                                  ('const_add_1_w', 'add_1_w'),
588                                  ('add_1_w', 'add_1'),
589                                  ('add_1', 'add_1_data'),
590                                  ('concat', 'concat_data'),
591                                  ('add_1_data', 'concat'),
592                                  ('concat_data', 'op_output')
593                                  ],
594                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
595                                  'const_mul_1_w': {'shape': np.array([3]),
596                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
597                                  'mul_1_w': {'shape': np.array([3]),
598                                              'value': np.array([0.67419986, 0.55901699, 0.48795004])},
599                                  'const_add_1_w': {'shape': np.array([3]),
600                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
601                                  'add_1_w': {'shape': np.array([3]),
602                                              'value': np.array([-0.67419986, -1.11803399, -1.46385011])},
603                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
604                                  'mul_1': {'can_be_fused': False},
605                                  'add_1': {'can_be_fused': False},
606                                  'concat_data': {}
607                                  })
608
609         graph.graph['layout'] = 'NHWC'
610         convert_bn_to_mul_add(graph)
611         graph_clean_up(graph)
612
613         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_data')
614         self.assertTrue(flag, resp)