"""
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
raise Error('Incorrect saved model tag was provided. Specify --saved_model_tags with no spaces in it')
argv.saved_model_tags = argv.saved_model_tags.split(',')
- outputs = None
+ argv.output = argv.output.split(',') if argv.output else None
- if argv.output:
- outputs = argv.output.split(',')
-
- placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)
+ argv.placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)
mean_values = parse_tuple_pairs(argv.mean_values)
scale_values = parse_tuple_pairs(argv.scale_values)
mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
+ argv.mean_scale_values = mean_scale
if not os.path.exists(argv.output_dir):
try:
raise Error("Output directory {} is not writable for current user. " +
refer_to_faq_msg(22), argv.output_dir)
- log.debug("Placeholder shapes : {}".format(placeholder_shapes))
+ log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))
ret_res = 1
if hasattr(argv, 'extensions') and argv.extensions and argv.extensions != '':
if is_tf:
import mo.pipeline.tf as mo_tf
- from mo.front.tf.register_custom_ops import update_registration
- import_extensions.load_dirs(argv.framework, extensions, update_registration)
- ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
- is_binary=not argv.input_model_is_text,
- user_shapes=placeholder_shapes,
- mean_scale_values=mean_scale)
+ from mo.front.tf.register_custom_ops import get_front_classes
+ import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+ ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, argv.output_dir,
+ is_binary=not argv.input_model_is_text)
elif is_caffe:
import mo.pipeline.caffe as mo_caffe
- from mo.front.caffe.register_custom_ops import update_registration
- import_extensions.load_dirs(argv.framework, extensions, update_registration)
- ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, outputs, argv.output_dir,
- argv.scale,
- user_shapes=placeholder_shapes,
- mean_scale_values=mean_scale,
+ from mo.front.caffe.register_custom_ops import get_front_classes
+ import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+ ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, argv.output_dir,
mean_file=argv.mean_file,
mean_file_offsets=mean_file_offsets,
custom_layers_mapping_path=custom_layers_mapping_path)
elif is_mxnet:
import mo.pipeline.mx as mo_mxnet
- from mo.front.mxnet.register_custom_ops import update_registration
- import_extensions.load_dirs(argv.framework, extensions, update_registration)
- ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
- placeholder_shapes=placeholder_shapes,
- mean_scale_values=mean_scale)
+ from mo.front.mxnet.register_custom_ops import get_front_classes
+ import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+ ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, argv.output_dir)
elif is_kaldi:
import mo.pipeline.kaldi as mo_kaldi
- from mo.front.kaldi.register_custom_ops import update_registration
- import_extensions.load_dirs(argv.framework, extensions, update_registration)
- ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
- placeholder_shapes=placeholder_shapes,
- mean_scale_values=mean_scale)
+ from mo.front.kaldi.register_custom_ops import get_front_classes
+ import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+ ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, argv.output_dir)
elif is_onnx:
import mo.pipeline.onnx as mo_onnx
- from mo.front.onnx.register_custom_ops import update_registration
- import_extensions.load_dirs(argv.framework, extensions, update_registration)
- ret_res = mo_onnx.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
- user_shapes=placeholder_shapes,
- mean_scale_values=mean_scale)
+ from mo.front.onnx.register_custom_ops import get_front_classes
+ import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+ ret_res = mo_onnx.driver(argv, argv.input_model, model_name, argv.output_dir)
if ret_res != 0:
return ret_res