"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
limitations under the License.
"""
+import logging as log
import os
from operator import itemgetter
-import logging as log
import networkx as nx
-from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image, \
- create_const_nodes
-from mo.graph.graph import Node, unique_id
+from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
+from mo.graph.graph import Node, Graph
from mo.middle.passes import tensor_names, convert_data_type
from mo.utils.error import Error
return node.soft_get('fw_tensor_debug_info')
-def get_sorted_outputs(graph: nx.MultiDiGraph):
+def get_sorted_outputs(graph: Graph):
outputs = []
outputs_for_sort = {}
for node in graph.nodes():
return [Node(graph, key) for key, value in sorted(outputs_for_sort.items(), key=itemgetter(1))]
-def collect_sub_graphs(graph: nx.MultiDiGraph):
+def collect_sub_graphs(graph: Graph):
''' Go over all nodes and sub_graphs in the graph recursively; returns all found sub-graphs. '''
result = []
for node in graph.nodes():
return result
-def relabel_nodes_inplace_safe(graph: nx.MultiDiGraph, new_labels: dict):
+def relabel_nodes_inplace_safe(graph: Graph, new_labels: dict):
''' Safely relabels graph in-place without graph copy.
Safity in this place means that it is guarantied that
there won't be collisions during relabiling process.
'''
# Relabel nodes in two stages
- intermediate_map = {node: unique_id(graph, '__relabel__{}__'.format(str(i))) for i, node in enumerate(graph.nodes())}
+ intermediate_map = {node: graph.unique_id('__relabel__{}__'.format(str(i))) for i, node in enumerate(graph.nodes())}
final_map = {dst: new_labels[src] for src, dst in intermediate_map.items()}
assert len(set(intermediate_map.keys()).intersection(set(intermediate_map.values()))) == 0
assert len(set(final_map.keys()).intersection(set(final_map.values()))) == 0
nx.relabel_nodes(graph, final_map, copy=False)
-def prepare_emit_ir(graph: nx.MultiDiGraph, data_type: str, output_dir: str, output_model_name: str,
+def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
mean_data: [list, None] = None, input_names: list = [], meta_info: dict = dict()):
-
for sub_graph in [graph] + collect_sub_graphs(graph):
- create_const_nodes(sub_graph, start_data_nodes_are_not_allowed=(sub_graph == graph))
op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
mapping = {v: u for u, v in enumerate(op_order)}
mapping.update({v: u for u, v in enumerate(data_order, start=len(sub_graph))})