Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / graph_test.py
index 5d4ed57..21bf45d 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -20,11 +20,11 @@ import networkx as nx
 
 from mo.utils.error import Error
 from mo.utils.graph import dfs, bfs_search, is_connected_component, sub_graph_between_nodes
-
+from mo.graph.graph import Graph
 
 class TestGraphUtils(unittest.TestCase):
     def test_simple_dfs(self):
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 5)))
         graph.add_edges_from([(1, 2), (1, 3), (3, 4)])
 
@@ -36,7 +36,7 @@ class TestGraphUtils(unittest.TestCase):
         """
         Check that BFS automatically determines input nodes and start searching from them.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 6)))
         graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5)])
 
@@ -47,7 +47,7 @@ class TestGraphUtils(unittest.TestCase):
         """
         Check that BFS stars from the user defined nodes and doesn't go in backward edge direction.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 7)))
         graph.add_edges_from([(1, 3), (2, 3), (3, 4), (4, 5), (6, 1)])
 
@@ -58,7 +58,7 @@ class TestGraphUtils(unittest.TestCase):
         """
         Check that if there are two separate sub-graphs the function returns False.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 7)))
         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6)])
         self.assertFalse(is_connected_component(graph, list(range(1, 7))))
@@ -71,7 +71,7 @@ class TestGraphUtils(unittest.TestCase):
         Check that if there are two separate sub-graphs the function connected by an edge going through the ignored node
         then the function returns False.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         node_names = list(range(1, 8))
         graph.add_nodes_from(node_names)
         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
@@ -81,7 +81,7 @@ class TestGraphUtils(unittest.TestCase):
         """
         Check that if the sub-graph is connected.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         node_names = list(range(1, 8))
         graph.add_nodes_from(node_names)
         graph.add_edges_from([(1, 2), (2, 3), (4, 5), (5, 6), (1, 7), (7, 4)])
@@ -91,7 +91,7 @@ class TestGraphUtils(unittest.TestCase):
         """
         Check that edges direction is ignored when checking for the connectivity.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         node_names = list(range(1, 5))
         graph.add_nodes_from(node_names)
         graph.add_edges_from([(2, 1), (2, 3), (4, 3)])
@@ -104,7 +104,7 @@ class TestGraphUtils(unittest.TestCase):
         Check that edges direction is ignored when checking for the connectivity. In this case the graph is not
         connected.
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 5)))
         graph.add_edges_from([(2, 1), (2, 3), (4, 3)])
         self.assertFalse(is_connected_component(graph, [1, 2, 4]))
@@ -121,7 +121,7 @@ class TestGraphUtils(unittest.TestCase):
             1 -> 2 -> 3 -> 4
         :return:
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 7)))
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2), (6, 5)])
         sub_graph_nodes = sub_graph_between_nodes(graph, [1], [4])
@@ -140,7 +140,7 @@ class TestGraphUtils(unittest.TestCase):
              \
         1 -> 2 -> 3 -> 4
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 6)))
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
         sub_graph_nodes = sub_graph_between_nodes(graph, [2], [4])
@@ -154,7 +154,7 @@ class TestGraphUtils(unittest.TestCase):
              \
         1 -> 2 -> 3 -> 4
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 6)))
         graph.node[5]['op'] = 'Placeholder'
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
@@ -168,7 +168,7 @@ class TestGraphUtils(unittest.TestCase):
              \
         1 -> 2 -> 3 -> 4
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 6)))
         graph.node[5]['op'] = 'Placeholder'
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
@@ -183,7 +183,7 @@ class TestGraphUtils(unittest.TestCase):
              \
         1 -> 2 -> 3 -> 4
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         graph.add_nodes_from(list(range(1, 6)))
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (5, 2)])
         sub_graph_nodes = sub_graph_between_nodes(graph, [2, 5], [4])
@@ -199,7 +199,7 @@ class TestGraphUtils(unittest.TestCase):
             / \
         9 ->   -> 7 -> 8
         """
-        graph = nx.MultiDiGraph()
+        graph = Graph()
         node_names = list(range(1, 10))
         graph.add_nodes_from(node_names)
         graph.add_edges_from([(1, 2), (2, 3), (3, 4), (2, 5), (5, 6), (5, 7), (7, 8), (9, 5)])