Publishing R5 content (#72)
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / loader.py
1 """
2  Copyright (c) 2018 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import os
18 import json
19
20 import networkx as nx
21 import numpy as np
22 import mxnet as mx
23 import logging as log
24
25 from mo.front.mxnet.extractors.utils import get_mxnet_node_edges, load_params
26 from mo.front.mxnet.extractor import common_mxnet_fields
27 from mo.front.mxnet.nd_to_params import build_params_file
28 from mo.graph.graph import Node
29 from mo.graph.graph import unique_id
30 from mo.utils.error import Error
31 from mo.utils.utils import refer_to_faq_msg
32
33
34 def load_symbol_nodes(model_name, legacy_mxnet_model: bool = False):
35     model_name = '%s-symbol.json' % model_name
36     if legacy_mxnet_model:
37         log.warning('For legacy MXNet models Model Optimizer does not support conversion of old MXNet models' +
38                     '(trained with 1.0.0 version of MXNet and lower) with custom layers. ' +
39                     refer_to_faq_msg(93))
40         sym = mx.symbol.load(model_name)
41         model_nodes = json.loads(sym.tojson())
42     else:
43         if os.path.isfile(model_name):
44             model_nodes = json.load(open(model_name))
45         else:
46             raise Error('Specified input json {} does not exist. ' +
47                         refer_to_faq_msg(84), model_name)
48
49     return model_nodes['nodes']
50
51
52 def parse_input_model(input_model):
53     path_wo_ext = '.'.join(input_model.split('.')[:-1])
54     model_name_w_iter = path_wo_ext.split(os.sep)[-1]
55     iteration_number = int(model_name_w_iter.split('-')[-1])
56     model_name = '-'.join(path_wo_ext.split('-')[:-1])
57     return model_name, iteration_number
58
59
60 def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '', pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
61     if not nd_prefix_name and not pretrained_model_name:
62         # model name always has extension 'param'
63         try:
64             model_name, iteration_number = parse_input_model(input_model_name)
65         except ValueError as err:
66             raise Error(
67                 'Input model name {} is not in an expected format, cannot extract iteration number. ' +
68                 refer_to_faq_msg(48),
69                 input_model_name)
70
71         if input_names:
72             model_params = load_params(input_model_name, data_names=input_names.split(','))
73         else:
74             model_params = load_params(input_model_name)
75
76     elif nd_prefix_name and pretrained_model_name and input_symbol:
77         model_name, iteration_number = parse_input_model(pretrained_model_name)
78         model_name = '-'.join(input_symbol.split('-')[:-1])
79         model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
80     else:
81         raise Error(
82             "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
83             refer_to_faq_msg(81))
84
85     model_nodes = load_symbol_nodes(model_name, legacy_mxnet_model)
86
87     return model_nodes, model_params, model_name, iteration_number
88
89
90 def symbol_attrs(symbol_node):
91     return {'symbol_dict': symbol_node}
92
93
94 def symbol2nx(model_nodes, model_params, input_names: str = ''):
95     if not input_names:
96         input_names = ('data',)
97     else:
98         input_names = input_names.split(',')
99
100     graph = nx.MultiDiGraph()
101     # as mxnet contain input layers as index of layer, for correct set up edges, we need provide index of layer with name of  graph node
102     index_node_keys = {}
103     for i, node in enumerate(model_nodes):
104         if node['name'] in model_params._arg_params and node['name'] not in input_names:
105             node['value'] = np.array(model_params._arg_params[node['name']].asnumpy(), dtype=np.float32)
106         elif node['name'] in model_params._aux_params and node['name'] not in input_names:
107             node['value'] = np.array(model_params._aux_params[node['name']].asnumpy(), dtype=np.float32)
108         node_name = unique_id(graph, node['name'])
109         graph.add_node(node_name, **symbol_attrs(node))
110         graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
111         index_node_keys[i] = node_name
112
113     for i, attrs in enumerate(model_nodes):
114         node = attrs
115         edges = get_mxnet_node_edges(node, i, list(model_nodes), index_node_keys)
116         if len(edges) > 0:
117             graph.add_edges_from(edges)
118
119     return graph
120
121
122 def find_output_node(graph: nx.MultiDiGraph, src_input_index):
123     for i, attrs in (list(graph.nodes(data=True))[src_input_index + 1:]):
124         for input_index in attrs['symbol_dict']['inputs']:
125             if input_index[0] == src_input_index:
126                 return i