Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / graph_test.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import unittest
18
19 import networkx as nx
20
21 from mo.utils.error import Error
22 from mo.utils.graph import dfs, bfs_search, is_connected_component, sub_graph_between_nodes
23 from mo.graph.graph import Graph
24
25 class TestGraphUtils(unittest.TestCase):
26     def test_simple_dfs(self):
27         graph = Graph()
28         graph.add_nodes_from(list(range(1, 5)))
29         graph.add_edges_from([(1, 2), (1, 3), (3, 4)])
30
31         visited = set()
32         order = dfs(graph, 1, visited)
33         self.assertTrue(order == [4, 3, 2, 1] or order == [2, 4, 3, 1])
34
35     def test_bfs_search_default_start_nodes(self):
36         """
37         Check that BFS automatically determines input nodes and start searching from them.
38         """
39         graph = Graph()
40         graph.add_nodes_from(list(range(1, 6)))
41         graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5)])
42
43         order = bfs_search(graph)
44         self.assertTrue(order == [1, 2, 3, 4, 5] or order == [2, 1, 3, 4, 5])
45
46     def test_bfs_search_specific_start_nodes(self):
47         """
48         Check that BFS stars from the user defined nodes and doesn't go in backward edge direction.
49         """
50         graph = Graph()
51         graph.add_nodes_from(list(range(1, 7)))
52         graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])
53
54         order = bfs_search(graph, [1])
55         self.assertTrue(order == [1, 3, 4, 5])
56
57     def test_is_connected_component_two_separate_sub_graphs(self):
58         """
59         Check that if there are two separate sub-graphs the function returns False.
60         """
61         graph = Graph()
62         graph.add_nodes_from(list(range(1, 7)))
63         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6)])
64         self.assertFalse(is_connected_component(graph, list(range(1, 7))))
65         self.assertFalse(is_connected_component(graph, [1, 3]))
66         self.assertFalse(is_connected_component(graph, [6, 4]))
67         self.assertFalse(is_connected_component(graph, [2, 5]))
68
69     def test_is_connected_component_two_separate_sub_graphs_divided_by_ignored_node(self):
70         """
71         Check that if there are two separate sub-graphs the function connected by an edge going through the ignored node
72         then the function returns False.
73         """
74         graph = Graph()
75         node_names = list(range(1, 8))
76         graph.add_nodes_from(node_names)
77         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
78         self.assertFalse(is_connected_component(graph, list(range(1, 7))))
79
80     def test_is_connected_component_connected(self):
81         """
82         Check that if the sub-graph is connected.
83         """
84         graph = Graph()
85         node_names = list(range(1, 8))
86         graph.add_nodes_from(node_names)
87         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
88         self.assertTrue(is_connected_component(graph, list(range(1, 8))))
89
90     def test_is_connected_component_edges_direction_is_ignored(self):
91         """
92         Check that edges direction is ignored when checking for the connectivity.
93         """
94         graph = Graph()
95         node_names = list(range(1, 5))
96         graph.add_nodes_from(node_names)
97         graph.add_edges_from([(2, 1), (2, 3), (4, 3)])
98         self.assertTrue(is_connected_component(graph, node_names))
99         self.assertTrue(is_connected_component(graph, [2, 1]))
100         self.assertTrue(is_connected_component(graph, [4, 2, 3]))
101
102     def test_is_connected_component_edges_direction_is_ignored_not_connected(self):
103         """
104         Check that edges direction is ignored when checking for the connectivity. In this case the graph is not
105         connected.
106         """
107         graph = Graph()
108         graph.add_nodes_from(list(range(1, 5)))
109         graph.add_edges_from([(2, 1), (2, 3), (4, 3)])
110         self.assertFalse(is_connected_component(graph, [1, 2, 4]))
111         self.assertFalse(is_connected_component(graph, [1, 4]))
112         self.assertFalse(is_connected_component(graph, [2, 4]))
113         self.assertFalse(is_connected_component(graph, [3, 4, 1]))
114
115     def test_sub_graph_between_nodes_include_incoming_edges_for_internal_nodes(self):
116         """
117         Check that the function adds input nodes for the internal nodes of the graph. For example, we need to add node 5
118         and 6 in the case below if we find match from node 1 till node 4.
119         6 -> 5 ->
120                  \
121             1 -> 2 -> 3 -> 4
122         :return:
123         """
124         graph = Graph()
125         graph.add_nodes_from(list(range(1, 7)))
126         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])
127         sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4])
128         self.assertIsNotNone(sub_graph_nodes)
129         self.assertListEqual(sorted(sub_graph_nodes), list(range(1, 7)))
130
131         sub_graph_nodes = sub_graph_between_nodes(graph, [1], [2])
132         self.assertIsNotNone(sub_graph_nodes)
133         self.assertListEqual(sorted(sub_graph_nodes), [1, 2, 5, 6])
134
135     def test_sub_graph_between_nodes_do_not_include_incoming_edges_for_input_nodes(self):
136         """
137         Check that the function doesn't add input nodes for the start nodes of the sub-graph. For example, we do not
138         need to add node 5 in the case below if we find match from node 1 till node 4.
139           5->
140              \
141         1 -> 2 -> 3 -> 4
142         """
143         graph = Graph()
144         graph.add_nodes_from(list(range(1, 6)))
145         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
146         sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
147         self.assertIsNotNone(sub_graph_nodes)
148         self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4])
149
150     def test_sub_graph_between_nodes_placeholder_included(self):
151         """
152         Check that the function doesn't allow to add Placeholders to the sub-graph. 5 is the Placeholder op.
153           5->
154              \
155         1 -> 2 -> 3 -> 4
156         """
157         graph = Graph()
158         graph.add_nodes_from(list(range(1, 6)))
159         graph.node[5]['op'] = 'Placeholder'
160         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
161         self.assertRaises(Error, sub_graph_between_nodes, graph, [1], [4])
162
163     def test_sub_graph_between_nodes_placeholder_excluded(self):
164         """
165         Check that the function do not check that node is Placeholders for the nodes not included into the sub-graph.
166         For example, node 5 is Placeholder but it is not included into the sub-graph, so this attribute is ignored.
167           5->
168              \
169         1 -> 2 -> 3 -> 4
170         """
171         graph = Graph()
172         graph.add_nodes_from(list(range(1, 6)))
173         graph.node[5]['op'] = 'Placeholder'
174         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
175         sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
176         self.assertIsNotNone(sub_graph_nodes)
177         self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4])
178
179     def test_sub_graph_between_nodes_multiple_inputs(self):
180         """
181         Check that the function works correctly when multiple inputs specified.
182           5->
183              \
184         1 -> 2 -> 3 -> 4
185         """
186         graph = Graph()
187         graph.add_nodes_from(list(range(1, 6)))
188         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
189         sub_graph_nodes = sub_graph_between_nodes(graph, [2, 5], [4])
190         self.assertIsNotNone(sub_graph_nodes)
191         self.assertListEqual(sorted(sub_graph_nodes), sorted([2, 3, 4, 5]))
192
193     def test_sub_graph_between_nodes_branches_included(self):
194         """
195         Check that the function works correctly for tree like structures.
196         1 -> 2 -> 3 -> 4
197              \
198              5 -> 6
199             / \
200         9 ->   -> 7 -> 8
201         """
202         graph = Graph()
203         node_names = list(range(1, 10))
204         graph.add_nodes_from(node_names)
205         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7), (7, 8), (9, 5)])
206         self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [4])), node_names)
207         self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [6])), node_names)
208         self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [8])), node_names)
209         # all nodes except 4 because it is a child of end node
210         self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [3])), [n for n in node_names if n != 4])
211         # all nodes except 1 because it is a parent node child of start node. The nodes 3 and 4 must be added because
212         # after merging node 2 into sub-graph the node 2 will be removed and it is not known how to calculate the tensor
213         # between node 2 and 3.
214         self.assertListEqual(sorted(sub_graph_between_nodes(graph, [2], [8])), [n for n in node_names if n != 1])