Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / tf / nearest_neighbor_upsampling.py
index 23b1f45..d42b73b 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 
 import logging as log
 
-import networkx as nx
-
 from extensions.front.Pack import Pack
 from extensions.ops.resample import ResampleOp
 from mo.front.common.replacement import FrontReplacementSubgraph
-from mo.graph.graph import replace_node
+from mo.graph.graph import Node, Graph
 
 
 class NearestNeighborUpsampling(FrontReplacementSubgraph):
@@ -56,7 +54,7 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
             ]
         )
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         log.debug('Matched NearestNeighborUpsampling pattern: {}'.format([node.id for node in match.values()]))
         try:
             input_height = match['pack_1'].in_node(1).value.item()
@@ -73,5 +71,5 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
                                          'resample_type': 'caffe.ResampleParameter.NEAREST'})
         resample_node = resample_op.create_node([match['op']])
 
-        replace_node(match['reshape_2'], resample_node)
+        match['reshape_2'].replace_node(resample_node)
         graph.remove_nodes_from([node.id for node in match.values() if node.id != match['op'].id])