2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from operator import itemgetter
23 from mo.back.ie_ir_ver_2.emitter import port_renumber, serialize_constants, generate_ie_ir, serialize_mean_image
24 from mo.graph.graph import Node, Graph
25 from mo.middle.passes import tensor_names, convert_data_type
26 from mo.utils.error import Error
# Deterministically order the nodes reachable from `outputs` by walking
# backwards through their inputs.
#
# :param outputs: start nodes, e.g. the list returned by get_sorted_outputs().
# :return: tuple (op_order, data_order) of node ids; ids of nodes whose kind
#     is 'data' are collected in data_order, op_order presumably collects the
#     op nodes (the guarding condition is not shown here).
# NOTE(review): `stack`, `visited`, `op_order`, `data_order` and the current
#     `node`/`node_id` are set up on lines not shown here; presumably `stack`
#     starts from `outputs` and nodes are taken from its front.
29 def determined_sort(outputs: list):
34 while len(stack) != 0:
# Ids of the current node's inputs: if in_nodes() is a list its elements are
# the input nodes themselves, otherwise its elements are keys resolved via
# node.in_node().
40 in_names = [n.id if isinstance(node.in_nodes(), list) else node.in_node(n).id for n in node.in_nodes()]
41 for in_node_name in in_names:
42 if in_node_name not in visited:
# Unvisited inputs are pushed to the front of the stack so they are handled
# before the current node (presumably so producers precede their consumers).
# NOTE(review): list.insert(0, ...) is O(len(stack)); collections.deque with
#     appendleft()/popleft() would make the stack operations O(1).
44 stack.insert(0, Node(node.graph, in_node_name))
49 op_order.append(node_id)
50 if node.kind == 'data':
51 data_order.append(node_id)
52 return op_order, data_order
# Return the debug information used to order the output node `node`.
#
# Starting at `node`, walks towards its inputs while the current node has
# neither a valid 'fw_tensor_debug_info' nor a valid 'output_sort_order'
# attribute and still has inputs. 'output_sort_order' takes precedence: it is
# returned when set, otherwise node.soft_get('fw_tensor_debug_info') is.
# NOTE(review): the loop body that steps to an input node is not shown here;
#     presumably it moves to the first input.
# NOTE(review): soft_get presumably returns None when the attribute is
#     missing (e.g. a graph input was reached) — confirm.
55 def get_fw_tensor_debug_info(node: Node):
56 while not node.has_valid('fw_tensor_debug_info') and not node.has_valid('output_sort_order') \
57 and len(node.in_nodes()):
59 if node.has_valid('output_sort_order'):
60 return node.soft_get('output_sort_order')
61 return node.soft_get('fw_tensor_debug_info')
# Return the output nodes of `graph` (nodes without outgoing edges) sorted by
# a key derived from their framework debug information.
#
# The key comes from get_fw_tensor_debug_info(): a string is used as is; for a
# list, the first entry contributes "<entry[0]>_<entry[1]>" (presumably the
# framework tensor name and output index); any other type is rejected.
#
# :raises Error: if the debug information has an unsupported type.
# NOTE(review): initialization of `outputs`/`outputs_for_sort`, the loop over
#     the collected outputs and the `else:` guarding the raise are on lines
#     not shown here.
64 def get_sorted_outputs(graph: Graph):
67 for node in graph.nodes():
68 if len(graph.out_edges(node)) == 0:
69 outputs.append(Node(graph, node))
73 debug_info = get_fw_tensor_debug_info(node)
74 if isinstance(debug_info, str):
75 outputs_for_sort[node.id] = debug_info
76 elif isinstance(debug_info, list):
77 outputs_for_sort[node.id] = debug_info[0][0] + '_' + str(debug_info[0][1])
79 raise Error('Unsupported type of the variable with debug information used to sort output nodes')
# Duplicate keys mean the key alone does not fix the relative order of those
# outputs, so the emitted IR order is not guaranteed to be reproducible.
80 if len(outputs_for_sort) != len(set(outputs_for_sort.values())):
81 log.warning('There are at least two output nodes with the same key used to sort the outputs. This means that '
82 'IRs with different order of nodes may be generated between Model Optimizer runs. The dictionary '
83 'with outputs is: {}'.format(outputs_for_sort))
# Sort (node_id, key) pairs by key and wrap the ids back into Node objects.
84 return [Node(graph, key) for key, value in sorted(outputs_for_sort.items(), key=itemgetter(1))]
# NOTE(review): `result` is initialized and returned on lines not shown here.
87 def collect_sub_graphs(graph: Graph):
88 ''' Recursively collect the sub-graphs of all nodes in `graph`, including nested ones; each sub-graph is listed before the sub-graphs nested inside it. '''
90 for node in graph.nodes():
91 node = Node(graph, node)
# For a node with sub-graphs, `node.sub_graphs` holds keys; node[key] is the
# sub-graph stored under that key.
92 if node.has_valid('sub_graphs'):
93 for sub_graph in node.sub_graphs:
94 result.append(node[sub_graph])
95 result += collect_sub_graphs(node[sub_graph])
# Relabel the nodes of `graph` in place according to `new_labels`
# (current node id -> new id).
#
# Every node of `graph` must have an entry in `new_labels`; a missing node
# raises KeyError while final_map is built.
# NOTE(review): the collision checks are plain asserts, which are skipped when
#     Python runs with -O; raise an exception instead if they must always run.
# NOTE(review): `nx` is presumably networkx (import not shown); copy=False
#     makes relabel_nodes modify `graph` in place.
99 def relabel_nodes_inplace_safe(graph: Graph, new_labels: dict):
100 ''' Safely relabels graph in-place without graph copy.
102 Safety here means that it is guaranteed that
103 there won't be collisions during the relabeling process.
105 # Relabel nodes in two stages
# Stage 1 moves every node to an id from graph.unique_id() (presumably one not
# yet used in the graph); stage 2 moves those ids to the requested labels.
# Going through intermediate ids means a new label never clashes with an old
# label that has not been renamed yet.
106 intermediate_map = {node: graph.unique_id('__relabel__{}__'.format(str(i))) for i, node in enumerate(graph.nodes())}
107 final_map = {dst: new_labels[src] for src, dst in intermediate_map.items()}
# Intermediate ids must differ both from the original ids and from the final
# labels.
108 assert len(set(intermediate_map.keys()).intersection(set(intermediate_map.values()))) == 0
109 assert len(set(final_map.keys()).intersection(set(final_map.values()))) == 0
110 nx.relabel_nodes(graph, intermediate_map, copy=False)
111 nx.relabel_nodes(graph, final_map, copy=False)
# Normalize `graph` and all of its nested sub-graphs, then write the IR files
# <output_model_name>.bin, <output_model_name>.xml and
# <output_model_name>.mapping into `output_dir`.
#
# :param graph: graph to emit; modified in place (nodes relabelled, ports
#     renumbered, data types converted).
# :param data_type: target data type passed to convert_data_type.convert().
# :param output_dir: directory the IR files are written to.
# :param output_model_name: base file name of the emitted files.
# :param mean_data: optional data passed to serialize_mean_image() together
#     with the .bin path; the resulting offset is passed to generate_ie_ir().
# :param input_names: forwarded to generate_ie_ir().
# :param meta_info: not used on the lines shown here; presumably forwarded to
#     generate_ie_ir() — confirm.
# NOTE(review): `input_names` and `meta_info` have mutable defaults ([] and
#     dict()) that are shared across calls; harmless only while nothing
#     mutates them — prefer None defaults.
114 def prepare_emit_ir(graph: Graph, data_type: str, output_dir: str, output_model_name: str,
115 mean_data: [list, None] = None, input_names: list = [], meta_info: dict = dict()):
116 for sub_graph in [graph] + collect_sub_graphs(graph):
117 op_order, data_order = determined_sort(get_sorted_outputs(sub_graph))
# Op nodes get ids 0, 1, ...; data nodes are numbered from len(sub_graph) on,
# so the two ranges do not overlap (assuming op_order lists each node once).
# Every node must end up in `mapping`, otherwise the relabeling fails.
118 mapping = {v: u for u, v in enumerate(op_order)}
119 mapping.update({v: u for u, v in enumerate(data_order, start=len(sub_graph))})
120 relabel_nodes_inplace_safe(sub_graph, mapping)
121 port_renumber(sub_graph)
122 convert_data_type.convert(sub_graph, data_type)
# From here on only the top-level graph is used for serialization.
124 tensor_names.propagate_op_name_to_tensor(graph)
126 bin_file = os.path.join(output_dir, '{}.bin'.format(output_model_name))
127 serialize_constants(graph, bin_file)
132 mean_offset, mean_size = serialize_mean_image(bin_file, mean_data=mean_data)
134 generate_ie_ir(graph=graph,
135 file_name=os.path.join(output_dir, '{}.xml'.format(output_model_name)),
136 input_names=input_names,
137 mean_offset=mean_offset,
140 tensor_names.output_tensor_names_map(graph, os.path.join(output_dir, '{}.mapping'.format(output_model_name)))