Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / tf2tfliteV2 / tf2tfliteV2.py
index 8b6ba0d..82d6ee2 100755 (executable)
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 # Copyright (C) 2018 The TensorFlow Authors
 #
@@ -48,8 +50,13 @@ def _get_parser():
 
     # Input and output path.
     parser.add_argument(
-        "--input_path", type=str, help="Full filepath of the input file.", required=True)
+        "-i",
+        "--input_path",
+        type=str,
+        help="Full filepath of the input file.",
+        required=True)
     parser.add_argument(
+        "-o",
         "--output_path",
         type=str,
         help="Full filepath of the output file.",
@@ -57,15 +64,20 @@ def _get_parser():
 
     # Input and output arrays.
     parser.add_argument(
+        "-I",
         "--input_arrays",
         type=str,
         help="Names of the input arrays, comma-separated.",
         required=True)
     parser.add_argument(
+        "-s",
         "--input_shapes",
         type=str,
-        help="Shapes corresponding to --input_arrays, colon-separated.")
+        help=
+        "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
+    )
     parser.add_argument(
+        "-O",
         "--output_arrays",
         type=str,
         help="Names of the output arrays, comma-separated.",
@@ -141,9 +153,14 @@ def _v2_convert(flags):
 
     wrap_func = wrap_frozen_graph(
         graph_def,
-        inputs=[_str + ":0" for _str in _parse_array(flags.input_arrays)],
-        # TODO What if multiple outputs come in?
-        outputs=[_str + ":0" for _str in _parse_array(flags.output_arrays)])
+        inputs=[
+            _str + ":0" if len(_str.split(":")) == 1 else _str
+            for _str in _parse_array(flags.input_arrays)
+        ],
+        outputs=[
+            _str + ":0" if len(_str.split(":")) == 1 else _str
+            for _str in _parse_array(flags.output_arrays)
+        ])
     converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
 
     converter.allow_custom_ops = True