Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / fusing / fuse_linear_seq_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
22 from mo.utils.unittest.graph import build_graph, compare_graphs
23
24 nodes_attributes = {
25     'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
26     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
27     # ScaleShift layer
28     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
29     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
30     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
31     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
32     # Mul and Add operations
33     'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
34     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
35     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
36     'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
37     'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
38     'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
39     # Mul2 and Add2 operations
40     'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
41     'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
42     'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
43     'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
44     'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
45     'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
46     # Mul3 and Add3 operations
47     'mul_3': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
48     'mul_3_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
49     'mul_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
50     'add_3': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
51     'add_3_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
52     'add_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
53     # Mul4 and Add4 operations
54     'mul_4': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
55     'mul_4_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
56     'mul_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
57     'add_4': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
58     'add_4_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
59     'add_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
60     # Concat1 operation
61     'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
62     'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
63     # Convolutions
64     'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
65     'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
66     'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
67     'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
68     'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
69     'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
70     'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
71     'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
72     # FullyConnected
73     'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
74     'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
75     'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
76     'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
77     # Placeholders
78     'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
79     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
80     'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
81     'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
82     'op_output': { 'kind': 'op', 'op': 'OpOutput'}
83 }
84
85
86 # Unit tests for fuse_mul_add_sequence
87 class LinSeqFusingTests(unittest.TestCase):
88     # Placeholder-+->Mul->Add->Mul-+->Concat
89     #             |                |
90     #             +----------------+
91     def test_fuse_lin_seq_1(self):
92         graph = build_graph(nodes_attributes,
93                             [('placeholder_1', 'placeholder_1_data'),
94                              ('placeholder_1_data', 'mul_1'),
95                              ('mul_1_w', 'mul_1'),
96                              ('mul_1', 'mul_1_data'),
97                              ('mul_1_data', 'add_1'),
98                              ('add_1_w', 'add_1'),
99                              ('add_1', 'add_1_data'),
100                              ('add_1_data', 'mul_2'),
101                              ('mul_2_w', 'mul_2'),
102                              ('mul_2', 'mul_2_data'),
103                              ('mul_2_data', 'concat_1'),
104                              ('concat_1', 'concat_1_data'),
105                              ('placeholder_1_data', 'concat_1'),
106                              ('concat_1_data', 'op_output')
107                              ],
108                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
109                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
110                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
111                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
112                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
113                              'add_1_w': {'shape': np.array([1]), 'value': 6},
114                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
115                              },
116                             nodes_with_edges_only=True)
117
118         graph_ref = build_graph(nodes_attributes,
119                                 [('placeholder_1', 'placeholder_1_data'),
120                                  ('placeholder_1_data', 'mul_1'),
121                                  ('mul_1_w', 'mul_1'),
122                                  ('mul_1', 'mul_1_data'),
123                                  ('mul_1_data', 'add_1'),
124                                  ('add_1_w', 'add_1'),
125                                  ('add_1', 'add_1_data'),
126                                  ('add_1_data', 'concat_1'),
127                                  ('concat_1', 'concat_1_data'),
128                                  ('placeholder_1_data', 'concat_1'),
129                                  ('concat_1_data', 'op_output')
130                                  ],
131                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
132                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
133                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
134                                  'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
135                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
136                                  'mul_1': {'can_be_fused': True},
137                                  'add_1': {'can_be_fused': True},
138                                  },
139                                 nodes_with_edges_only=True)
140
141         graph.graph['layout'] = 'NHWC'
142         fuse_mul_add_sequence(graph)
143         self.assertTrue(len(graph.node) == len(graph_ref.node),
144                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
145
146         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
147         self.assertTrue(flag, resp)
148
149     #             +----------------+
150     #             |                |
151     # Placeholder-+->Mul->Add->Mul-+---------------+->Concat
152     #                           |                  |
153     #                           +-->Placeholder----+
154     def test_fuse_lin_seq_2(self):
155         graph = build_graph(nodes_attributes,
156                             [('placeholder_1', 'placeholder_1_data'),
157                              ('placeholder_1_data', 'mul_1'),
158                              ('mul_1_w', 'mul_1'),
159                              ('mul_1', 'mul_1_data'),
160                              ('mul_1_data', 'add_1'),
161                              ('add_1_w', 'add_1'),
162                              ('add_1', 'add_1_data'),
163                              ('add_1_data', 'mul_2'),
164                              ('mul_2_w', 'mul_2'),
165                              ('mul_2', 'mul_2_data'),
166                              ('mul_2_data', 'concat_1'),
167                              ('concat_1', 'concat_1_data'),
168                              ('placeholder_1_data', 'concat_1'),
169                              ('mul_2_data', 'placeholder_2'),
170                              ('placeholder_2', 'placeholder_2_data'),
171                              ('placeholder_2_data', 'concat_1'),
172                              ('concat_1_data', 'op_output')
173                              ],
174                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
175                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
176                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
177                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
178                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
179                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
180                              'add_1_w': {'shape': np.array([1]), 'value': 6},
181                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
182                              },
183                             nodes_with_edges_only=True)
184
185         graph_ref = build_graph(nodes_attributes,
186                                 [('placeholder_1', 'placeholder_1_data'),
187                                  ('placeholder_1_data', 'mul_1'),
188                                  ('mul_1_w', 'mul_1'),
189                                  ('mul_1', 'mul_1_data'),
190                                  ('mul_1_data', 'add_1'),
191                                  ('add_1_w', 'add_1'),
192                                  ('add_1', 'add_1_data'),
193                                  ('add_1_data', 'concat_1'),
194                                  ('concat_1', 'concat_1_data'),
195                                  ('placeholder_1_data', 'concat_1'),
196                                  ('add_1_data', 'placeholder_2'),
197                                  ('placeholder_2', 'placeholder_2_data'),
198                                  ('placeholder_2_data', 'concat_1'),
199                                  ('concat_1_data', 'op_output')
200                                  ],
201                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
202                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
203                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
204                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
205                                  'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
206                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
207                                  'mul_1': {'can_be_fused': True},
208                                  'add_1': {'can_be_fused': True},
209                                  },
210                                 nodes_with_edges_only=True)
211         graph.graph['layout'] = 'NHWC'
212         fuse_mul_add_sequence(graph)
213         self.assertTrue(len(graph.node) == len(graph_ref.node),
214                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
215
216         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
217         self.assertTrue(flag, resp)
218
219     #                      +----->Placeholder
220     #                      |        |          =>  The same graph
221     # Placeholder--->Mul->Add->Mul--+->Concat
222     def test_fuse_lin_seq_3(self):
223         graph = build_graph(nodes_attributes,
224                             [('placeholder_1', 'placeholder_1_data'),
225                              ('placeholder_1_data', 'mul_1'),
226                              ('mul_1_w', 'mul_1'),
227                              ('mul_1', 'mul_1_data'),
228                              ('mul_1_data', 'add_1'),
229                              ('add_1_w', 'add_1'),
230                              ('add_1', 'add_1_data'),
231                              ('add_1_data', 'mul_2'),
232                              ('mul_2_w', 'mul_2'),
233                              ('mul_2', 'mul_2_data'),
234                              ('mul_2_data', 'concat_1'),
235                              ('concat_1', 'concat_1_data'),
236                              ('add_1_data', 'placeholder_2'),
237                              ('placeholder_2', 'placeholder_2_data'),
238                              ('placeholder_2_data', 'concat_1'),
239                              ('concat_1_data', 'op_output')
240                              ],
241                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
242                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
243                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
244                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
245                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
246                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
247                              'add_1_w': {'shape': np.array([1]), 'value': 6},
248                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
249                              },
250                             nodes_with_edges_only=True)
251
252         graph_ref = build_graph(nodes_attributes,
253                                 [('placeholder_1', 'placeholder_1_data'),
254                                  ('placeholder_1_data', 'mul_1'),
255                                  ('mul_1_w', 'mul_1'),
256                                  ('mul_1', 'mul_1_data'),
257                                  ('mul_1_data', 'add_1'),
258                                  ('add_1_w', 'add_1'),
259                                  ('add_1', 'add_1_data'),
260                                  ('add_1_data', 'mul_2'),
261                                  ('mul_2_w', 'mul_2'),
262                                  ('mul_2', 'mul_2_data'),
263                                  ('mul_2_data', 'concat_1'),
264                                  ('concat_1', 'concat_1_data'),
265                                  ('add_1_data', 'placeholder_2'),
266                                  ('placeholder_2', 'placeholder_2_data'),
267                                  ('placeholder_2_data', 'concat_1'),
268                                  ('concat_1_data', 'op_output')
269                                  ],
270                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
271                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
272                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
273                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
274                                  'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
275                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
276                                  'add_1_w': {'shape': np.array([1]), 'value': 6},
277                                  'mul_2_w': {'shape': np.array([1]), 'value': 6},
278                                  },
279                                 nodes_with_edges_only=True)
280
281         fuse_mul_add_sequence(graph)
282         self.assertTrue(len(graph.node) == len(graph_ref.node),
283                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
284
285         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
286         self.assertTrue(flag, resp)
287
288     #                 +-------->Placeholder                          +-------->Placeholder
289     #                 |            |           =>                    |            |
290     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder-+->Mul->Mul->Add-+->Concat
291     def test_fuse_lin_seq_4(self):
292         graph = build_graph(nodes_attributes,
293                             [('placeholder_1', 'placeholder_1_data'),
294                              ('placeholder_1_data', 'mul_1'),
295                              ('mul_1_w', 'mul_1'),
296                              ('mul_1', 'mul_1_data'),
297                              ('mul_1_data', 'add_1'),
298                              ('add_1_w', 'add_1'),
299                              ('add_1', 'add_1_data'),
300                              ('add_1_data', 'mul_2'),
301                              ('mul_2_w', 'mul_2'),
302                              ('mul_2', 'mul_2_data'),
303                              ('mul_2_data', 'concat_1'),
304                              ('concat_1', 'concat_1_data'),
305                              ('mul_1_data', 'placeholder_2'),
306                              ('placeholder_2', 'placeholder_2_data'),
307                              ('placeholder_2_data', 'concat_1'),
308                              ('concat_1_data', 'op_output')
309                              ],
310                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
311                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
312                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
313                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
314                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
315                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
316                              'add_1_w': {'shape': np.array([1]), 'value': 6},
317                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
318                              },
319                             nodes_with_edges_only=True)
320
321         graph_ref = build_graph(nodes_attributes,
322                                 [('placeholder_1', 'placeholder_1_data'),
323                                  ('placeholder_1_data', 'mul_1'),
324                                  ('mul_1_w', 'mul_1'),
325                                  ('mul_1', 'mul_1_data'),
326                                  ('mul_1_data', 'mul_2'),
327                                  ('mul_2_w', 'mul_2'),
328                                  ('mul_2', 'mul_2_data'),
329                                  ('mul_2_data', 'add_1'),
330                                  ('add_1_w', 'add_1'),
331                                  ('add_1', 'add_1_data'),
332                                  ('add_1_data', 'concat_1'),
333                                  ('concat_1', 'concat_1_data'),
334                                  ('mul_1_data', 'placeholder_2'),
335                                  ('placeholder_2', 'placeholder_2_data'),
336                                  ('placeholder_2_data', 'concat_1'),
337                                  ('concat_1_data', 'op_output')
338                                  ],
339                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
340                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
341                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
342                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
343                                  'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
344                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
345                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
346                                  'mul_2_w': {'shape': np.array([1]), 'value': np.array([6])},
347                                  },
348                                 nodes_with_edges_only=True)
349
350         graph.graph['layout'] = 'NHWC'
351         fuse_mul_add_sequence(graph)
352         self.assertTrue(len(graph.node) == len(graph_ref.node),
353                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
354
355         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
356         self.assertTrue(flag, resp)
357
358     #                 +-------->Placeholder                          +->Placeholder
359     #                 |            |           =>                    |            |
360     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder--->Mul-----------+->Concat
361     def test_fuse_lin_seq_5(self):
362         graph = build_graph(nodes_attributes,
363                             [('placeholder_1', 'placeholder_1_data'),
364                              ('placeholder_1_data', 'mul_1'),
365                              ('mul_1_w', 'mul_1'),
366                              ('mul_1', 'mul_1_data'),
367                              ('mul_1_data', 'add_1'),
368                              ('add_1_w', 'add_1'),
369                              ('add_1', 'add_1_data'),
370                              ('add_1_data', 'mul_2'),
371                              ('mul_2_w', 'mul_2'),
372                              ('mul_2', 'mul_2_data'),
373                              ('mul_2_data', 'concat_1'),
374                              ('concat_1', 'concat_1_data'),
375                              ('mul_1_data', 'placeholder_2'),
376                              ('placeholder_2', 'placeholder_2_data'),
377                              ('placeholder_2_data', 'concat_1'),
378                              ('concat_1_data', 'op_output')
379                              ],
380                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
381                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
382                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
383                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
384                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
385                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
386                              'add_1_w': {'shape': np.array([1]), 'value': 0},
387                              'mul_2_w': {'shape': np.array([1]), 'value': 1},
388                              },
389                             nodes_with_edges_only=True)
390
391         graph_ref = build_graph(nodes_attributes,
392                                 [('placeholder_1', 'placeholder_1_data'),
393                                  ('placeholder_1_data', 'mul_1'),
394                                  ('mul_1_w', 'mul_1'),
395                                  ('mul_1', 'mul_1_data'),
396                                  ('mul_1_data', 'concat_1'),
397                                  ('concat_1', 'concat_1_data'),
398                                  ('mul_1_data', 'placeholder_2'),
399                                  ('placeholder_2', 'placeholder_2_data'),
400                                  ('placeholder_2_data', 'concat_1'),
401                                  ('concat_1_data', 'op_output')
402                                  ],
403                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
404                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
405                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
406                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
407                                  },
408                                 nodes_with_edges_only=True)
409
410         graph.graph['layout'] = 'NHWC'
411         fuse_mul_add_sequence(graph)
412
413         self.assertTrue(len(graph.node) == len(graph_ref.node),
414                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
415
416         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
417         self.assertTrue(flag, resp)
418
419     #                 +-------->Placeholder                          +->Placeholder
420     #                 |            |           =>                    |            |
421     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder--->Mul-->Add-----+->Concat
422     def test_fuse_lin_seq_6(self):
423         graph = build_graph(nodes_attributes,
424                             [('placeholder_1', 'placeholder_1_data'),
425                              ('placeholder_1_data', 'mul_1'),
426                              ('mul_1_w', 'mul_1'),
427                              ('mul_1', 'mul_1_data'),
428                              ('mul_1_data', 'add_1'),
429                              ('add_1_w', 'add_1'),
430                              ('add_1', 'add_1_data'),
431                              ('add_1_data', 'mul_2'),
432                              ('mul_2_w', 'mul_2'),
433                              ('mul_2', 'mul_2_data'),
434                              ('mul_2_data', 'concat_1'),
435                              ('concat_1', 'concat_1_data'),
436                              ('mul_1_data', 'placeholder_2'),
437                              ('placeholder_2', 'placeholder_2_data'),
438                              ('placeholder_2_data', 'concat_1'),
439                              ('concat_1_data', 'op_output')
440                              ],
441                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
442                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
443                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
444                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
445                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
446                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
447                              'add_1_w': {'shape': np.array([1]), 'value': 6},
448                              'mul_2_w': {'shape': np.array([1]), 'value': 1},
449                              },
450                             nodes_with_edges_only=True)
451
452         graph_ref = build_graph(nodes_attributes,
453                                 [('placeholder_1', 'placeholder_1_data'),
454                                  ('placeholder_1_data', 'mul_1'),
455                                  ('mul_1_w', 'mul_1'),
456                                  ('mul_1', 'mul_1_data'),
457                                  ('mul_1_data', 'add_1'),
458                                  ('add_1_w', 'add_1'),
459                                  ('add_1', 'add_1_data'),
460                                  ('add_1_data', 'concat_1'),
461                                  ('concat_1', 'concat_1_data'),
462                                  ('mul_1_data', 'placeholder_2'),
463                                  ('placeholder_2', 'placeholder_2_data'),
464                                  ('placeholder_2_data', 'concat_1'),
465                                  ('concat_1_data', 'op_output')
466                                  ],
467                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
468                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
469                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
470                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
471                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([6])},
472                                  },
473                                 nodes_with_edges_only=True)
474
475         graph.graph['layout'] = 'NHWC'
476         fuse_mul_add_sequence(graph)
477         self.assertTrue(len(graph.node) == len(graph_ref.node),
478                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
479
480         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
481         self.assertTrue(flag, resp)
482
483     #                 +-------->Placeholder                          +->Placeholder
484     #                 |            |           =>                    |            |
485     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder--->Mul-->Mul-----+->Concat
486     def test_fuse_lin_seq_7(self):
487         graph = build_graph(nodes_attributes,
488                             [('placeholder_1', 'placeholder_1_data'),
489                              ('placeholder_1_data', 'mul_1'),
490                              ('mul_1_w', 'mul_1'),
491                              ('mul_1', 'mul_1_data'),
492                              ('mul_1_data', 'add_1'),
493                              ('add_1_w', 'add_1'),
494                              ('add_1', 'add_1_data'),
495                              ('add_1_data', 'mul_2'),
496                              ('mul_2_w', 'mul_2'),
497                              ('mul_2', 'mul_2_data'),
498                              ('mul_2_data', 'concat_1'),
499                              ('concat_1', 'concat_1_data'),
500                              ('mul_1_data', 'placeholder_2'),
501                              ('placeholder_2', 'placeholder_2_data'),
502                              ('placeholder_2_data', 'concat_1'),
503                              ('concat_1_data', 'op_output')
504                              ],
505                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
506                              'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
507                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
508                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
509                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
510                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
511                              'add_1_w': {'shape': np.array([1]), 'value': 0},
512                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
513                              },
514                             nodes_with_edges_only=True)
515
516         graph_ref = build_graph(nodes_attributes,
517                                 [('placeholder_1', 'placeholder_1_data'),
518                                  ('placeholder_1_data', 'mul_1'),
519                                  ('mul_1_w', 'mul_1'),
520                                  ('mul_1', 'mul_1_data'),
521                                  ('mul_1_data', 'mul_2'),
522                                  ('mul_2_w', 'mul_2'),
523                                  ('mul_2', 'mul_2_data'),
524                                  ('mul_2_data', 'concat_1'),
525                                  ('concat_1', 'concat_1_data'),
526                                  ('mul_1_data', 'placeholder_2'),
527                                  ('placeholder_2', 'placeholder_2_data'),
528                                  ('placeholder_2_data', 'concat_1'),
529                                  ('concat_1_data', 'op_output')
530                                  ],
531                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
532                                  'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
533                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
534                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
535                                  'mul_2_w': {'shape': np.array([1]), 'value': np.array([6])},
536                                  },
537                                 nodes_with_edges_only=True)
538
539         graph.graph['layout'] = 'NHWC'
540         fuse_mul_add_sequence(graph)
541         self.assertTrue(len(graph.node) == len(graph_ref.node),
542                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
543
544         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
545         self.assertTrue(flag, resp)
546
547     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder->Concat
548     def test_fuse_lin_seq_8(self):
549         graph = build_graph(nodes_attributes,
550                             [('placeholder_1', 'placeholder_1_data'),
551                              ('placeholder_1_data', 'mul_1'),
552                              ('mul_1_w', 'mul_1'),
553                              ('mul_1', 'mul_1_data'),
554                              ('mul_1_data', 'add_1'),
555                              ('add_1_w', 'add_1'),
556                              ('add_1', 'add_1_data'),
557                              ('add_1_data', 'mul_2'),
558                              ('mul_2_w', 'mul_2'),
559                              ('mul_2', 'mul_2_data'),
560                              ('mul_2_data', 'concat_1'),
561                              ('concat_1', 'concat_1_data'),
562                              ('concat_1_data', 'op_output')
563                              ],
564                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
565                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
566                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
567                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
568                              'mul_1_w': {'shape': np.array([1]), 'value': 1},
569                              'add_1_w': {'shape': np.array([1]), 'value': 0},
570                              'mul_2_w': {'shape': np.array([1]), 'value': 1},
571                              },
572                             nodes_with_edges_only=True)
573
574         graph_ref = build_graph(nodes_attributes,
575                                 [('placeholder_1', 'placeholder_1_data'),
576                                  ('placeholder_1_data', 'concat_1'),
577                                  ('concat_1', 'concat_1_data'),
578                                  ('concat_1_data', 'op_output')
579                                  ],
580                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}},
581                                 nodes_with_edges_only=True)
582
583         graph.graph['layout'] = 'NHWC'
584         fuse_mul_add_sequence(graph)
585         self.assertTrue(len(graph.node) == len(graph_ref.node),
586                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
587
588         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
589         self.assertTrue(flag, resp)
590
591     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder->Mul->Add->Concat
592     def test_fuse_lin_seq_9(self):
593         graph = build_graph(nodes_attributes,
594                             [('placeholder_1', 'placeholder_1_data'),
595                              ('placeholder_1_data', 'mul_1'),
596                              ('mul_1_w', 'mul_1'),
597                              ('mul_1', 'mul_1_data'),
598                              ('mul_1_data', 'add_1'),
599                              ('add_1_w', 'add_1'),
600                              ('add_1', 'add_1_data'),
601                              ('add_1_data', 'mul_2'),
602                              ('mul_2_w', 'mul_2'),
603                              ('mul_2', 'mul_2_data'),
604                              ('mul_2_data', 'concat_1'),
605                              ('concat_1', 'concat_1_data'),
606                              ('concat_1_data', 'op_output')
607                              ],
608                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
609                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
610                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
611                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
612                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
613                              'add_1_w': {'shape': np.array([1]), 'value': 6},
614                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
615                              },
616                             nodes_with_edges_only=True)
617
618         graph_ref = build_graph(nodes_attributes,
619                                 [('placeholder_1', 'placeholder_1_data'),
620                                  ('placeholder_1_data', 'mul_1'),
621                                  ('mul_1_w', 'mul_1'),
622                                  ('mul_1', 'mul_1_data'),
623                                  ('mul_1_data', 'add_1'),
624                                  ('add_1_w', 'add_1'),
625                                  ('add_1', 'add_1_data'),
626                                  ('add_1_data', 'concat_1'),
627                                  ('concat_1', 'concat_1_data'),
628                                  ('concat_1_data', 'op_output')
629                                  ],
630                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
631                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
632                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
633                                  'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
634                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
635                                  },
636                                 nodes_with_edges_only=True)
637
638         graph.graph['layout'] = 'NHWC'
639         fuse_mul_add_sequence(graph)
640         self.assertTrue(len(graph.node) == len(graph_ref.node),
641                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
642
643         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
644         self.assertTrue(flag, resp)
645
646     # Placeholder--->Mul->Add->Mul-+->Concat         Placeholder->Mul->Add->Concat
647     def test_fuse_lin_seq_10(self):
648         graph = build_graph(nodes_attributes,
649                             [('placeholder_1', 'placeholder_1_data'),
650                              ('placeholder_1_data', 'mul_1'),
651                              ('mul_1_w', 'mul_1'),
652                              ('mul_1', 'mul_1_data'),
653                              ('mul_1_data', 'add_1'),
654                              ('add_1_w', 'add_1'),
655                              ('add_1', 'add_1_data'),
656                              ('add_1_data', 'mul_2'),
657                              ('mul_2_w', 'mul_2'),
658                              ('mul_2', 'mul_2_data'),
659                              ('mul_2_data', 'concat_1'),
660                              ('concat_1', 'concat_1_data'),
661                              ('concat_1_data', 'op_output')
662                              ],
663                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
664                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
665                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
666                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
667                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
668                              'add_1_w': {'shape': np.array([3]), 'value': np.array([6, 6, 6])},
669                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
670                              },
671                             nodes_with_edges_only=True)
672
673         graph_ref = build_graph(nodes_attributes,
674                                 [('placeholder_1', 'placeholder_1_data'),
675                                  ('placeholder_1_data', 'mul_1'),
676                                  ('mul_1_w', 'mul_1'),
677                                  ('mul_1', 'mul_1_data'),
678                                  ('mul_1_data', 'add_1'),
679                                  ('add_1_w', 'add_1'),
680                                  ('add_1', 'add_1_data'),
681                                  ('add_1_data', 'concat_1'),
682                                  ('concat_1', 'concat_1_data'),
683                                  ('concat_1_data', 'op_output')
684                                  ],
685                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
686                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
687                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
688                                  'mul_1_w': {'shape': np.array([3]), 'value': np.array([36, 36, 36])},
689                                  'add_1_w': {'shape': np.array([3]), 'value': np.array([36, 36, 36])},
690                                  },
691                                 nodes_with_edges_only=True)
692
693         graph.graph['layout'] = 'NHWC'
694         fuse_mul_add_sequence(graph)
695         self.assertTrue(len(graph.node) == len(graph_ref.node),
696                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
697
698         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
699         self.assertTrue(flag, resp)
700
701     # Placeholder-+->Mul->Add->Mul-+->Concat
702     #             |                |            With 'can_be_fused' = False
703     #             +----------------+
704     def test_fuse_lin_seq_11(self):
705         graph = build_graph(nodes_attributes,
706                             [('placeholder_1', 'placeholder_1_data'),
707                              ('placeholder_1_data', 'mul_1'),
708                              ('mul_1_w', 'mul_1'),
709                              ('mul_1', 'mul_1_data'),
710                              ('mul_1_data', 'add_1'),
711                              ('add_1_w', 'add_1'),
712                              ('add_1', 'add_1_data'),
713                              ('add_1_data', 'mul_2'),
714                              ('mul_2_w', 'mul_2'),
715                              ('mul_2', 'mul_2_data'),
716                              ('mul_2_data', 'concat_1'),
717                              ('concat_1', 'concat_1_data'),
718                              ('placeholder_1_data', 'concat_1'),
719                              ('concat_1_data', 'op_output')
720                              ],
721                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
722                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
723                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
724                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
725                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
726                              'add_1_w': {'shape': np.array([1]), 'value': 6},
727                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
728                              'mul_1': {'can_be_fused': False},
729                              'add_1': {'can_be_fused': False},
730                              },
731                             nodes_with_edges_only=True)
732
733         graph_ref = build_graph(nodes_attributes,
734                                 [('placeholder_1', 'placeholder_1_data'),
735                                  ('placeholder_1_data', 'mul_1'),
736                                  ('mul_1_w', 'mul_1'),
737                                  ('mul_1', 'mul_1_data'),
738                                  ('mul_1_data', 'add_1'),
739                                  ('add_1_w', 'add_1'),
740                                  ('add_1', 'add_1_data'),
741                                  ('add_1_data', 'mul_2'),
742                                  ('mul_2_w', 'mul_2'),
743                                  ('mul_2', 'mul_2_data'),
744                                  ('mul_2_data', 'concat_1'),
745                                  ('concat_1', 'concat_1_data'),
746                                  ('placeholder_1_data', 'concat_1'),
747                                  ('concat_1_data', 'op_output')
748                                  ],
749                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
750                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
751                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
752                                  'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
753                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
754                                  'add_1_w': {'shape': np.array([1]), 'value': 6},
755                                  'mul_2_w': {'shape': np.array([1]), 'value': 6},
756                                  'mul_1': {'can_be_fused': False},
757                                  'add_1': {'can_be_fused': False},
758                                  },
759                                 nodes_with_edges_only=True)
760
761         graph.graph['layout'] = 'NHWC'
762         fuse_mul_add_sequence(graph)
763         self.assertTrue(len(graph.node) == len(graph_ref.node),
764                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
765
766         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
767         self.assertTrue(flag, resp)
768
769     # Placeholder-+->Mul->Add->Mul-+->Concat
770     #             |                |            With 'can_be_fused' = False
771     #             +----------------+
772     def test_fuse_lin_seq_12(self):
773         graph = build_graph(nodes_attributes,
774                             [('placeholder_1', 'placeholder_1_data'),
775                              ('placeholder_1_data', 'mul_1'),
776                              ('mul_1_w', 'mul_1'),
777                              ('mul_1', 'mul_1_data'),
778                              ('mul_1_data', 'add_1'),
779                              ('add_1_w', 'add_1'),
780                              ('add_1', 'add_1_data'),
781                              ('add_1_data', 'mul_2'),
782                              ('mul_2_w', 'mul_2'),
783                              ('mul_2', 'mul_2_data'),
784                              ('mul_2_data', 'concat_1'),
785                              ('concat_1', 'concat_1_data'),
786                              ('placeholder_1_data', 'concat_1'),
787                              ('concat_1_data', 'op_output')
788                              ],
789                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
790                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
791                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
792                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
793                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
794                              'add_1_w': {'shape': np.array([1]), 'value': 6},
795                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
796                              'add_1': {'can_be_fused': False},
797                              },
798                             nodes_with_edges_only=True)
799
800         graph_ref = build_graph(nodes_attributes,
801                                 [('placeholder_1', 'placeholder_1_data'),
802                                  ('placeholder_1_data', 'mul_1'),
803                                  ('mul_1_w', 'mul_1'),
804                                  ('mul_1', 'mul_1_data'),
805                                  ('mul_1_data', 'add_1'),
806                                  ('add_1_w', 'add_1'),
807                                  ('add_1', 'add_1_data'),
808                                  ('add_1_data', 'mul_2'),
809                                  ('mul_2_w', 'mul_2'),
810                                  ('mul_2', 'mul_2_data'),
811                                  ('mul_2_data', 'concat_1'),
812                                  ('concat_1', 'concat_1_data'),
813                                  ('placeholder_1_data', 'concat_1'),
814                                  ('concat_1_data', 'op_output')
815                                  ],
816                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
817                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
818                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
819                                  'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
820                                  'mul_1_w': {'shape': np.array([1]), 'value': 6},
821                                  'add_1_w': {'shape': np.array([1]), 'value': 6},
822                                  'mul_2_w': {'shape': np.array([1]), 'value': 6},
823                                  'add_1': {'can_be_fused': False},
824                                  },
825                                 nodes_with_edges_only=True)
826
827         graph.graph['layout'] = 'NHWC'
828         fuse_mul_add_sequence(graph)
829         self.assertTrue(len(graph.node) == len(graph_ref.node),
830                         "Graphs has different number of nodes: {} and {}".format(len(graph.node), len(graph_ref.node)))
831
832         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
833         self.assertTrue(flag, resp)
834
835     # Placeholder-+->Mul->Add->Mul-+->Concat
836     #             |                |
837     #             +->Mul->Mul->----+  (This Mul ops has shared weights with upper Mul ops)
838     def test_fuse_lin_seq_shared_weights_1(self):
839         graph = build_graph(nodes_attributes,
840                             [('placeholder_1', 'placeholder_1_data'),
841                              ('placeholder_1_data', 'mul_1'),
842                              ('mul_1_w', 'mul_1'),
843                              ('mul_1', 'mul_1_data'),
844                              ('mul_1_data', 'add_1'),
845                              ('add_1_w', 'add_1'),
846                              ('add_1', 'add_1_data'),
847                              ('add_1_data', 'mul_2'),
848                              ('mul_2_w', 'mul_2'),
849                              ('mul_2', 'mul_2_data'),
850                              ('mul_2_data', 'concat_1'),
851                              ('concat_1', 'concat_1_data'),
852                              ('placeholder_1_data', 'mul_3'),
853                              ('mul_3', 'mul_3_data'),
854                              ('mul_1_w', 'mul_3'),
855                              ('mul_3_data', 'mul_4'),
856                              ('mul_2_w', 'mul_4'),
857                              ('mul_4', 'mul_4_data'),
858                              ('mul_4_data', 'concat_1'),
859                              ('concat_1_data', 'op_output')
860                              ],
861                             {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
862                              'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
863                              'add_1_data': {'shape': np.array([1, 227, 227, 3])},
864                              'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
865                              'mul_3_data': {'shape': np.array([1, 227, 227, 3])},
866                              'mul_4_data': {'shape': np.array([1, 227, 227, 3])},
867                              'mul_1_w': {'shape': np.array([1]), 'value': 6},
868                              'add_1_w': {'shape': np.array([1]), 'value': 6},
869                              'mul_2_w': {'shape': np.array([1]), 'value': 6},
870                              },
871                             nodes_with_edges_only=True)
872
873         graph_ref = build_graph(nodes_attributes,
874                                 [('placeholder_1', 'placeholder_1_data'),
875                                  ('placeholder_1_data', 'mul_1'),
876                                  ('mul_1_w', 'mul_1'),
877                                  ('mul_1', 'mul_1_data'),
878                                  ('mul_1_data', 'add_1'),
879                                  ('add_1_w', 'add_1'),
880                                  ('add_1', 'add_1_data'),
881                                  ('add_1_data', 'concat_1'),
882                                  ('concat_1', 'concat_1_data'),
883                                  ('placeholder_1_data', 'mul_3'),
884                                  ('mul_3', 'mul_3_data'),
885                                  ('mul_3_w', 'mul_3'),
886                                  ('mul_3_data', 'concat_1'),
887                                  ('concat_1_data', 'op_output')
888                                  ],
889                                 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
890                                  'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
891                                  'add_1_data': {'shape': np.array([1, 227, 227, 3])},
892                                  'mul_3_data': {'shape': np.array([1, 227, 227, 3])},
893                                  'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
894                                  'mul_3_w': {'shape': np.array([1]), 'value': np.array([36])},
895                                  'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
896                                  'mul_1': {'can_be_fused': True},
897                                  'add_1': {'can_be_fused': True},
898                                  },
899                                 nodes_with_edges_only=True)
900
901         graph.graph['layout'] = 'NHWC'
902         fuse_mul_add_sequence(graph)
903         self.assertTrue(len(graph.node) == len(graph_ref.node),
904                         "Graphs has different number of nodes: {} and {}".format(len(graph.node),
905                                                                                  len(graph_ref.node)))
906
907         (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
908         self.assertTrue(flag, resp)