Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / middle / passes / infer_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import numpy as np
20
21 from mo.front.common.partial_infer.concat import concat_infer
22 from mo.graph.graph import Node
23 from mo.middle.passes.infer import override_placeholder_shapes, partial_infer
24 from mo.utils.error import Error
25 from mo.utils.unittest.graph import build_graph
26
27 nodes_attributes = {'node_1': {'type': 'Identity', 'value': None, 'kind': 'op'},
28                     'node_1_data': {'value': None, 'kind': 'data', 'data_type': None},
29                     'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
30                     'concat': {'type': 'Concat', 'value': None, 'kind': 'op'},
31                     'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
32                     'node_3_data': {'value': None, 'kind': 'data', 'data_type': None},
33                     # Placeholders
34                     'placeholder_1': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
35                     'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
36                     'placeholder_2': {'shape': None, 'type': 'Input', 'kind': 'op', 'op': 'Placeholder'},
37                     'pl_1': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
38                     'pl_1_data': {'value': None, 'kind': 'data', 'data_type': None},
39                     'pl_2': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
40                     'pl_2_data': {'value': None, 'kind': 'data', 'data_type': None},
41                     'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
42                     # ScaleShift layer
43                     'scaleshift_1': {'type': 'ScaleShift', 'kind': 'op', 'op': 'ScaleShift'},
44                     'scaleshift_1_w': {'value': None, 'shape': None, 'kind': 'data'},
45                     'scaleshift_1_b': {'value': None, 'shape': None, 'kind': 'data'},
46                     'scaleshift_1_data': {'value': None, 'shape': None, 'kind': 'data'},
47                     # Mul op
48                     'mul_1': {'type': None, 'kind': 'op', 'op': 'Mul'},
49                     'mul_1_w': {'value': None, 'shape': None, 'kind': 'data'},
50                     'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
51                     'op_output': { 'kind': 'op', 'op': 'OpOutput', 'infer': lambda x: None}
52                     }
53
54
55 class TestInferPass(unittest.TestCase):
56     def test_override_placeholder_shapes(self):
57         """
58         Test for overriding shape in placeholder by shape from user_shapes.
59         """
60         graph = build_graph(nodes_attributes,
61                             [('node_1', 'node_2'),
62                              ('node_2', 'op_output')
63                              ],
64                             {'node_2': {'shape': None},
65                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
66                              },
67                             nodes_with_edges_only=True)
68
69         ph_shape = np.array([1, 3, 224, 224])
70         user_dict = {'node_1': [{'shape': ph_shape}]}
71         override_placeholder_shapes(graph, user_dict)
72         res_shape = graph.node['node_1']['shape']
73         self.assertTrue(np.array_equal(ph_shape, res_shape))
74
75     def test_override_placeholder_no_shape(self):
76         """
77         Test for case when user_shapes is not defined.
78         """
79         graph = build_graph(nodes_attributes,
80                             [('node_1', 'node_2'),
81                              ('node_2', 'op_output')
82                              ],
83                             {'node_2': {'shape': None, 'op': 'Placeholder'},
84                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
85                              },
86                             nodes_with_edges_only=True)
87         out = override_placeholder_shapes(graph, None)
88         res_shape = graph.node['node_1']['shape']
89         placeholder_shape = np.array([1, 3, 227, 227])
90         self.assertIsNone(out)
91         self.assertTrue(np.array_equal(placeholder_shape, res_shape))
92
93     def test_override_placeholder_shapes(self):
94         """
95         Test for case when user_shapes is not None, but it shouldn't rewrite shapes.
96         """
97         graph = build_graph(nodes_attributes,
98                             [('node_1', 'node_2'),
99                              ('node_2', 'op_output')
100                              ],
101                             {'node_2': {'shape': None},
102                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
103                              },
104                             nodes_with_edges_only=True)
105
106         node_1_shape = np.array([1, 3, 227, 227])
107         user_dict = {'some_node': [{'shape': np.zeros((3))}]}
108         override_placeholder_shapes(graph, user_dict)
109         res_shape = graph.node['node_1']['shape']
110         self.assertTrue(np.array_equal(node_1_shape, res_shape))
111
112     def test_override_placeholder_shapes_dict(self):
113         graph = build_graph(nodes_attributes,
114                             [('node_1', 'node_2'),
115                              ('node_2', 'op_output')
116                              ],
117                             {'node_2': {'shape': None, 'op': 'Placeholder'},
118                              'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Placeholder'}
119                              },
120                             nodes_with_edges_only=True)
121
122         placeholder_shape = np.array([1, 3, 224, 224])
123         user_shapes = {
124             'node_1': [{'shape': placeholder_shape}],
125             'node_2': [{'shape': placeholder_shape}],
126         }
127         override_placeholder_shapes(graph, user_shapes)
128         res_shape = graph.node['node_1']['shape']
129         res_shape2 = graph.node['node_2']['shape']
130         self.assertTrue(np.array_equal(placeholder_shape, res_shape))
131         self.assertTrue(np.array_equal(placeholder_shape, res_shape2))
132
133     nodes = {
134         'placeholder_1': {'name': 'placeholder_1', 'shape': [1, 2, 3, 4], 'type': 'Placeholder', 'value': None,
135                           'kind': 'op', 'op': 'Placeholder'},
136         'placeholder_2': {'name': 'placeholder_2', 'shape': [5, 6, 7, 8], 'type': 'Placeholder', 'value': None,
137                           'kind': 'op', 'op': 'Placeholder'},
138         '1': {'name': 'node_1', 'type': 'Identity', 'value': None, 'kind': 'op'},
139         '2': {'name': 'node_2', 'type': 'Identity', 'value': None, 'kind': 'op'},
140         '3': {'name': 'concat', 'type': 'Identity', 'value': None, 'kind': 'op'},
141         '4': {'name': 'output', 'type': 'SoftMax', 'value': None, 'kind': 'op'}
142     }
143     edges = [
144         ('placeholder_1', '1'),
145         ('1', '3'),
146         ('placeholder_2', '2'),
147         ('2', '3'),
148         ('3', '4')
149     ]
150
151     def test_override_placeholder_shapes_batch_is_not_set(self):
152         """
153         Test case when batch is not set. (shapes shouldn't change)
154         """
155         graph = build_graph(self.nodes, self.edges)
156         shapes = {}
157         batch = None
158         override_placeholder_shapes(graph, shapes, batch)
159         res_shape_1 = graph.node['placeholder_1']['shape']
160         res_shape_2 = graph.node['placeholder_2']['shape']
161         self.assertTrue(np.array_equal(self.nodes['placeholder_1']['shape'], res_shape_1))
162         self.assertTrue(np.array_equal(self.nodes['placeholder_2']['shape'], res_shape_2))
163
164     def test_override_placeholder_shapes_real_inputs_and_batch(self):
165         """
166         Test case when batch is set and shapes should overwrite by user shapes.
167         """
168         graph = build_graph(self.nodes, self.edges)
169         shapes = {'placeholder_1': [{'shape': np.array([1, 2, 3, 4])}],
170                   'placeholder_2': [{'shape': np.array([1, 5, 6, 7])}]}
171         batch = 4
172         override_placeholder_shapes(graph, shapes, batch)
173         res_shape_1 = graph.node['placeholder_1']['shape']
174         res_shape_2 = graph.node['placeholder_2']['shape']
175         self.assertTrue(np.array_equal(res_shape_1, np.array([4, 2, 3, 4])))
176         self.assertTrue(np.array_equal(res_shape_2, np.array([4, 5, 6, 7])))
177
178     def test_override_placeholder_shapes_real_inputs_and_batch_2(self):
179         """
180         Test case when batch is set, but shapes in user_shapes is None.
181         """
182         graph = build_graph(self.nodes, self.edges)
183         shapes = {'placeholder_1': [{'shape': None}], 'placeholder_2': [{'shape': None}]}
184         batch = 4
185         graph.node['placeholder_2']['shape'] = np.array([1, 2, 3, 4])
186         graph.node['placeholder_2']['shape'] = np.array([1, 5, 6, 7])
187         override_placeholder_shapes(graph, shapes, batch)
188         np.testing.assert_array_equal(graph.node['placeholder_1']['shape'], np.array([4, 2, 3, 4]))
189         np.testing.assert_array_equal(graph.node['placeholder_2']['shape'], np.array([4, 5, 6, 7]))
190
191     def test_partial_infer(self):
192         graph = build_graph(nodes_attributes,
193                             [('node_1', 'concat'),
194                              ('node_2', 'concat'),
195                              ('concat', 'node_3'),
196                              ('node_3', 'op_output')
197                              ],
198                             {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
199                              'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
200                              'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
201                              'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer}
202                              },
203                             nodes_with_edges_only=True)
204
205         start_node = 'concat'
206         partial_infer(graph, start_node)
207         node = Node(graph, start_node)
208         self.assertTrue(node.is_partial_inferred)
209         self.assertTrue(node.out_node().is_partial_inferred)
210
211         # check if previous nodes are not inferred
212         node = Node(graph, start_node)
213         while True:
214             # collect nodes in a list
215             if isinstance(node.in_nodes(), list):
216                 in_nodes = node.in_nodes()
217             else:
218                 in_nodes = [y for x, y in node.in_nodes().items()]
219
220             # check parents and find next parent
221             for n in in_nodes:
222                 if 'embedded_input_' not in n.id:
223                     node = n
224                 self.assertFalse(n.has('is_partial_inferred'))
225
226             if not len(in_nodes):
227                 break
228
229     def test_partial_infer_no_shape(self):
230         graph = build_graph(nodes_attributes,
231                             [('node_1', 'node_2'),
232                              ('node_2', 'op_output')
233                              ],
234                             {'node_2': {'shape': None, 'infer': None},
235                              'node_1': {'shape': None, 'infer': None}
236                              },
237                             nodes_with_edges_only=True)
238         self.assertRaises(Error, partial_infer, graph, 'node_1')
239
240     def test_partial_infer_cycle(self):
241         graph = build_graph(nodes_attributes,
242                             [('node_1', 'concat'),
243                              ('node_2', 'concat'),
244                              ('concat', 'node_3'),
245                              ('node_3', 'concat'),
246                              ('node_3', 'op_output')
247                              ],
248                             {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
249                              'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
250                              'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
251                              'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer}
252                              },
253                             nodes_with_edges_only=True)
254
255         start_node = 'concat'
256         self.assertRaises(Error, partial_infer, graph, start_node)
257
258
259 class CycleTest(unittest.TestCase):
260     def test_is_not_fully_inferred_param(self):
261         # Node that have is_not_fully_inferred=True
262         graph = build_graph(nodes_attributes,
263                             [('node_1', 'concat'),
264                              ('node_2', 'concat'),
265                              ('concat', 'node_3'),
266                              ('node_3', 'op_output')
267                              ],
268                             {'node_3': {'kind': 'data', 'shape': None, 'infer': None},
269                              'node_1': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
270                              'node_2': {'kind': 'data', 'shape': np.array([1, 3, 227, 227]), 'infer': None},
271                              'concat': {'kind': 'op', 'axis': 2, 'infer': concat_infer, 'is_not_fully_inferred': True}
272                              },
273                             nodes_with_edges_only=True)
274
275         start_node = 'concat'
276         try:
277             partial_infer(graph, start_node)
278         except Error:
279             self.fail("Unexpected Error raised")
280         node = Node(graph, start_node)
281         self.assertTrue(node.is_partial_inferred)
282         self.assertTrue(node.out_node().is_partial_inferred)
283
284     def test_for_is_cyclic1(self):
285         # Test for case of cyclic graph without is_cyclic attrs
286         graph = build_graph(nodes_attributes,
287                             [('node_1', 'node_1_data'),
288                              ('node_1_data', 'node_3'),
289                              ('node_3', 'node_3_data'),
290                              ('node_3_data', 'node_1')],
291                             nodes_with_edges_only=True)
292         with self.assertRaisesRegex(Error, 'Graph contains a cycle. Can not proceed.*'):
293             partial_infer(graph)