2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.graph.graph import Node
22 from mo.middle.passes.eliminate import graph_clean_up
23 from mo.middle.passes.fusing.fuse_linear_ops import _fuse_mul, _fuse_add, fuse_linear_ops
24 from mo.utils.unittest.graph import build_graph, compare_graphs
# Attribute templates for every node name the tests below refer to.  Each test
# passes this dict to build_graph together with a per-test dict of attributes
# (shapes, values, ...) that are presumably layered on top of these defaults --
# TODO confirm against mo.utils.unittest.graph.build_graph.
# Naming convention: '<op>' is an operation node, '<op>_data' its output data
# node, '<op>_w' / '<op>_b' weight/bias data nodes, and 'const_<op>_w' /
# 'const_<op>_b' the constant op nodes that produce them.
nodes_attributes = {
    # Placeholder layers
    'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # ScaleShift layer
    'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
    'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'const_scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'const_scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Mul and Add operations
    'mul_1': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'const_mul_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'mul_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_1': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_1_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'const_add_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'add_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Mul2 and Add2 operations
    'mul_2': {'type': 'Mul', 'kind': 'op', 'op': 'Mul', 'can_be_fused': True},
    'mul_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'const_mul_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'mul_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'add_2': {'type': 'Add', 'kind': 'op', 'op': 'Add', 'can_be_fused': True},
    'add_2_w': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'const_add_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'add_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Concat operation
    'concat_1': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
    'concat_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Convolutions (NHWC layout)
    'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
    'conv_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'const_conv_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'const_conv_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2D', 'layout': 'NHWC'},
    'conv_2_w': {'value': None, 'shape': None, 'kind': 'data'},
    'conv_2_b': {'value': None, 'shape': None, 'kind': 'data'},
    'const_conv_2_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'const_conv_2_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'conv_2_data': {'value': None, 'shape': None, 'kind': 'data'},
    # FullyConnected
    'fc_1': {'type': 'FullyConnected', 'kind': 'op', 'op': 'InnerProduct', 'layout': 'NHWC'},
    'fc_1_w': {'value': None, 'shape': None, 'kind': 'data'},
    'fc_1_b': {'value': None, 'shape': None, 'kind': 'data'},
    'const_fc_1_w': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'const_fc_1_b': {'value': None, 'shape': None, 'kind': 'op', 'data_type': None},
    'fc_1_data': {'value': None, 'shape': None, 'kind': 'data'},
    # Additional placeholders
    'placeholder_2': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    'placeholder_3': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
    'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
    # Graph outputs
    'op_output': {'kind': 'op', 'op': 'OpOutput'},
    'op_output_1': {'kind': 'op', 'op': 'OpOutput'},
    'op_output_2': {'kind': 'op', 'op': 'OpOutput'},
}
# Unit tests for fuse_mul
class FuseMulTests(unittest.TestCase):
    """Tests for ``_fuse_mul``: folding a Mul by a constant into adjacent Conv/FC nodes.

    Forward fusion (``backward=False``, Mul -> Conv/FC) scales the consumer's
    weights along its input-channel axis.  Backward fusion (``backward=True``,
    Conv -> Mul) scales the producer's weights along its output-channel axis
    and its biases as well.  Each test builds the graph, runs the pass plus
    ``graph_clean_up`` and compares the result with a hand-built reference.
    """

    # Mul(array)->Conv(w+b)
    def test_fuse_mul_to_conv_1(self):
        # Placeholder->Mul->Conv
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             })

        # A (3, 1) multiplier broadcasts against the trailing (3, 96) axes, i.e. it
        # scales the weights along input_channel_dim == 2.
        ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
        self.assertTrue(flag, resp)

    # Mul(scalar)->Conv(w+b)
    def test_fuse_mul_to_conv_2(self):
        # Placeholder->Mul->Conv
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             })

        # The scalar is expected to be expanded to one factor per input channel.
        ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
        self.assertTrue(flag, resp)

    # Conv(w+b)->Mul(array)
    def test_fuse_mul_to_conv_3(self):
        # Placeholder->Conv->Mul
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
                             'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
                             'mul_1_data': {'shape': np.array([1, 55, 55, 96])},
                             'const_mul_1_w': {'shape': np.array([96]), 'value': np.arange(96)},
                             'mul_1_w': {'shape': np.array([96]), 'value': np.arange(96)},
                             })

        # Backward fusion scales the last (output-channel) axis of the weights and
        # the biases by the per-channel multiplier.
        ref_weights = np.ones((11, 11, 3, 96)) * np.arange(96)
        ref_biases = np.ones(96) * np.arange(96)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'conv_1_data': {'shape': np.array([1, 55, 55, 96])}
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
        self.assertTrue(flag, resp)

    # Conv(w+b)->Mul(scalar)
    def test_fuse_mul_to_conv_4(self):
        # Placeholder->Conv->Mul
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.ones(96)},
                             'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
                             'mul_1_data': {'shape': np.array([1, 55, 55, 96])},
                             'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             })

        # A scalar multiplier scales both weights and biases uniformly.
        ref_weights = np.ones((11, 11, 3, 96)) * np.array([6])
        ref_biases = np.ones(96) * np.array([6])

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'conv_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'conv_1_data': {'shape': np.array([1, 55, 55, 96])}
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
        self.assertTrue(flag, resp)

    # Before:                                  After:
    #   Op0 -+-> Op1 ----------+                 Op0 -+-> Op1 ---+
    #        +-> Op2 ----------+-> Concat             +-> Op2 ---+-> Concat
    #        +-> Mul -> Conv --+                      +-> Conv --+
    # (Op0 = placeholder_1, Op1 = placeholder_2, Op2 = placeholder_3; Mul by a scalar.)
    def test_fuse_mul_to_conv_5(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('placeholder_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_1_data', 'placeholder_3'),
                             ('placeholder_3', 'placeholder_3_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('placeholder_3_data', 'concat_1'),
                             ('conv_1_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'mul_1_w': {'shape': np.array([1]), 'value': 6},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             })

        ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('placeholder_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_1_data', 'placeholder_3'),
                                 ('placeholder_3', 'placeholder_3_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('placeholder_3_data', 'concat_1'),
                                 ('conv_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3,
                                              'input_channel_dim': 2, 'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'placeholder_2_data': {},
                                 'placeholder_3_data': {},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
        self.assertTrue(flag, resp)

    # Same topology as test_fuse_mul_to_conv_5, but the scalar multiplier is given
    # as a one-element np.array instead of a Python int.
    def test_fuse_mul_to_conv_5_nparray(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('placeholder_1_data', 'placeholder_2'),
                             ('placeholder_2', 'placeholder_2_data'),
                             ('placeholder_1_data', 'placeholder_3'),
                             ('placeholder_3', 'placeholder_3_data'),
                             ('placeholder_2_data', 'concat_1'),
                             ('placeholder_3_data', 'concat_1'),
                             ('conv_1_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
                             'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             })

        ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([6, 6, 6]), (3, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('placeholder_1_data', 'placeholder_2'),
                                 ('placeholder_2', 'placeholder_2_data'),
                                 ('placeholder_1_data', 'placeholder_3'),
                                 ('placeholder_3', 'placeholder_3_data'),
                                 ('placeholder_2_data', 'concat_1'),
                                 ('placeholder_3_data', 'concat_1'),
                                 ('conv_1_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3,
                                              'input_channel_dim': 2, 'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'placeholder_2_data': {},
                                 'placeholder_3_data': {},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
        self.assertTrue(flag, resp)

    # Before:                                   After:
    #   Op -> Mul(array) -+-> Conv(w+b) -+          Op -+-> Conv1 -+
    #                     +-> Conv(w+b) -+-> Concat     +-> Conv2 -+-> Concat
    def test_fuse_mul_to_convolutions_1(self):
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('mul_1_data', 'conv_2'),
                             ('const_conv_2_w', 'conv_2_w'),
                             ('const_conv_2_b', 'conv_2_b'),
                             ('conv_2_w', 'conv_2'),
                             ('conv_2_b', 'conv_2'),
                             ('conv_2', 'conv_2_data'),
                             ('conv_1_data', 'concat_1'),
                             ('conv_2_data', 'concat_1'),
                             ('concat_1', 'concat_1_data'),
                             ('concat_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
                             'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
                             })

        # Both consumers receive the same input-channel scaling.
        ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('placeholder_1_data', 'conv_2'),
                                 ('const_conv_2_w', 'conv_2_w'),
                                 ('const_conv_2_b', 'conv_2_b'),
                                 ('conv_2_w', 'conv_2'),
                                 ('conv_2_b', 'conv_2'),
                                 ('conv_2', 'conv_2_data'),
                                 ('conv_1_data', 'concat_1'),
                                 ('conv_2_data', 'concat_1'),
                                 ('concat_1', 'concat_1_data'),
                                 ('concat_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
                                 'const_conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1'), Node(graph, 'conv_2')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
        self.assertTrue(flag, resp)

    # Mul(array)->FC(w+b)
    def test_fuse_mul_to_fc_1(self):
        # Placeholder->Mul->FC
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'fc_1'),
                             ('const_fc_1_w', 'fc_1_w'),
                             ('const_fc_1_b', 'fc_1_b'),
                             ('fc_1_w', 'fc_1'),
                             ('fc_1_b', 'fc_1'),
                             ('fc_1', 'fc_1_data'),
                             ('fc_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 2048])},
                             'mul_1_data': {'shape': np.array([1, 2048])},
                             'const_mul_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                             'mul_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                             'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
                             'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
                                        'output_channel_dim': 0, 'input_channel_dim': 1,
                                        'dims_number': 2},
                             'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                             'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                             'fc_1_data': {'shape': np.array([1, 10260])},
                             })

        # FC weights are (out, in); broadcasting a length-2048 vector scales the
        # last axis, which is input_channel_dim == 1.
        ref_weights = np.ones((10260, 2048)) * np.arange(2048)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'fc_1'),
                                 ('const_fc_1_w', 'fc_1_w'),
                                 ('const_fc_1_b', 'fc_1_b'),
                                 ('fc_1_w', 'fc_1'),
                                 ('fc_1_b', 'fc_1'),
                                 ('fc_1', 'fc_1_data'),
                                 ('fc_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 2048])},
                                 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                            'output_channel_dim': 0, 'input_channel_dim': 1,
                                            'dims_number': 2},
                                 'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                                 'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                                 'fc_1_data': {'shape': np.array([1, 10260])},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'fc_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
        self.assertTrue(flag, resp)

    # Mul(scalar)->Conv(w+b) where conv_1 has can_be_fused=False: the graph must
    # be left unchanged, so the reference is identical to the input graph.
    def test_fuse_mul_to_conv_6(self):
        # Placeholder->Mul->Conv
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('const_conv_1_b', 'conv_1_b'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1_b', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                             'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
                             'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
                             'conv_1': {'can_be_fused': False},
                             'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                             'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                          'output_channel_dim': 3, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                             })

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'mul_1'),
                                 ('const_mul_1_w', 'mul_1_w'),
                                 ('mul_1_w', 'mul_1'),
                                 ('mul_1', 'mul_1_data'),
                                 ('mul_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('const_conv_1_b', 'conv_1_b'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1_b', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
                                 'const_mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
                                 'mul_1_w': {'shape': np.array([]), 'value': np.array(6)},
                                 'conv_1': {'can_be_fused': False},
                                 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
                                 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
                                              'output_channel_dim': 3, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
        self.assertTrue(flag, resp)

    # Mul(array)->DWConv(w)
    def test_fuse_mul_to_dwconv_1(self):
        # Placeholder->Mul->Conv
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
                             'mul_1_data': {'shape': np.array([1, 112, 112, 6])},
                             'const_mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
                             'mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
                             'const_conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1))},
                             'conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1)),
                                          'output_channel_dim': 2, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             })

        # Depthwise weights share one axis (dim 2) for input and output channels;
        # a (6, 1) multiplier broadcasts against the trailing (6, 1) axes.
        ref_weights = np.ones((3, 3, 6, 1)) * np.reshape(np.array([1, 2, 3, 4, 5, 6]), (6, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 2, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data')
        self.assertTrue(flag, resp)

    # DWConv(w)->Mul(array)
    def test_fuse_mul_to_dwconv_2(self):
        # Placeholder->Conv->Mul
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'conv_1'),
                             ('const_conv_1_w', 'conv_1_w'),
                             ('conv_1_w', 'conv_1'),
                             ('conv_1', 'conv_1_data'),
                             ('conv_1_data', 'mul_1'),
                             ('const_mul_1_w', 'mul_1_w'),
                             ('mul_1_w', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
                             'mul_1_data': {'shape': np.array([1, 112, 112, 6])},
                             'const_mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
                             'mul_1_w': {'shape': np.array([6]), 'value': np.array([1, 2, 3, 4, 5, 6])},
                             'const_conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1))},
                             'conv_1_w': {'shape': np.array([3, 3, 6, 1]), 'value': np.ones((3, 3, 6, 1)),
                                          'output_channel_dim': 2, 'input_channel_dim': 2,
                                          'dims_number': 4},
                             })

        # Backward fusion scales output_channel_dim == 2 of the depthwise weights.
        ref_weights = np.ones((3, 3, 6, 1)) * np.reshape(np.array([1, 2, 3, 4, 5, 6]), (6, 1))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'conv_1'),
                                 ('const_conv_1_w', 'conv_1_w'),
                                 ('conv_1_w', 'conv_1'),
                                 ('conv_1', 'conv_1_data'),
                                 ('conv_1_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 112, 112, 6])},
                                 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                              'output_channel_dim': 2, 'input_channel_dim': 2,
                                              'dims_number': 4},
                                 })

        _fuse_mul(graph, Node(graph, 'mul_1'), [Node(graph, 'conv_1')], backward=True)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'mul_1_data', 'conv_1_data')
        self.assertTrue(flag, resp)
# Unit tests for fuse_add
class FuseAddTests(unittest.TestCase):
    """Unit tests for ``_fuse_add`` folding a constant Add into a FullyConnected layer."""

    # Add(array)->FC(w+b)
    def test_fuse_add_to_fc_1(self):
        """Forward fusion: an Add in front of the FC is folded into the FC biases."""
        # Placeholder->Add->FC
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'add_1'),
                             ('const_add_1_w', 'add_1_w'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'fc_1'),
                             ('const_fc_1_w', 'fc_1_w'),
                             ('const_fc_1_b', 'fc_1_b'),
                             ('fc_1_w', 'fc_1'),
                             ('fc_1_b', 'fc_1'),
                             ('fc_1', 'fc_1_data'),
                             ('fc_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 2048])},
                             'add_1_data': {'shape': np.array([1, 2048])},
                             'const_add_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                             'add_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                             'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
                             'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
                                        'output_channel_dim': 0, 'input_channel_dim': 1,
                                        'dims_number': 2},
                             'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                             'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                             'fc_1_data': {'shape': np.array([1, 10260])},
                             })

        ref_weights = np.ones((10260, 2048))
        # W @ (x + a) + b == W @ x + (b + W @ a): weights unchanged, biases absorb W @ a.
        ref_biases = np.ones(10260) + np.dot(ref_weights, np.arange(2048))

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'fc_1'),
                                 ('const_fc_1_w', 'fc_1_w'),
                                 ('const_fc_1_b', 'fc_1_b'),
                                 ('fc_1_w', 'fc_1'),
                                 ('fc_1_b', 'fc_1'),
                                 ('fc_1', 'fc_1_data'),
                                 ('fc_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 2048])},
                                 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                            'output_channel_dim': 0, 'input_channel_dim': 1,
                                            'dims_number': 2},
                                 'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_data': {'shape': np.array([1, 10260])},
                                 })

        _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
        self.assertTrue(flag, resp)

    # FC(w)->Add(array)
    def test_fuse_add_to_fc_2(self):
        """Backward fusion: an Add after a bias-less FC becomes the FC bias."""
        # Placeholder->FC->Add
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'fc_1'),
                             ('const_fc_1_w', 'fc_1_w'),
                             ('fc_1_w', 'fc_1'),
                             ('fc_1', 'fc_1_data'),
                             ('fc_1_data', 'add_1'),
                             ('const_add_1_w', 'add_1_w'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 2048])},
                             'add_1_data': {'shape': np.array([1, 10260])},
                             'const_add_1_w': {'shape': np.array([10260]), 'value': np.arange(10260)},
                             'add_1_w': {'shape': np.array([10260]), 'value': np.arange(10260),
                                         'data_type': None},
                             'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
                             'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
                                        'output_channel_dim': 0, 'input_channel_dim': 1,
                                        'dims_number': 2, 'data_type': None},
                             'fc_1_data': {'shape': np.array([1, 10260])},
                             })

        ref_weights = np.ones((10260, 2048))
        ref_biases = np.arange(10260)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'fc_1'),
                                 ('const_fc_1_w', 'fc_1_w'),
                                 ('const_fc_1_b', 'fc_1_b'),
                                 ('fc_1_w', 'fc_1'),
                                 ('fc_1_b', 'fc_1'),
                                 ('fc_1', 'fc_1_data'),
                                 ('fc_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 2048])},
                                 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                            'output_channel_dim': 0, 'input_channel_dim': 1,
                                            'dims_number': 2},
                                 'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_data': {'shape': np.array([1, 10260])},
                                 })

        _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=True)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
        self.assertTrue(flag, resp)

    # FC(w)->Add(scalar)
    def test_fuse_add_to_fc_3(self):
        """Backward fusion of a plain scalar Add: it becomes a constant FC bias of 6s."""
        # Placeholder->FC->Add
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'fc_1'),
                             ('const_fc_1_w', 'fc_1_w'),
                             ('fc_1_w', 'fc_1'),
                             ('fc_1', 'fc_1_data'),
                             ('fc_1_data', 'add_1'),
                             ('const_add_1_w', 'add_1_w'),
                             ('add_1_w', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output'),
                             ],
                            {'placeholder_1_data': {'shape': np.array([1, 2048])},
                             'add_1_data': {'shape': np.array([1, 10260])},
                             # The constant is deliberately a plain Python int, not an ndarray.
                             'const_add_1_w': {'shape': np.array([1]), 'value': 6, 'data_type': None},
                             'add_1_w': {'shape': np.array([1]), 'value': 6, 'data_type': None},
                             'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
                             'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
                                        'output_channel_dim': 0, 'input_channel_dim': 1,
                                        'dims_number': 2, 'data_type': None},
                             'fc_1_data': {'shape': np.array([1, 10260])},
                             })

        ref_weights = np.ones((10260, 2048))
        ref_biases = np.full([10260], 6)

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'fc_1'),
                                 ('const_fc_1_w', 'fc_1_w'),
                                 ('const_fc_1_b', 'fc_1_b'),
                                 ('fc_1_w', 'fc_1'),
                                 ('fc_1_b', 'fc_1'),
                                 ('fc_1', 'fc_1_data'),
                                 ('fc_1_data', 'op_output'),
                                 ],
                                {'placeholder_1_data': {'shape': np.array([1, 2048])},
                                 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
                                 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
                                            'output_channel_dim': 0, 'input_channel_dim': 1,
                                            'dims_number': 2},
                                 'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
                                 'fc_1_data': {'shape': np.array([1, 10260])},
                                 })

        _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=True)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
        self.assertTrue(flag, resp)

    # Add(array)->FC(w+b) can_be_fused = False
    def test_fuse_add_to_fc_4(self):
        """An FC marked ``can_be_fused: False`` must leave the graph unchanged."""

        def build_unfusable_graph():
            # Placeholder->Add->FC; called once per graph so no arrays are shared between them.
            return build_graph(nodes_attributes,
                               [('placeholder_1', 'placeholder_1_data'),
                                ('placeholder_1_data', 'add_1'),
                                ('const_add_1_w', 'add_1_w'),
                                ('add_1_w', 'add_1'),
                                ('add_1', 'add_1_data'),
                                ('add_1_data', 'fc_1'),
                                ('const_fc_1_w', 'fc_1_w'),
                                ('const_fc_1_b', 'fc_1_b'),
                                ('fc_1_w', 'fc_1'),
                                ('fc_1_b', 'fc_1'),
                                ('fc_1', 'fc_1_data'),
                                ('fc_1_data', 'op_output'),
                                ],
                               {'placeholder_1_data': {'shape': np.array([1, 2048])},
                                'add_1_data': {'shape': np.array([1, 2048])},
                                'const_add_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                                'add_1_w': {'shape': np.array([2048]), 'value': np.arange(2048)},
                                'fc_1': {'can_be_fused': False},
                                'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
                                'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
                                           'output_channel_dim': 0, 'input_channel_dim': 1,
                                           'dims_number': 2},
                                'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                                'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
                                'fc_1_data': {'shape': np.array([1, 10260])},
                                })

        graph = build_unfusable_graph()
        graph_ref = build_unfusable_graph()

        _fuse_add(graph, Node(graph, 'add_1'), [Node(graph, 'fc_1')], backward=False)
        graph_clean_up(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
        self.assertTrue(flag, resp)
1025 # Unit tests for fuse_linear_ops
1026 class FuseLinOpsTests(unittest.TestCase):
1027 # Op->Mul(array)-+->Conv(w+b)->Add-+->Concat Op-+->Conv1-+-->Concat
1029 # +-->Conv(w+b)-----+ +->Conv2-+
1030 def test_fuse_lin_ops_1(self):
1031 graph = build_graph(nodes_attributes,
1032 [('placeholder_1', 'placeholder_1_data'),
1033 ('placeholder_1_data', 'mul_1'),
1034 ('const_mul_1_w', 'mul_1_w'),
1035 ('mul_1_w', 'mul_1'),
1036 ('mul_1', 'mul_1_data'),
1037 ('mul_1_data', 'conv_1'),
1038 ('const_conv_1_w', 'conv_1_w'),
1039 ('const_conv_1_b', 'conv_1_b'),
1040 ('conv_1_w', 'conv_1'),
1041 ('conv_1_b', 'conv_1'),
1042 ('conv_1', 'conv_1_data'),
1043 ('mul_1_data', 'conv_2'),
1044 ('const_conv_2_w', 'conv_2_w'),
1045 ('const_conv_2_b', 'conv_2_b'),
1046 ('conv_2_w', 'conv_2'),
1047 ('conv_2_b', 'conv_2'),
1048 ('conv_2', 'conv_2_data'),
1049 ('conv_1_data', 'concat_1'),
1050 ('conv_2_data', 'concat_1'),
1051 ('concat_1', 'concat_1_data'),
1052 ('concat_1_data', 'op_output')
1054 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1055 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1056 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1057 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1058 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1059 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1060 'output_channel_dim': 3, 'input_channel_dim': 2,
1062 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1063 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1064 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1065 'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1066 'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1067 'output_channel_dim': 3, 'input_channel_dim': 2,
1069 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1070 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1071 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1074 ref_weights = np.ones((11, 11, 3, 96)) * np.reshape(np.array([1, 2, 3]), (3, 1))
1076 graph_ref = build_graph(nodes_attributes,
1077 [('placeholder_1', 'placeholder_1_data'),
1078 ('placeholder_1_data', 'conv_1'),
1079 ('const_conv_1_w', 'conv_1_w'),
1080 ('const_conv_1_b', 'conv_1_b'),
1081 ('conv_1_w', 'conv_1'),
1082 ('conv_1_b', 'conv_1'),
1083 ('conv_1', 'conv_1_data'),
1084 ('placeholder_1_data', 'conv_2'),
1085 ('const_conv_2_w', 'conv_2_w'),
1086 ('const_conv_2_b', 'conv_2_b'),
1087 ('conv_2_w', 'conv_2'),
1088 ('conv_2_b', 'conv_2'),
1089 ('conv_2', 'conv_2_data'),
1090 ('conv_1_data', 'concat_1'),
1091 ('conv_2_data', 'concat_1'),
1092 ('concat_1', 'concat_1_data'),
1093 ('concat_1_data', 'op_output')
1095 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1096 'const_conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1097 'conv_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1098 'output_channel_dim': 3, 'input_channel_dim': 2,
1100 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1101 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1102 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1103 'const_conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights},
1104 'conv_2_w': {'shape': ref_weights.shape, 'value': ref_weights,
1105 'output_channel_dim': 3, 'input_channel_dim': 2,
1107 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1108 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1109 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1112 fuse_linear_ops(graph)
1113 graph_clean_up(graph)
1115 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1116 self.assertTrue(flag, resp)
1118 # Mul(array)->FC(w+b)
1119 def test_fuse_mul_to_fc_1(self):
1120 # Placeholder->Mul->FC
1121 graph = build_graph(nodes_attributes,
1122 [('placeholder_1', 'placeholder_1_data'),
1123 ('placeholder_1_data', 'mul_1'),
1124 ('const_mul_1_w', 'mul_1_w'),
1125 ('mul_1_w', 'mul_1'),
1126 ('mul_1', 'mul_1_data'),
1127 ('mul_1_data', 'fc_1'),
1128 ('const_fc_1_w', 'fc_1_w'),
1129 ('const_fc_1_b', 'fc_1_b'),
1132 ('fc_1', 'fc_1_data'),
1133 ('fc_1_data', 'op_output')
1135 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1136 'mul_1_data': {'shape': np.array([1, 2048])},
1137 'const_mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1138 'mul_1_w': {'shape': np.array([2048]), 'value': np.array([x for x in range(2048)])},
1139 'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
1140 'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
1141 'output_channel_dim': 0, 'input_channel_dim': 1,
1143 'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1144 'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1145 'fc_1_data': {'shape': np.array([1, 10260])},
1147 ref_weights = np.ones((10260, 2048)) * np.array([x for x in range(2048)])
1149 graph_ref = build_graph(nodes_attributes,
1150 [('placeholder_1', 'placeholder_1_data'),
1151 ('placeholder_1_data', 'fc_1'),
1152 ('const_fc_1_w', 'fc_1_w'),
1153 ('const_fc_1_b', 'fc_1_b'),
1156 ('fc_1', 'fc_1_data'),
1157 ('fc_1_data', 'op_output')
1159 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1160 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1161 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1162 'output_channel_dim': 0, 'input_channel_dim': 1,
1164 'const_fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1165 'fc_1_b': {'shape': np.array([10260]), 'value': np.ones(10260)},
1166 'fc_1_data': {'shape': np.array([1, 10260])},
1169 fuse_linear_ops(graph)
1170 graph_clean_up(graph)
1172 (flag, resp) = compare_graphs(graph, graph_ref, 'fc_1_data')
1173 self.assertTrue(flag, resp)
1175 # FC(w)->Add(scalar)
1176 def test_fuse_add_to_fc_3(self):
1177 # Placeholder->FC->Add
1178 graph = build_graph(nodes_attributes,
1179 [('placeholder_1', 'placeholder_1_data'),
1180 ('placeholder_1_data', 'fc_1'),
1181 ('const_fc_1_w', 'fc_1_w'),
1183 ('fc_1', 'fc_1_data'),
1184 ('fc_1_data', 'add_1'),
1185 ('const_add_1_w', 'add_1_w'),
1186 ('add_1_w', 'add_1'),
1187 ('add_1', 'add_1_data'),
1188 ('add_1_data', 'op_output')
1190 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1191 'add_1_data': {'shape': np.array([1, 10260])},
1192 'const_add_1_w': {'shape': np.array([1]), 'value': np.array([6]), 'data_type': None},
1193 'add_1_w': {'shape': np.array([1]), 'value': np.array([6]), 'data_type': None},
1194 'const_fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048))},
1195 'fc_1_w': {'shape': np.array([10260, 2048]), 'value': np.ones((10260, 2048)),
1196 'output_channel_dim': 0, 'input_channel_dim': 1,
1197 'dims_number': 2, 'data_type': None},
1198 'fc_1_data': {'shape': np.array([1, 10260])},
1201 ref_weights = np.ones((10260, 2048))
1202 ref_biases = np.array([6 for x in range(10260)])
1204 graph_ref = build_graph(nodes_attributes,
1205 [('placeholder_1', 'placeholder_1_data'),
1206 ('placeholder_1_data', 'fc_1'),
1207 ('const_fc_1_w', 'fc_1_w'),
1208 ('const_fc_1_b', 'fc_1_b'),
1211 ('fc_1', 'fc_1_data'),
1212 ('fc_1_data', 'op_output')
1214 {'placeholder_1_data': {'shape': np.array([1, 2048])},
1215 'const_fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights},
1216 'fc_1_w': {'shape': ref_weights.shape, 'value': ref_weights,
1217 'output_channel_dim': 0, 'input_channel_dim': 1,
1219 'const_fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
1220 'fc_1_b': {'shape': ref_biases.shape, 'value': ref_biases},
1221 'fc_1_data': {'shape': np.array([1, 10260])},
1224 fuse_linear_ops(graph)
1225 graph_clean_up(graph)
1227 (flag, resp) = compare_graphs(graph, graph_ref, 'add_1_data', 'fc_1_data')
1228 self.assertTrue(flag, resp)
1232 # Placeholder--->Add->Mul-----+->Concat
1233 def test_fuse_lin_op_1(self):
1234 graph = build_graph(nodes_attributes,
1235 [('placeholder_1_data', 'conv_1'),
1236 ('conv_1', 'conv_1_data'),
1237 ('const_conv_1_w', 'conv_1_w'),
1238 ('const_conv_1_b', 'conv_1_b'),
1239 ('conv_1_w', 'conv_1'),
1240 ('conv_1_b', 'conv_1'),
1241 ('conv_1_data', 'add_1'),
1242 ('const_add_1_w', 'add_1_w'),
1243 ('add_1_w', 'add_1'),
1244 ('add_1', 'add_1_data'),
1245 ('concat_1', 'concat_1_data'),
1246 ('const_mul_1_w', 'mul_1_w'),
1247 ('mul_1_w', 'mul_1'),
1248 ('mul_1', 'mul_1_data'),
1249 ('add_1_data', 'concat_1'),
1250 ('mul_1_data', 'concat_1'),
1251 ('add_1_data', 'mul_1'),
1252 ('concat_1_data', 'op_output')
1254 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1255 'const_conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3))},
1256 'conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3)),
1257 'output_channel_dim': 3, 'input_channel_dim': 2,
1259 'const_conv_1_b': {'shape': np.array([3]), 'value': np.zeros(3)},
1260 'conv_1_b': {'shape': np.array([3]), 'value': np.zeros(3)},
1261 'conv_1_data': {'shape': np.array([1, 227, 227, 3])},
1262 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1263 'add_1_data': {'shape': np.array([1, 227, 227, 3])},
1264 'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1265 'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1266 'const_add_1_w': {'shape': np.array([1]), 'value': np.array([1])},
1267 'add_1_w': {'shape': np.array([1]), 'value': np.array([1])},
1271 graph_ref = build_graph(nodes_attributes,
1272 [('placeholder_1_data', 'conv_1'),
1273 ('conv_1', 'conv_1_data'),
1274 ('const_conv_1_w', 'conv_1_w'),
1275 ('const_conv_1_b', 'conv_1_b'),
1276 ('conv_1_w', 'conv_1'),
1277 ('conv_1_b', 'conv_1'),
1278 ('conv_1_data', 'concat_1'),
1279 ('const_mul_1_w', 'mul_1_w'),
1280 ('mul_1_w', 'mul_1'),
1281 ('conv_1_data', 'mul_1'),
1282 ('concat_1', 'concat_1_data'),
1283 ('mul_1', 'mul_1_data'),
1284 ('mul_1_data', 'concat_1'),
1285 ('concat_1_data', 'op_output')
1287 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1288 'const_conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3))},
1289 'conv_1_w': {'shape': np.array([1, 1, 3, 3]), 'value': np.zeros((1, 1, 3, 3)),
1290 'output_channel_dim': 3, 'input_channel_dim': 2,
1292 'const_conv_1_b': {'shape': np.array([3]), 'value': np.ones(3)},
1293 'conv_1_b': {'shape': np.array([3]), 'value': np.ones(3)},
1294 'conv_1_data': {'shape': np.array([1, 227, 227, 3])},
1295 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1296 'const_mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1297 'mul_1_w': {'shape': np.array([1]), 'value': np.array([6])},
1301 fuse_linear_ops(graph)
1302 graph_clean_up(graph)
1304 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1305 self.assertTrue(flag, resp)
1307 # Op->Mul(array)-+->Conv(w+b)------+->Concat
1308 # | | => Same('can_be_fused': False)
1309 # +-->Conv(w+b)-----+
1310 def test_fuse_lin_ops_2(self):
1311 graph = build_graph(nodes_attributes,
1312 [('placeholder_1', 'placeholder_1_data'),
1313 ('placeholder_1_data', 'mul_1'),
1314 ('const_mul_1_w', 'mul_1_w'),
1315 ('mul_1_w', 'mul_1'),
1316 ('mul_1', 'mul_1_data'),
1317 ('mul_1_data', 'conv_1'),
1318 ('const_conv_1_w', 'conv_1_w'),
1319 ('const_conv_1_b', 'conv_1_b'),
1320 ('conv_1_w', 'conv_1'),
1321 ('conv_1_b', 'conv_1'),
1322 ('conv_1', 'conv_1_data'),
1323 ('mul_1_data', 'conv_2'),
1324 ('const_conv_2_w', 'conv_2_w'),
1325 ('const_conv_2_b', 'conv_2_b'),
1326 ('conv_2_w', 'conv_2'),
1327 ('conv_2_b', 'conv_2'),
1328 ('conv_2', 'conv_2_data'),
1329 ('conv_1_data', 'concat_1'),
1330 ('conv_2_data', 'concat_1'),
1331 ('concat_1', 'concat_1_data'),
1332 ('concat_1_data', 'op_output')
1335 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1336 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1337 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1338 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1339 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1340 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1341 'output_channel_dim': 3, 'input_channel_dim': 2,
1343 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1344 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1345 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1346 'conv_2': {'can_be_fused': False},
1347 'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1348 'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1349 'output_channel_dim': 3, 'input_channel_dim': 2,
1351 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1352 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1353 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1357 graph_ref = build_graph(nodes_attributes,
1358 [('placeholder_1', 'placeholder_1_data'),
1359 ('placeholder_1_data', 'mul_1'),
1360 ('const_mul_1_w', 'mul_1_w'),
1361 ('mul_1_w', 'mul_1'),
1362 ('mul_1', 'mul_1_data'),
1363 ('mul_1_data', 'conv_1'),
1364 ('const_conv_1_w', 'conv_1_w'),
1365 ('const_conv_1_b', 'conv_1_b'),
1366 ('conv_1_w', 'conv_1'),
1367 ('conv_1_b', 'conv_1'),
1368 ('conv_1', 'conv_1_data'),
1369 ('mul_1_data', 'conv_2'),
1370 ('const_conv_2_w', 'conv_2_w'),
1371 ('const_conv_2_b', 'conv_2_b'),
1372 ('conv_2_w', 'conv_2'),
1373 ('conv_2_b', 'conv_2'),
1374 ('conv_2', 'conv_2_data'),
1375 ('conv_1_data', 'concat_1'),
1376 ('conv_2_data', 'concat_1'),
1377 ('concat_1', 'concat_1_data'),
1378 ('concat_1_data', 'op_output')
1380 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1381 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1382 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1383 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1384 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1385 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1386 'output_channel_dim': 3, 'input_channel_dim': 2,
1388 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1389 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1390 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1391 'conv_2': {'can_be_fused': False},
1392 'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1393 'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1394 'output_channel_dim': 3, 'input_channel_dim': 2,
1396 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1397 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1398 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1402 fuse_linear_ops(graph)
1403 graph_clean_up(graph)
1405 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1406 self.assertTrue(flag, resp)
1408 # Op->Mul(array)-+->Conv(w+b)------+->Concat
1409 # | | => Same('can_be_fused': False)
1410 # +-->Conv(w+b)-----+
1411 def test_fuse_lin_ops_3(self):
1412 graph = build_graph(nodes_attributes,
1413 [('placeholder_1', 'placeholder_1_data'),
1414 ('placeholder_1_data', 'mul_1'),
1415 ('const_mul_1_w', 'mul_1_w'),
1416 ('mul_1_w', 'mul_1'),
1417 ('mul_1', 'mul_1_data'),
1418 ('mul_1_data', 'conv_1'),
1419 ('const_conv_1_w', 'conv_1_w'),
1420 ('const_conv_1_b', 'conv_1_b'),
1421 ('conv_1_w', 'conv_1'),
1422 ('conv_1_b', 'conv_1'),
1423 ('conv_1', 'conv_1_data'),
1424 ('mul_1_data', 'conv_2'),
1425 ('const_conv_2_w', 'conv_2_w'),
1426 ('const_conv_2_b', 'conv_2_b'),
1427 ('conv_2_w', 'conv_2'),
1428 ('conv_2_b', 'conv_2'),
1429 ('conv_2', 'conv_2_data'),
1430 ('conv_1_data', 'concat_1'),
1431 ('conv_2_data', 'concat_1'),
1432 ('concat_1', 'concat_1_data'),
1433 ('concat_1_data', 'op_output')
1435 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1436 'mul_1': {'can_be_fused': False},
1437 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1438 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1439 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1440 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1441 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1442 'output_channel_dim': 3, 'input_channel_dim': 2,
1444 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1445 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1446 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1447 'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1448 'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1449 'output_channel_dim': 3, 'input_channel_dim': 2,
1451 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1452 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1453 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1457 graph_ref = build_graph(nodes_attributes,
1458 [('placeholder_1', 'placeholder_1_data'),
1459 ('placeholder_1_data', 'mul_1'),
1460 ('const_mul_1_w', 'mul_1_w'),
1461 ('mul_1_w', 'mul_1'),
1462 ('mul_1', 'mul_1_data'),
1463 ('mul_1_data', 'conv_1'),
1464 ('const_conv_1_w', 'conv_1_w'),
1465 ('const_conv_1_b', 'conv_1_b'),
1466 ('conv_1_w', 'conv_1'),
1467 ('conv_1_b', 'conv_1'),
1468 ('conv_1', 'conv_1_data'),
1469 ('mul_1_data', 'conv_2'),
1470 ('const_conv_2_w', 'conv_2_w'),
1471 ('const_conv_2_b', 'conv_2_b'),
1472 ('conv_2_w', 'conv_2'),
1473 ('conv_2_b', 'conv_2'),
1474 ('conv_2', 'conv_2_data'),
1475 ('conv_1_data', 'concat_1'),
1476 ('conv_2_data', 'concat_1'),
1477 ('concat_1', 'concat_1_data'),
1478 ('concat_1_data', 'op_output')
1480 {'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
1481 'mul_1': {'can_be_fused': False},
1482 'mul_1_data': {'shape': np.array([1, 227, 227, 3])},
1483 'const_mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1484 'mul_1_w': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
1485 'const_conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1486 'conv_1_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1487 'output_channel_dim': 3, 'input_channel_dim': 2,
1489 'const_conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1490 'conv_1_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1491 'conv_1_data': {'shape': np.array([1, 55, 55, 96])},
1492 'const_conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96))},
1493 'conv_2_w': {'shape': np.array([11, 11, 3, 96]), 'value': np.ones((11, 11, 3, 96)),
1494 'output_channel_dim': 3, 'input_channel_dim': 2,
1496 'const_conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1497 'conv_2_b': {'shape': np.array([96]), 'value': np.zeros(96)},
1498 'conv_2_data': {'shape': np.array([1, 55, 55, 96])},
1502 fuse_linear_ops(graph)
1503 graph_clean_up(graph)
1505 (flag, resp) = compare_graphs(graph, graph_ref, 'concat_1_data')
1506 self.assertTrue(flag, resp)