Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / image_scaler.py
index c034256..8ec13c6 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
-import networkx as nx
 import numpy as np
 
 from mo.front.common.replacement import FrontReplacementOp
+from mo.graph.graph import Graph
 from mo.ops.const import Const
 from mo.ops.lin_op import Mul, Add
 
@@ -26,7 +26,7 @@ class ImageScaler(FrontReplacementOp):
     op = "ImageScaler"
     enabled = True
 
-    def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_sub_graph(self, graph: Graph, match: dict):
         # This replacer replace ImageScalar operation to Mul->Add sequence
         # Also it check that weights and biases are good
         op = match['op']
@@ -38,28 +38,24 @@ class ImageScaler(FrontReplacementOp):
         if all([x == 0 for x in np.nditer(op.bias)]):
             has_bias = False
 
-        # Get all outputs for op node
-        out_nodes = [node for node in op.out_nodes().values()]
+        assert len(op.in_ports()) == 1
 
-        assert len(op.in_nodes()) == 1
+        last_port = op.in_port(0).get_source()
 
-        last_node = op.in_node()
         # Create Mul & Add nodes
         if has_weights:
-            mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape))
-            mul_op = Mul(graph, dict(name=op.id + '/mul_'))
-            last_node = mul_op.create_node(inputs=[last_node, mul_weights.create_node()])
+            mul_weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
+            mul_op = Mul(graph, dict(name=op.id + '/mul_')).create_node()
+            op.in_port(0).get_connection().set_destination(mul_op.in_port(0))
+            mul_weights.out_port(0).connect(mul_op.in_port(1))
+            last_port = mul_op.out_port(0)
 
         if has_bias:
-            add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
-            add_op = Add(graph, dict(name=op.id + '/add_'))
-            last_node = add_op.create_node(inputs=[last_node, add_bias.create_node()])
-
-        # Move edges from ImageScaler to last_node (Mul or Add)
-        for out_node in out_nodes:
-            edge_attrs = graph.get_edge_data(op.id, out_node.id)[0]
-            graph.remove_edge(op.id, out_node.id)
-            graph.add_edges_from([(last_node.id, out_node.id, edge_attrs)])
-
-        # Disconnect ImageScalar node
-        graph.remove_edge(op.in_node().id, op.id)
+            add_bias = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
+            add_op = Add(graph, dict(name=op.id + '/add_')).create_node()
+            last_port.get_connection().set_destination(add_op.in_port(0))
+            add_bias.out_port(0).connect(add_op.in_port(1))
+            last_port = add_op.out_port(0)
+
+        op.in_port(0).disconnect()
+        op.out_port(0).get_connection().set_source(last_port)