Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / ConvFlatten.py
index 27282d3..2fd80f2 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from mo.front.subgraph_matcher import SubgraphMatch
 from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
-from mo.graph.graph import insert_node_after
+from mo.graph.graph import Graph
 from mo.ops.permute import Permute
 
 
 class ConvFlattenReplacement(FrontReplacementFromConfigFileSubGraph):
     replacement_id = 'ConvFlatten'
 
-    def output_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def output_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         return {}
 
-    def input_edges_match(self, graph: nx.DiGraph, match: SubgraphMatch, new_sub_graph: dict):
+    def input_edges_match(self, graph: Graph, match: SubgraphMatch, new_sub_graph: dict):
         return {}
 
-    def nodes_to_remove(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def nodes_to_remove(self, graph: Graph, match: SubgraphMatch):
         # no need to remove any of matched nodes. We just insert 'Permute' node before the matched sub-graph.
         return []
 
-    def generate_sub_graph(self, graph: nx.MultiDiGraph, match: SubgraphMatch):
+    def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
         permute_op = Permute(graph, {'order': np.array([0, 2, 3, 1])})
         permute_node = permute_op.add_node({'name': match.scope + '_permute_'})
 
@@ -44,5 +43,5 @@ class ConvFlattenReplacement(FrontReplacementFromConfigFileSubGraph):
 
         # reshape_in_node is the node after which we should insert Permute
         reshape_in_node = reshape_node.in_nodes()[0]
-        insert_node_after(reshape_in_node, permute_node, 0)
+        reshape_in_node.insert_node_after(permute_node, 0)
         return {}