2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.front.common.partial_infer.concat import concat_infer
22 from mo.graph.graph import Node
23 from mo.middle.passes.infer import override_placeholder_shapes, partial_infer
24 from mo.utils.error import Error
25 from mo.utils.unittest.graph import build_graph
# Node-attribute templates shared by the graphs these tests build.
# Keys are the node ids used in build_graph edge lists; 'kind' marks a node
# as an operation ('op') or as a data node ('data').
nodes_attributes = {
    # Generic Identity / Concat ops and their data nodes.
    'node_1': dict(type='Identity', value=None, kind='op'),
    'node_1_data': dict(value=None, kind='data', data_type=None),
    'node_2': dict(type='Identity', value=None, kind='op'),
    'concat': dict(type='Concat', value=None, kind='op'),
    'node_3': dict(type='Identity', value=None, kind='op'),
    'node_3_data': dict(value=None, kind='data', data_type=None),

    # Placeholder ops (type 'Input' or 'Placeholder') and their data nodes.
    'placeholder_1': dict(shape=None, type='Input', kind='op', op='Placeholder'),
    'placeholder_1_data': dict(value=None, shape=None, kind='data', data_type=None),
    'placeholder_2': dict(shape=None, type='Input', kind='op', op='Placeholder'),
    'pl_1': dict(type='Placeholder', kind='op', op='Placeholder'),
    'pl_1_data': dict(value=None, kind='data', data_type=None),
    'pl_2': dict(type='Placeholder', kind='op', op='Placeholder'),
    'pl_2_data': dict(value=None, kind='data', data_type=None),
    'placeholder_2_data': dict(value=None, shape=None, kind='data', data_type=None),

    # ScaleShift op and its data nodes.
    'scaleshift_1': dict(type='ScaleShift', kind='op', op='ScaleShift'),
    'scaleshift_1_w': dict(value=None, shape=None, kind='data'),
    'scaleshift_1_b': dict(value=None, shape=None, kind='data'),
    'scaleshift_1_data': dict(value=None, shape=None, kind='data'),

    # Mul op and its data nodes.
    'mul_1': dict(type=None, kind='op', op='Mul'),
    'mul_1_w': dict(value=None, shape=None, kind='data'),
    'mul_1_data': dict(value=None, shape=None, kind='data'),

    # OpOutput terminal node; its 'infer' callback does nothing and returns None.
    'op_output': dict(kind='op', op='OpOutput', infer=lambda x: None),
}
class TestInferPass(unittest.TestCase):
    """Tests for override_placeholder_shapes and partial_infer."""

    # Graph used by the batch-related override tests: placeholder_1 -> '1' and
    # placeholder_2 -> '2' both feed '3', which feeds '4'.
    # NOTE(review): these are shared class-level dicts; the tests assume
    # build_graph copies the attributes rather than aliasing them — confirm.
    nodes = {
        'placeholder_1': {'name': 'placeholder_1', 'shape': [1, 2, 3, 4], 'type': 'Placeholder', 'value': None,
                          'kind': 'op', 'op': 'Placeholder'},
        'placeholder_2': {'name': 'placeholder_2', 'shape': [5, 6, 7, 8], 'type': 'Placeholder', 'value': None,
                          'kind': 'op', 'op': 'Placeholder'},
        '1': {'name': 'node_1', 'type': 'Identity', 'value': None, 'kind': 'op'},
        '2': {'name': 'node_2', 'type': 'Identity', 'value': None, 'kind': 'op'},
        '3': {'name': 'concat', 'type': 'Identity', 'value': None, 'kind': 'op'},
        '4': {'name': 'output', 'type': 'SoftMax', 'value': None, 'kind': 'op'}
    }
    edges = [
        ('placeholder_1', '1'),
        ('1', '3'),
        ('placeholder_2', '2'),
        ('2', '3'),
        ('3', '4')
    ]

    def test_override_placeholder_shapes(self):
        """
        Test for overriding shape in placeholder by shape from user_shapes.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None},
                             'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
                             },
                            nodes_with_edges_only=True)

        ph_shape = np.array([1, 3, 224, 224])
        user_dict = {'node_1': [{'shape': ph_shape}]}
        override_placeholder_shapes(graph, user_dict)
        res_shape = graph.node['node_1']['shape']
        self.assertTrue(np.array_equal(ph_shape, res_shape))

    def test_override_placeholder_no_shape(self):
        """
        Test for case when user_shapes is not defined.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None, 'op': 'Placeholder'},
                             'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
                             },
                            nodes_with_edges_only=True)
        out = override_placeholder_shapes(graph, None)
        res_shape = graph.node['node_1']['shape']
        placeholder_shape = np.array([1, 3, 227, 227])
        self.assertIsNone(out)
        self.assertTrue(np.array_equal(placeholder_shape, res_shape))

    # Must have a distinct name: reusing test_override_placeholder_shapes would
    # silently replace the test above, so that one would never run.
    def test_override_placeholder_shapes_unknown_node(self):
        """
        Test for case when user_shapes is not None, but it shouldn't rewrite shapes
        (user_shapes only names a node that is not in the graph).
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None},
                             'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
                             },
                            nodes_with_edges_only=True)

        node_1_shape = np.array([1, 3, 227, 227])
        user_dict = {'some_node': [{'shape': np.zeros(3)}]}
        override_placeholder_shapes(graph, user_dict)
        res_shape = graph.node['node_1']['shape']
        self.assertTrue(np.array_equal(node_1_shape, res_shape))

    def test_override_placeholder_shapes_dict(self):
        """
        Test that every placeholder listed in user_shapes gets the user shape.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None, 'op': 'Placeholder'},
                             'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
                             },
                            nodes_with_edges_only=True)

        placeholder_shape = np.array([1, 3, 224, 224])
        user_shapes = {
            'node_1': [{'shape': placeholder_shape}],
            'node_2': [{'shape': placeholder_shape}],
        }
        override_placeholder_shapes(graph, user_shapes)
        res_shape = graph.node['node_1']['shape']
        res_shape2 = graph.node['node_2']['shape']
        self.assertTrue(np.array_equal(placeholder_shape, res_shape))
        self.assertTrue(np.array_equal(placeholder_shape, res_shape2))

    def test_override_placeholder_shapes_batch_is_not_set(self):
        """
        Test case when batch is not set. (shapes shouldn't change)
        """
        graph = build_graph(self.nodes, self.edges)
        shapes = {}
        batch = None
        override_placeholder_shapes(graph, shapes, batch)
        res_shape_1 = graph.node['placeholder_1']['shape']
        res_shape_2 = graph.node['placeholder_2']['shape']
        self.assertTrue(np.array_equal(self.nodes['placeholder_1']['shape'], res_shape_1))
        self.assertTrue(np.array_equal(self.nodes['placeholder_2']['shape'], res_shape_2))

    def test_override_placeholder_shapes_real_inputs_and_batch(self):
        """
        Test case when batch is set and shapes should be overwritten by user shapes.
        """
        graph = build_graph(self.nodes, self.edges)
        shapes = {'placeholder_1': [{'shape': np.array([1, 2, 3, 4])}],
                  'placeholder_2': [{'shape': np.array([1, 5, 6, 7])}]}
        batch = 4
        override_placeholder_shapes(graph, shapes, batch)
        res_shape_1 = graph.node['placeholder_1']['shape']
        res_shape_2 = graph.node['placeholder_2']['shape']
        self.assertTrue(np.array_equal(res_shape_1, np.array([4, 2, 3, 4])))
        self.assertTrue(np.array_equal(res_shape_2, np.array([4, 5, 6, 7])))

    def test_override_placeholder_shapes_real_inputs_and_batch_2(self):
        """
        Test case when batch is set, but shapes in user_shapes is None.
        """
        graph = build_graph(self.nodes, self.edges)
        shapes = {'placeholder_1': [{'shape': None}], 'placeholder_2': [{'shape': None}]}
        batch = 4
        # Each placeholder gets its own starting shape (previously placeholder_2
        # was assigned twice and placeholder_1 was never set up).
        graph.node['placeholder_1']['shape'] = np.array([1, 2, 3, 4])
        graph.node['placeholder_2']['shape'] = np.array([1, 5, 6, 7])
        override_placeholder_shapes(graph, shapes, batch)
        np.testing.assert_array_equal(graph.node['placeholder_1']['shape'], np.array([4, 2, 3, 4]))
        np.testing.assert_array_equal(graph.node['placeholder_2']['shape'], np.array([4, 5, 6, 7]))

    def test_partial_infer(self):
        """
        Test that partial_infer marks the start node and its output as partially
        inferred while the parents visited upstream of it stay unmarked.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'concat'),
                             ('node_2', 'concat'),
                             ('concat', 'node_3'),
                             ('node_3', 'op_output')
                             ],
                            {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
                             'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                             'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                             'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer}
                             },
                            nodes_with_edges_only=True)

        start_node = 'concat'
        partial_infer(graph, start_node)
        node = Node(graph, start_node)
        self.assertTrue(node.is_partial_inferred)
        self.assertTrue(node.out_node().is_partial_inferred)

        # check if previous nodes are not inferred: walk upstream from the start
        # node, following the last non-embedded parent at each step
        node = Node(graph, start_node)
        while True:
            # collect parent nodes in a list
            in_nodes = node.in_nodes()
            if not isinstance(in_nodes, list):
                in_nodes = list(in_nodes.values())

            # check every parent and pick the next one to follow
            next_node = None
            for n in in_nodes:
                if 'embedded_input_' not in n.id:
                    next_node = n
                self.assertFalse(n.has('is_partial_inferred'))

            # Stop when there is no parent to follow. Breaking only on an empty
            # parent list would loop forever if every parent were embedded.
            if next_node is None:
                break
            node = next_node

    def test_partial_infer_no_shape(self):
        """
        Test that partial_infer raises Error for a graph whose nodes have neither
        a shape nor an infer function.
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'node_2'),
                             ('node_2', 'op_output')
                             ],
                            {'node_2': {'shape': None, 'infer': None},
                             'node_1': {'shape': None, 'infer': None}
                             },
                            nodes_with_edges_only=True)
        self.assertRaises(Error, partial_infer, graph, 'node_1')

    def test_partial_infer_cycle(self):
        """
        Test that partial_infer raises Error when the graph contains a cycle
        (concat -> node_3 -> concat).
        """
        graph = build_graph(nodes_attributes,
                            [('node_1', 'concat'),
                             ('node_2', 'concat'),
                             ('concat', 'node_3'),
                             ('node_3', 'concat'),
                             ('node_3', 'op_output')
                             ],
                            {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
                             'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                             'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
                             'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer}
                             },
                            nodes_with_edges_only=True)

        start_node = 'concat'
        self.assertRaises(Error, partial_infer, graph, start_node)
259 class CycleTest(unittest.TestCase):
def test_is_not_fully_inferred_param(self):
    """
    A node flagged with is_not_fully_inferred=True must not make partial_infer
    raise; the node and its output must still end up partially inferred.
    """
    edge_list = [('node_1', 'concat'),
                 ('node_2', 'concat'),
                 ('concat', 'node_3'),
                 ('node_3', 'op_output')]
    overrides = {
        'node_3': {'kind': 'data', 'shape': None, 'infer': None},
        'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
        'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
        'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer, 'is_not_fully_inferred': True},
    }
    graph = build_graph(nodes_attributes, edge_list, overrides, nodes_with_edges_only=True)

    concat_id = 'concat'
    try:
        partial_infer(graph, concat_id)
    except Error:
        self.fail("Unexpected Error raised")

    concat_node = Node(graph, concat_id)
    self.assertTrue(concat_node.is_partial_inferred)
    self.assertTrue(concat_node.out_node().is_partial_inferred)
284 def test_for_is_cyclic1(self):
285 # Test for case of cyclic graph without is_cyclic attrs
286 graph = build_graph(nodes_attributes,
287 [('node_1', 'node_1_data'),
288 ('node_1_data', 'node_3'),
289 ('node_3', 'node_3_data'),
290 ('node_3_data', 'node_1')],
291 nodes_with_edges_only=True)
292 with self.assertRaisesRegex(Error, 'Graph contains a cycle. Can not proceed.*'):