"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import argparse
import datetime
import logging as log
import os
import sys
import traceback
from collections import OrderedDict

import numpy as np

from mo.utils import import_extensions
from mo.utils.cli_parser import get_placeholder_shapes, get_tuple_values, get_model_name, \
    get_common_cli_options, get_caffe_cli_options, get_tf_cli_options, get_mxnet_cli_options, get_kaldi_cli_options, \
    get_onnx_cli_options, get_mean_scale_dictionary, parse_tuple_pairs
from mo.utils.error import Error, FrameworkError
from mo.utils.guess_framework import guess_framework_by_ext
from mo.utils.logger import init_logger
from mo.utils.utils import refer_to_faq_msg
from mo.utils.version import get_version
from mo.utils.versions_checker import check_requirements
def replace_ext(name: str, old: str, new: str):
    """Replace the file extension of *name* when it matches *old*.

    :param name: file name or path, e.g. ``'model.caffemodel'``.
    :param old: extension to replace, including the leading dot, e.g. ``'.caffemodel'``.
    :param new: replacement extension, including the leading dot, e.g. ``'.prototxt'``.
    :return: *name* with *old* swapped for *new*, or ``None`` when *name*
             does not end with *old* (the caller treats ``None`` as
             "could not deduce a companion file").
    """
    base, ext = os.path.splitext(name)
    log.debug("base: {}, ext: {}".format(base, ext))
    # Only substitute on an exact extension match; otherwise fall through
    # and implicitly return None so callers can detect the failure.
    if ext == old:
        return base + new
def print_argv(argv: argparse.Namespace, is_caffe: bool, is_tf: bool, is_mxnet: bool, is_kaldi: bool, is_onnx: bool,
               model_name: str):
    """Pretty-print the effective Model Optimizer arguments to stdout.

    Common options are always printed; framework-specific sections are added
    only for the framework flag(s) that are True. Finishes with the Model
    Optimizer version line.

    :param argv: parsed command-line arguments.
    :param is_caffe, is_tf, is_mxnet, is_kaldi, is_onnx: which framework is active.
    :param model_name: resolved output model name shown in the common section.
    """
    print('Model Optimizer arguments:')
    # OrderedDict keeps the common section first, then the framework section.
    props = OrderedDict()
    props['common_args'] = get_common_cli_options(model_name)
    if is_caffe:
        props['caffe_args'] = get_caffe_cli_options()
    if is_tf:
        props['tf_args'] = get_tf_cli_options()
    if is_mxnet:
        props['mxnet_args'] = get_mxnet_cli_options()
    if is_kaldi:
        props['kaldi_args'] = get_kaldi_cli_options()
    if is_onnx:
        props['onnx_args'] = get_onnx_cli_options()

    framework_specifics_map = {
        'common_args': 'Common parameters:',
        'caffe_args': 'Caffe specific parameters:',
        'tf_args': 'TensorFlow specific parameters:',
        'mxnet_args': 'MXNet specific parameters:',
        'kaldi_args': 'Kaldi specific parameters:',
        'onnx_args': 'ONNX specific parameters:',
    }

    lines = []
    for key in props:
        lines.append(framework_specifics_map[key])
        for (op, desc) in props[key].items():
            if isinstance(desc, list):
                # A [label, formatter] pair: run the argument value through the formatter.
                lines.append('\t{}: \t{}'.format(desc[0], desc[1](getattr(argv, op, 'NONE'))))
            elif op == 'k':
                # NOTE(review): guard reconstructed — 'k' is the Caffe custom layers
                # mapping option; show 'Default' when it still points at the stock file.
                default_path = os.path.join(os.path.dirname(sys.argv[0]),
                                            'extensions/front/caffe/CustomLayersMapping.xml')
                if getattr(argv, op, 'NONE') == default_path:
                    lines.append('\t{}: \t{}'.format(desc, 'Default'))
                else:
                    lines.append('\t{}: \t{}'.format(desc, getattr(argv, op, 'NONE')))
            else:
                lines.append('\t{}: \t{}'.format(desc, getattr(argv, op, 'NONE')))
    lines.append('Model Optimizer version: \t{}'.format(get_version()))
    print('\n'.join(lines))
def driver(argv: argparse.Namespace):
    """Run the Model Optimizer conversion pipeline for parsed CLI arguments.

    Deduces the source framework when --framework is not given, validates the
    argument combination, prepares shapes and mean/scale values, then
    dispatches to the framework-specific pipeline (TF, Caffe, MXNet, Kaldi or
    ONNX).

    :param argv: parsed command-line arguments (argparse.Namespace).
    :return: 0 on success; a non-zero code from the requirements check or the
             framework pipeline otherwise.
    :raises Error: on invalid/ambiguous argument combinations.
    """
    if argv.version:
        print('Version of Model Optimizer is: {}'.format(get_version()))
        return 0

    init_logger(argv.log_level.upper(), argv.silent)
    start_time = datetime.datetime.now()

    # --- Deduce the framework when --framework was not provided ------------
    # 'name' in argv tests Namespace attribute presence (argparse supports
    # containment checks on Namespace objects).
    if not argv.framework:
        if 'saved_model_dir' in argv and argv.saved_model_dir or \
                'input_meta_graph' in argv and argv.input_meta_graph:
            argv.framework = 'tf'
        # BUG FIX: was "'input_symbol '" (trailing space), which can never
        # match the real attribute name, so MXNet detection via --input_symbol
        # silently failed.
        elif 'input_symbol' in argv and argv.input_symbol or \
                'pretrained_model_name' in argv and argv.pretrained_model_name:
            argv.framework = 'mxnet'
        elif 'input_proto' in argv and argv.input_proto:
            argv.framework = 'caffe'
        elif argv.input_model is None:
            raise Error('Path to input model is required: use --input_model.')
        else:
            argv.framework = guess_framework_by_ext(argv.input_model)
        if not argv.framework:
            raise Error('Framework name can not be deduced from the given options: {}={}. ' +
                        'Use --framework to choose one of caffe, tf, mxnet, kaldi, onnx. ' +
                        refer_to_faq_msg(15),
                        '--input_model', argv.input_model)

    # Exactly one of these flags becomes True.
    is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx = (argv.framework == x for x in
                                                    ['tf', 'caffe', 'mxnet', 'kaldi', 'onnx'])

    # --- Validate that the mandatory model path(s) are present -------------
    if is_tf and not argv.input_model and not argv.saved_model_dir and not argv.input_meta_graph:
        raise Error('Path to input model or saved model dir is required: use --input_model, --saved_model_dir or '
                    '--input_meta_graph')
    elif is_mxnet and not argv.input_model and not argv.input_symbol and not argv.pretrained_model_name:
        raise Error('Path to input model or input symbol or pretrained_model_name is required: use --input_model or '
                    '--input_symbol or --pretrained_model_name')
    elif is_caffe and not argv.input_model and not argv.input_proto:
        raise Error('Path to input model or input proto is required: use --input_model or --input_proto')
    elif (is_kaldi or is_onnx) and not argv.input_model:
        raise Error('Path to input model is required: use --input_model.')

    log.debug("Model Optimizer started")

    # --- Derive the output model name (<name>.xml / <name>.bin) ------------
    model_name = "<UNKNOWN_NAME>"
    if argv.model_name:
        model_name = argv.model_name
    elif argv.input_model:
        model_name = get_model_name(argv.input_model)
    elif is_tf and argv.saved_model_dir:
        model_name = "saved_model"
    elif is_tf and argv.input_meta_graph:
        model_name = get_model_name(argv.input_meta_graph)
    elif is_mxnet and argv.input_symbol:
        model_name = get_model_name(argv.input_symbol)

    log.debug('Output model name would be {}{{.xml, .bin}}'.format(model_name))

    # if --input_proto is not provided, try to retrieve another one
    # by suffix substitution from model file name
    if is_caffe and not argv.input_proto:
        argv.input_proto = replace_ext(argv.input_model, '.caffemodel', '.prototxt')

        if not argv.input_proto:
            raise Error("Cannot find prototxt file: for Caffe please specify --input_proto - a " +
                        "protobuf file that stores topology and --input_model that stores " +
                        "pretrained weights. " +
                        refer_to_faq_msg(20))
        log.info('Deduced name for prototxt: {}'.format(argv.input_proto))

    # NOTE(review): guard reconstructed — argument echo is suppressed in
    # silent mode; confirm against the CLI's --silent semantics.
    if not argv.silent:
        print_argv(argv, is_caffe, is_tf, is_mxnet, is_kaldi, is_onnx, model_name)

    if not any([is_tf, is_caffe, is_mxnet, is_kaldi, is_onnx]):
        raise Error('Framework {} is not a valid target. ' +
                    'Please use --framework with one from the list: caffe, tf, mxnet, kaldi, onnx. ' +
                    refer_to_faq_msg(15),
                    argv.framework)

    ret_code = check_requirements(framework=argv.framework)
    if ret_code:
        return ret_code

    if is_mxnet and not argv.input_shape:
        raise Error('Input shape is required to convert MXNet model. Please provide it with --input_shape. ' +
                    refer_to_faq_msg(16))

    # --- Caffe mean file / mean values handling ----------------------------
    mean_file_offsets = None
    if is_caffe and argv.mean_file and argv.mean_values:
        raise Error('Both --mean_file and mean_values are specified. Specify either mean file or mean values. ' +
                    refer_to_faq_msg(17))
    elif is_caffe and argv.mean_file and argv.mean_file_offsets:
        values = get_tuple_values(argv.mean_file_offsets, t=int, num_exp_values=2)
        mean_file_offsets = np.array([int(x) for x in values[0].split(',')])
        if not all([offset >= 0 for offset in mean_file_offsets]):
            raise Error("Negative value specified for --mean_file_offsets option. "
                        "Please specify positive integer values in format '(x,y)'. " +
                        refer_to_faq_msg(18))

    custom_layers_mapping_path = argv.k if is_caffe and argv.k else None

    if argv.scale and argv.scale_values:
        raise Error(
            'Both --scale and --scale_values are defined. Specify either scale factor or scale values per input ' +
            'channels. ' + refer_to_faq_msg(19))

    if argv.scale and argv.scale < 1.0:
        # Warning only: a divisor < 1.0 is usually a sign the user meant a multiplier.
        log.error("The scale value is less than 1.0. This is most probably an issue because the scale value specifies "
                  "floating point value which all input values will be *divided*.", extra={'is_warning': True})

    if argv.input_model and (is_tf and argv.saved_model_dir):
        raise Error('Both --input_model and --saved_model_dir are defined. '
                    'Specify either input model or saved model directory.')

    if argv.saved_model_tags is not None:
        if ' ' in argv.saved_model_tags:
            raise Error('Incorrect saved model tag was provided. Specify --saved_model_tags with no spaces in it')
        argv.saved_model_tags = argv.saved_model_tags.split(',')

    argv.output = argv.output.split(',') if argv.output else None

    argv.placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)

    mean_values = parse_tuple_pairs(argv.mean_values)
    scale_values = parse_tuple_pairs(argv.scale_values)
    mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
    argv.mean_scale_values = mean_scale

    # --- Ensure the output directory exists and is writable ----------------
    if not os.path.exists(argv.output_dir):
        try:
            os.makedirs(argv.output_dir)
        except PermissionError as e:
            raise Error("Failed to create directory {}. Permission denied! " +
                        refer_to_faq_msg(22),
                        argv.output_dir) from e
    else:
        if not os.access(argv.output_dir, os.W_OK):
            raise Error("Output directory {} is not writable for current user. " +
                        refer_to_faq_msg(22), argv.output_dir)

    log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))

    ret_res = 1
    if hasattr(argv, 'extensions') and argv.extensions and argv.extensions != '':
        extensions = argv.extensions.split(',')
    else:
        extensions = None

    # --- Parse --freeze_placeholder_with_value "name->value,..." -----------
    if argv.freeze_placeholder_with_value is not None:
        replacements = {}
        for replace in argv.freeze_placeholder_with_value.split(','):
            rp = replace.split('->')
            if len(rp) != 2:
                raise Error("Wrong replacement syntax. Use --freeze_placeholder_with_value "
                            "\"node1_name->value1,node2_name->value2\"")
            if rp[0] in replacements and replacements[rp[0]] != rp[1]:
                raise Error("Overriding replacement value of placeholder with name '{}': old value = {}, new value = {}"
                            ".".format(rp[0], replacements[rp[0]], rp[1]))
            value = rp[1]
            # A bracketed value like "[1 2 3]" becomes a list of strings.
            if '[' in value.strip(' '):
                value = value.replace('[', '').replace(']', '').split(' ')
            replacements[rp[0]] = value
        argv.freeze_placeholder_with_value = replacements

    # --- Dispatch to the framework-specific pipeline -----------------------
    # Pipelines are imported lazily so only the selected framework's
    # dependencies are required at runtime.
    if is_tf:
        import mo.pipeline.tf as mo_tf
        from mo.front.tf.register_custom_ops import get_front_classes
        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
        ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, argv.output_dir,
                              is_binary=not argv.input_model_is_text)
    elif is_caffe:
        import mo.pipeline.caffe as mo_caffe
        from mo.front.caffe.register_custom_ops import get_front_classes
        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
        ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, argv.output_dir,
                                  mean_file=argv.mean_file,
                                  mean_file_offsets=mean_file_offsets,
                                  custom_layers_mapping_path=custom_layers_mapping_path)
    elif is_mxnet:
        import mo.pipeline.mx as mo_mxnet
        from mo.front.mxnet.register_custom_ops import get_front_classes
        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
        ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, argv.output_dir)
    elif is_kaldi:
        import mo.pipeline.kaldi as mo_kaldi
        from mo.front.kaldi.register_custom_ops import get_front_classes
        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
        ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, argv.output_dir)
    elif is_onnx:
        import mo.pipeline.onnx as mo_onnx
        from mo.front.onnx.register_custom_ops import get_front_classes
        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
        ret_res = mo_onnx.driver(argv, argv.input_model, model_name, argv.output_dir)

    if ret_res != 0:
        return ret_res
    # The custom-operations-config-update mode does not emit an IR, so skip
    # the success banner in that case.
    if not (is_tf and argv.tensorflow_custom_operations_config_update):
        output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
        print('\n[ SUCCESS ] Generated IR model.')
        print('[ SUCCESS ] XML file: {}.xml'.format(os.path.join(output_dir, model_name)))
        print('[ SUCCESS ] BIN file: {}.bin'.format(os.path.join(output_dir, model_name)))
        elapsed_time = datetime.datetime.now() - start_time
        print('[ SUCCESS ] Total execution time: {:.2f} seconds. '.format(elapsed_time.total_seconds()))
    return ret_res
def main(cli_parser: argparse.ArgumentParser, framework: str):
    """CLI entry point: parse arguments, run the driver and map failures to an exit code.

    :param cli_parser: fully configured argument parser for the chosen framework.
    :param framework: framework name to force ('tf', 'caffe', ...) or a falsy
                      value to let the driver deduce it.
    :return: the driver's return code on success, 1 on any handled failure.
    """
    # Initialize logger with 'ERROR' as default level to be able to form nice messages
    # before arg parser deliver log_level requested by user
    init_logger('ERROR', False)
    try:
        argv = cli_parser.parse_args()
        if framework:
            argv.framework = framework
        return driver(argv)
    except (FileNotFoundError, NotADirectoryError) as e:
        # Extract just the offending path from the OSError text for a concise message.
        log.error('File {} was not found'.format(str(e).split('No such file or directory:')[1]))
        log.debug(traceback.format_exc())
    except Error as err:
        log.error(err)
        log.debug(traceback.format_exc())
    except FrameworkError as err:
        log.error(err, extra={'framework_error': True})
        log.debug(traceback.format_exc())
    except Exception as err:
        # Anything else is an internal error: print a bug-report banner with the trace.
        log.error("-------------------------------------------------")
        log.error("----------------- INTERNAL ERROR ----------------")
        log.error("Unexpected exception happened.")
        log.error("Please contact Model Optimizer developers and forward the following information:")
        log.error(str(err))
        log.error(traceback.format_exc())
        log.error("---------------- END OF BUG REPORT --------------")
        log.error("-------------------------------------------------")
    return 1