Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / loader.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import os
18 import json
19
20 import numpy as np
21 import mxnet as mx
22 import logging as log
23
24 from mo.front.mxnet.extractors.utils import get_mxnet_node_edges, load_params, init_rnn_states
25 from mo.front.mxnet.extractor import common_mxnet_fields
26 from mo.front.mxnet.nd_to_params import build_params_file
27 from mo.graph.graph import Node, Graph
28 from mo.utils.error import Error
29 from mo.utils.utils import refer_to_faq_msg
30
31
32 def load_symbol_nodes(model_name, legacy_mxnet_model: bool = False):
33     model_name = '%s-symbol.json' % model_name
34     if legacy_mxnet_model:
35         log.warning('For legacy MXNet models Model Optimizer does not support conversion of old MXNet models' +
36                     '(trained with 1.0.0 version of MXNet and lower) with custom layers. ' +
37                     refer_to_faq_msg(93))
38         sym = mx.symbol.load(model_name)
39         model_nodes = json.loads(sym.tojson())
40     else:
41         if os.path.isfile(model_name):
42             model_nodes = json.load(open(model_name))
43         else:
44             raise Error('Specified input json {} does not exist. ' +
45                         refer_to_faq_msg(84), model_name)
46
47     return model_nodes['nodes']
48
49
50 def parse_input_model(input_model):
51     path_wo_ext = '.'.join(input_model.split('.')[:-1])
52     model_name_w_iter = path_wo_ext.split(os.sep)[-1]
53     iteration_number = int(model_name_w_iter.split('-')[-1])
54     model_name = '-'.join(path_wo_ext.split('-')[:-1])
55     return model_name, iteration_number
56
57
58 def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '', pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
59     if not nd_prefix_name and not pretrained_model_name:
60         # model name always has extension 'param'
61         try:
62             model_name, iteration_number = parse_input_model(input_model_name)
63         except ValueError as err:
64             raise Error(
65                 'Input model name {} is not in an expected format, cannot extract iteration number. ' +
66                 refer_to_faq_msg(48),
67                 input_model_name)
68
69         if input_names:
70             model_params = load_params(input_model_name, data_names=input_names.split(','))
71         else:
72             model_params = load_params(input_model_name)
73
74     elif nd_prefix_name and pretrained_model_name and input_symbol:
75         model_name, iteration_number = parse_input_model(pretrained_model_name)
76         model_name = '-'.join(input_symbol.split('-')[:-1])
77         model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
78     else:
79         raise Error(
80             "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
81             refer_to_faq_msg(81))
82
83     model_nodes = load_symbol_nodes(model_name, legacy_mxnet_model)
84
85     return model_nodes, model_params, model_name, iteration_number
86
87
88 def symbol_attrs(symbol_node):
89     return {'symbol_dict': symbol_node}
90
91
92 def symbol2nx(model_nodes, model_params, input_names: str = ''):
93     if not input_names:
94         input_names = ('data',)
95     else:
96         input_names = input_names.split(',')
97
98     rnn_states = init_rnn_states(model_nodes)
99     names_rnn_states = list(rnn_states.keys())
100
101     graph = Graph()
102     # as mxnet contain input layers as index of layer, for correct set up edges, we need provide index of layer with name of  graph node
103     index_node_keys = {}
104     for i, node in enumerate(model_nodes):
105         if node['name'] in model_params._arg_params and node['name'] not in input_names:
106             node['value'] = np.array(model_params._arg_params[node['name']].asnumpy(), dtype=np.float32)
107         elif node['name'] in model_params._aux_params and node['name'] not in input_names:
108             node['value'] = np.array(model_params._aux_params[node['name']].asnumpy(), dtype=np.float32)
109         elif node['name'] in names_rnn_states:
110             node['value'] = np.zeros(rnn_states[node['name']])
111         node_name = graph.unique_id(node['name'])
112         graph.add_node(node_name, **symbol_attrs(node))
113         graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
114         index_node_keys[i] = node_name
115
116     for i, attrs in enumerate(model_nodes):
117         node = attrs
118         edges = get_mxnet_node_edges(node, i, list(model_nodes), index_node_keys)
119         if len(edges) > 0:
120             graph.add_edges_from(edges)
121
122     return graph
123
124
125 def find_output_node(graph: Graph, src_input_index):
126     for i, attrs in (list(graph.nodes(data=True))[src_input_index + 1:]):
127         for input_index in attrs['symbol_dict']['inputs']:
128             if input_index[0] == src_input_index:
129                 return i