Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / front / div.py
similarity index 54%
rename from model-optimizer/mo/ops/div.py
rename to model-optimizer/extensions/front/div.py
index 4f39e4c..9509d79 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
 """
 
 import numpy as np
-import networkx as nx
 
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node
+from mo.graph.graph import Node, Graph
 from mo.ops.eltwise import Eltwise
 from mo.ops.power import Power
 
@@ -27,13 +26,15 @@ class Div(FrontReplacementOp):
     op = "Div"
     enabled = True
 
-    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
-        reciprocal = Power(graph, dict(scale=1, power=np.float64(-1), shift=0, name=node.name + '/reciprocal_'))
-        mul = Eltwise(graph, dict(operation='mul', name=node.name + '/mul_'))
+    def replace_op(self, graph: Graph, node: Node):
+        reciprocal = Power(graph, {'scale': 1, 'power': np.float64(-1), 'shift': 0,
+                                   'name': node.name + '/reciprocal_'}).create_node()
+        mul = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'}).create_node()
+
+        # Connect nodes
+        node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
+        node.in_port(0).get_connection().set_destination(mul.in_port(1))
+        reciprocal.out_port(0).connect(mul.in_port(0))
 
-        out_node = mul.create_node([(node.in_node(0), node.in_edge(0)['out']),
-                                    reciprocal.create_node([(node.in_node(1), node.in_edge(1)['out'])])
-                                   ])
-        # Replace edge from out port 0 of the matched node with a edge from node out_node.id with port 0.
         # The "explicit" version of the return value is: [(out_node.id, 0)])
-        return [out_node.id]
+        return [mul.id]