2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from mo.utils.error import Error
22 from mo.utils.graph import dfs, bfs_search, is_connected_component, sub_graph_between_nodes
23 from mo.graph.graph import Graph
class TestGraphUtils(unittest.TestCase):
    """
    Unit tests for the graph helpers from mo.utils.graph: dfs, bfs_search, is_connected_component and
    sub_graph_between_nodes.
    """

    @staticmethod
    def _make_graph(nodes, edges):
        """
        Build a Graph containing the given nodes and directed edges.

        :param nodes: iterable of node ids to add.
        :param edges: iterable of (source, destination) node id pairs added as directed edges.
        :return: the constructed Graph.
        """
        graph = Graph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(edges)
        return graph

    def test_simple_dfs(self):
        """
        Check the DFS visit order starting from node 1 in the graph below.

        1 -> 2
        |
        v
        3 -> 4
        """
        graph = self._make_graph(list(range(1, 5)), [(1, 2), (1, 3), (3, 4)])

        visited = set()
        order = dfs(graph, 1, visited)
        # Both accepted orders list every node after all of its descendants; they differ only in whether the
        # branch through node 2 or the branch through node 3 is explored first.
        self.assertIn(order, [[4, 3, 2, 1], [2, 4, 3, 1]])

    def test_bfs_search_default_start_nodes(self):
        """
        Check that BFS automatically determines input nodes and starts searching from them.

        1 -> 3 -> 4 -> 5
             ^
             |
             2
        """
        graph = self._make_graph(list(range(1, 6)), [(1, 3), (2, 3), (3, 4), (4, 5)])

        order = bfs_search(graph)
        # Nodes 1 and 2 are both inputs, so either of them may come first.
        self.assertIn(order, [[1, 2, 3, 4, 5], [2, 1, 3, 4, 5]])

    def test_bfs_search_specific_start_nodes(self):
        """
        Check that BFS starts from the user-defined nodes and doesn't go in the backward edge direction.

        6 -> 1 -> 3 -> 4 -> 5
                  ^
                  |
                  2
        """
        graph = self._make_graph(list(range(1, 7)), [(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])

        order = bfs_search(graph, [1])
        # Nodes 6 and 2 can be reached from node 1 only against the edge direction, so they must not be visited.
        self.assertEqual(order, [1, 3, 4, 5])

    def test_is_connected_component_two_separate_sub_graphs(self):
        """
        Check that if there are two separate sub-graphs the function returns False.

        1 -> 2 -> 3    4 -> 5 -> 6
        """
        graph = self._make_graph(list(range(1, 7)), [(1, 2), (2, 3), (4, 5), (5, 6)])
        self.assertFalse(is_connected_component(graph, list(range(1, 7))))
        # 1 and 3 (as well as 6 and 4) are linked only through a node that is not in the checked set.
        self.assertFalse(is_connected_component(graph, [1, 3]))
        self.assertFalse(is_connected_component(graph, [6, 4]))
        self.assertFalse(is_connected_component(graph, [2, 5]))

    def test_is_connected_component_two_separate_sub_graphs_divided_by_ignored_node(self):
        """
        Check that if two sub-graphs are connected only by a path going through a node that is not in the checked
        set (node 7 below), then the function returns False.

        1 -> 2 -> 3
        |
        v
        7 -> 4 -> 5 -> 6
        """
        graph = self._make_graph(list(range(1, 8)), [(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
        self.assertFalse(is_connected_component(graph, list(range(1, 7))))

    def test_is_connected_component_connected(self):
        """
        Check that the function returns True if the sub-graph is connected. Node 7 is part of the checked set here.

        1 -> 2 -> 3
        |
        v
        7 -> 4 -> 5 -> 6
        """
        graph = self._make_graph(list(range(1, 8)), [(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
        self.assertTrue(is_connected_component(graph, list(range(1, 8))))

    def test_is_connected_component_edges_direction_is_ignored(self):
        """
        Check that edges direction is ignored when checking for the connectivity.

        1 <- 2 -> 3 <- 4
        """
        node_names = list(range(1, 5))
        graph = self._make_graph(node_names, [(2, 1), (2, 3), (4, 3)])
        self.assertTrue(is_connected_component(graph, node_names))
        self.assertTrue(is_connected_component(graph, [2, 1]))
        self.assertTrue(is_connected_component(graph, [4, 2, 3]))

    def test_is_connected_component_edges_direction_is_ignored_not_connected(self):
        """
        Check that edges direction is ignored when checking for the connectivity. In this case each checked node set
        is not connected because a node on the only path between its members (node 2 or node 3) is excluded.

        1 <- 2 -> 3 <- 4
        """
        graph = self._make_graph(list(range(1, 5)), [(2, 1), (2, 3), (4, 3)])
        self.assertFalse(is_connected_component(graph, [1, 2, 4]))
        self.assertFalse(is_connected_component(graph, [1, 4]))
        self.assertFalse(is_connected_component(graph, [2, 4]))
        self.assertFalse(is_connected_component(graph, [3, 4, 1]))

    def test_sub_graph_between_nodes_include_incoming_edges_for_internal_nodes(self):
        """
        Check that the function adds input nodes for the internal nodes of the graph. For example, we need to add
        nodes 5 and 6 in the case below if we find a match from node 1 till node 4.

        1 -> 2 -> 3 -> 4
             ^
             |
             5 <- 6
        """
        graph = self._make_graph(list(range(1, 7)), [(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])
        sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4])
        self.assertIsNotNone(sub_graph_nodes)
        self.assertListEqual(sorted(sub_graph_nodes), list(range(1, 7)))

        sub_graph_nodes = sub_graph_between_nodes(graph, [1], [2])
        self.assertIsNotNone(sub_graph_nodes)
        self.assertListEqual(sorted(sub_graph_nodes), [1, 2, 5, 6])

    def test_sub_graph_between_nodes_do_not_include_incoming_edges_for_input_nodes(self):
        """
        Check that the function doesn't add input nodes for the start nodes of the sub-graph. For example, we do not
        need to add node 5 in the case below if we find a match from node 2 till node 4.

        1 -> 2 -> 3 -> 4
             ^
             |
             5
        """
        graph = self._make_graph(list(range(1, 6)), [(1, 2), (2, 3), (3, 4), (5, 2)])
        sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
        self.assertIsNotNone(sub_graph_nodes)
        self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4])

    def test_sub_graph_between_nodes_placeholder_included(self):
        """
        Check that the function raises an Error when a Placeholder would have to be added to the sub-graph.
        Node 5 is the Placeholder op; it feeds the internal node 2 of the match from node 1 till node 4.

        1 -> 2 -> 3 -> 4
             ^
             |
             5
        """
        graph = self._make_graph(list(range(1, 6)), [(1, 2), (2, 3), (3, 4), (5, 2)])
        graph.node[5]['op'] = 'Placeholder'
        self.assertRaises(Error, sub_graph_between_nodes, graph, [1], [4])

    def test_sub_graph_between_nodes_placeholder_excluded(self):
        """
        Check that the function does not check whether a node is a Placeholder for nodes not included into the
        sub-graph. For example, node 5 is a Placeholder, but it is not included into the sub-graph, so this
        attribute is ignored.

        1 -> 2 -> 3 -> 4
             ^
             |
             5
        """
        graph = self._make_graph(list(range(1, 6)), [(1, 2), (2, 3), (3, 4), (5, 2)])
        graph.node[5]['op'] = 'Placeholder'
        sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
        self.assertIsNotNone(sub_graph_nodes)
        self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4])

    def test_sub_graph_between_nodes_multiple_inputs(self):
        """
        Check that the function works correctly when multiple inputs are specified.

        1 -> 2 -> 3 -> 4
             ^
             |
             5
        """
        graph = self._make_graph(list(range(1, 6)), [(1, 2), (2, 3), (3, 4), (5, 2)])
        sub_graph_nodes = sub_graph_between_nodes(graph, [2, 5], [4])
        self.assertIsNotNone(sub_graph_nodes)
        self.assertListEqual(sorted(sub_graph_nodes), [2, 3, 4, 5])

    def test_sub_graph_between_nodes_branches_included(self):
        """
        Check that the function works correctly for tree-like structures.

        1 -> 2 -> 3 -> 4
             |
             v
        9 -> 5 -> 6
             |
             v
             7 -> 8
        """
        node_names = list(range(1, 10))
        graph = self._make_graph(node_names,
                                 [(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7), (7, 8), (9, 5)])
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [4])), node_names)
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [6])), node_names)
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [8])), node_names)
        # all nodes except 4 because it is a child of the end node
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [1], [3])), [n for n in node_names if n != 4])
        # all nodes except 1 because it is the parent of the start node 2. Nodes 3 and 4 must be added because once
        # node 2 is merged into the sub-graph it is removed, and it would not be known how to calculate the tensor on
        # the edge between nodes 2 and 3.
        self.assertListEqual(sorted(sub_graph_between_nodes(graph, [2], [8])), [n for n in node_names if n != 1])