Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / fuse_linear_ops_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.graph.graph import Node
22 from mo.middle.passes.eliminate import graph_clean_up
23 from mo.middle.passes.fusing.fuse_linear_ops import _fuse_mul, _fuse_add, fuse_linear_ops
24 from mo.utils.unittest.graph import build_graph, compare_graphs
25
26 nodes_attributes = {
27     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
28     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
29     # ScaleShift layer
30     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
31     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
32     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
33     'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
34     'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
35     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
36     # Mul and Add operations
37     'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
38     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
39     'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
40     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
41     'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
42     'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
43     'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
44     'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
45     # Mul2 and Add2 operations
46     'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
47     'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
48     'const_mul_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
49     'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
50     'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
51     'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
52     'const_add_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
53     'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
54     # Concat1 operation
55     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
56     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
57     # Convolutions
58     'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
59     'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
60     'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
61     'const_conv_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
62     'const_conv_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
63     'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
64     'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
65     'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
66     'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
67     'const_conv_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
68     'const_conv_2_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
69     'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
70     # FullyConnected
71     'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
72     'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
73     'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
74     'const_fc_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
75     'const_fc_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
76     'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
77     # Placeholders
78     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
79     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
80     'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
81     'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
82     'op_output': {'kind': 'op', 'op': 'OpOutput'},
83     'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
84     'op_output_2': {'kind': 'op', 'op': 'OpOutput'}
85 }
86
87
88 # Unit tests for fuse_mul
89 class FuseMulTests(unittest.TestCase):
90     # Mul(array)->Conv(w+b)
91     def test_fuse_mul_to_conv_1(self):
92         # Placeholder->Mul->Conv
93         graph = build_graph(nodes_attributes,
94                             [('placeholder_1', 'placeholder_1_data'),
95                              ('placeholder_1_data', 'mul_1'),
96                              ('const_mul_1_w', 'mul_1_w'),
97                              ('mul_1_w', 'mul_1'),
98                              ('mul_1', 'mul_1_data'),
99                              ('mul_1_data', 'conv_1'),
100                              ('const_conv_1_w', 'conv_1_w'),
101                              ('const_conv_1_b', 'conv_1_b'),
102                              ('conv_1_w', 'conv_1'),
103                              ('conv_1_b', 'conv_1'),
104                              ('conv_1', 'conv_1_data'),
105                              ('conv_1_data', 'op_output')
106                              ],
107                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
108                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
109                              'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
110                              'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
111                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
112                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
113                                           'output_channel_dim': 3, 'input_channel_dim': 2,
114                                           'dims_number': 4},
115                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
116                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
117                              'conv_1_data': {}
118                              })
119         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))
120
121         graph_ref = build_graph(nodes_attributes,
122                                 [('placeholder_1', 'placeholder_1_data'),
123                                  ('placeholder_1_data', 'conv_1'),
124                                  ('const_conv_1_w', 'conv_1_w'),
125                                  ('const_conv_1_b', 'conv_1_b'),
126                                  ('conv_1_w', 'conv_1'),
127                                  ('conv_1_b', 'conv_1'),
128                                  ('conv_1', 'conv_1_data'),
129                                  ('conv_1_data', 'op_output')
130                                  ],
131                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
132                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
133                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
134                                               'output_channel_dim': 3, 'input_channel_dim': 2,
135                                               'dims_number': 4},
136                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
137                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
138                                  'conv_1_data': {}
139                                  })
140
141         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
142         graph_clean_up(graph)
143
144         (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
145         self.assertTrue(flag, resp)
146
147     # Mul(scalar)->Conv(w+b)
148     def test_fuse_mul_to_conv_2(self):
149         # Placeholder->Mul->Conv
150         graph = build_graph(nodes_attributes,
151                             [('placeholder_1', 'placeholder_1_data'),
152                              ('placeholder_1_data', 'mul_1'),
153                              ('const_mul_1_w', 'mul_1_w'),
154                              ('mul_1_w', 'mul_1'),
155                              ('mul_1', 'mul_1_data'),
156                              ('mul_1_data', 'conv_1'),
157                              ('const_conv_1_w', 'conv_1_w'),
158                              ('const_conv_1_b', 'conv_1_b'),
159                              ('conv_1_w', 'conv_1'),
160                              ('conv_1_b', 'conv_1'),
161                              ('conv_1', 'conv_1_data'),
162                              ('conv_1_data', 'op_output')
163                              ],
164                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
165                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
166                              'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
167                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
168                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
169                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
170                                           'output_channel_dim': 3, 'input_channel_dim': 2,
171                                           'dims_number': 4},
172                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
173                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
174                              'conv_1_data': {}
175                              })
176         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))
177
178         graph_ref = build_graph(nodes_attributes,
179                                 [('placeholder_1', 'placeholder_1_data'),
180                                  ('placeholder_1_data', 'conv_1'),
181                                  ('const_conv_1_w', 'conv_1_w'),
182                                  ('const_conv_1_b', 'conv_1_b'),
183                                  ('conv_1_w', 'conv_1'),
184                                  ('conv_1_b', 'conv_1'),
185                                  ('conv_1', 'conv_1_data'),
186                                  ('conv_1_data', 'op_output')
187                                  ],
188                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
189                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
190                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
191                                               'output_channel_dim': 3, 'input_channel_dim': 2,
192                                               'dims_number': 4},
193                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
194                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
195                                  'conv_1_data': {}
196                                  })
197
198         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
199         graph_clean_up(graph)
200
201         (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
202         self.assertTrue(flag, resp)
203
204     # Conv(w+b)->Mul(array)
205     def test_fuse_mul_to_conv_3(self):
206         # Placeholder->Conv->Mul
207         graph = build_graph(nodes_attributes,
208                             [('placeholder_1', 'placeholder_1_data'),
209                              ('placeholder_1_data', 'conv_1'),
210                              ('const_conv_1_w', 'conv_1_w'),
211                              ('const_conv_1_b', 'conv_1_b'),
212                              ('conv_1_w', 'conv_1'),
213                              ('conv_1_b', 'conv_1'),
214                              ('conv_1', 'conv_1_data'),
215                              ('conv_1_data', 'mul_1'),
216                              ('const_mul_1_w', 'mul_1_w'),
217                              ('mul_1_w', 'mul_1'),
218                              ('mul_1', 'mul_1_data'),
219                              ('mul_1_data', 'op_output')
220                              ],
221                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
222                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
223                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
224                                           'output_channel_dim': 3, 'input_channel_dim': 2,
225                                           'dims_number': 4},
226                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
227                              'conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
228                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
229                              'mul_1_data': {'shape': np.array([1, 55, 55, 96])},
230                              'const_mul_1_w': {'shape': np.array([96]), 'value': np.array([x for x in range(96)])},
231                              'mul_1_w': {'shape': np.array([96]), 'value': np.array([x for x in range(96)])},
232                              })
233         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([x for x in range(96)]), 96)
234         ref_biases = np.ones(96) * np.array([x for x in range(96)])
235
236         graph_ref = build_graph(nodes_attributes,
237                                 [('placeholder_1', 'placeholder_1_data'),
238                                  ('placeholder_1_data', 'conv_1'),
239                                  ('const_conv_1_w', 'conv_1_w'),
240                                  ('const_conv_1_b', 'conv_1_b'),
241                                  ('conv_1_w', 'conv_1'),
242                                  ('conv_1_b', 'conv_1'),
243                                  ('conv_1', 'conv_1_data'),
244                                  ('conv_1_data', 'op_output')
245                                  ],
246                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
247                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
248                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
249                                               'output_channel_dim': 3, 'input_channel_dim': 2,
250                                               'dims_number': 4},
251                                  'const_conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
252                                  'conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
253                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])}
254                                  })
255
256         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
257         graph_clean_up(graph)
258
259         (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
260         self.assertTrue(flag, resp)
261
262     # Conv(w)->Mul(scalar)
263     def test_fuse_mul_to_conv_4(self):
264         # Placeholder->Conv->Mul
265         graph = build_graph(nodes_attributes,
266                             [('placeholder_1', 'placeholder_1_data'),
267                              ('placeholder_1_data', 'conv_1'),
268                              ('const_conv_1_w', 'conv_1_w'),
269                              ('const_conv_1_b', 'conv_1_b'),
270                              ('conv_1_w', 'conv_1'),
271                              ('conv_1_b', 'conv_1'),
272                              ('conv_1', 'conv_1_data'),
273                              ('conv_1_data', 'mul_1'),
274                              ('const_mul_1_w', 'mul_1_w'),
275                              ('mul_1_w', 'mul_1'),
276                              ('mul_1', 'mul_1_data'),
277                              ('mul_1_data', 'op_output')
278                              ],
279                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
280                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
281                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
282                                           'output_channel_dim': 3, 'input_channel_dim': 2,
283                                           'dims_number': 4},
284                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
285                              'conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
286                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
287                              'mul_1_data': {'shape': np.array([1, 55, 55, 96])},
288                              'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
289                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
290                              })
291         ref_weights = np.ones((11, 11, 3, 96)) * np.array([6])
292         ref_biases = np.ones(96) * np.array([6])
293
294         graph_ref = build_graph(nodes_attributes,
295                                 [('placeholder_1', 'placeholder_1_data'),
296                                  ('placeholder_1_data', 'conv_1'),
297                                  ('const_conv_1_w', 'conv_1_w'),
298                                  ('const_conv_1_b', 'conv_1_b'),
299                                  ('conv_1_w', 'conv_1'),
300                                  ('conv_1_b', 'conv_1'),
301                                  ('conv_1', 'conv_1_data'),
302                                  ('conv_1_data', 'op_output')
303                                  ],
304                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
305                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
306                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
307                                               'output_channel_dim': 3, 'input_channel_dim': 2,
308                                               'dims_number': 4},
309                                  'const_conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
310                                  'conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
311                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])}
312                                  })
313
314         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
315         graph_clean_up(graph)
316
317         (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
318         self.assertTrue(flag, resp)
319
320     # Op0-+->Op1--+----+-->Concat     Op0-+->Op1--+--+-->Concat
321     #  |  |       |    |               |  |       |  |
322     #  |  +->Op2--+    |          =>   |  +->Op2--+  |
323     #  +---->Mul->Conv-+               +---->Conv----+
324     def test_fuse_mul_to_conv_5(self):
325         graph = build_graph(nodes_attributes,
326                             [('placeholder_1', 'placeholder_1_data'),
327                              ('placeholder_1_data', 'mul_1'),
328                              ('const_mul_1_w', 'mul_1_w'),
329                              ('mul_1_w', 'mul_1'),
330                              ('mul_1', 'mul_1_data'),
331                              ('mul_1_data', 'conv_1'),
332                              ('const_conv_1_w', 'conv_1_w'),
333                              ('const_conv_1_b', 'conv_1_b'),
334                              ('conv_1_w', 'conv_1'),
335                              ('conv_1_b', 'conv_1'),
336                              ('conv_1', 'conv_1_data'),
337                              ('placeholder_1_data', 'placeholder_2'),
338                              ('placeholder_2', 'placeholder_2_data'),
339                              ('placeholder_1_data', 'placeholder_3'),
340                              ('placeholder_3', 'placeholder_3_data'),
341                              ('placeholder_2_data', 'concat_1'),
342                              ('placeholder_3_data', 'concat_1'),
343                              ('conv_1_data', 'concat_1'),
344                              ('concat_1', 'concat_1_data'),
345                              ('concat_1_data', 'op_output')
346
347                              ],
348                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
349                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
350                              'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
351                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
352                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
353                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
354                                           'output_channel_dim': 3, 'input_channel_dim': 2,
355                                           'dims_number': 4},
356                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
357                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
358                              'concat_1_data': {}
359                              })
360         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))
361
362         graph_ref = build_graph(nodes_attributes,
363                                 [('placeholder_1', 'placeholder_1_data'),
364                                  ('placeholder_1_data', 'conv_1'),
365                                  ('const_conv_1_w', 'conv_1_w'),
366                                  ('const_conv_1_b', 'conv_1_b'),
367                                  ('conv_1_w', 'conv_1'),
368                                  ('conv_1_b', 'conv_1'),
369                                  ('conv_1', 'conv_1_data'),
370                                  ('placeholder_1_data', 'placeholder_2'),
371                                  ('placeholder_2', 'placeholder_2_data'),
372                                  ('placeholder_1_data', 'placeholder_3'),
373                                  ('placeholder_3', 'placeholder_3_data'),
374                                  ('placeholder_2_data', 'concat_1'),
375                                  ('placeholder_3_data', 'concat_1'),
376                                  ('conv_1_data', 'concat_1'),
377                                  ('concat_1', 'concat_1_data'),
378                                  ('concat_1_data', 'op_output'),
379                                  ],
380                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
381                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
382                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
383                                               'output_channel_dim': 3,
384                                               'input_channel_dim': 2, 'dims_number': 4},
385                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
386                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
387                                  'conv_1_data': {},
388                                  'placeholder_2_data': {},
389                                  'placeholder_3_data': {},
390                                  })
391
392         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
393         graph_clean_up(graph)
394
395         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
396         self.assertTrue(flag, resp)
397
398     def test_fuse_mul_to_conv_5_nparray(self):
399         graph = build_graph(nodes_attributes,
400                             [('placeholder_1', 'placeholder_1_data'),
401                              ('placeholder_1_data', 'mul_1'),
402                              ('const_mul_1_w', 'mul_1_w'),
403                              ('mul_1_w', 'mul_1'),
404                              ('mul_1', 'mul_1_data'),
405                              ('mul_1_data', 'conv_1'),
406                              ('const_conv_1_w', 'conv_1_w'),
407                              ('const_conv_1_b', 'conv_1_b'),
408                              ('conv_1_w', 'conv_1'),
409                              ('conv_1_b', 'conv_1'),
410                              ('conv_1', 'conv_1_data'),
411                              ('placeholder_1_data', 'placeholder_2'),
412                              ('placeholder_2', 'placeholder_2_data'),
413                              ('placeholder_1_data', 'placeholder_3'),
414                              ('placeholder_3', 'placeholder_3_data'),
415                              ('placeholder_2_data', 'concat_1'),
416                              ('placeholder_3_data', 'concat_1'),
417                              ('conv_1_data', 'concat_1'),
418                              ('concat_1', 'concat_1_data'),
419                              ('concat_1_data', 'op_output'),
420
421                              ],
422                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
423                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
424                              'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
425                              'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
426                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
427                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
428                                           'output_channel_dim': 3, 'input_channel_dim': 2,
429                                           'dims_number': 4},
430                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
431                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
432                              'concat_1_data': {}
433                              })
434         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))
435
436         graph_ref = build_graph(nodes_attributes,
437                                 [('placeholder_1', 'placeholder_1_data'),
438                                  ('placeholder_1_data', 'conv_1'),
439                                  ('const_conv_1_w', 'conv_1_w'),
440                                  ('const_conv_1_b', 'conv_1_b'),
441                                  ('conv_1_w', 'conv_1'),
442                                  ('conv_1_b', 'conv_1'),
443                                  ('conv_1', 'conv_1_data'),
444                                  ('placeholder_1_data', 'placeholder_2'),
445                                  ('placeholder_2', 'placeholder_2_data'),
446                                  ('placeholder_1_data', 'placeholder_3'),
447                                  ('placeholder_3', 'placeholder_3_data'),
448                                  ('placeholder_2_data', 'concat_1'),
449                                  ('placeholder_3_data', 'concat_1'),
450                                  ('conv_1_data', 'concat_1'),
451                                  ('concat_1', 'concat_1_data'),
452                                  ('concat_1_data', 'op_output'),
453                                  ],
454                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
455                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
456                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
457                                               'output_channel_dim': 3,
458                                               'input_channel_dim': 2, 'dims_number': 4},
459                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
460                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
461                                  'conv_1_data': {},
462                                  'placeholder_2_data': {},
463                                  'placeholder_3_data': {},
464                                  })
465
466         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
467         graph_clean_up(graph)
468
469         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
470         self.assertTrue(flag, resp)
471
472     # Op->Mul(array)-+->Conv(w+b)--+->Concat     Op-+->Conv1-+-->Concat
473     #                |             |         =>     |        |
474     #                +-->Conv(w+b)-+                +->Conv2-+
475     def test_fuse_mul_to_convolutions_1(self):
476         graph = build_graph(nodes_attributes,
477                             [('placeholder_1', 'placeholder_1_data'),
478                              ('placeholder_1_data', 'mul_1'),
479                              ('const_mul_1_w', 'mul_1_w'),
480                              ('mul_1_w', 'mul_1'),
481                              ('mul_1', 'mul_1_data'),
482                              ('mul_1_data', 'conv_1'),
483                              ('const_conv_1_w', 'conv_1_w'),
484                              ('const_conv_1_b', 'conv_1_b'),
485                              ('conv_1_w', 'conv_1'),
486                              ('conv_1_b', 'conv_1'),
487                              ('conv_1', 'conv_1_data'),
488                              ('mul_1_data', 'conv_2'),
489                              ('const_conv_2_w', 'conv_2_w'),
490                              ('const_conv_2_b', 'conv_2_b'),
491                              ('conv_2_w', 'conv_2'),
492                              ('conv_2_b', 'conv_2'),
493                              ('conv_2', 'conv_2_data'),
494                              ('conv_1_data', 'concat_1'),
495                              ('conv_2_data', 'concat_1'),
496                              ('concat_1', 'concat_1_data'),
497                              ('concat_1_data', 'op_output')
498                              ],
499                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
500                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
501                              'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
502                              'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
503                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
504                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
505                                           'output_channel_dim': 3, 'input_channel_dim': 2,
506                                           'dims_number': 4},
507                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
508                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
509                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
510                              'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
511                              'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
512                                           'output_channel_dim': 3, 'input_channel_dim': 2,
513                                           'dims_number': 4},
514                              'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
515                              'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
516                              'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
517                              'concat_1_data': {}
518                              })
519         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))
520
521         graph_ref = build_graph(nodes_attributes,
522                                 [('placeholder_1', 'placeholder_1_data'),
523                                  ('placeholder_1_data', 'conv_1'),
524                                  ('const_conv_1_w', 'conv_1_w'),
525                                  ('const_conv_1_b', 'conv_1_b'),
526                                  ('conv_1_w', 'conv_1'),
527                                  ('conv_1_b', 'conv_1'),
528                                  ('conv_1', 'conv_1_data'),
529                                  ('placeholder_1_data', 'conv_2'),
530                                  ('const_conv_2_w', 'conv_2_w'),
531                                  ('const_conv_2_b', 'conv_2_b'),
532                                  ('conv_2_w', 'conv_2'),
533                                  ('conv_2_b', 'conv_2'),
534                                  ('conv_2', 'conv_2_data'),
535                                  ('conv_1_data', 'concat_1'),
536                                  ('conv_2_data', 'concat_1'),
537                                  ('concat_1', 'concat_1_data'),
538                                  ('concat_1_data', 'op_output')
539                                  ],
540                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
541                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
542                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
543                                               'output_channel_dim': 3, 'input_channel_dim': 2,
544                                               'dims_number': 4},
545                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
546                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
547                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
548                                  'const_conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights},
549                                  'conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights,
550                                               'output_channel_dim': 3, 'input_channel_dim': 2,
551                                               'dims_number': 4},
552                                  'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
553                                  'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
554                                  'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
555                                  })
556
557         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1'), Node(graph, 'conv_2')], backward=False)
558         graph_clean_up(graph)
559
560         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
561         self.assertTrue(flag, resp)
562
563     # Mul(array)->FC(w+b)
564     def test_fuse_mul_to_fc_1(self):
565         # Placeholder->Mul->FC
566         graph = build_graph(nodes_attributes,
567                             [('placeholder_1', 'placeholder_1_data'),
568                              ('placeholder_1_data', 'mul_1'),
569                              ('const_mul_1_w', 'mul_1_w'),
570                              ('mul_1_w', 'mul_1'),
571                              ('mul_1', 'mul_1_data'),
572                              ('mul_1_data', 'fc_1'),
573                              ('const_fc_1_w', 'fc_1_w'),
574                              ('const_fc_1_b', 'fc_1_b'),
575                              ('fc_1_w', 'fc_1'),
576                              ('fc_1_b', 'fc_1'),
577                              ('fc_1', 'fc_1_data'),
578                              ('fc_1_data', 'op_output')
579                              ],
580                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
581                              'mul_1_data': {'shape': np.array([1, 2048])},
582                              'const_mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
583                              'mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
584                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
585                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
586                                         'output_channel_dim': 0, 'input_channel_dim': 1,
587                                         'dims_number': 2},
588                              'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
589                              'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
590                              'fc_1_data': {'shape': np.array([1, 10260])},
591                              })
592         ref_weights = np.ones((10260, 2048)) * np.array([x for x in range(2048)])
593
594         graph_ref = build_graph(nodes_attributes,
595                                 [('placeholder_1', 'placeholder_1_data'),
596                                  ('placeholder_1_data', 'fc_1'),
597                                  ('const_fc_1_w', 'fc_1_w'),
598                                  ('const_fc_1_b', 'fc_1_b'),
599                                  ('fc_1_w', 'fc_1'),
600                                  ('fc_1_b', 'fc_1'),
601                                  ('fc_1', 'fc_1_data'),
602                                  ('fc_1_data', 'op_output')
603
604                                  ],
605                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
606                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
607                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
608                                             'output_channel_dim': 0, 'input_channel_dim': 1,
609                                             'dims_number': 2},
610                                  'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
611                                  'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
612                                  'fc_1_data': {'shape': np.array([1, 10260])},
613                                  })
614
615         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'fc_1')], backward=False)
616         graph_clean_up(graph)
617
618         (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
619         self.assertTrue(flag, resp)
620
621     # Mul(scalar)->Conv(w+b) can_be_fused = False
622     def test_fuse_mul_to_conv_6(self):
623         # Placeholder->Mul->Conv
624         graph = build_graph(nodes_attributes,
625                             [('placeholder_1', 'placeholder_1_data'),
626                              ('placeholder_1_data', 'mul_1'),
627                              ('const_mul_1_w', 'mul_1_w'),
628                              ('mul_1_w', 'mul_1'),
629                              ('mul_1', 'mul_1_data'),
630                              ('mul_1_data', 'conv_1'),
631                              ('const_conv_1_w', 'conv_1_w'),
632                              ('const_conv_1_b', 'conv_1_b'),
633                              ('conv_1_w', 'conv_1'),
634                              ('conv_1_b', 'conv_1'),
635                              ('conv_1', 'conv_1_data'),
636                              ('conv_1_data', 'op_output')
637                              ],
638                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
639                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
640                              'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
641                              'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
642                              'conv_1': {'can_be_fused': False},
643                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
644                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
645                                           'output_channel_dim': 3, 'input_channel_dim': 2,
646                                           'dims_number': 4},
647                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
648                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
649                              'conv_1_data': {}
650                              })
651
652         graph_ref = build_graph(nodes_attributes,
653                                 [('placeholder_1', 'placeholder_1_data'),
654                                  ('placeholder_1_data', 'mul_1'),
655                                  ('const_mul_1_w', 'mul_1_w'),
656                                  ('mul_1_w', 'mul_1'),
657                                  ('mul_1', 'mul_1_data'),
658                                  ('mul_1_data', 'conv_1'),
659                                  ('const_conv_1_w', 'conv_1_w'),
660                                  ('const_conv_1_b', 'conv_1_b'),
661                                  ('conv_1_w', 'conv_1'),
662                                  ('conv_1_b', 'conv_1'),
663                                  ('conv_1', 'conv_1_data'),
664                                  ('conv_1_data', 'op_output')
665                                  ],
666                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
667                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
668                                  'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
669                                  'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
670                                  'conv_1': {'can_be_fused': False},
671                                  'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
672                                  'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
673                                               'output_channel_dim': 3, 'input_channel_dim': 2,
674                                               'dims_number': 4},
675                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
676                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
677                                  'conv_1_data': {}
678                                  })
679
680         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
681         graph_clean_up(graph)
682
683         (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
684         self.assertTrue(flag, resp)
685
686     # Mul(array)->DWConv(w+b)
687     def test_fuse_mul_to_dwconv_1(self):
688         # Placeholder->Mul->Conv
689         graph = build_graph(nodes_attributes,
690                             [('placeholder_1', 'placeholder_1_data'),
691                              ('placeholder_1_data', 'mul_1'),
692                              ('const_mul_1_w', 'mul_1_w'),
693                              ('mul_1_w', 'mul_1'),
694                              ('mul_1', 'mul_1_data'),
695                              ('mul_1_data', 'conv_1'),
696                              ('const_conv_1_w', 'conv_1_w'),
697                              ('conv_1_w', 'conv_1'),
698                              ('conv_1', 'conv_1_data'),
699                              ('conv_1_data', 'op_output')
700                              ],
701                             {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
702                              'mul_1_data': {'shape': np.array([1, 112, 112, 6])},
703                              'const_mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
704                              'mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
705                              'const_conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1))},
706                              'conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1)),
707                                           'output_channel_dim': 2, 'input_channel_dim': 2,
708                                           'dims_number': 4},
709                              'conv_1_data': {}
710                              })
711         ref_weights = np.ones((3, 3, 6, 1)) * np.reshape(np.array([1, 2, 3, 4, 5, 6]), (6, 1))
712
713         graph_ref = build_graph(nodes_attributes,
714                                 [('placeholder_1', 'placeholder_1_data'),
715                                  ('placeholder_1_data', 'conv_1'),
716                                  ('const_conv_1_w', 'conv_1_w'),
717                                  ('conv_1_w', 'conv_1'),
718                                  ('conv_1', 'conv_1_data'),
719                                  ('conv_1_data', 'op_output')
720                                  ],
721                                 {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
722                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
723                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
724                                               'output_channel_dim': 2, 'input_channel_dim': 2,
725                                               'dims_number': 4},
726                                  'conv_1_data': {}
727                                  })
728
729         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
730         graph_clean_up(graph)
731
732         (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
733         self.assertTrue(flag, resp)
734
735     # DWConv(w)->Mul(scalar)
736     def test_fuse_mul_to_dwconv_2(self):
737         # Placeholder->Conv->Mul
738         graph = build_graph(nodes_attributes,
739                             [('placeholder_1', 'placeholder_1_data'),
740                              ('placeholder_1_data', 'conv_1'),
741                              ('const_conv_1_w', 'conv_1_w'),
742                              ('conv_1_w', 'conv_1'),
743                              ('conv_1', 'conv_1_data'),
744                              ('conv_1_data', 'mul_1'),
745                              ('const_mul_1_w', 'mul_1_w'),
746                              ('mul_1_w', 'mul_1'),
747                              ('mul_1', 'mul_1_data'),
748                              ('mul_1_data', 'op_output')
749                              ],
750                             {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
751                              'mul_1_data': {'shape': np.array([1, 112, 112, 6])},
752                              'const_mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
753                              'mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
754                              'const_conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1))},
755                              'conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1)),
756                                           'output_channel_dim': 2, 'input_channel_dim': 2,
757                                           'dims_number': 4},
758                              'conv_1_data': {}
759                              })
760
761         ref_weights = np.ones((3, 3, 6, 1)) * np.reshape(np.array([1, 2, 3, 4, 5, 6]), (6, 1))
762
763         graph_ref = build_graph(nodes_attributes,
764                                 [('placeholder_1', 'placeholder_1_data'),
765                                  ('placeholder_1_data', 'conv_1'),
766                                  ('const_conv_1_w', 'conv_1_w'),
767                                  ('conv_1_w', 'conv_1'),
768                                  ('conv_1', 'conv_1_data'),
769                                  ('conv_1_data', 'op_output')
770                                  ],
771                                 {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
772                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
773                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
774                                               'output_channel_dim': 2, 'input_channel_dim': 2,
775                                               'dims_number': 4},
776                                  })
777
778         _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
779         graph_clean_up(graph)
780
781         (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
782         self.assertTrue(flag, resp)
783
784
785 # Unit tests for fuse_add
786 class FuseAddTests(unittest.TestCase):
787     # Add(array)->FC(w+b)
788     def test_fuse_add_to_fc_1(self):
789         # Placeholder->Add->FC
790         graph = build_graph(nodes_attributes,
791                             [('placeholder_1', 'placeholder_1_data'),
792                              ('placeholder_1_data', 'add_1'),
793                              ('const_add_1_w', 'add_1_w'),
794                              ('add_1_w', 'add_1'),
795                              ('add_1', 'add_1_data'),
796                              ('add_1_data', 'fc_1'),
797                              ('const_fc_1_w', 'fc_1_w'),
798                              ('const_fc_1_b', 'fc_1_b'),
799                              ('fc_1_w', 'fc_1'),
800                              ('fc_1_b', 'fc_1'),
801                              ('fc_1', 'fc_1_data'),
802                              ('fc_1_data', 'op_output')
803
804                              ],
805                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
806                              'add_1_data': {'shape': np.array([1, 2048])},
807                              'const_add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
808                              'add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
809                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
810                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
811                                         'output_channel_dim': 0, 'input_channel_dim': 1,
812                                         'dims_number': 2},
813                              'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
814                              'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
815                              'fc_1_data': {'shape': np.array([1, 10260])},
816                              })
817         ref_weights = np.ones((10260, 2048))
818         ref_biases = np.ones(10260) + np.dot(np.ones((10260, 2048)), np.array([x for x in range(2048)]))
819
820         graph_ref = build_graph(nodes_attributes,
821                                 [('placeholder_1', 'placeholder_1_data'),
822                                  ('placeholder_1_data', 'fc_1'),
823                                  ('const_fc_1_w', 'fc_1_w'),
824                                  ('const_fc_1_b', 'fc_1_b'),
825                                  ('fc_1_w', 'fc_1'),
826                                  ('fc_1_b', 'fc_1'),
827                                  ('fc_1', 'fc_1_data'),
828                                  ('fc_1_data', 'op_output')
829                                  ],
830                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
831                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
832                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
833                                             'output_channel_dim': 0, 'input_channel_dim': 1,
834                                             'dims_number': 2},
835                                  'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
836                                  'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
837                                  'fc_1_data': {'shape': np.array([1, 10260])},
838                                  })
839
840         _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=False)
841         graph_clean_up(graph)
842
843         (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
844         self.assertTrue(flag, resp)
845
846     # FC(w)->Add(array)
847     def test_fuse_add_to_fc_2(self):
848         # Placeholder->FC->Add
849         graph = build_graph(nodes_attributes,
850                             [('placeholder_1', 'placeholder_1_data'),
851                              ('placeholder_1_data', 'fc_1'),
852                              ('const_fc_1_w', 'fc_1_w'),
853                              ('fc_1_w', 'fc_1'),
854                              ('fc_1', 'fc_1_data'),
855                              ('fc_1_data', 'add_1'),
856                              ('const_add_1_w', 'add_1_w'),
857                              ('add_1_w', 'add_1'),
858                              ('add_1', 'add_1_data'),
859                              ('add_1_data', 'op_output_1')
860                              ],
861                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
862                              'add_1_data': {'shape': np.array([1, 10260])},
863                              'const_add_1_w': {'shape': np.array([10260]), 'value': np.array([x for x in range(10260)])},
864                              'add_1_w': {'shape': np.array([10260]), 'value': np.array([x for x in range(10260)]),
865                                          'data_type': None},
866                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
867                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
868                                         'output_channel_dim': 0, 'input_channel_dim': 1,
869                                         'dims_number': 2, 'data_type': None},
870                              'fc_1_data': {'shape': np.array([1, 10260])},
871                              })
872
873         ref_weights = np.ones((10260, 2048))
874         ref_biases = np.array([x for x in range(10260)])
875
876         graph_ref = build_graph(nodes_attributes,
877                                 [('placeholder_1', 'placeholder_1_data'),
878                                  ('placeholder_1_data', 'fc_1'),
879                                  ('const_fc_1_w', 'fc_1_w'),
880                                  ('const_fc_1_b', 'fc_1_b'),
881                                  ('fc_1_w', 'fc_1'),
882                                  ('fc_1_b', 'fc_1'),
883                                  ('fc_1', 'fc_1_data'),
884                                  ('fc_1_data', 'op_output')
885                                  ],
886                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
887                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
888                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
889                                             'output_channel_dim': 0, 'input_channel_dim': 1,
890                                             'dims_number': 2},
891                                  'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
892                                  'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
893                                  'fc_1_data': {'shape': np.array([1, 10260])},
894                                  })
895
896         _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=True)
897         graph_clean_up(graph)
898
899         (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
900         self.assertTrue(flag, resp)
901
902     # FC(w)->Add(scalar)
903     def test_fuse_add_to_fc_3(self):
904         # Placeholder->FC->Add
905         graph = build_graph(nodes_attributes,
906                             [('placeholder_1', 'placeholder_1_data'),
907                              ('placeholder_1_data', 'fc_1'),
908                              ('const_fc_1_w', 'fc_1_w'),
909                              ('fc_1_w', 'fc_1'),
910                              ('fc_1', 'fc_1_data'),
911                              ('fc_1_data', 'add_1'),
912                              ('const_add_1_w', 'add_1_w'),
913                              ('add_1_w', 'add_1'),
914                              ('add_1', 'add_1_data'),
915                              ('add_1_data', 'op_output')
916                              ],
917                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
918                              'add_1_data': {'shape': np.array([1, 10260])},
919                              'const_add_1_w': {'shape': np.array([1]), 'value': 6, 'data_type': None},
920                              'add_1_w': {'shape': np.array([1]), 'value': 6, 'data_type': None},
921                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
922                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
923                                         'output_channel_dim': 0, 'input_channel_dim': 1,
924                                         'dims_number': 2, 'data_type': None},
925                              'fc_1_data': {'shape': np.array([1, 10260])},
926                              })
927
928         ref_weights = np.ones((10260, 2048))
929         ref_biases = np.full([10260], 6)
930
931         graph_ref = build_graph(nodes_attributes,
932                                 [('placeholder_1', 'placeholder_1_data'),
933                                  ('placeholder_1_data', 'fc_1'),
934                                  ('const_fc_1_w', 'fc_1_w'),
935                                  ('const_fc_1_b', 'fc_1_b'),
936                                  ('fc_1_w', 'fc_1'),
937                                  ('fc_1_b', 'fc_1'),
938                                  ('fc_1', 'fc_1_data'),
939                                  ('fc_1_data', 'op_output')
940
941                                  ],
942                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
943                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
944                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
945                                             'output_channel_dim': 0, 'input_channel_dim': 1,
946                                             'dims_number': 2},
947                                  'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
948                                  'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
949                                  'fc_1_data': {'shape': np.array([1, 10260])},
950                                  })
951
952         _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=True)
953         graph_clean_up(graph)
954
955         (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
956         self.assertTrue(flag, resp)
957
958     # Add(array)->FC(w+b) can_be_fused = False
959     def test_fuse_add_to_fc_4(self):
960         # Placeholder->Add->FC
961         graph = build_graph(nodes_attributes,
962                             [('placeholder_1', 'placeholder_1_data'),
963                              ('placeholder_1_data', 'add_1'),
964                              ('const_add_1_w', 'add_1_w'),
965                              ('add_1_w', 'add_1'),
966                              ('add_1', 'add_1_data'),
967                              ('add_1_data', 'fc_1'),
968                              ('const_fc_1_w', 'fc_1_w'),
969                              ('const_fc_1_b', 'fc_1_b'),
970                              ('fc_1_w', 'fc_1'),
971                              ('fc_1_b', 'fc_1'),
972                              ('fc_1', 'fc_1_data'),
973                              ('fc_1_data', 'op_output')
974
975                              ],
976                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
977                              'add_1_data': {'shape': np.array([1, 2048])},
978                              'const_add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
979                              'add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
980                              'fc_1': {'can_be_fused': False},
981                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
982                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
983                                         'output_channel_dim': 0, 'input_channel_dim': 1,
984                                         'dims_number': 2},
985                              'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
986                              'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
987                              'fc_1_data': {'shape': np.array([1, 10260])},
988                              })
989
990         graph_ref = build_graph(nodes_attributes,
991                                 [('placeholder_1', 'placeholder_1_data'),
992                                  ('placeholder_1_data', 'add_1'),
993                                  ('const_add_1_w', 'add_1_w'),
994                                  ('add_1_w', 'add_1'),
995                                  ('add_1', 'add_1_data'),
996                                  ('add_1_data', 'fc_1'),
997                                  ('const_fc_1_w', 'fc_1_w'),
998                                  ('const_fc_1_b', 'fc_1_b'),
999                                  ('fc_1_w', 'fc_1'),
1000                                  ('fc_1_b', 'fc_1'),
1001                                  ('fc_1', 'fc_1_data'),
1002                                  ('fc_1_data', 'op_output')
1003                                  ],
1004                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1005                                  'add_1_data': {'shape': np.array([1, 2048])},
1006                                  'const_add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1007                                  'add_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1008                                  'fc_1': {'can_be_fused': False},
1009                                  'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
1010                                  'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
1011                                             'output_channel_dim': 0, 'input_channel_dim': 1,
1012                                             'dims_number': 2},
1013                                  'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1014                                  'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1015                                  'fc_1_data': {'shape': np.array([1, 10260])},
1016                                  })
1017
1018         _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=False)
1019         graph_clean_up(graph)
1020
1021         (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
1022         self.assertTrue(flag, resp)
1023
1024
1025 # Unit tests for fuse_linear_ops
1026 class FuseLinOpsTests(unittest.TestCase):
1027     # Op->Mul(array)-+->Conv(w+b)->Add-+->Concat     Op-+->Conv1-+-->Concat
1028     #                |                 |         =>     |        |
1029     #                +-->Conv(w+b)-----+                +->Conv2-+
1030     def test_fuse_lin_ops_1(self):
1031         graph = build_graph(nodes_attributes,
1032                             [('placeholder_1', 'placeholder_1_data'),
1033                              ('placeholder_1_data', 'mul_1'),
1034                              ('const_mul_1_w', 'mul_1_w'),
1035                              ('mul_1_w', 'mul_1'),
1036                              ('mul_1', 'mul_1_data'),
1037                              ('mul_1_data', 'conv_1'),
1038                              ('const_conv_1_w', 'conv_1_w'),
1039                              ('const_conv_1_b', 'conv_1_b'),
1040                              ('conv_1_w', 'conv_1'),
1041                              ('conv_1_b', 'conv_1'),
1042                              ('conv_1', 'conv_1_data'),
1043                              ('mul_1_data', 'conv_2'),
1044                              ('const_conv_2_w', 'conv_2_w'),
1045                              ('const_conv_2_b', 'conv_2_b'),
1046                              ('conv_2_w', 'conv_2'),
1047                              ('conv_2_b', 'conv_2'),
1048                              ('conv_2', 'conv_2_data'),
1049                              ('conv_1_data', 'concat_1'),
1050                              ('conv_2_data', 'concat_1'),
1051                              ('concat_1', 'concat_1_data'),
1052                              ('concat_1_data', 'op_output')
1053                              ],
1054                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1055                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1056                              'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1057                              'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1058                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1059                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1060                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1061                                           'dims_number': 4},
1062                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1063                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1064                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1065                              'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1066                              'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1067                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1068                                           'dims_number': 4},
1069                              'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1070                              'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1071                              'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1072                              'concat_1_data': {}
1073                              })
1074         ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))
1075
1076         graph_ref = build_graph(nodes_attributes,
1077                                 [('placeholder_1', 'placeholder_1_data'),
1078                                  ('placeholder_1_data', 'conv_1'),
1079                                  ('const_conv_1_w', 'conv_1_w'),
1080                                  ('const_conv_1_b', 'conv_1_b'),
1081                                  ('conv_1_w', 'conv_1'),
1082                                  ('conv_1_b', 'conv_1'),
1083                                  ('conv_1', 'conv_1_data'),
1084                                  ('placeholder_1_data', 'conv_2'),
1085                                  ('const_conv_2_w', 'conv_2_w'),
1086                                  ('const_conv_2_b', 'conv_2_b'),
1087                                  ('conv_2_w', 'conv_2'),
1088                                  ('conv_2_b', 'conv_2'),
1089                                  ('conv_2', 'conv_2_data'),
1090                                  ('conv_1_data', 'concat_1'),
1091                                  ('conv_2_data', 'concat_1'),
1092                                  ('concat_1', 'concat_1_data'),
1093                                  ('concat_1_data', 'op_output')
1094                                  ],
1095                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1096                                  'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1097                                  'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1098                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1099                                               'dims_number': 4},
1100                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1101                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1102                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1103                                  'const_conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights},
1104                                  'conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights,
1105                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1106                                               'dims_number': 4},
1107                                  'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1108                                  'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1109                                  'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1110                                  })
1111
1112         fuse_linear_ops(graph)
1113         graph_clean_up(graph)
1114
1115         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1116         self.assertTrue(flag, resp)
1117
1118     # Mul(array)->FC(w+b)
1119     def test_fuse_mul_to_fc_1(self):
1120         # Placeholder->Mul->FC
1121         graph = build_graph(nodes_attributes,
1122                             [('placeholder_1', 'placeholder_1_data'),
1123                              ('placeholder_1_data', 'mul_1'),
1124                              ('const_mul_1_w', 'mul_1_w'),
1125                              ('mul_1_w', 'mul_1'),
1126                              ('mul_1', 'mul_1_data'),
1127                              ('mul_1_data', 'fc_1'),
1128                              ('const_fc_1_w', 'fc_1_w'),
1129                              ('const_fc_1_b', 'fc_1_b'),
1130                              ('fc_1_w', 'fc_1'),
1131                              ('fc_1_b', 'fc_1'),
1132                              ('fc_1', 'fc_1_data'),
1133                              ('fc_1_data', 'op_output')
1134                              ],
1135                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
1136                              'mul_1_data': {'shape': np.array([1, 2048])},
1137                              'const_mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1138                              'mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1139                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
1140                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
1141                                         'output_channel_dim': 0, 'input_channel_dim': 1,
1142                                         'dims_number': 2},
1143                              'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1144                              'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1145                              'fc_1_data': {'shape': np.array([1, 10260])},
1146                              })
1147         ref_weights = np.ones((10260, 2048)) * np.array([x for x in range(2048)])
1148
1149         graph_ref = build_graph(nodes_attributes,
1150                                 [('placeholder_1', 'placeholder_1_data'),
1151                                  ('placeholder_1_data', 'fc_1'),
1152                                  ('const_fc_1_w', 'fc_1_w'),
1153                                  ('const_fc_1_b', 'fc_1_b'),
1154                                  ('fc_1_w', 'fc_1'),
1155                                  ('fc_1_b', 'fc_1'),
1156                                  ('fc_1', 'fc_1_data'),
1157                                  ('fc_1_data', 'op_output')
1158                                  ],
1159                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1160                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1161                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1162                                             'output_channel_dim': 0, 'input_channel_dim': 1,
1163                                             'dims_number': 2},
1164                                  'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1165                                  'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1166                                  'fc_1_data': {'shape': np.array([1, 10260])},
1167                                  })
1168
1169         fuse_linear_ops(graph)
1170         graph_clean_up(graph)
1171
1172         (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
1173         self.assertTrue(flag, resp)
1174
1175     # FC(w)->Add(scalar)
1176     def test_fuse_add_to_fc_3(self):
1177         # Placeholder->FC->Add
1178         graph = build_graph(nodes_attributes,
1179                             [('placeholder_1', 'placeholder_1_data'),
1180                              ('placeholder_1_data', 'fc_1'),
1181                              ('const_fc_1_w', 'fc_1_w'),
1182                              ('fc_1_w', 'fc_1'),
1183                              ('fc_1', 'fc_1_data'),
1184                              ('fc_1_data', 'add_1'),
1185                              ('const_add_1_w', 'add_1_w'),
1186                              ('add_1_w', 'add_1'),
1187                              ('add_1', 'add_1_data'),
1188                              ('add_1_data', 'op_output')
1189                              ],
1190                             {'placeholder_1_data': {'shape': np.array([1, 2048])},
1191                              'add_1_data': {'shape': np.array([1, 10260])},
1192                              'const_add_1_w': {'shape': np.array([1]), 'value': np.array([6]), 'data_type': None},
1193                              'add_1_w': {'shape': np.array([1]), 'value': np.array([6]), 'data_type': None},
1194                              'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
1195                              'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
1196                                         'output_channel_dim': 0, 'input_channel_dim': 1,
1197                                         'dims_number': 2, 'data_type': None},
1198                              'fc_1_data': {'shape': np.array([1, 10260])},
1199                              })
1200
1201         ref_weights = np.ones((10260, 2048))
1202         ref_biases = np.array([6 for x in range(10260)])
1203
1204         graph_ref = build_graph(nodes_attributes,
1205                                 [('placeholder_1', 'placeholder_1_data'),
1206                                  ('placeholder_1_data', 'fc_1'),
1207                                  ('const_fc_1_w', 'fc_1_w'),
1208                                  ('const_fc_1_b', 'fc_1_b'),
1209                                  ('fc_1_w', 'fc_1'),
1210                                  ('fc_1_b', 'fc_1'),
1211                                  ('fc_1', 'fc_1_data'),
1212                                  ('fc_1_data', 'op_output')
1213                                  ],
1214                                 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1215                                  'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1216                                  'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1217                                             'output_channel_dim': 0, 'input_channel_dim': 1,
1218                                             'dims_number': 2},
1219                                  'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
1220                                  'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
1221                                  'fc_1_data': {'shape': np.array([1, 10260])},
1222                                  })
1223
1224         fuse_linear_ops(graph)
1225         graph_clean_up(graph)
1226
1227         (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
1228         self.assertTrue(flag, resp)
1229
1230     #                 +-----------+
1231     #                 |           |           =>  Same
1232     # Placeholder--->Add->Mul-----+->Concat
1233     def test_fuse_lin_op_1(self):
1234         graph = build_graph(nodes_attributes,
1235                             [('placeholder_1_data', 'conv_1'),
1236                              ('conv_1', 'conv_1_data'),
1237                              ('const_conv_1_w', 'conv_1_w'),
1238                              ('const_conv_1_b', 'conv_1_b'),
1239                              ('conv_1_w', 'conv_1'),
1240                              ('conv_1_b', 'conv_1'),
1241                              ('conv_1_data', 'add_1'),
1242                              ('const_add_1_w', 'add_1_w'),
1243                              ('add_1_w', 'add_1'),
1244                              ('add_1', 'add_1_data'),
1245                              ('concat_1', 'concat_1_data'),
1246                              ('const_mul_1_w', 'mul_1_w'),
1247                              ('mul_1_w', 'mul_1'),
1248                              ('mul_1', 'mul_1_data'),
1249                              ('add_1_data', 'concat_1'),
1250                              ('mul_1_data', 'concat_1'),
1251                              ('add_1_data', 'mul_1'),
1252                              ('concat_1_data', 'op_output')
1253                              ],
1254                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1255                              'const_conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3))},
1256                              'conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3)),
1257                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1258                                           'dims_number': 4},
1259                              'const_conv_1_b': {'shape': np.array([3]), 'value': np.zeros(3)},
1260                              'conv_1_b': {'shape': np.array([3]), 'value': np.zeros(3)},
1261                              'conv_1_data': {'shape': np.array([1, 227, 227, 3])},
1262                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1263                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
1264                              'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1265                              'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1266                              'const_add_1_w': {'shape': np.array([1]), 'value': np.array([1])},
1267                              'add_1_w': {'shape': np.array([1]), 'value': np.array([1])},
1268                              'concat_1_data': {}
1269                              })
1270
1271         graph_ref = build_graph(nodes_attributes,
1272                                 [('placeholder_1_data', 'conv_1'),
1273                                  ('conv_1', 'conv_1_data'),
1274                                  ('const_conv_1_w', 'conv_1_w'),
1275                                  ('const_conv_1_b', 'conv_1_b'),
1276                                  ('conv_1_w', 'conv_1'),
1277                                  ('conv_1_b', 'conv_1'),
1278                                  ('conv_1_data', 'concat_1'),
1279                                  ('const_mul_1_w', 'mul_1_w'),
1280                                  ('mul_1_w', 'mul_1'),
1281                                  ('conv_1_data', 'mul_1'),
1282                                  ('concat_1', 'concat_1_data'),
1283                                  ('mul_1', 'mul_1_data'),
1284                                  ('mul_1_data', 'concat_1'),
1285                                  ('concat_1_data', 'op_output')
1286                                  ],
1287                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1288                                  'const_conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3))},
1289                                  'conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3)),
1290                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1291                                               'dims_number': 4},
1292                                  'const_conv_1_b': {'shape': np.array([3]), 'value': np.ones(3)},
1293                                  'conv_1_b': {'shape': np.array([3]), 'value': np.ones(3)},
1294                                  'conv_1_data': {'shape': np.array([1, 227, 227, 3])},
1295                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1296                                  'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1297                                  'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1298                                  'concat_1_data': {}
1299                                  })
1300
1301         fuse_linear_ops(graph)
1302         graph_clean_up(graph)
1303
1304         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1305         self.assertTrue(flag, resp)
1306
1307     # Op->Mul(array)-+->Conv(w+b)------+->Concat
1308     #                |                 |         =>  Same('can_be_fused': False)
1309     #                +-->Conv(w+b)-----+
1310     def test_fuse_lin_ops_2(self):
1311         graph = build_graph(nodes_attributes,
1312                             [('placeholder_1', 'placeholder_1_data'),
1313                              ('placeholder_1_data', 'mul_1'),
1314                              ('const_mul_1_w', 'mul_1_w'),
1315                              ('mul_1_w', 'mul_1'),
1316                              ('mul_1', 'mul_1_data'),
1317                              ('mul_1_data', 'conv_1'),
1318                              ('const_conv_1_w', 'conv_1_w'),
1319                              ('const_conv_1_b', 'conv_1_b'),
1320                              ('conv_1_w', 'conv_1'),
1321                              ('conv_1_b', 'conv_1'),
1322                              ('conv_1', 'conv_1_data'),
1323                              ('mul_1_data', 'conv_2'),
1324                              ('const_conv_2_w', 'conv_2_w'),
1325                              ('const_conv_2_b', 'conv_2_b'),
1326                              ('conv_2_w', 'conv_2'),
1327                              ('conv_2_b', 'conv_2'),
1328                              ('conv_2', 'conv_2_data'),
1329                              ('conv_1_data', 'concat_1'),
1330                              ('conv_2_data', 'concat_1'),
1331                              ('concat_1', 'concat_1_data'),
1332                              ('concat_1_data', 'op_output')
1333
1334                              ],
1335                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1336                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1337                              'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1338                              'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1339                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1340                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1341                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1342                                           'dims_number': 4},
1343                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1344                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1345                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1346                              'conv_2': {'can_be_fused': False},
1347                              'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1348                              'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1349                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1350                                           'dims_number': 4},
1351                              'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1352                              'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1353                              'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1354                              'concat_1_data': {}
1355                              })
1356
1357         graph_ref = build_graph(nodes_attributes,
1358                                 [('placeholder_1', 'placeholder_1_data'),
1359                                  ('placeholder_1_data', 'mul_1'),
1360                                  ('const_mul_1_w', 'mul_1_w'),
1361                                  ('mul_1_w', 'mul_1'),
1362                                  ('mul_1', 'mul_1_data'),
1363                                  ('mul_1_data', 'conv_1'),
1364                                  ('const_conv_1_w', 'conv_1_w'),
1365                                  ('const_conv_1_b', 'conv_1_b'),
1366                                  ('conv_1_w', 'conv_1'),
1367                                  ('conv_1_b', 'conv_1'),
1368                                  ('conv_1', 'conv_1_data'),
1369                                  ('mul_1_data', 'conv_2'),
1370                                  ('const_conv_2_w', 'conv_2_w'),
1371                                  ('const_conv_2_b', 'conv_2_b'),
1372                                  ('conv_2_w', 'conv_2'),
1373                                  ('conv_2_b', 'conv_2'),
1374                                  ('conv_2', 'conv_2_data'),
1375                                  ('conv_1_data', 'concat_1'),
1376                                  ('conv_2_data', 'concat_1'),
1377                                  ('concat_1', 'concat_1_data'),
1378                                  ('concat_1_data', 'op_output')
1379                                  ],
1380                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1381                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1382                                  'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1383                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1384                                  'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1385                                  'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1386                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1387                                               'dims_number': 4},
1388                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1389                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1390                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1391                                  'conv_2': {'can_be_fused': False},
1392                                  'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1393                                  'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1394                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1395                                               'dims_number': 4},
1396                                  'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1397                                  'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1398                                  'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1399                                  'concat_1_data': {}
1400                                  })
1401
1402         fuse_linear_ops(graph)
1403         graph_clean_up(graph)
1404
1405         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1406         self.assertTrue(flag, resp)
1407
1408     # Op->Mul(array)-+->Conv(w+b)------+->Concat
1409     #                |                 |         =>  Same('can_be_fused': False)
1410     #                +-->Conv(w+b)-----+
1411     def test_fuse_lin_ops_3(self):
1412         graph = build_graph(nodes_attributes,
1413                             [('placeholder_1', 'placeholder_1_data'),
1414                              ('placeholder_1_data', 'mul_1'),
1415                              ('const_mul_1_w', 'mul_1_w'),
1416                              ('mul_1_w', 'mul_1'),
1417                              ('mul_1', 'mul_1_data'),
1418                              ('mul_1_data', 'conv_1'),
1419                              ('const_conv_1_w', 'conv_1_w'),
1420                              ('const_conv_1_b', 'conv_1_b'),
1421                              ('conv_1_w', 'conv_1'),
1422                              ('conv_1_b', 'conv_1'),
1423                              ('conv_1', 'conv_1_data'),
1424                              ('mul_1_data', 'conv_2'),
1425                              ('const_conv_2_w', 'conv_2_w'),
1426                              ('const_conv_2_b', 'conv_2_b'),
1427                              ('conv_2_w', 'conv_2'),
1428                              ('conv_2_b', 'conv_2'),
1429                              ('conv_2', 'conv_2_data'),
1430                              ('conv_1_data', 'concat_1'),
1431                              ('conv_2_data', 'concat_1'),
1432                              ('concat_1', 'concat_1_data'),
1433                              ('concat_1_data', 'op_output')
1434                              ],
1435                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1436                              'mul_1': {'can_be_fused': False},
1437                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1438                              'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1439                              'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1440                              'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1441                              'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1442                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1443                                           'dims_number': 4},
1444                              'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1445                              'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1446                              'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1447                              'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1448                              'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1449                                           'output_channel_dim': 3, 'input_channel_dim': 2,
1450                                           'dims_number': 4},
1451                              'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1452                              'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1453                              'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1454                              'concat_1_data': {}
1455                              })
1456
1457         graph_ref = build_graph(nodes_attributes,
1458                                 [('placeholder_1', 'placeholder_1_data'),
1459                                  ('placeholder_1_data', 'mul_1'),
1460                                  ('const_mul_1_w', 'mul_1_w'),
1461                                  ('mul_1_w', 'mul_1'),
1462                                  ('mul_1', 'mul_1_data'),
1463                                  ('mul_1_data', 'conv_1'),
1464                                  ('const_conv_1_w', 'conv_1_w'),
1465                                  ('const_conv_1_b', 'conv_1_b'),
1466                                  ('conv_1_w', 'conv_1'),
1467                                  ('conv_1_b', 'conv_1'),
1468                                  ('conv_1', 'conv_1_data'),
1469                                  ('mul_1_data', 'conv_2'),
1470                                  ('const_conv_2_w', 'conv_2_w'),
1471                                  ('const_conv_2_b', 'conv_2_b'),
1472                                  ('conv_2_w', 'conv_2'),
1473                                  ('conv_2_b', 'conv_2'),
1474                                  ('conv_2', 'conv_2_data'),
1475                                  ('conv_1_data', 'concat_1'),
1476                                  ('conv_2_data', 'concat_1'),
1477                                  ('concat_1', 'concat_1_data'),
1478                                  ('concat_1_data', 'op_output')
1479                                  ],
1480                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1481                                  'mul_1': {'can_be_fused': False},
1482                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1483                                  'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1484                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1485                                  'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1486                                  'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1487                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1488                                               'dims_number': 4},
1489                                  'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1490                                  'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1491                                  'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1492                                  'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1493                                  'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1494                                               'output_channel_dim': 3, 'input_channel_dim': 2,
1495                                               'dims_number': 4},
1496                                  'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1497                                  'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1498                                  'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1499                                  'concat_1_data': {}
1500                                  })
1501
1502         fuse_linear_ops(graph)
1503         graph_clean_up(graph)
1504
1505         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1506         self.assertTrue(flag, resp)