Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / extensions / back / insert_compatibility_l2normalization.py
index 4f4dfe9..994b5af 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2017-2018 Intel Corporation
+ Copyright (c) 2017-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
 import numpy as np
 import networkx as nx
 from mo.ops.op import Op
-from mo.graph.graph import create_edge
+from mo.graph.graph import Graph
 from mo.back.replacement import BackReplacementPattern
 
 
@@ -32,7 +32,7 @@ class CompatibilityL2NormalizationPattern(BackReplacementPattern):
             ],
             edges=[])
 
-    def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
+    def replace_pattern(self, graph: Graph, match: dict):
         """
         Adds Normalize layer weights, which are required by Inference Engine, 
         but do not always exist in MXNet model. 
@@ -42,7 +42,7 @@ class CompatibilityL2NormalizationPattern(BackReplacementPattern):
         
         Parameters
         ----------
-        graph : nx.MultiDiGraph
+        graph : Graph
            Graph with loaded model.
          match : dict
            Patterns which were found in graph structure.
@@ -51,4 +51,4 @@ class CompatibilityL2NormalizationPattern(BackReplacementPattern):
         if len(l2_normalization_node.in_nodes()) < 2:
             value = np.full([l2_normalization_node.in_node(0).shape[1]], 1.0, dtype=np.float32)
             weights_node = Op.create_input_data_node(graph, name=l2_normalization_node['name'] + '_weights', value=value)
-            create_edge(weights_node, l2_normalization_node, out_port=0, in_port=1, edge_attrs={'bin': 'weights'})
+            graph.create_edge(weights_node, l2_normalization_node, out_port=0, in_port=1, edge_attrs={'bin': 'weights'})