Publishing 2019 R1 content
[platform/upstream/dldt.git] / model-optimizer / mo / main.py
index f843c5d..ac96364 100644 (file)
@@ -1,5 +1,5 @@
 """
- Copyright (c) 2018 Intel Corporation
+ Copyright (c) 2018-2019 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -210,16 +210,14 @@ def driver(argv: argparse.Namespace):
                 raise Error('Incorrect saved model tag was provided. Specify --saved_model_tags with no spaces in it')
             argv.saved_model_tags = argv.saved_model_tags.split(',')
 
-    outputs = None
+    argv.output = argv.output.split(',') if argv.output else None
 
-    if argv.output:
-        outputs = argv.output.split(',')
-
-    placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)
+    argv.placeholder_shapes = get_placeholder_shapes(argv.input, argv.input_shape, argv.batch)
 
     mean_values = parse_tuple_pairs(argv.mean_values)
     scale_values = parse_tuple_pairs(argv.scale_values)
     mean_scale = get_mean_scale_dictionary(mean_values, scale_values, argv.input)
+    argv.mean_scale_values = mean_scale
 
     if not os.path.exists(argv.output_dir):
         try:
@@ -233,7 +231,7 @@ def driver(argv: argparse.Namespace):
             raise Error("Output directory {} is not writable for current user. " +
                         refer_to_faq_msg(22), argv.output_dir)
 
-    log.debug("Placeholder shapes : {}".format(placeholder_shapes))
+    log.debug("Placeholder shapes : {}".format(argv.placeholder_shapes))
 
     ret_res = 1
     if hasattr(argv, 'extensions') and argv.extensions and argv.extensions != '':
@@ -259,47 +257,36 @@ def driver(argv: argparse.Namespace):
 
     if is_tf:
         import mo.pipeline.tf as mo_tf
-        from mo.front.tf.register_custom_ops import update_registration
-        import_extensions.load_dirs(argv.framework, extensions, update_registration)
-        ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
-                              is_binary=not argv.input_model_is_text,
-                              user_shapes=placeholder_shapes,
-                              mean_scale_values=mean_scale)
+        from mo.front.tf.register_custom_ops import get_front_classes
+        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+        ret_res = mo_tf.tf2nx(argv, argv.input_model, model_name, argv.output_dir,
+                              is_binary=not argv.input_model_is_text)
 
     elif is_caffe:
         import mo.pipeline.caffe as mo_caffe
-        from mo.front.caffe.register_custom_ops import update_registration
-        import_extensions.load_dirs(argv.framework, extensions, update_registration)
-        ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, outputs, argv.output_dir,
-                                  argv.scale,
-                                  user_shapes=placeholder_shapes,
-                                  mean_scale_values=mean_scale,
+        from mo.front.caffe.register_custom_ops import get_front_classes
+        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+        ret_res = mo_caffe.driver(argv, argv.input_proto, argv.input_model, model_name, argv.output_dir,
                                   mean_file=argv.mean_file,
                                   mean_file_offsets=mean_file_offsets,
                                   custom_layers_mapping_path=custom_layers_mapping_path)
 
     elif is_mxnet:
         import mo.pipeline.mx as mo_mxnet
-        from mo.front.mxnet.register_custom_ops import update_registration
-        import_extensions.load_dirs(argv.framework, extensions, update_registration)
-        ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
-                                  placeholder_shapes=placeholder_shapes,
-                                  mean_scale_values=mean_scale)
+        from mo.front.mxnet.register_custom_ops import get_front_classes
+        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+        ret_res = mo_mxnet.driver(argv, argv.input_model, model_name, argv.output_dir)
 
     elif is_kaldi:
         import mo.pipeline.kaldi as mo_kaldi
-        from mo.front.kaldi.register_custom_ops import update_registration
-        import_extensions.load_dirs(argv.framework, extensions, update_registration)
-        ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
-                                  placeholder_shapes=placeholder_shapes,
-                                  mean_scale_values=mean_scale)
+        from mo.front.kaldi.register_custom_ops import get_front_classes
+        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+        ret_res = mo_kaldi.driver(argv, argv.input_model, model_name, argv.output_dir)
     elif is_onnx:
         import mo.pipeline.onnx as mo_onnx
-        from mo.front.onnx.register_custom_ops import update_registration
-        import_extensions.load_dirs(argv.framework, extensions, update_registration)
-        ret_res = mo_onnx.driver(argv, argv.input_model, model_name, outputs, argv.output_dir, argv.scale,
-                                 user_shapes=placeholder_shapes,
-                                 mean_scale_values=mean_scale)
+        from mo.front.onnx.register_custom_ops import get_front_classes
+        import_extensions.load_dirs(argv.framework, extensions, get_front_classes)
+        ret_res = mo_onnx.driver(argv, argv.input_model, model_name, argv.output_dir)
 
     if ret_res != 0:
         return ret_res