Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / pipeline / common.py
index 7c21c90..6d4b94c 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  limitations under the License.
 """
 
+import logging as log
 import os
 from operator import itemgetter
 
-import logging as log
 import networkx as nx
 
-from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image, \
-    create_const_nodes
-from mo.graph.graph import Node, unique_id
+from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
+from mo.graph.graph import Node, Graph
 from mo.middle.passes import tensor_names, convert_data_type
 from mo.utils.error import Error
 
@@ -62,7 +61,7 @@ def get_fw_tensor_debug_info(node: Node):
     return node.soft_get('fw_tensor_debug_info')
 
 
-def get_sorted_outputs(graph: nx.MultiDiGraph):
+def get_sorted_outputs(graph: Graph):
     outputs = []
     outputs_for_sort = {}
     for node in graph.nodes():
@@ -85,7 +84,7 @@ def get_sorted_outputs(graph: nx.MultiDiGraph):
     return [Node(graph, key) for key, value in sorted(outputs_for_sort.items(), key=itemgetter(1))]
 
 
-def collect_sub_graphs(graph: nx.MultiDiGraph):
+def collect_sub_graphs(graph: Graph):
     ''' Go over all nodes and sub_graphs in the graph recursively; returns all found sub-graphs. '''
     result = []
     for node in graph.nodes():
@@ -97,14 +96,14 @@ def collect_sub_graphs(graph: nx.MultiDiGraph):
     return result
 
 
-def relabel_nodes_inplace_safe(graph: nx.MultiDiGraph, new_labels: dict):
+def relabel_nodes_inplace_safe(graph: Graph, new_labels: dict):
     ''' Safely relabels graph in-place without graph copy.
         
         Safity in this place means that it is guarantied that
         there won't be collisions during relabiling process.
     '''
     # Relabel nodes in two stages
-    intermediate_map = {node: unique_id(graph, '__relabel__{}__'.format(str(i))) for i, node in enumerate(graph.nodes())}
+    intermediate_map = {node: graph.unique_id('__relabel__{}__'.format(str(i))) for i, node in enumerate(graph.nodes())}
     final_map = {dst: new_labels[src] for src, dst in intermediate_map.items()}
     assert len(set(intermediate_map.keys()).intersection(set(intermediate_map.values()))) == 0
     assert len(set(final_map.keys()).intersection(set(final_map.values()))) == 0
@@ -112,11 +111,9 @@ def relabel_nodes_inplace_safe(graph: nx.MultiDiGraph, new_labels: dict):
     nx.relabel_nodes(graph, final_map, copy=False)
 
 
-def prepare_emit_ir(graph: nx.MultiDiGraph, data_type: str, output_dir: str, output_model_name: str,
+def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
                     mean_data: [list, None] = None, input_names: list = [], meta_info: dict = dict()):
-
     for sub_graph in [graph] + collect_sub_graphs(graph):
-        create_const_nodes(sub_graph, start_data_nodes_are_not_allowed=(sub_graph == graph))
         op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
         mapping = {v: u for u, v in enumerate(op_order)}
         mapping.update({v: u for u, v in enumerate(data_order, start=len(sub_graph))})