2 Copyright (c) 2018-2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
24 from mo.front.mxnet.extractors.utils import get_mxnet_node_edges, load_params, init_rnn_states
25 from mo.front.mxnet.extractor import common_mxnet_fields
26 from mo.front.mxnet.nd_to_params import build_params_file
27 from mo.graph.graph import Node, Graph
28 from mo.utils.error import Error
29 from mo.utils.utils import refer_to_faq_msg
def load_symbol_nodes(model_name, legacy_mxnet_model: bool = False):
    """Return the ``'nodes'`` list of the MXNet symbol file ``<model_name>-symbol.json``.

    :param model_name: symbol file prefix; ``'-symbol.json'`` is appended to it.
    :param legacy_mxnet_model: when True, a warning about custom-layer support is
        logged and the file is loaded through ``mx.symbol.load`` and re-serialised
        with ``tojson()`` instead of being read directly as JSON.
    :return: the list stored under the ``'nodes'`` key of the symbol JSON.
    :raises Error: on the non-legacy path, if the symbol file does not exist.
    """
    symbol_file = '%s-symbol.json' % model_name
    if legacy_mxnet_model:
        log.warning('For legacy MXNet models Model Optimizer does not support conversion of old MXNet models ' +
                    '(trained with 1.0.0 version of MXNet and lower) with custom layers. ' +
                    refer_to_faq_msg(93))
        # Let MXNet parse the file itself, then work with its JSON serialisation.
        sym = mx.symbol.load(symbol_file)
        model_nodes = json.loads(sym.tojson())
    elif os.path.isfile(symbol_file):
        # The context manager closes the handle; a bare open() leaked it.
        with open(symbol_file) as json_file:
            model_nodes = json.load(json_file)
    else:
        raise Error('Specified input json {} does not exist. ' +
                    refer_to_faq_msg(84), symbol_file)

    return model_nodes['nodes']
def parse_input_model(input_model):
    """Split an MXNet checkpoint path into ``(model_name, iteration_number)``.

    Everything after the last ``'.'`` is treated as the extension and dropped.
    The iteration number is the integer after the last ``'-'`` of the file's
    base name. The model name is the extension-less path up to, but not
    including, its last ``'-'``.

    >>> parse_input_model('resnet-0010.params')
    ('resnet', 10)

    :raises ValueError: if the iteration part is not an integer.
    """
    stem = input_model.rpartition('.')[0]
    base_name = stem.rpartition(os.sep)[2]
    iteration_number = int(base_name.rpartition('-')[2])
    return stem.rpartition('-')[0], iteration_number
def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '', pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
    """Load the symbol nodes and parameters of an MXNet model.

    Two input modes are supported:

    * only ``input_model_name``: the model prefix and iteration number are
      parsed from that file name, and its parameters are read with
      ``load_params``;
    * ``nd_prefix_name``, ``pretrained_model_name`` and ``input_symbol``
      together: the iteration number is parsed from ``pretrained_model_name``,
      the model prefix is ``input_symbol`` minus its last ``'-'`` segment, and
      parameters are assembled by ``build_params_file``. A malformed
      ``pretrained_model_name`` surfaces as ``ValueError`` here.

    Passing ``nd_prefix_name`` or ``pretrained_model_name`` without all three
    of the second mode's arguments is rejected.

    :return: ``(model_nodes, model_params, model_name, iteration_number)``.
    :raises Error: if ``input_model_name`` carries no iteration number, or the
        second mode's arguments are incomplete.
    """
    if nd_prefix_name or pretrained_model_name:
        if not (nd_prefix_name and pretrained_model_name and input_symbol):
            raise Error(
                "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
                refer_to_faq_msg(81))
        iteration_number = parse_input_model(pretrained_model_name)[1]
        model_name = '-'.join(input_symbol.split('-')[:-1])
        model_params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
    else:
        # model name always has extension 'param'
        try:
            model_name, iteration_number = parse_input_model(input_model_name)
        except ValueError as err:
            raise Error(
                'Input model name {} is not in an expected format, cannot extract iteration number. ' +
                refer_to_faq_msg(48),
                input_model_name) from err

        if input_names:
            model_params = load_params(input_model_name, data_names=input_names.split(','))
        else:
            model_params = load_params(input_model_name)

    model_nodes = load_symbol_nodes(model_name, legacy_mxnet_model)
    return model_nodes, model_params, model_name, iteration_number
def symbol_attrs(symbol_node):
    """Wrap a raw MXNet symbol node as graph-node attributes.

    The node is stored by reference, not copied, under the single key
    ``'symbol_dict'``.
    """
    return dict(symbol_dict=symbol_node)
def symbol2nx(model_nodes, model_params, input_names: str = ''):
    """Build a Model Optimizer ``Graph`` from MXNet symbol nodes.

    Every entry of *model_nodes* becomes a graph node whose attributes hold the
    raw node dict under ``'symbol_dict'`` plus ``common_mxnet_fields``.

    A node receives a ``'value'`` entry in these cases:

    * its name is in ``model_params._arg_params`` or ``._aux_params`` and is
      not a network input: the weights, as float32;
    * its name is in the table returned by ``init_rnn_states``: zeros of the
      shape stored there.

    Edges come from ``get_mxnet_node_edges``.

    :param model_nodes: list of node dicts; a dict is mutated in place when a
        ``'value'`` is attached.
    :param model_params: object exposing ``_arg_params`` / ``_aux_params``
        mappings whose values provide ``.asnumpy()``.
    :param input_names: comma-separated network input names; ``'data'`` if empty.
    :return: the populated graph.
    """
    if not input_names:
        input_names = ('data',)
    else:
        input_names = input_names.split(',')

    # Maps RNN state node name -> shape (passed straight to np.zeros below).
    rnn_states = init_rnn_states(model_nodes)

    graph = Graph()
    # MXNet refers to a node's inputs by their position in model_nodes, so
    # remember which graph node id was assigned to each position.
    index_node_keys = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        if name in model_params._arg_params and name not in input_names:
            node['value'] = np.array(model_params._arg_params[name].asnumpy(), dtype=np.float32)
        elif name in model_params._aux_params and name not in input_names:
            node['value'] = np.array(model_params._aux_params[name].asnumpy(), dtype=np.float32)
        elif name in rnn_states:
            node['value'] = np.zeros(rnn_states[name])
        node_name = graph.unique_id(name)
        graph.add_node(node_name, **symbol_attrs(node))
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name

    # Copy the node list once rather than on every iteration; the old
    # per-iteration list() made edge construction quadratic in node count.
    # NOTE(review): assumes get_mxnet_node_edges only reads nodes_list.
    nodes_list = list(model_nodes)
    for i, node in enumerate(model_nodes):
        edges = get_mxnet_node_edges(node, i, nodes_list, index_node_keys)
        graph.add_edges_from(edges)

    return graph
125 def find_output_node(graph: Graph, src_input_index):
126 for i, attrs in (list(graph.nodes(data=True))[src_input_index + 1:]):
127 for input_index in attrs['symbol_dict']['inputs']:
128 if input_index[0] == src_input_index: