2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
20 from generator import generator, generate
22 from mo.front.extractor import input_user_data_repack, output_user_data_repack, extract_port_from_string, \
23 update_ie_fields, add_input_op
24 from mo.front.extractor import spatial_attr_getter, add_input_ops, attr_getter, CaffePythonFrontExtractorOp, \
26 from mo.graph.graph import Node
27 from mo.middle.passes import eliminate
28 from mo.utils.error import Error
29 from mo.utils.unittest.extractors import FakeMultiParam
30 from mo.utils.unittest.graph import build_graph, build_graph_with_edge_attrs, build_graph_with_attrs, compare_graphs
class FakePythonParam:
    """Test double that exposes the wrapped parameter as ``python_param``.

    An instance is passed to ``CaffePythonFrontExtractorOp.get_attrs`` by the
    Caffe Python-layer test in this module, wrapping a ``FakeMultiParam``.
    """

    def __init__(self, param: FakeMultiParam):
        # Ordinary attribute assignment; equivalent to the explicit __setattr__ call.
        self.python_param = param
# Node attribute templates for the small Pooling graph built by
# TestExtractor.test_spatial_attr_getter (data -> Pooling op -> data -> OpOutput).
nodes_attributes = {'input': {'kind': 'data'},
                    'pool_1': {'type': 'Pooling', 'kind': 'op'},
                    'output': {'kind': 'data'},
                    'op_output': {'kind': 'op', 'op': 'OpOutput'},
                    }
class UpdateIEFieldsTest(unittest.TestCase):
    """Checks how ``update_ie_fields`` reacts to the ``ir_version`` argument."""

    def test_default_update_ie_fields(self):
        # No assertion: the test passes as long as ir_version=None does not raise.
        update_ie_fields({}, ir_version=None)

    def test_not_set_update_ie_fields(self):
        # An unknown version string must be rejected with a descriptive Error.
        unknown_version = 'abracadabra'
        with self.assertRaisesRegex(Error, 'Unrecognized IR version.*'):
            update_ie_fields({}, ir_version=unknown_version)
54 class TestExtractor(unittest.TestCase):
# Tests for the spatial_attr_getter and attr_getter helpers of mo.front.extractor.
# NOTE(review): this copy of the class is damaged. Indentation is missing, every
# line starts with a stray number, and several lines are absent: the
# `params = {` / `expect_params = {` openings and their closing braces, some of
# their entries, and the first edges passed to build_graph in
# test_spatial_attr_getter. `params` and `expect_params` are used below but
# never assigned in this view, so the class does not parse or run as shown.
# Restore it from the canonical source.
55 def test_spatial_attr_getter(self):
56 input_shape = np.array([1, 125, 13, 13])
# Per-parameter 4-element arrays for pool_1; the loop below expects
# spatial_attr_getter(dim=0) / (dim=1) to return elements [2] / [3], matching
# the spatial_dims [2, 3] set on pool_1.
58 'kernel': np.array([1, 1, 1, 2]),
59 'pad': np.array([1, 1, 3, 4]),
60 'stride': np.array([1, 1, 2, 3]),
62 graph = build_graph(nodes_attributes,
65 ('output', 'op_output')
67 {'input': {'shape': input_shape},
68 'pool_1': {**params, 'spatial_dims': [2, 3]},
69 'output': {'shape': None}})
70 pool_1_node = Node(graph, 'pool_1')
# Only ndarray-valued parameters are checked. `post` is an identity function
# here, presumably so the raw element comes back unchanged — confirm against
# spatial_attr_getter.
71 for param in params.keys():
72 if type(params[param]) is np.ndarray:
73 port_lambda = lambda x: x
74 self.assertEqual(params[param][2],
75 spatial_attr_getter(pool_1_node, field=param, dim=0, post=port_lambda))
76 self.assertEqual(params[param][3],
77 spatial_attr_getter(pool_1_node, field=param, dim=1, post=port_lambda))
79 def test_attr_getter(self):
# NOTE(review): the edge list below references 'op_output', which is not among
# the visible entries of `nodes`; check whether the missing line after 'output'
# defines it.
80 nodes = {'input': {'kind': 'data'},
81 'reshape': {'type': 'Reshape', 'kind': 'op'},
82 'output': {'kind': 'data'}
84 input_shape = np.array([1, 125, 13, 13])
# 'max_size' is an ndarray, so the `type(...) is list` filter below skips it.
# Only list-valued entries of `params` (not visible here) are compared with
# their `expect_params` counterparts. The visible "3,2,1,0" entry suggests a
# comma-joined string form — confirm against attr_getter.
87 'max_size': np.array([3, 2, 1, 0])
91 'max_size': "3,2,1,0",
93 graph = build_graph(nodes,
94 [('input', 'reshape'),
95 ('reshape', 'output'),
96 ('output', 'op_output')
98 {'input': {'shape': input_shape},
99 'reshape': {**params, 'spatial_dims': [2, 3]},
100 'output': {'shape': None}})
101 pool_1_node = Node(graph, 'reshape')
102 for param in params.keys():
103 if type(params[param]) is list:
104 self.assertEqual(expect_params[param],
105 attr_getter(pool_1_node, param))
class TestAddInputOp(unittest.TestCase):
    """Tests for ``add_input_op`` on one input or output port of ``op_node``.

    Each test calls ``add_input_op`` for port 1 (with or without a data node),
    removes the edge that previously occupied that port, and compares the
    result with a hand-built reference graph containing a Placeholder
    ``input_node`` with the requested shape.
    """

    # Graph for the input-port tests: future_input feeds in-port 1 of op_node,
    # another_node feeds in-port 0.
    nodes = [
        ('op_node', {'kind': 'op'}),
        ('future_input', {'kind': 'op'}),
        ('another_node', {'kind': 'op'}),
    ]
    edges = [('future_input', 'op_node', {'in': 1, 'out': 0}),
             ('another_node', 'op_node', {'in': 0, 'out': 0})]

    def test_in_port_no_data(self):
        # The Placeholder is connected directly to in-port 1 of op_node.
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
        new_input_shape = np.array([1, 2, 3, 4])
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:],
                                           new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
                                                                                 'shape': new_input_shape})],
                                           new_edges_with_attrs=[('input_node', 'op_node', {'in': 1, 'out': 0})])
        add_input_op(graph, 'op_node', 1, data=False, shape=new_input_shape)
        graph.remove_edge('future_input', 'op_node')
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='op_node')
        self.assertTrue(flag, resp)

    def test_in_port_with_data(self):
        # With data=True the reference expects Placeholder -> data node -> op_node.
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
        new_input_shape = np.array([1, 2, 3, 4])
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges[1:],
                                           new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
                                                                                 'shape': new_input_shape}),
                                                                 ('input_data', {'kind': 'data'})],
                                           new_edges_with_attrs=[('input_node', 'input_data', {'in': 0, 'out': 0}),
                                                                 ('input_data', 'op_node', {'in': 1, 'out': 0})])
        add_input_op(graph, 'op_node', 1, data=True, shape=new_input_shape)
        graph.remove_edge('future_input', 'op_node')
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='op_node')
        self.assertTrue(flag, resp)

    # Graph for the output-port tests: out-port 1 of op_node feeds future_input,
    # out-port 0 feeds another_node.
    nodes_out = [
        ('op_node', {'kind': 'op'}),
        ('future_input', {'kind': 'op'}),
        ('another_node', {'kind': 'op'}),
    ]
    edges_out = [('op_node', 'future_input', {'in': 0, 'out': 1}),
                 ('op_node', 'another_node', {'in': 0, 'out': 0})]

    def test_out_port_no_data(self):
        # With is_out_port=True the Placeholder takes over the consumer of out-port 1.
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out)
        new_input_shape = np.array([1, 2, 3, 4])
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
                                           new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
                                                                                 'shape': new_input_shape})],
                                           new_edges_with_attrs=[('input_node', 'future_input', {'in': 0, 'out': 0})])
        add_input_op(graph, 'op_node', 1, data=False, shape=new_input_shape, is_out_port=True)
        graph.remove_edge('op_node', 'future_input')
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='another_node')
        self.assertTrue(flag, resp)
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='future_input')
        self.assertTrue(flag, resp)

    def test_out_port_with_data(self):
        # The existing output data node of out-port 1 is re-fed by the new Placeholder.
        graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
                                       new_nodes_with_attrs=[('input_data', {'kind': 'data', 'shape': None})],
                                       new_edges_with_attrs=[('op_node', 'input_data', {'out': 1, 'in': 0}),
                                                             ('input_data', 'future_input', {'in': 0, 'out': 0})])
        new_input_shape = np.array([1, 2, 3, 4])
        graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out[1:],
                                           new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Placeholder',
                                                                                 'shape': new_input_shape}),
                                                                 ('input_data', {'kind': 'data', 'shape': None})],
                                           new_edges_with_attrs=[('input_node', 'input_data', {'in': 0, 'out': 0}),
                                                                 ('input_data', 'future_input', {'in': 0, 'out': 0})])
        add_input_op(graph, 'op_node', 1, data=True, shape=new_input_shape, is_out_port=True)
        graph.remove_edge('op_node', 'input_data')

        (flag, resp) = compare_graphs(graph, graph_ref, last_node='another_node')
        self.assertTrue(flag, resp)
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='future_input')
        self.assertTrue(flag, resp)
186 class TestInputAddition(unittest.TestCase):
# Tests for add_input_ops covering three areas:
#   - rejection of invalid 'in'/'out' port specifications;
#   - replacement of existing Placeholder inputs with user-defined ones, both
#     before (before_infer=True) and after (before_infer=False) shape inference;
#   - cutting the graph at an explicit input or output port.
# NOTE(review): this copy of the class is damaged. Indentation is missing, every
# line starts with a stray number, and lines were dropped. These include the
# `edges = [` opening behind self.edges, the `nodes = {` / `edges = [` openings
# and closing brackets in most tests, the assignment of `shape` in
# test_one_input_no_shape, and the continuation of `nodes_attrs=` in the two
# output-port tests. It will not parse as shown; restore it from the canonical
# source.
188 nodes = {'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
189 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
190 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
193 ('node_1', 'conv_1'),
194 ('conv_1', 'relu_1'),
# 'out': None must be rejected with an explicit Error message.
197 def test_none_out_port_raise(self):
198 graph = build_graph(self.nodes, self.edges)
199 shape = np.array([1, 2, 3, 4])
200 inputs = {'conv_1': [{'shape': shape, 'out': None}]}
201 with self.assertRaisesRegex(Error, 'Output port for input node conv_1 should be specified, it cannot be None!'):
202 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
# Port 5 does not exist on conv_1; both the output- and input-port variants must raise.
204 def test_wrong_output_port_raise(self):
205 graph = build_graph(self.nodes, self.edges)
206 shape = np.array([1, 2, 3, 4])
207 inputs = {'conv_1': [{'shape': shape, 'out': 5}]}
208 with self.assertRaisesRegex(Error, 'Output port index 5 is out of number of available output ports for node'):
209 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
211 def test_wrong_input_port_raise(self):
212 graph = build_graph(self.nodes, self.edges)
213 shape = np.array([1, 2, 3, 4])
214 inputs = {'conv_1': [{'shape': shape, 'in': 5}]}
215 with self.assertRaisesRegex(Error, 'Input port index 5 is out of number of available input ports for node'):
216 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
# A shape given for conv_1 (no port) should replace old_input: a new node marked
# is_input with that shape feeds conv_1, and old_input is disconnected from it.
218 def test_one_input_one_shape(self):
219 shape = np.array([1, 2, 3, 4])
220 inputs = {'conv_1': [{'shape': shape}]}
222 'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
223 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
224 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
225 'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
228 ('old_input', 'conv_1'),
229 ('conv_1', 'relu_1'),
232 graph = build_graph(nodes, edges)
233 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
234 new_input = list(graph.in_edges('conv_1'))[0][0]
235 self.assertFalse(graph.node['old_input']['is_input'])
236 self.assertTrue(graph.node[new_input]['is_input'])
237 self.assertTrue((new_input, 'conv_1') in graph.edges())
238 self.assertTrue(('old_input', 'conv_1') not in graph.edges())
239 shapes_are_equal = np.array_equal(graph.node[new_input]['shape'], shape)
240 self.assertTrue(shapes_are_equal)
# NOTE(review): `shape` is used below but its assignment is not visible here; the
# test name suggests it is None — confirm.
# With before_infer=False the graph has data nodes, so the new input op sits two
# hops upstream of conv_1 (new op -> new data node -> conv_1).
242 def test_one_input_no_shape(self):
244 inputs = {'conv_1': [{'shape': shape}]}
246 'old_input': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
247 'old_input_data': {'kind': 'data', 'value': None, 'shape': np.array([-1, 224, 224, 3])},
248 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
249 'conv_1_data': {'kind': 'data', 'value': True, 'shape': np.array([-1, 224, 224, 3])},
250 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
251 'relu_1_data': {'kind': 'data', 'value': None, 'shape': np.array([-1, 112, 112, 64])},
252 'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'},
253 'output_data': {'name': 'output_data', 'kind': 'data', 'shape': np.array([-1, 112, 112, 64])},
254 'op_output': {'kind': 'op', 'op': 'OpOutput'}
257 ('old_input', 'old_input_data'),
258 ('old_input_data', 'conv_1'),
259 ('conv_1', 'conv_1_data'),
260 ('conv_1_data', 'relu_1'),
261 ('relu_1', 'relu_1_data'),
262 ('relu_1_data', 'output'),
263 ('output', 'output_data'),
264 ('output_data', 'op_output')
266 graph = build_graph(nodes, edges)
267 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
268 new_input = list(graph.in_edges(list(graph.in_edges('conv_1'))[0][0]))[0][0]
269 new_input_data = list(graph.in_edges('conv_1'))[0][0]
270 self.assertFalse(graph.node['old_input']['is_input'])
271 self.assertTrue(graph.node[new_input]['is_input'])
272 self.assertTrue((new_input_data, 'conv_1') in graph.edges())
273 self.assertTrue(('old_input_data', 'conv_1') not in graph.edges())
274 self.assertIsNotNone(graph.node[new_input_data]['shape'])
# Two user inputs (node_1, node_4) with list shapes; both get new is_input nodes
# carrying exactly those shapes.
276 def test_two_inputs_two_shapes_positive_1(self):
277 shape_1 = [1, 2, 3, 4]
278 shape_2 = [4, 3, 2, 1]
279 inputs = {'node_1': [{'shape': shape_1}], 'node_4': [{'shape': shape_2}]}
281 'input_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
282 'input_2': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
283 'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
284 'node_2': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
285 'node_3': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
286 'node_4': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
287 'output': {'kind': 'op', 'op': 'OpOutput'}
290 ('input_1', 'node_1'),
291 ('node_1', 'node_2'),
292 ('node_3', 'output'),
293 ('input_2', 'node_4'),
296 graph = build_graph(nodes, edges)
297 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
298 new_input_1 = list(graph.in_edges('node_1'))[0][0]
299 new_input_2 = list(graph.in_edges('node_4'))[0][0]
300 self.assertFalse(graph.node['input_1']['is_input'])
301 self.assertTrue(graph.node[new_input_1]['is_input'])
302 self.assertTrue(graph.node[new_input_2]['is_input'])
303 self.assertTrue((new_input_1, 'node_1') in graph.edges())
304 self.assertTrue((new_input_2, 'node_4') in graph.edges())
305 self.assertListEqual(shape_1, graph.node[new_input_1]['shape'])
306 self.assertListEqual(shape_2, graph.node[new_input_2]['shape'])
# input_3 is a Placeholder that `inputs` does not mention; add_input_ops is
# expected to raise Error (presumably because not every graph input is
# overridden — confirm).
308 def test_two_inputs_two_shapes_not_all_inputs(self):
309 shape_1 = [1, 2, 3, 4]
310 shape_2 = [4, 3, 2, 1]
311 inputs = {'node_1': [{'shape': shape_1}], 'node_4': [{'shape': shape_2}]}
313 'input_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
314 'input_2': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
315 'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
316 'node_2': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
317 'node_3': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
318 'node_4': {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'},
319 'output': { 'kind': 'op', 'op': 'OpOutput'},
320 'input_3': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'}
323 ('input_1', 'node_1'),
324 ('node_1', 'node_2'),
325 ('node_3', 'output'),
326 ('input_2', 'node_4'),
327 ('node_4', 'output'),
328 ('input_3', 'output')
330 graph = build_graph(nodes, edges)
331 self.assertRaises(Error, add_input_ops, graph, inputs, True)
333 # Tests that cut the graph at an explicit input ('in') or output ('out') port
# 'in': 0 on conv_1: the reference is the original graph with `shape` set on
# old_input.
334 def test_add_input_with_input_port_before_infer(self):
335 shape = np.array([1, 2, 3, 4])
336 inputs = {'conv_1': [{'shape': shape, 'in': 0}]}
338 'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
339 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
340 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
341 'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
344 ('old_input', 'conv_1'),
345 ('conv_1', 'relu_1'),
348 graph = build_graph(nodes, edges)
349 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
352 graph_ref = build_graph(nodes, edges, update_attributes={'old_input': {'shape': shape}})
353 (flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
354 self.assertTrue(flag, resp)
356 # Also check that conv_1 is now fed by a new input node and old_input is no longer an input
357 new_input = list(graph.in_edges('conv_1'))[0][0]
358 self.assertFalse(graph.node['old_input']['is_input'])
359 self.assertTrue(graph.node[new_input]['is_input'])
360 self.assertTrue((new_input, 'conv_1') in graph.edges())
361 self.assertTrue(('old_input', 'conv_1') not in graph.edges())
# 'out': 0 on conv_1: relu_1 should be fed by a new Placeholder named
# 'conv_1/placeholder_out_port_0', while old_input -> conv_1 stays intact.
363 def test_add_input_with_output_port_before_infer(self):
364 shape = np.array([1, 2, 3, 4])
365 inputs = {'conv_1': [{'shape': shape, 'out': 0}]}
367 'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
368 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
369 'conv_2': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
370 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
371 'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'}
374 ('old_input', 'conv_1'),
375 ('conv_1', 'relu_1'),
376 ('conv_2', 'relu_1'),
379 graph = build_graph(nodes, edges)
380 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=True)
382 graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Placeholder', 'shape': shape},
384 edges=[('new_input', 'relu_1'),
385 ('relu_1', 'output'),
386 ('conv_2', 'relu_1'),
387 ('old_input', 'conv_1'),],)
388 # Check that the new input is attached to the right ports
389 (flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
390 self.assertTrue(flag, resp)
392 # Check that other graph is not damaged
393 (flag, resp) = compare_graphs(graph, graph_ref, last_node='conv_1')
394 self.assertTrue(flag, resp)
396 # Checks for new input and edges
397 self.assertTrue('conv_1/placeholder_out_port_0' in graph.nodes())
398 new_input = 'conv_1/placeholder_out_port_0'
399 self.assertTrue(graph.node[new_input]['is_input'])
400 self.assertTrue((new_input, 'relu_1') in graph.edges())
# NOTE(review): old_input is only ever connected to conv_1, so the check below
# always passes. ('conv_1', 'relu_1') not in graph.edges() was probably intended;
# compare the ('conv_1', 'conv_data') check in the after-infer test.
401 self.assertTrue(('old_input', 'relu_1') not in graph.edges())
# Same cut after inference: the new Placeholder should take over conv_1's output
# data node conv_data, and conv_1 -> conv_data must be gone.
403 def test_add_input_with_output_port_after_infer(self):
404 shape = np.array([1, 2, 3, 4])
405 inputs = {'conv_1': [{'shape': shape, 'out': 0}]}
407 'old_input': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
408 'inp_data' : {'kind': 'data', 'shape': shape + 1},
409 'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
410 'conv_data': {'kind': 'data', 'shape': shape},
411 'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
414 ('old_input', 'inp_data'),
415 ('inp_data', 'conv_1'),
416 ('conv_1', 'conv_data'),
417 ('conv_data', 'relu_1'),
419 graph = build_graph(nodes, edges)
420 add_input_ops(graph=graph, user_defined_inputs=inputs, before_infer=False)
422 graph_ref = build_graph(nodes_attrs={'new_input': {'kind': 'op', 'op': 'Placeholder', 'shape': shape},
424 edges=[('old_input', 'inp_data'),
425 ('inp_data', 'conv_1'),
426 ('new_input', 'conv_data'),
427 ('conv_data', 'relu_1'),
429 # Check that the new input is attached to the right ports
430 (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1')
431 self.assertTrue(flag, resp)
433 # Check that other graph is not damaged
434 (flag, resp) = compare_graphs(graph, graph_ref, last_node='conv_1')
435 self.assertTrue(flag, resp)
437 # Checks for new input and edges
438 self.assertTrue('conv_1/placeholder_out_port_0' in graph.nodes())
439 new_input = 'conv_1/placeholder_out_port_0'
441 self.assertTrue(graph.node[new_input]['is_input'])
442 self.assertTrue((new_input, 'conv_data') in graph.edges())
443 self.assertTrue(('conv_1', 'conv_data') not in graph.edges())
class TestOutputCut(unittest.TestCase):
    """Tests for ``add_output_ops`` cutting the graph at node ``C``.

    Both tests build the same five-node graph: A -> C (in 0), B -> C (in 1),
    C (out 0) -> D and C (out 1) -> E. They then call ``add_output_ops`` with
    the parametrized output specification and run ``eliminate.graph_clean_up``.
    """

    @generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]})
    def test_output_port_cut(self, output):
        # Cutting at C itself or at one of its output ports: after clean-up C
        # must keep exactly one consumer and both of its producers.
        nodes = {'A': {'type': 'Identity', 'kind': 'op'},
                 'B': {'type': 'Identity', 'kind': 'op'},
                 'C': {'type': 'Identity', 'kind': 'op'},
                 'D': {'type': 'Identity', 'kind': 'op'},
                 'E': {'type': 'Identity', 'kind': 'op'},
                 }
        edges = [
            ('A', 'C', {'in': 0, 'out': 0}),
            ('B', 'C', {'in': 1, 'out': 0}),
            ('C', 'D', {'in': 0, 'out': 0}),
            ('C', 'E', {'in': 0, 'out': 1})
        ]
        graph = build_graph_with_edge_attrs(nodes, edges)
        add_output_ops(graph, output)
        eliminate.graph_clean_up(graph)
        self.assertEqual(len(Node(graph, 'C').out_nodes()), 1)
        self.assertEqual(len(Node(graph, 'C').in_nodes()), 2)

    # Must be named differently from test_output_port_cut: a duplicate method
    # name silently replaces the earlier definition in the class namespace, so
    # the output-port cases above would never run.
    @generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]})
    def test_input_port_cut(self, output):
        # Cutting at an input port of C (A and B are Placeholders here): only
        # two nodes are expected to remain after clean-up.
        nodes = {'A': {'op': 'Placeholder', 'kind': 'op'},
                 'B': {'op': 'Placeholder', 'kind': 'op'},
                 'C': {'type': 'Identity', 'kind': 'op'},
                 'D': {'type': 'Identity', 'kind': 'op'},
                 'E': {'type': 'Identity', 'kind': 'op'},
                 }
        edges = [
            ('A', 'C', {'in': 0, 'out': 0}),
            ('B', 'C', {'in': 1, 'out': 0}),
            ('C', 'D', {'in': 0, 'out': 0}),
            ('C', 'E', {'in': 0, 'out': 1})
        ]
        graph = build_graph_with_edge_attrs(nodes, edges)
        add_output_ops(graph, output)
        eliminate.graph_clean_up(graph)
        self.assertEqual(len(graph.nodes()), 2)
class TestUserDataRepack(unittest.TestCase):
    """Tests for ``input_user_data_repack`` and ``output_user_data_repack``.

    User-facing node names ('Aa', 'Bb', 'Cc', ...) are mapped to graph node ids
    ('A', 'B', 'C', ...). A ``'name:N'`` spec yields ``{'out': N}``, an
    ``'N:name'`` spec yields ``{'in': N}``, and a bare name yields ``{'port': None}``.
    """

    nodes = {'A': {'name': 'Aa', 'op': 'Placeholder', 'kind': 'op'},
             'B': {'name': 'Bb', 'op': 'Placeholder', 'kind': 'op'},
             'C': {'name': 'Cc', 'type': 'Identity', 'value': None, 'kind': 'op'},
             'D': {'name': 'Dd', 'type': 'Identity', 'value': None, 'kind': 'op'},
             'E': {'name': 'Ee', 'type': 'Identity', 'value': None, 'kind': 'op'},
             }
    edges = [
        ('A', 'C', {'in': 0, 'out': 0}),
        ('B', 'C', {'in': 1, 'out': 0}),
        ('C', 'D', {'in': 0, 'out': 0}),
        ('C', 'E', {'in': 0, 'out': 1})
    ]

    def test_input_user_data_repack_none(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        user_inputs, freeze_placeholder = input_user_data_repack(graph, None, None)
        self.assertIsNone(user_inputs)
        self.assertIsNone(freeze_placeholder)

    def test_input_user_data_repack_names_to_ids_list(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        user_inputs, freeze_placeholder = input_user_data_repack(graph, ['Aa', 'Bb'], None)
        self.assertDictEqual(user_inputs, {'A': [{'shape': None, 'port': None}], 'B': [{'shape': None, 'port': None}]})
        self.assertIsNone(freeze_placeholder)

    def test_input_user_data_repack_names_ports_in_out(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        user_inputs, freeze_placeholder = input_user_data_repack(graph, ['Aa:1', '0:Bb'], None)
        self.assertDictEqual(user_inputs, {'A': [{'shape': None, 'out': 1}], 'B': [{'shape': None, 'in': 0}]})
        self.assertIsNone(freeze_placeholder)

    def test_input_user_data_repack_dict_with_shapes(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        shape_1 = np.array([1, 160, 160, 3])
        shape_2 = np.array([1, 127, 127, 3])
        user_inputs, freeze_placeholder = input_user_data_repack(graph, {'Aa': shape_1, 'Bb': shape_2}, None)
        self.assertDictEqual(user_inputs,
                             {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': shape_2, 'port': None}]})
        self.assertIsNone(freeze_placeholder)

    def test_input_user_data_repack_dict_with_shapes_and_ports(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        shape_1 = np.array([1, 160, 160, 3])
        shape_2 = np.array([1, 127, 127, 3])
        user_inputs, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Bb:1': shape_2}, None)
        self.assertDictEqual(user_inputs, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': shape_2, 'out': 1}]})
        self.assertIsNone(freeze_placeholder)

    def test_freeze_placeholder_and_input(self):
        # Frozen placeholder B is still reported as an input (with no shape/port).
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        shape_1 = np.array([1, 160, 160, 3])
        user_inputs, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1}, {'Bb': False})
        self.assertDictEqual(user_inputs, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': None, 'port': None}]})
        self.assertEqual(freeze_placeholder, {'B': False})

    def test_error(self):
        # A bare shape is rejected here; the graph has two placeholders (A and B) and neither is frozen.
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)

    def test_error_2(self):
        # NOTE(review): identical to test_error; presumably meant to cover a
        # different invalid input — confirm, then change or drop it.
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)

    def test_error_3(self):
        # 'Bcb' matches no node name.
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        self.assertRaises(Error, input_user_data_repack, graph, ['Bcb'], None)

    def test_input_and_freeze(self):
        # With B frozen, a bare shape is attributed to the remaining placeholder A.
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        shape_1 = np.array([1, 160, 160, 3])
        user_inputs, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
        self.assertDictEqual(user_inputs, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
        self.assertDictEqual(freeze_placeholder, {'B': True})

    def test_output_user_data_repack(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        output = output_user_data_repack(graph, ['Cc'])
        self.assertDictEqual(output, {'C': [{'port': None}]})

    def test_output_user_data_repack_ports(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        output = output_user_data_repack(graph, ['Cc:1', '0:Cc'])
        self.assertDictEqual(output, {'C': [{'out': 1}, {'in': 0}]})

    def test_output_user_data_repack_none(self):
        graph = build_graph_with_edge_attrs(self.nodes, self.edges)
        output = output_user_data_repack(graph, None)
        self.assertIsNone(output)
class TestExtractPort(unittest.TestCase):
    """Tests for ``extract_port_from_string``.

    The function returns ``(name, in_port, out_port)``. ``'name:N'`` selects
    output port N, ``'N:name'`` selects input port N, and a bare name selects
    neither. A non-integer port, or ports on both sides, raises ``Error``.
    """

    def test_out_port(self):
        name, in_port, out_port = extract_port_from_string('node_name:1')
        self.assertEqual(name, 'node_name')
        self.assertIsNone(in_port)
        self.assertEqual(out_port, 1)

    def test_in_port(self):
        name, in_port, out_port = extract_port_from_string('0:node_name')
        self.assertEqual(name, 'node_name')
        self.assertEqual(in_port, 0)
        self.assertIsNone(out_port)

    def test_no_port(self):
        name, in_port, out_port = extract_port_from_string('node_name')
        self.assertEqual(name, 'node_name')
        self.assertIsNone(in_port)
        self.assertIsNone(out_port)

    def test_non_int(self):
        self.assertRaises(Error, extract_port_from_string, 'port:node_name')

    def test_two_ports(self):
        self.assertRaises(Error, extract_port_from_string, '1:node_name:0')
class TestCaffePythonFrontExtractorOp(unittest.TestCase):
    """Tests for ``CaffePythonFrontExtractorOp.get_attrs``."""

    def test_get_attrs(self):
        # param_str carries "'key': value" pairs; get_attrs is expected to turn
        # them into a dict with the corresponding Python values.
        param_str = "'test_attr_1': 12, 'test_attr_2': 'sdf sdf'"
        fake_layer = FakePythonParam(FakeMultiParam({'param_str': param_str}))
        expected = {"test_attr_1": 12, "test_attr_2": "sdf sdf"}
        self.assertEqual(expected, CaffePythonFrontExtractorOp.get_attrs(fake_layer))