"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import unittest

import numpy as np

from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.utils.unittest.graph import build_graph, compare_graphs
# Node templates shared by every graph built in this module. Each test passes
# this dict to build_graph together with an edge list; per-test attributes
# (shapes, weight values, 'can_be_fused') are passed to build_graph separately.
nodes_attributes = {
    'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # ScaleShift layer
    'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
    'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Mul and Add operations
    'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Mul2 and Add2 operations
    'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Mul3 and Add3 operations
    'mul_3': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_3_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'mul_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_3': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_3_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Mul4 and Add4 operations
    'mul_4': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_4_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'mul_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_4': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_4_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Concat1 operation
    'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
    'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Convolutions
    'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
    'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
    'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
    # FullyConnected
    'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
    'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Placeholders
    'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Graph output
    'op_output': {'kind': 'op', 'op': 'OpOutput'},
}
# Unit tests for fuse_mul_add_sequence
class LinSeqFusingTests(unittest.TestCase):
    """Unit tests for ``fuse_mul_add_sequence``.

    Every test builds an input graph and a reference graph with ``build_graph``,
    runs the pass on the input graph and checks that both graphs match when
    traversed back from ``concat_1_data``.

    Note: ``build_graph`` assigns input ports in edge order, so the relative order
    of edges entering the same node is significant and must be kept.
    """

    def _fuse_and_compare(self, graph, graph_ref):
        """Run ``fuse_mul_add_sequence`` on ``graph`` and assert it matches ``graph_ref``.

        The input graph is marked as NHWC before the pass runs. The node counts are
        compared first so that a leftover or missing node produces a readable
        message; then the graphs are compared structurally from ``concat_1_data``.
        """
        graph.graph['layout'] = 'NHWC'
        fuse_mul_add_sequence(graph)
        # number_of_nodes() works on every networkx version (Graph.node was removed in 2.4).
        num_nodes, num_ref_nodes = graph.number_of_nodes(), graph_ref.number_of_nodes()
        self.assertEqual(num_nodes, num_ref_nodes,
                         "Graphs have different number of nodes: {} and {}".format(num_nodes, num_ref_nodes))
        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
        self.assertTrue(flag, resp)

    # Placeholder-+->Mul->Add->Mul-+->Concat
    #             |                |
    #             +----------------+
    def test_fuse_lin_seq_1(self):
        """Mul(6)->Add(6)->Mul(6) folds into Mul(36)->Add(36): ((x * 6) + 6) * 6 == 36 * x + 36."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             # Weight inputs: without these edges nodes_with_edges_only=True
                             # would drop mul_1_w/add_1_w and their overrides below.
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('placeholder_1_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('placeholder_1_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'mul_1': {'can_be_fused': True},
                                 'add_1': {'can_be_fused': True},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder-+->Mul->Add->Mul-+---------------+->Concat
    #             |                |               |
    #             |                +->Placeholder--+
    #             |                                |
    #             +--------------------------------+
    def test_fuse_lin_seq_2(self):
        """The chain is fused although its last output has two consumers; both move to add_1_data."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('placeholder_1_data', 'concat_1'),
                             ('mul_2_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('placeholder_1_data', 'concat_1'),
                                 ('add_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'mul_1': {'can_be_fused': True},
                                 'add_1': {'can_be_fused': True},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    #                       +->Placeholder-+
    #                       |              |            => The same graph
    # Placeholder->Mul->Add-+->Mul---------+->Concat
    def test_fuse_lin_seq_3(self):
        """add_1_data has two consumers, which splits the chain, so the graph must stay unchanged."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('add_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'mul_2'),
                                 ('mul_2_w', 'mul_2'),
                                 ('mul_2', 'mul_2_data'),
                                 ('mul_2_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('add_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1_w': {'shape': np.array([1]), 'value': 6},
                                 'mul_2_w': {'shape': np.array([1]), 'value': 6},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    #                  +->Placeholder-+
    #                  |              |
    # Placeholder->Mul-+->Add->Mul----+->Concat
    def test_fuse_lin_seq_4(self):
        """mul_1_data branches, so only the Add(6)->Mul(6) tail is fused into Mul(6)->Add(36)."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('mul_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'mul_2'),
                                 ('mul_2_w', 'mul_2'),
                                 ('mul_2', 'mul_2_data'),
                                 ('mul_2_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('mul_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'mul_2_w': {'shape': np.array([1]), 'value': np.array([6])},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    #                  +->Placeholder-+
    #                  |              |
    # Placeholder->Mul-+->Add->Mul----+->Concat
    def test_fuse_lin_seq_5(self):
        """Add(0) and Mul(1) after the branch point are identities and are removed entirely."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('mul_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 0},
                             'mul_2_w': {'shape': np.array([1]), 'value': 1},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('mul_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    #                  +->Placeholder-+
    #                  |              |
    # Placeholder->Mul-+->Add->Mul----+->Concat
    def test_fuse_lin_seq_6(self):
        """Mul(1) after the branch point is removed while Add(6) is kept."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('mul_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 1},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('mul_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([6])},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    #                  +->Placeholder-+
    #                  |              |
    # Placeholder->Mul-+->Add->Mul----+->Concat
    def test_fuse_lin_seq_7(self):
        """Add(0) after the branch point is removed while Mul(6) is kept."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('mul_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 0},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'mul_2'),
                                 ('mul_2_w', 'mul_2'),
                                 ('mul_2', 'mul_2_data'),
                                 ('mul_2_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('mul_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'placeholder_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'mul_2_w': {'shape': np.array([1]), 'value': np.array([6])},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder->Mul->Add->Mul->Concat  =>  Placeholder->Concat
    def test_fuse_lin_seq_8(self):
        """A chain of identities Mul(1)->Add(0)->Mul(1) is removed completely."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 1},
                             'add_1_w': {'shape': np.array([1]), 'value': 0},
                             'mul_2_w': {'shape': np.array([1]), 'value': 1},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])}},
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder->Mul->Add->Mul->Concat  =>  Placeholder->Mul->Add->Concat
    def test_fuse_lin_seq_9(self):
        """Without a side branch, Mul(6)->Add(6)->Mul(6) folds into Mul(36)->Add(36)."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder->Mul->Add->Mul->Concat  =>  Placeholder->Mul->Add->Concat
    def test_fuse_lin_seq_10(self):
        """A per-channel Add weight ([6, 6, 6]) makes both fused weights per-channel ([36, 36, 36])."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([3]), 'value': np.array([6, 6, 6])},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([3]), 'value': np.array([36, 36, 36])},
                                 'add_1_w': {'shape': np.array([3]), 'value': np.array([36, 36, 36])},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder-+->Mul->Add->Mul-+->Concat
    #             |                |         With 'can_be_fused' = False
    #             +----------------+
    def test_fuse_lin_seq_11(self):
        """Mul and Add marked with can_be_fused=False are left untouched."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('placeholder_1_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             'mul_1': {'can_be_fused': False},
                             'add_1': {'can_be_fused': False},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'mul_2'),
                                 ('mul_2_w', 'mul_2'),
                                 ('mul_2', 'mul_2_data'),
                                 ('mul_2_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('placeholder_1_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1_w': {'shape': np.array([1]), 'value': 6},
                                 'mul_2_w': {'shape': np.array([1]), 'value': 6},
                                 'mul_1': {'can_be_fused': False},
                                 'add_1': {'can_be_fused': False},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder-+->Mul->Add->Mul-+->Concat
    #             |                |         With 'can_be_fused' = False on Add only
    #             +----------------+
    def test_fuse_lin_seq_12(self):
        """A single Add with can_be_fused=False leaves the whole graph unchanged."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('placeholder_1_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             'add_1': {'can_be_fused': False},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'mul_2'),
                                 ('mul_2_w', 'mul_2'),
                                 ('mul_2', 'mul_2_data'),
                                 ('mul_2_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('placeholder_1_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1_w': {'shape': np.array([1]), 'value': 6},
                                 'mul_2_w': {'shape': np.array([1]), 'value': 6},
                                 'add_1': {'can_be_fused': False},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)

    # Placeholder-+->Mul->Add->Mul-+->Concat
    #             |                |
    #             +->Mul->Mul------+   (these Mul ops share weights with the upper Mul ops)
    def test_fuse_lin_seq_shared_weights_1(self):
        """Two branches sharing Mul weights are each fused; the Mul(6)->Mul(6) branch becomes Mul(36)."""
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'mul_2'),
                             ('mul_2_w', 'mul_2'),
                             ('mul_2', 'mul_2_data'),
                             ('mul_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('placeholder_1_data', 'mul_3'),
                             ('mul_3', 'mul_3_data'),
                             ('mul_1_w', 'mul_3'),
                             ('mul_3_data', 'mul_4'),
                             ('mul_2_w', 'mul_4'),
                             ('mul_4', 'mul_4_data'),
                             ('mul_4_data', 'concat_1'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_2_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_3_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_4_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'add_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_2_w': {'shape': np.array([1]), 'value': 6},
                             },
                            nodes_with_edges_only=True)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'add_1'),
                                 ('add_1_w', 'add_1'),
                                 ('add_1', 'add_1_data'),
                                 ('add_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('placeholder_1_data', 'mul_3'),
                                 ('mul_3', 'mul_3_data'),
                                 ('mul_3_w', 'mul_3'),
                                 ('mul_3_data', 'concat_1'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_3_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'mul_3_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'add_1_w': {'shape': np.array([1]), 'value': np.array([36])},
                                 'mul_1': {'can_be_fused': True},
                                 'add_1': {'can_be_fused': True},
                                 },
                                nodes_with_edges_only=True)

        self._fuse_and_compare(graph, graph_ref)