Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / front / mxnet / nd_to_params.py
1 """
2  Copyright (c) 2017-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import os
18
19 import mxnet as mx
20 from mo.front.mxnet.extractors.utils import load_params
21
22
23 def save_params_file(model_name: str, args: dict, auxs: dict, iteration_number: int = 0):
24     pretrained = {}
25     for key in args:
26         pretrained["arg:" + key] = args[key]
27
28     for key in auxs:
29         pretrained["aux:" + key] = auxs[key]
30
31     save_model_path = '{}-{:04}.params'.format(model_name, iteration_number)
32     save_model_path = os.path.expanduser(save_model_path)
33     if os.path.isfile(save_model_path):
34         os.remove(save_model_path)
35     mx.nd.save(save_model_path, pretrained)
36
37
38 def add_pretrained_model(pretrained_params: dict, args: dict, pretrained_model: str, iteration_number: int,
39                          input_names: str):
40     if input_names:
41         input_names = input_names.split(',')
42     else:
43         input_names = 'data'
44
45     arg_dict = args
46     if pretrained_params:
47         symbol, arg_params, aux_params = mx.model.load_checkpoint(pretrained_model, iteration_number)
48         arg_names = symbol.list_arguments()
49         arg_dict = {}
50
51         for name in arg_names:
52             if name in input_names:
53                 continue
54             key = "arg:" + name
55             if key in pretrained_params:
56                 arg_dict[name] = pretrained_params[key].copyto(mx.cpu())
57         del pretrained_params
58         arg_dict.update(args)
59     return arg_dict
60
61
62 def build_params_file(nd_prefix_name: str = '', pretrained_model: str = '', input_names: str = ''):
63     path_wo_ext = '.'.join(pretrained_model.split('.')[:-1])
64     pretrained_model_name_w_iter = path_wo_ext.split(os.sep)[-1]
65     pretrained_model_name = '-'.join(path_wo_ext.split('-')[:-1])
66     iteration_number = int(pretrained_model_name_w_iter.split('-')[-1])
67     files_dir = os.path.dirname(pretrained_model)
68
69     if input_names:
70         model_params = load_params(pretrained_model, data_names=input_names.split(','))
71     else:
72         model_params = load_params(pretrained_model)
73
74     pretrained_params = mx.nd.load(pretrained_model) if pretrained_model_name else None
75     nd_args = mx.nd.load(os.path.join(files_dir, '%s_args.nd' % nd_prefix_name)) if nd_prefix_name else None
76     nd_auxs = mx.nd.load(os.path.join(files_dir, '%s_auxs.nd' % nd_prefix_name)) if nd_prefix_name else None
77     nd_args = add_pretrained_model(pretrained_params, nd_args, pretrained_model_name,
78                                    iteration_number,
79                                    input_names)
80
81     model_params._arg_params = nd_args
82     model_params._aux_params = nd_auxs
83     model_params._param_names = list(nd_args.keys())
84     model_params._aux_names = list(nd_auxs.keys())
85     return model_params