Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / main.py
1 """
2  Copyright (c) 2018-2019 Intel Corporation
3
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 """
16
17 import argparse
18 import datetime
19 import logging as log
20 import os
21 import sys
22 import traceback
23 from collections import OrderedDict
24
25 import numpy as np
26
27 from mo.utils import import_extensions
28 from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \
29     get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \
30     get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs
31 from mo.utils.error import Error, FrameworkError
32 from mo.utils.guess_framework import guess_framework_by_ext
33 from mo.utils.logger import init_logger
34 from mo.utils.utils import refer_to_faq_msg
35 from mo.utils.version import get_version
36 from mo.utils.versions_checker import check_requirements
37
38
39 def replace_ext(name: str, old: str, new: str):
40     base, ext = os.path.splitext(name)
41     log.debug("base: {}, ext: {}".format(base, ext))
42     if ext == old:
43         return base + new
44
45
46 def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet: bool, is_kaldi: bool, is_onnx: bool,
47                model_name: str):
48     print('Model Optimizer arguments:')
49     props = OrderedDict()
50     props['common_args'] = get_common_cli_options(model_name)
51     if is_caffe:
52         props['caffe_args'] = get_caffe_cli_options()
53     if is_tf:
54         props['tf_args'] = get_tf_cli_options()
55     if is_mxnet:
56         props['mxnet_args'] = get_mxnet_cli_options()
57     if is_kaldi:
58         props['kaldi_args'] = get_kaldi_cli_options()
59     if is_onnx:
60         props['onnx_args'] = get_onnx_cli_options()
61
62     framework_specifics_map = {
63         'common_args': 'Common parameters:',
64         'caffe_args': 'Caffe specific parameters:',
65         'tf_args': 'TensorFlow specific parameters:',
66         'mxnet_args': 'MXNet specific parameters:',
67         'kaldi_args': 'Kaldi specific parameters:',
68         'onnx_args': 'ONNX specific parameters:',
69     }
70
71     lines = []
72     for key in props:
73         lines.append(framework_specifics_map[key])
74         for (op, desc) in props[key].items():
75             if isinstance(desc, list):
76                 lines.append('\t{}: \t{}'.format(desc[0], desc[1](getattr(argv, op, 'NONE'))))
77             else:
78                 if op is 'k':
79                     default_path = os.path.join(os.path.dirname(sys.argv[0]),
80                                                 'extensions/front/caffe/CustomLayersMapping.xml')
81                     if getattr(argv, op, 'NONE') == default_path:
82                         lines.append('\t{}: \t{}'.format(desc, 'Default'))
83                         continue
84                 lines.append('\t{}: \t{}'.format(desc, getattr(argv, op, 'NONE')))
85     lines.append('Model Optimizer version: \t{}'.format(get_version()))
86     print('\n'.join(lines))
87
88
89 def driver(argv: argparse.Namespace):
90     if argv.version:
91         print('Version of Model Optimizer is: {}'.format(get_version()))
92         return 0
93
94     init_logger(argv.log_level.upper(), argv.silent)
95     start_time = datetime.datetime.now()
96
97     if not argv.framework:
98         if 'saved_model_dir' in argv and argv.saved_model_dir or \
99                 'input_meta_graph' in argv and argv.input_meta_graph:
100             argv.framework = 'tf'
101         elif 'input_symbol ' in argv and argv.input_symbol or \
102                 'pretrained_model_name' in argv and argv.pretrained_model_name:
103             argv.framework = 'mxnet'
104         elif 'input_proto' in argv and argv.input_proto:
105             argv.framework = 'caffe'
106         elif argv.input_model is None:
107             raise Error('Path to input model is required: use --input_model.')
108         else:
109             argv.framework = guess_framework_by_ext(argv.input_model)
110         if not argv.framework:
111             raise Error(
112                 'Framework name can not be deduced from the given options: {}={}. ' +
113                 'Use --framework to choose one of caffe, tf, mxnet, kaldi, onnx',
114                 '--input_model',
115                 argv.input_model,
116                 refer_to_faq_msg(15),
117             )
118
119     is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx = (argv.framework == x for x in
120                                                     ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx'])
121
122     if is_tf and not argv.input_model and not argv.saved_model_dir and not argv.input_meta_graph:
123         raise Error('Path to input model or saved model dir is required: use --input_model, --saved_model_dir or '
124                     '--input_meta_graph')
125     elif is_mxnet and not argv.input_model and not argv.input_symbol and not argv.pretrained_model_name:
126         raise Error('Path to input model or input symbol or pretrained_model_name is required: use --input_model or '
127                     '--input_symbol or --pretrained_model_name')
128     elif is_caffe and not argv.input_model and not argv.input_proto:
129         raise Error('Path to input model or input proto is required: use --input_model or --input_proto')
130     elif (is_kaldi or is_onnx) and not argv.input_model:
131         raise Error('Path to input model is required: use --input_model.')
132
133     log.debug(str(argv))
134     log.debug("Model Optimizer started")
135
136     model_name = "<UNKNOWN_NAME>"
137     if argv.model_name:
138         model_name = argv.model_name
139     elif argv.input_model:
140         model_name = get_model_name(argv.input_model)
141     elif is_tf and argv.saved_model_dir:
142         model_name = "saved_model"
143     elif is_tf and argv.input_meta_graph:
144         model_name = get_model_name(argv.input_meta_graph)
145     elif is_mxnet and argv.input_symbol:
146         model_name = get_model_name(argv.input_symbol)
147
148     log.debug('Output model name would be {}{{.xml, .bin}}'.format(model_name))
149
150     # if --input_proto is not provided, try to retrieve another one
151     # by suffix substitution from model file name
152     if is_caffe and not argv.input_proto:
153         argv.input_proto = replace_ext(argv.input_model, '.caffemodel', '.prototxt')
154
155         if not argv.input_proto:
156             raise Error("Cannot find prototxt file: for Caffe please specify --input_proto - a " +
157                         "protobuf file that stores topology and --input_model that stores " +
158                         "pretrained weights. " +
159                         refer_to_faq_msg(20))
160         log.info('Deduced name for prototxt: {}'.format(argv.input_proto))
161
162     if not argv.silent:
163         print_argv(argv, is_caffe, is_tf, is_mxnet, is_kaldi, is_onnx, model_name)
164
165     if not any([is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx]):
166         raise Error(
167             'Framework {} is not a valid target. ' +
168             'Please use --framework with one from the list: caffe, tf, mxnet, kaldi, onnx. ' +
169             refer_to_faq_msg(15),
170             argv.framework
171         )
172
173     ret_code = check_requirements(framework=argv.framework)
174     if ret_code:
175         return ret_code
176
177     if is_mxnet and not argv.input_shape:
178         raise Error('Input shape is required to convert MXNet model. Please provide it with --input_shape. ' +
179                     refer_to_faq_msg(16))
180
181     mean_file_offsets = None
182     if is_caffe and argv.mean_file and argv.mean_values:
183         raise Error('Both --mean_file and mean_values are specified. Specify either mean file or mean values. ' +
184                     refer_to_faq_msg(17))
185     elif is_caffe and argv.mean_file and argv.mean_file_offsets:
186
187         values = get_tuple_values(argv.mean_file_offsets, t=int, num_exp_values=2)
188         mean_file_offsets = np.array([int(x) for x in values[0].split(',')])
189         if not all([offset >= 0 for offset in mean_file_offsets]):
190             raise Error("Negative value specified for --mean_file_offsets option. "
191                         "Please specify positive integer values in format '(x,y)'. " +
192                         refer_to_faq_msg(18))
193     custom_layers_mapping_path = argv.k if is_caffe and argv.k else None
194
195     if argv.scale and argv.scale_values:
196         raise Error(
197             'Both --scale and --scale_values are defined. Specify either scale factor or scale values per input ' +
198             'channels. ' + refer_to_faq_msg(19))
199
200     if argv.scale and argv.scale < 1.0:
201         log.error("The scale value is less than 1.0. This is most probably an issue because the scale value specifies "
202                   "floating point value which all input values will be *divided*.", extra={'is_warning': True})
203
204     if argv.input_model and (is_tf and argv.saved_model_dir):
205         raise Error('Both --input_model and --saved_model_dir are defined. '
206                     'Specify either input model or saved model directory.')
207     if is_tf:
208         if argv.saved_model_tags is not None:
209             if ' ' in argv.saved_model_tags:
210                 raise Error('Incorrect saved model tag was provided. Specify --saved_model_tags with no spaces in it')
211             argv.saved_model_tags = argv.saved_model_tags.split(',')
212
213     argv.output = argv.output.split(',') if argv.output else None
214
215     argv.placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)
216
217     mean_values = parse_tuple_pairs(argv.mean_values)
218     scale_values = parse_tuple_pairs(argv.scale_values)
219     mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
220     argv.mean_scale_values = mean_scale
221
222     if not os.path.exists(argv.output_dir):
223         try:
224             os.makedirs(argv.output_dir)
225         except PermissionError as e:
226             raise Error("Failed to create directory {}. Permission denied! " +
227                         refer_to_faq_msg(22),
228                         argv.output_dir) from e
229     else:
230         if not os.access(argv.output_dir, os.W_OK):
231             raise Error("Output directory {} is not writable for current user. " +
232                         refer_to_faq_msg(22), argv.output_dir)
233
234     log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))
235
236     ret_res = 1
237     if hasattr(argv, 'extensions') and argv.extensions and argv.extensions != '':
238         extensions = argv.extensions.split(',')
239     else:
240         extensions = None
241
242     if argv.freeze_placeholder_with_value is not None:
243         replacements = {}
244         for replace in argv.freeze_placeholder_with_value.split(','):
245             rp = replace.split('->')
246             if len(rp) != 2:
247                 raise Error("Wrong replacement syntax. Use --freeze_placeholder_with_value "
248                             "\"node1_name->value1,node2_name->value2\"")
249             if rp[0] in replacements and replacements[rp[0]] != rp[1]:
250                 raise Error("Overriding replacement value of placeholder with name '{}': old value = {}, new value = {}"
251                             ".".format(rp[0], replacements[rp[0]], rp[1]))
252             value = rp[1]
253             if '[' in value.strip(' '):
254                 value = value.replace('[', '').replace(']', '').split(' ')
255             replacements[rp[0]] = value
256         argv.freeze_placeholder_with_value = replacements
257
258     if is_tf:
259         import mo.pipeline.tf as mo_tf
260         from mo.front.tf.register_custom_ops import get_front_classes
261         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
262         ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, argv.output_dir,
263                               is_binary=not argv.input_model_is_text)
264
265     elif is_caffe:
266         import mo.pipeline.caffe as mo_caffe
267         from mo.front.caffe.register_custom_ops import get_front_classes
268         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
269         ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, argv.output_dir,
270                                   mean_file=argv.mean_file,
271                                   mean_file_offsets=mean_file_offsets,
272                                   custom_layers_mapping_path=custom_layers_mapping_path)
273
274     elif is_mxnet:
275         import mo.pipeline.mx as mo_mxnet
276         from mo.front.mxnet.register_custom_ops import get_front_classes
277         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
278         ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, argv.output_dir)
279
280     elif is_kaldi:
281         import mo.pipeline.kaldi as mo_kaldi
282         from mo.front.kaldi.register_custom_ops import get_front_classes
283         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
284         ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, argv.output_dir)
285     elif is_onnx:
286         import mo.pipeline.onnx as mo_onnx
287         from mo.front.onnx.register_custom_ops import get_front_classes
288         import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
289         ret_res = mo_onnx.driver(argv, argv.input_model, model_name, argv.output_dir)
290
291     if ret_res != 0:
292         return ret_res
293     if not (is_tf and argv.tensorflow_custom_operations_config_update):
294         output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
295         print('\n[ SUCCESS ] Generated IR model.')
296         print('[ SUCCESS ] XML file: {}.xml'.format(os.path.join(output_dir, model_name)))
297         print('[ SUCCESS ] BIN file: {}.bin'.format(os.path.join(output_dir, model_name)))
298         elapsed_time = datetime.datetime.now() - start_time
299         print('[ SUCCESS ] Total execution time: {:.2f} seconds. '.format(elapsed_time.total_seconds()))
300     return ret_res
301
302
303 def main(cli_parser: argparse.ArgumentParser, framework: str):
304     try:
305         # Initialize logger with 'ERROR' as default level to be able to form nice messages
306         # before arg parser deliver log_level requested by user
307         init_logger('ERROR', False)
308
309         argv = cli_parser.parse_args()
310         if framework:
311             argv.framework = framework
312         return driver(argv)
313     except (FileNotFoundError, NotADirectoryError) as e:
314         log.error('File {} was not found'.format(str(e).split('No such file or directory:')[1]))
315         log.debug(traceback.format_exc())
316     except Error as err:
317         log.error(err)
318         log.debug(traceback.format_exc())
319     except FrameworkError as err:
320         log.error(err, extra={'framework_error': True})
321         log.debug(traceback.format_exc())
322     except Exception as err:
323         log.error("-------------------------------------------------")
324         log.error("----------------- INTERNAL ERROR ----------------")
325         log.error("Unexpected exception happened.")
326         log.error("Please contact Model Optimizer developers and forward the following information:")
327         log.error(str(err))
328         log.error(traceback.format_exc())
329         log.error("---------------- END OF BUG REPORT --------------")
330         log.error("-------------------------------------------------")
331     return 1