Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / tf / replacement.py
index b9e1e60..c9b48ee 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 import logging as log
 
-import networkx as nx
-
 from mo.front.common.custom_replacement_registry import CustomReplacementRegistry
 from mo.front.common.replacement import FrontReplacementSubgraph, FrontReplacementPattern
 from mo.front.subgraph_matcher import SubgraphMatcher, SubgraphMatch
 from mo.front.tf.custom_subgraph_call import merge_nodes
-from mo.graph.graph import dump_graph_for_graphviz, unique_id
+from mo.graph.graph import Graph
 from mo.ops.op import Op
 from mo.utils import class_registration
 from mo.utils.graph import is_connected_component
@@ -40,7 +38,7 @@ class FrontReplacementFromConfigFileGeneral(FrontReplacementPattern):
     def transform_graph(self, graph, replacement_descriptions):
         raise Exception('Function "transform_graph" must be overridden in the sub-class')
 
-    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+    def find_and_replace_pattern(self, graph: Graph):
         replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
         if replacement_descriptions is None or len(replacement_descriptions) < 1:
             log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
@@ -72,10 +70,10 @@ class FrontReplacementFromConfigFileSubGraph(FrontReplacementSubgraph):
     def __init__(self):
         super().__init__()
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         return match.matched_nodes_names()
 
-    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
+    def find_and_replace_pattern(self, graph: Graph):
         replacement_descriptions = CustomReplacementRegistry().get_custom_replacement_description(self.replacement_id)
         if replacement_descriptions is None:
             log.info("Failed to find custom replacement description with id '{}'".format(self.replacement_id))
@@ -87,7 +85,7 @@ class FrontReplacementFromConfigFileSubGraph(FrontReplacementSubgraph):
                 if not is_connected_component(graph, match.matched_nodes_names()):
                     log.warning("The following nodes don't form connected sub-graph: {}".format(
                         match.matched_nodes_names()))
-                    dump_graph_for_graphviz(graph, match.matched_nodes_names())
+                    graph.dump_graph_for_graphviz(match.matched_nodes_names())
                 self.replace_sub_graph(graph, match)
 
     registered_ops = {}
@@ -111,7 +109,7 @@ class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
         super().__init__()
 
     def input_edges_match(self,  # pylint: disable=method-hidden
-                          graph: nx.DiGraph,
+                          graph: Graph,
                           match: SubgraphMatch,
                           new_sub_graph: dict):
         """
@@ -131,7 +129,7 @@ class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
         return input_edges_match
 
     def output_edges_match(self,  # pylint: disable=method-hidden
-                           graph: nx.DiGraph,
+                           graph: Graph,
                            match: SubgraphMatch,
                            new_sub_graph: dict):
         """
@@ -150,7 +148,7 @@ class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
             output_edges_match[(output_node.id, output_port)] = (new_sub_graph['new_node'].id, sub_graph_output_port)
         return output_edges_match
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         replacement_desc = match.custom_replacement_desc
         op = Op.get_op_class_by_name(replacement_desc.op)(graph, match.custom_replacement_desc.custom_attributes)
         op.default_backend_attrs = list(match.custom_replacement_desc.custom_attributes.keys())
@@ -159,7 +157,7 @@ class FrontReplacementFromConfigFileOp(FrontReplacementFromConfigFileSubGraph):
             op.substitute_ie_attrs(op.attrs)
             node = merge_nodes(graph, match.matched_nodes_names(), replacement_desc.get_inputs_description(),
                                replacement_desc.get_outputs_description())
-            node.name = unique_id(graph, op.attrs['type'])
+            node.name = graph.unique_id(op.attrs['type'])
             node_attrs = graph.node[node.id]
             # copy attributes which are defined in the custom operation
             for key in op.attrs.keys():