Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / extractor_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20 from generator import generator, generate
21
22 from mo.front.extractor import input_user_data_repack, output_user_data_repack, extract_port_from_string, \
23     update_ie_fields, add_input_op
24 from mo.front.extractor import spatial_attr_getter, add_input_ops, attr_getter, CaffePythonFrontExtractorOp, \
25     add_output_ops
26 from mo.graph.graph import Node
27 from mo.middle.passes import eliminate
28 from mo.utils.error import Error
29 from mo.utils.unittest.extractors import FakeMultiParam
30 from mo.utils.unittest.graph import build_graph, build_graph_with_edge_attrs, build_graph_with_attrs, compare_graphs
31
32
33 class FakePythonParam:
34     def __init__(self, param: FakeMultiParam):
35         self.__setattr__('python_param', param)
36
37
38 nodes_attributes = {'input': {'kind': 'data'},
39                     'pool_1': {'type': 'Pooling', 'kind': 'op'},
40                     'output': {'kind': 'data'},
41                     'op_output': {'kind': 'op', 'op': 'OpOutput'},
42                     }
43
44
45 class UpdateIEFieldsTest(unittest.TestCase):
46     def test_default_update_ie_fields(self):
47         update_ie_fields({}, ir_version=None)
48
49     def test_not_set_update_ie_fields(self):
50         with self.assertRaisesRegex(Error, 'Unrecognized IR version.*'):
51             update_ie_fields({}, ir_version='abracadabra')
52
53
54 class TestExtractor(unittest.TestCase):
55     def test_spatial_attr_getter(self):
56         input_shape = np.array([1, 125, 13, 13])
57         params = {
58             'kernel': np.array([1, 1, 1, 2]),
59             'pad': np.array([1, 1, 3, 4]),
60             'stride': np.array([1, 1, 2, 3]),
61         }
62         graph = build_graph(nodes_attributes,
63                             [('input', 'pool_1'),
64                              ('pool_1', 'output'),
65                              ('output', 'op_output')
66                              ],
67                             {'input': {'shape': input_shape},
68                              'pool_1': {**params, 'spatial_dims': [2, 3]},
69                              'output': {'shape': None}})
70         pool_1_node = Node(graph, 'pool_1')
71         for param in params.keys():
72             if type(params[param]) is np.ndarray:
73                 port_lambda = lambda x: x
74                 self.assertEqual(params[param][2],
75                                  spatial_attr_getter(pool_1_node, field=param, dim=0, post=port_lambda))
76                 self.assertEqual(params[param][3],
77                                  spatial_attr_getter(pool_1_node, field=param, dim=1, post=port_lambda))
78
79     def test_attr_getter(self):
80         nodes = {'input': {'kind': 'data'},
81                  'reshape': {'type': 'Reshape', 'kind': 'op'},
82                  'output': {'kind': 'data'}
83                  }
84         input_shape = np.array([1, 125, 13, 13])
85         params = {
86             'dim': [1, 1, 2, 3],
87             'max_size': np.array([3, 2, 1, 0])
88         }
89         expect_params = {
90             'dim': "1,1,2,3",
91             'max_size': "3,2,1,0",
92         }
93         graph = build_graph(nodes,
94                             [('input', 'reshape'),
95                              ('reshape', 'output'),
96                              ('output', 'op_output')
97                              ],
98                             {'input': {'shape': input_shape},
99                              'reshape': {**params, 'spatial_dims': [2, 3]},
100                              'output': {'shape': None}})
101         pool_1_node = Node(graph, 'reshape')
102         for param in params.keys():
103             if type(params[param]) is list:
104                 self.assertEqual(expect_params[param],
105                                  attr_getter(pool_1_node, param))
106
107
108 class TestAddInputOp(unittest.TestCase):
109     nodes = [
110         ('op_node', {'kind': 'op'}),
111         ('future_input', {'kind': 'op'}),
112         ('another_node', {'kind': 'op'}),
113     ]
114     edges = [('future_input', 'op_node', {'in': 1, 'out': 0}),
115              ('another_node', 'op_node', {'in': 0, 'out': 0})]
116
117     def test_in_port_no_data(self):
118         graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
119         new_input_shape = np.array([1, 2, 3, 4])
120         graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:],
121                                            new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
122                                                                                  'shape': new_input_shape})],
123                                            new_edges_with_attrs=[('input_node', 'op_node', {'in': 1, 'out': 0})])
124         add_input_op(graph, 'op_node', 1, data=False, shape=new_input_shape)
125         graph.remove_edge('future_input', 'op_node')
126         (flag, resp) = compare_graphs(graph, graph_ref, last_node='op_node')
127         self.assertTrue(flag, resp)
128
129     def test_in_port_with_data(self):
130         graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
131         new_input_shape = np.array([1, 2, 3, 4])
132         graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:],
133                                            new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
134                                                                                  'shape': new_input_shape}),
135                                                                  ('input_data', {'kind': 'data'})],
136                                            new_edges_with_attrs=[('input_node', 'input_data', {'in': 0, 'out': 0}),
137                                                                  ('input_data', 'op_node', {'in': 1, 'out': 0})])
138         add_input_op(graph, 'op_node', 1, data=True, shape=new_input_shape)
139         graph.remove_edge('future_input', 'op_node')
140         (flag, resp) = compare_graphs(graph, graph_ref, last_node='op_node')
141         self.assertTrue(flag, resp)
142
143     nodes_out = [
144         ('op_node', {'kind': 'op'}),
145         ('future_input', {'kind': 'op'}),
146         ('another_node', {'kind': 'op'}),
147     ]
148     edges_out = [('op_node', 'future_input', {'in': 0, 'out': 1}),
149                  ('op_node', 'another_node', {'in': 0, 'out': 0})]
150
151     def test_out_port_no_data(self):
152         graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out)
153         new_input_shape = np.array([1, 2, 3, 4])
154         graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
155                                            new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
156                                                                                  'shape': new_input_shape})],
157                                            new_edges_with_attrs=[('input_node', 'future_input', {'in': 0, 'out': 0})])
158         add_input_op(graph, 'op_node', 1, data=False, shape=new_input_shape, is_out_port=True)
159         graph.remove_edge('op_node', 'future_input')
160         (flag, resp) = compare_graphs(graph, graph_ref, last_node='another_node')
161         self.assertTrue(flag, resp)
162         (flag, resp) = compare_graphs(graph, graph_ref, last_node='future_input')
163         self.assertTrue(flag, resp)
164
165     def test_out_port_with_data(self):
166         graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
167                                        new_nodes_with_attrs=[('input_data', {'kind': 'data', 'shape': None})],
168                                        new_edges_with_attrs=[('op_node', 'input_data', {'out': 1, 'in': 0}),
169                                                              ('input_data', 'future_input', {'in': 0, 'out': 0})])
170         new_input_shape = np.array([1, 2, 3, 4])
171         graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
172                                            new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
173                                                                                  'shape': new_input_shape}),
174                                                                  ('input_data', {'kind': 'data', 'shape': None})],
175                                            new_edges_with_attrs=[('input_node', 'input_data', {'in': 0, 'out': 0}),
176                                                                  ('input_data', 'future_input', {'in': 0, 'out': 0})])
177         add_input_op(graph, 'op_node', 1, data=True, shape=new_input_shape, is_out_port=True)
178         graph.remove_edge('op_node', 'input_data')
179
180         (flag, resp) = compare_graphs(graph, graph_ref, last_node='another_node')
181         self.assertTrue(flag, resp)
182         (flag, resp) = compare_graphs(graph, graph_ref, last_node='future_input')
183         self.assertTrue(flag, resp)
184
185
186 class TestInputAddition(unittest.TestCase):
187     # Tests for input
188     nodes = {'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
189              'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
190              'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
191              }
192     edges = [
193         ('node_1', 'conv_1'),
194         ('conv_1', 'relu_1'),
195     ]
196
197     def test_none_out_port_raise(self):
198         graph = build_graph(self.nodes, self.edges)
199         shape = np.array([1, 2, 3, 4])
200         inputs = {'conv_1': [{'shape': shape, 'out': None}]}
201         with self.assertRaisesRegex(Error, 'Output port for input node conv_1 should be specified, it cannot be None!'):
202             add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
203
204     def test_wrong_output_port_raise(self):
205         graph = build_graph(self.nodes, self.edges)
206         shape = np.array([1, 2, 3, 4])
207         inputs = {'conv_1': [{'shape': shape, 'out': 5}]}
208         with self.assertRaisesRegex(Error, 'Output port index 5 is out of number of available output ports for node'):
209             add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
210
211     def test_wrong_input_port_raise(self):
212         graph = build_graph(self.nodes, self.edges)
213         shape = np.array([1, 2, 3, 4])
214         inputs = {'conv_1': [{'shape': shape, 'in': 5}]}
215         with self.assertRaisesRegex(Error, 'Input port index 5 is out of number of available input ports for node'):
216             add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
217
218     def test_one_input_one_shape(self):
219         shape = np.array([1, 2, 3, 4])
220         inputs = {'conv_1': [{'shape': shape}]}
221         nodes = {
222             'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
223             'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
224             'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
225             'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
226         }
227         edges = [
228             ('old_input', 'conv_1'),
229             ('conv_1', 'relu_1'),
230             ('relu_1', 'output')
231         ]
232         graph = build_graph(nodes, edges)
233         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
234         new_input = list(graph.in_edges('conv_1'))[0][0]
235         self.assertFalse(graph.node['old_input']['is_input'])
236         self.assertTrue(graph.node[new_input]['is_input'])
237         self.assertTrue((new_input, 'conv_1') in graph.edges())
238         self.assertTrue(('old_input', 'conv_1') not in graph.edges())
239         shapes_are_equal = np.array_equal(graph.node[new_input]['shape'], shape)
240         self.assertTrue(shapes_are_equal)
241
242     def test_one_input_no_shape(self):
243         shape = None
244         inputs = {'conv_1': [{'shape': shape}]}
245         nodes = {
246             'old_input': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
247             'old_input_data': {'kind': 'data', 'value': None, 'shape': np.array([-1, 224, 224, 3])},
248             'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
249             'conv_1_data': {'kind': 'data', 'value': True, 'shape': np.array([-1, 224, 224, 3])},
250             'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
251             'relu_1_data': {'kind': 'data', 'value': None, 'shape': np.array([-1, 112, 112, 64])},
252             'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'},
253             'output_data': {'name': 'output_data', 'kind': 'data', 'shape': np.array([-1, 112, 112, 64])},
254             'op_output': {'kind': 'op', 'op': 'OpOutput'}
255         }
256         edges = [
257             ('old_input', 'old_input_data'),
258             ('old_input_data', 'conv_1'),
259             ('conv_1', 'conv_1_data'),
260             ('conv_1_data', 'relu_1'),
261             ('relu_1', 'relu_1_data'),
262             ('relu_1_data', 'output'),
263             ('output', 'output_data'),
264             ('output_data', 'op_output')
265         ]
266         graph = build_graph(nodes, edges)
267         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
268         new_input = list(graph.in_edges(list(graph.in_edges('conv_1'))[0][0]))[0][0]
269         new_input_data = list(graph.in_edges('conv_1'))[0][0]
270         self.assertFalse(graph.node['old_input']['is_input'])
271         self.assertTrue(graph.node[new_input]['is_input'])
272         self.assertTrue((new_input_data, 'conv_1') in graph.edges())
273         self.assertTrue(('old_input_data', 'conv_1') not in graph.edges())
274         self.assertIsNotNone(graph.node[new_input_data]['shape'])
275
276     def test_two_inputs_two_shapes_positive_1(self):
277         shape_1 = [1, 2, 3, 4]
278         shape_2 = [4, 3, 2, 1]
279         inputs = {'node_1': [{'shape': shape_1}], 'node_4': [{'shape': shape_2}]}
280         nodes = {
281             'input_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
282             'input_2': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
283             'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
284             'node_2': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
285             'node_3': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
286             'node_4': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
287             'output': {'kind': 'op', 'op': 'OpOutput'}
288         }
289         edges = [
290             ('input_1', 'node_1'),
291             ('node_1', 'node_2'),
292             ('node_3', 'output'),
293             ('input_2', 'node_4'),
294             ('node_4', 'output')
295         ]
296         graph = build_graph(nodes, edges)
297         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
298         new_input_1 = list(graph.in_edges('node_1'))[0][0]
299         new_input_2 = list(graph.in_edges('node_4'))[0][0]
300         self.assertFalse(graph.node['input_1']['is_input'])
301         self.assertTrue(graph.node[new_input_1]['is_input'])
302         self.assertTrue(graph.node[new_input_2]['is_input'])
303         self.assertTrue((new_input_1, 'node_1') in graph.edges())
304         self.assertTrue((new_input_2, 'node_4') in graph.edges())
305         self.assertListEqual(shape_1, graph.node[new_input_1]['shape'])
306         self.assertListEqual(shape_2, graph.node[new_input_2]['shape'])
307
308     def test_two_inputs_two_shapes_not_all_inputs(self):
309         shape_1 = [1, 2, 3, 4]
310         shape_2 = [4, 3, 2, 1]
311         inputs = {'node_1': [{'shape': shape_1}], 'node_4': [{'shape': shape_2}]}
312         nodes = {
313             'input_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
314             'input_2': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
315             'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
316             'node_2': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
317             'node_3': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
318             'node_4': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
319             'output': { 'kind': 'op', 'op': 'OpOutput'},
320             'input_3': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'}
321         }
322         edges = [
323             ('input_1', 'node_1'),
324             ('node_1', 'node_2'),
325             ('node_3', 'output'),
326             ('input_2', 'node_4'),
327             ('node_4', 'output'),
328             ('input_3', 'output')
329         ]
330         graph = build_graph(nodes, edges)
331         self.assertRaises(Error, add_input_ops, graph, inputs, True)
332
333     # Tests for cases with input/output ports cutting
334     def test_add_input_with_input_port_before_infer(self):
335         shape = np.array([1, 2, 3, 4])
336         inputs = {'conv_1': [{'shape': shape, 'in': 0}]}
337         nodes = {
338             'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
339             'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
340             'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
341             'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
342         }
343         edges = [
344             ('old_input', 'conv_1'),
345             ('conv_1', 'relu_1'),
346             ('relu_1', 'output')
347         ]
348         graph = build_graph(nodes, edges)
349         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
350
351         # Check that graph
352         graph_ref = build_graph(nodes, edges, update_attributes={'old_input': {'shape': shape}})
353         (flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
354         self.assertTrue(flag, resp)
355
356         # also checks that new old_input was changed
357         new_input = list(graph.in_edges('conv_1'))[0][0]
358         self.assertFalse(graph.node['old_input']['is_input'])
359         self.assertTrue(graph.node[new_input]['is_input'])
360         self.assertTrue((new_input, 'conv_1') in graph.edges())
361         self.assertTrue(('old_input', 'conv_1') not in graph.edges())
362
363     def test_add_input_with_output_port_before_infer(self):
364         shape = np.array([1, 2, 3, 4])
365         inputs = {'conv_1': [{'shape': shape, 'out': 0}]}
366         nodes = {
367             'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
368             'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
369             'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
370             'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
371             'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
372         }
373         edges = [
374             ('old_input', 'conv_1'),
375             ('conv_1', 'relu_1'),
376             ('conv_2', 'relu_1'),
377             ('relu_1', 'output')
378         ]
379         graph = build_graph(nodes, edges)
380         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
381
382         graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Placeholder', 'shape': shape},
383                                              **nodes},
384                                 edges=[('new_input', 'relu_1'),
385                                        ('relu_1', 'output'),
386                                        ('conv_2', 'relu_1'),
387                                        ('old_input', 'conv_1'),],)
388         # Check that new input is added right (with right ports !)
389         (flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
390         self.assertTrue(flag, resp)
391
392         # Check that other graph is not damaged
393         (flag, resp) = compare_graphs(graph, graph_ref, last_node='conv_1')
394         self.assertTrue(flag, resp)
395
396         # Checks for new input and edges
397         self.assertTrue('conv_1/placeholder_out_port_0' in graph.nodes())
398         new_input = 'conv_1/placeholder_out_port_0'
399         self.assertTrue(graph.node[new_input]['is_input'])
400         self.assertTrue((new_input, 'relu_1') in graph.edges())
401         self.assertTrue(('old_input', 'relu_1') not in graph.edges())
402
403     def test_add_input_with_output_port_after_infer(self):
404         shape = np.array([1, 2, 3, 4])
405         inputs = {'conv_1': [{'shape': shape, 'out': 0}]}
406         nodes = {
407             'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
408             'inp_data' : {'kind': 'data', 'shape': shape + 1},
409             'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
410             'conv_data': {'kind': 'data', 'shape': shape},
411             'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
412         }
413         edges = [
414             ('old_input', 'inp_data'),
415             ('inp_data', 'conv_1'),
416             ('conv_1', 'conv_data'),
417             ('conv_data', 'relu_1'),
418         ]
419         graph = build_graph(nodes, edges)
420         add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
421
422         graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Placeholder', 'shape': shape},
423                                              **nodes},
424                                 edges=[('old_input', 'inp_data'),
425                                        ('inp_data', 'conv_1'),
426                                        ('new_input', 'conv_data'),
427                                        ('conv_data', 'relu_1'),
428                                        ],)
429         # Check that new input is added right (with right ports !)
430         (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1')
431         self.assertTrue(flag, resp)
432
433         # Check that other graph is not damaged
434         (flag, resp) = compare_graphs(graph, graph_ref, last_node='conv_1')
435         self.assertTrue(flag, resp)
436
437         # Checks for new input and edges
438         self.assertTrue('conv_1/placeholder_out_port_0' in graph.nodes())
439         new_input = 'conv_1/placeholder_out_port_0'
440
441         self.assertTrue(graph.node[new_input]['is_input'])
442         self.assertTrue((new_input, 'conv_data') in graph.edges())
443         self.assertTrue(('conv_1', 'conv_data') not in graph.edges())
444
445 @generator
446 class TestOutputCut(unittest.TestCase):
447     # {'embeddings': [{'port': None}]}
448     @generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]})
449     def test_output_port_cut(self, output):
450         nodes = {'A': {'type': 'Identity', 'kind': 'op'},
451                  'B': {'type': 'Identity', 'kind': 'op'},
452                  'C': {'type': 'Identity', 'kind': 'op'},
453                  'D': {'type': 'Identity', 'kind': 'op'},
454                  'E': {'type': 'Identity', 'kind': 'op'},
455                  }
456         edges = [
457             ('A', 'C', {'in': 0, 'out': 0}),
458             ('B', 'C', {'in': 1, 'out': 0}),
459             ('C', 'D', {'in': 0, 'out': 0}),
460             ('C', 'E', {'in': 0, 'out': 1})
461         ]
462         graph = build_graph_with_edge_attrs(nodes, edges)
463         sinks = add_output_ops(graph, output)
464         eliminate.graph_clean_up(graph)
465         self.assertEqual(len(Node(graph, 'C').out_nodes()), 1)
466         self.assertEqual(len(Node(graph, 'C').in_nodes()), 2)
467
468     @generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]})
469     def test_output_port_cut(self, output):
470         nodes = {'A': {'op': 'Placeholder', 'kind': 'op'},
471                  'B': {'op': 'Placeholder', 'kind': 'op'},
472                  'C': {'type': 'Identity', 'kind': 'op'},
473                  'D': {'type': 'Identity', 'kind': 'op'},
474                  'E': {'type': 'Identity', 'kind': 'op'},
475                  }
476         edges = [
477             ('A', 'C', {'in': 0, 'out': 0}),
478             ('B', 'C', {'in': 1, 'out': 0}),
479             ('C', 'D', {'in': 0, 'out': 0}),
480             ('C', 'E', {'in': 0, 'out': 1})
481         ]
482         graph = build_graph_with_edge_attrs(nodes, edges)
483         sinks = add_output_ops(graph, output)
484         eliminate.graph_clean_up(graph)
485         self.assertEqual(len(graph.nodes()), 2)
486
487
488 class TestUserDataRepack(unittest.TestCase):
489     nodes = {'A': {'name': 'Aa', 'op': 'Placeholder', 'kind': 'op'},
490              'B': {'name': 'Bb', 'op': 'Placeholder', 'kind': 'op'},
491              'C': {'name': 'Cc', 'type': 'Identity', 'value': None, 'kind': 'op'},
492              'D': {'name': 'Dd', 'type': 'Identity', 'value': None, 'kind': 'op'},
493              'E': {'name': 'Ee', 'type': 'Identity', 'value': None, 'kind': 'op'},
494              }
495     edges = [
496         ('A', 'C', {'in': 0, 'out': 0}),
497         ('B', 'C', {'in': 1, 'out': 0}),
498         ('C', 'D', {'in': 0, 'out': 0}),
499         ('C', 'E', {'in': 0, 'out': 1})
500     ]
501
502     def test_input_user_data_repack_none(self):
503         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
504         input, freeze_placeholder = input_user_data_repack(graph, None, None)
505         self.assertEqual(input, None)
506         self.assertEqual(freeze_placeholder, None)
507
508     def test_input_user_data_repack_names_to_ids_list(self):
509         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
510         input, freeze_placeholder = input_user_data_repack(graph, ['Aa', 'Bb'], None)
511         self.assertDictEqual(input, {'A': [{'shape': None, 'port': None}], 'B': [{'shape': None, 'port': None}]})
512         self.assertEqual(freeze_placeholder, None)
513
514     def test_input_user_data_repack_names_ports_in_out(self):
515         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
516         input, freeze_placeholder = input_user_data_repack(graph, ['Aa:1', '0:Bb'], None)
517         self.assertDictEqual(input, {'A': [{'shape': None, 'out': 1}], 'B': [{'shape': None, 'in': 0}]})
518         self.assertEqual(freeze_placeholder, None)
519
520     def test_input_user_data_repack_dict_with_shapes(self):
521         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
522         shape_1 = np.array([1, 160, 160, 3])
523         shape_2 = np.array([1, 127, 127, 3])
524         input, freeze_placeholder = input_user_data_repack(graph, {'Aa': shape_1, 'Bb': shape_2}, None)
525         self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': shape_2, 'port': None}]})
526         self.assertEqual(freeze_placeholder, None)
527
528     def test_input_user_data_repack_dict_with_shapes_and_ports(self):
529         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
530         shape_1 = np.array([1, 160, 160, 3])
531         shape_2 = np.array([1, 127, 127, 3])
532         input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Bb:1': shape_2}, None)
533         self.assertDictEqual(input, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': shape_2, 'out': 1}]})
534         self.assertEqual(freeze_placeholder, None)
535
536     def test_freeze_placeholder_and_input(self):
537         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
538         shape_1 = np.array([1, 160, 160, 3])
539         input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1}, {'Bb': False})
540         self.assertDictEqual(input, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': None, 'port': None}]})
541         self.assertEqual(freeze_placeholder, {'B': False})
542
543     def test_error(self):
544         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
545         self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
546
547     def test_error_2(self):
548         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
549         self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
550
551     def test_error_3(self):
552         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
553         self.assertRaises(Error, input_user_data_repack, graph, ['Bcb'], None)
554
555     def test_input_and_freeze(self):
556         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
557         shape_1 = np.array([1, 160, 160, 3])
558         input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
559         self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
560         self.assertDictEqual(freeze_placeholder, {'B': True})
561
562     def test_output_user_data_repack(self):
563         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
564         output = output_user_data_repack(graph, ['Cc'])
565         self.assertDictEqual(output, {'C': [{'port': None}]})
566
567     def test_output_user_data_repack_ports(self):
568         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
569         output = output_user_data_repack(graph, ['Cc:1', '0:Cc'])
570         self.assertDictEqual(output, {'C': [{'out': 1}, {'in': 0}]})
571
572     def test_output_user_data_repack_none(self):
573         graph = build_graph_with_edge_attrs(self.nodes, self.edges)
574         output = output_user_data_repack(graph, None)
575         self.assertEqual(output, None)
576
577
578 class TestExtractPort(unittest.TestCase):
579     def test_out_port(self):
580         name, in_port, out_port = extract_port_from_string('node_name:1')
581         self.assertEqual(name, 'node_name')
582         self.assertEqual(in_port, None)
583         self.assertEqual(out_port, 1)
584
585     def test_in_port(self):
586         name, in_port, out_port = extract_port_from_string('0:node_name')
587         self.assertEqual(name, 'node_name')
588         self.assertEqual(in_port, 0)
589         self.assertEqual(out_port, None)
590
591     def test_no_port(self):
592         name, in_port, out_port = extract_port_from_string('node_name')
593         self.assertEqual(name, 'node_name')
594         self.assertEqual(in_port, None)
595         self.assertEqual(out_port, None)
596
597     def test_non_int(self):
598         self.assertRaises(Error, extract_port_from_string, 'port:node_name')
599
600     def test_two_ports(self):
601         self.assertRaises(Error, extract_port_from_string, '1:node_name:0')
602
603
604 class TestCaffePythonFrontExtractorOp(unittest.TestCase):
605     def test_get_attrs(self):
606         exp_attrs = {"test_attr_1": 12, "test_attr_2": "sdf sdf"}
607         param_str = "'test_attr_1': 12, 'test_attr_2': 'sdf sdf'"
608         attrs = CaffePythonFrontExtractorOp.get_attrs(FakePythonParam(FakeMultiParam({'param_str': param_str})))
609         self.assertEqual(exp_attrs, attrs)