Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / utils / graph.py
index b651228..cf2d136 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ from re import match, compile
 import logging as log
 import networkx as nx
 
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.utils.error import Error
 from mo.utils.utils import refer_to_faq_msg
 
@@ -52,7 +52,7 @@ def backward_bfs_for_operation(start_node: Node, op_names: list):
     return [Node(start_node.graph, x) for x in ret]
 
 
-def bfs_search(graph: nx.MultiDiGraph, start_nodes: list = list()):
+def bfs_search(graph: Graph, start_nodes: list = list()):
     """
     Performs breadth-first search over a graph and returns a list of nodes in the BFS order.
     :param graph: networkx graph to traverse.
@@ -77,7 +77,7 @@ def bfs_search(graph: nx.MultiDiGraph, start_nodes: list = list()):
     return result
 
 
-def dfs(graph: nx.MultiDiGraph, node_name: str, visited: set):
+def dfs(graph: Graph, node_name: str, visited: set):
     """
     Implementation of the depth-first search algorithm starting from the specific node.
     :param graph: networkx graph to operate on.
@@ -103,7 +103,7 @@ def dfs(graph: nx.MultiDiGraph, node_name: str, visited: set):
     return order
 
 
-def pseudo_topological_sort(graph: nx.MultiDiGraph, reverse: bool = False):
+def pseudo_topological_sort(graph: Graph, reverse: bool = False):
     """
     The function performs topological sort but doesn't check for cycle existence. So it may produce wrong nodes order
     for some applications.
@@ -127,7 +127,7 @@ def pseudo_topological_sort(graph: nx.MultiDiGraph, reverse: bool = False):
         return list(reversed(order))
 
 
-def nodes_matching_name_pattern(graph: nx.MultiDiGraph, pattern: str):
+def nodes_matching_name_pattern(graph: Graph, pattern: str):
     """
     Returns list of node names of the graph that match regular expression.
     :param graph: graph to operate on.
@@ -138,7 +138,7 @@ def nodes_matching_name_pattern(graph: nx.MultiDiGraph, pattern: str):
     return [node_name for node_name in list(graph.nodes()) if match(compiled_pattern, node_name)]
 
 
-def is_connected_component(graph: nx.MultiDiGraph, node_names: list):
+def is_connected_component(graph: Graph, node_names: list):
     """
     Checks that specified list of nodes forms a connected sub-graph. It ignores edges direction.
     The algorithm is the following. Run BFS from one of the nodes from the node_names list ignoring edges order and
@@ -167,7 +167,7 @@ def is_connected_component(graph: nx.MultiDiGraph, node_names: list):
     return set(node_names).issubset(visited)
 
 
-def sub_graph_between_nodes(graph: nx.MultiDiGraph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
+def sub_graph_between_nodes(graph: Graph, start_nodes: list, end_nodes: list, detect_extra_start_node: callable=None):
     """
     Finds nodes of the sub-graph between 'start_nodes' and 'end_nodes'. Input nodes for the sub-graph nodes are also
     added to the sub-graph. Constant inputs of the 'start_nodes' are also added to the sub-graph.
@@ -251,7 +251,7 @@ def node_neighbourhood(node_name: str, depth: int, next_node_fn):
     return list(dist.keys())
 
 
-def node_incoming_neighbourhood(graph: nx.MultiDiGraph, node_name: str, depth: int):
+def node_incoming_neighbourhood(graph: Graph, node_name: str, depth: int):
     """
     Find input neighbourhood of the node.
     :param graph: graph to operate on.
@@ -262,7 +262,7 @@ def node_incoming_neighbourhood(graph: nx.MultiDiGraph, node_name: str, depth: i
     return node_neighbourhood(node_name, depth, lambda node_name: [u for u, v in graph.in_edges([node_name])])
 
 
-def node_outcoming_neighbourhood(graph: nx.MultiDiGraph, node_name: str, depth: int):
+def node_outcoming_neighbourhood(graph: Graph, node_name: str, depth: int):
     """
     Find output neighbourhood of the node.
     :param graph: graph to operate on.
@@ -273,7 +273,7 @@ def node_outcoming_neighbourhood(graph: nx.MultiDiGraph, node_name: str, depth:
     return node_neighbourhood(node_name, depth, lambda node_name: [v for u, v in graph.out_edges([node_name])])
 
 
-def scope_output_nodes(graph: nx.MultiDiGraph, scope: str, scope_delimiter: str='/'):
+def scope_output_nodes(graph: Graph, scope: str, scope_delimiter: str='/'):
     """
     The function returns nodes producing output of the sub-graph defined by scope (name prefix). The node is considered
     output of the scope if it is in this scope and it's output is outside of the scope.