2 Copyright (c) 2018 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
25 from mo.front.mxnet.extractors.utils import get_mxnet_node_edges, load_params
26 from mo.front.mxnet.extractor import common_mxnet_fields
27 from mo.front.mxnet.nd_to_params import build_params_file
28 from mo.graph.graph import Node
29 from mo.graph.graph import unique_id
30 from mo.utils.error import Error
31 from mo.utils.utils import refer_to_faq_msg
def load_symbol_nodes(model_name, legacy_mxnet_model: bool = False):
    """Load the node list of the MXNet symbol file ``<model_name>-symbol.json``.

    Args:
        model_name: model path prefix; ``-symbol.json`` is appended to form the file name.
        legacy_mxnet_model: when True, the symbol is loaded with ``mx.symbol.load`` and
            re-serialized through ``tojson()`` (a warning about custom layers in legacy
            models is logged); otherwise the file is parsed directly with ``json``.

    Returns:
        The ``nodes`` list of the symbol JSON.

    Raises:
        Error: on the non-legacy path, when the symbol file does not exist.
    """
    json_name = '%s-symbol.json' % model_name
    if legacy_mxnet_model:
        log.warning('For legacy MXNet models Model Optimizer does not support conversion of old MXNet models ' +
                    '(trained with 1.0.0 version of MXNet and lower) with custom layers. ' +
                    refer_to_faq_msg(93))
        sym = mx.symbol.load(json_name)
        model_nodes = json.loads(sym.tojson())
    else:
        if not os.path.isfile(json_name):
            raise Error('Specified input json {} does not exist. ' +
                        refer_to_faq_msg(84), json_name)
        # A context manager guarantees the handle is closed (a bare open() leaked it);
        # JSON text is UTF-8, so do not depend on the platform's locale encoding.
        with open(json_name, encoding='utf-8') as symbol_file:
            model_nodes = json.load(symbol_file)

    return model_nodes['nodes']
def parse_input_model(input_model):
    """Split an MXNet checkpoint path ``<prefix>-<iteration>.<ext>`` into prefix and iteration.

    Example: ``'/models/resnet-50-0010.params'`` -> ``('/models/resnet-50', 10)``.

    Args:
        input_model: path to the checkpoint parameters file; only its last extension is stripped.

    Returns:
        ``(model_name, iteration_number)``: the path without extension and ``-<iteration>``
        suffix, and the iteration as an ``int``.

    Raises:
        ValueError: if the file name has no ``-<iteration>`` suffix or the suffix is not an integer.
    """
    path_wo_ext = '.'.join(input_model.split('.')[:-1])
    # Take prefix and iteration from the SAME '-' so a '-' inside a directory name is never
    # mistaken for the separator (previously '/my-dir/0001.params' produced ('/my', 1)).
    model_name, separator, iteration = path_wo_ext.rpartition('-')
    if not separator:
        raise ValueError('no "-<iteration>" suffix in {!r}'.format(input_model))
    # If the last '-' belongs to a directory, the suffix still contains a path separator,
    # so int() rejects it with ValueError, as for any other malformed name.
    return model_name, int(iteration)
def load_symbol_def(input_model_name, input_symbol, input_names: str = '', nd_prefix_name: str = '',
                    pretrained_model_name: str = '', legacy_mxnet_model: bool = False):
    """Load the MXNet symbol nodes and parameters of the model to convert.

    Two argument combinations are accepted:

    * neither ``nd_prefix_name`` nor ``pretrained_model_name``: ``input_model_name`` is a
      ``<prefix>-<iteration>.<ext>`` checkpoint whose parameters are read by ``load_params``
      (restricted to ``input_names`` when given);
    * ``nd_prefix_name``, ``pretrained_model_name`` and ``input_symbol`` all given: the iteration
      comes from ``pretrained_model_name``, the model name is ``input_symbol`` without its last
      ``-`` suffix, and the parameters are assembled by ``build_params_file``.

    Returns:
        ``(model_nodes, model_params, model_name, iteration_number)``.

    Raises:
        Error: if ``input_model_name`` cannot be parsed, or for any other argument combination.
        ValueError: propagated from ``parse_input_model`` for a malformed ``pretrained_model_name``.
    """
    if not (nd_prefix_name or pretrained_model_name):
        # the checkpoint is expected to be named '<prefix>-<iteration>.params'
        try:
            prefix, iteration = parse_input_model(input_model_name)
        except ValueError as err:
            raise Error(
                'Input model name {} is not in an expected format, cannot extract iteration number. ' +
                refer_to_faq_msg(48),
                input_model_name) from err

        params = (load_params(input_model_name, data_names=input_names.split(','))
                  if input_names else load_params(input_model_name))
    elif nd_prefix_name and pretrained_model_name and input_symbol:
        # only the iteration is taken from the pretrained file; the name comes from the symbol file
        iteration = parse_input_model(pretrained_model_name)[1]
        prefix = '-'.join(input_symbol.split('-')[:-1])
        params = build_params_file(nd_prefix_name, pretrained_model_name, input_names)
    else:
        raise Error(
            "Arguments --nd_prefix_name, --pretrained_model_name and --input_symbol should be provided. Please provide all or do not use any. " +
            refer_to_faq_msg(81))

    nodes = load_symbol_nodes(prefix, legacy_mxnet_model)
    return nodes, params, prefix, iteration
def symbol_attrs(symbol_node):
    """Return graph-node attributes that keep the raw MXNet node dict under ``symbol_dict``."""
    attrs = dict(symbol_dict=symbol_node)
    return attrs
def symbol2nx(model_nodes, model_params, input_names: str = ''):
    """Build a ``nx.MultiDiGraph`` from MXNet symbol nodes and their parameters.

    Args:
        model_nodes: list of MXNet node dicts. Each dict is mutated in place: a node whose
            name is found in ``model_params._arg_params`` (checked first) or
            ``model_params._aux_params`` and that is not a graph input gets a float32
            ``value`` array.
        model_params: object exposing ``_arg_params`` / ``_aux_params`` mappings whose values
            provide ``asnumpy()``.
        input_names: comma-separated graph input names; ``'data'`` when empty. Inputs never
            receive a ``value``, even if a parameter of the same name exists.

    Returns:
        The constructed ``nx.MultiDiGraph``.
    """
    if not input_names:
        input_names = ('data',)
    else:
        input_names = input_names.split(',')

    graph = nx.MultiDiGraph()
    # MXNet nodes refer to their inputs by position in model_nodes; remember the (unique)
    # graph node name chosen for each position so edges can be resolved to graph names.
    index_node_keys = {}
    for i, node in enumerate(model_nodes):
        name = node['name']
        if name not in input_names:
            if name in model_params._arg_params:
                node['value'] = np.array(model_params._arg_params[name].asnumpy(), dtype=np.float32)
            elif name in model_params._aux_params:
                node['value'] = np.array(model_params._aux_params[name].asnumpy(), dtype=np.float32)
        node_name = unique_id(graph, name)
        graph.add_node(node_name, **symbol_attrs(node))
        graph.node[node_name].update(common_mxnet_fields(Node(graph, node_name)))
        index_node_keys[i] = node_name

    # Copy the node list once: copying it per node made edge construction O(n^2).
    # NOTE(review): assumes get_mxnet_node_edges does not mutate the list it receives — confirm.
    nodes_list = list(model_nodes)
    for i, node in enumerate(model_nodes):
        # add_edges_from is a no-op for an empty edge list
        graph.add_edges_from(get_mxnet_node_edges(node, i, nodes_list, index_node_keys))

    return graph
122 def find_output_node(graph: nx.MultiDiGraph, src_input_index):
123 for i, attrs in (list(graph.nodes(data=True))[src_input_index + 1:]):
124 for input_index in attrs['symbol_dict']['inputs']:
125 if input_index[0] == src_input_index: