Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / tf2tfliteV2 / tf2tfliteV2.py
index 82d6ee2..c51dabd 100755 (executable)
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 # Copyright (C) 2018 The TensorFlow Authors
@@ -48,6 +48,27 @@ def _get_parser():
     converter_version.add_argument(
         "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
 
+    # Input model format
+    model_format_arg = parser.add_mutually_exclusive_group()
+    model_format_arg.add_argument(
+        "--graph_def",
+        action="store_const",
+        dest="model_format",
+        const="graph_def",
+        help="Use graph def file(default)")
+    model_format_arg.add_argument(
+        "--saved_model",
+        action="store_const",
+        dest="model_format",
+        const="saved_model",
+        help="Use saved model")
+    model_format_arg.add_argument(
+        "--keras_model",
+        action="store_const",
+        dest="model_format",
+        const="keras_model",
+        help="Use keras model")
+
     # Input and output path.
     parser.add_argument(
         "-i",
@@ -83,6 +104,8 @@ def _get_parser():
         help="Names of the output arrays, comma-separated.",
         required=True)
 
+    # Set default value
+    parser.set_defaults(model_format="graph_def")
     return parser
 
 
@@ -122,17 +145,26 @@ def _parse_array(arrays, type_fn=str):
 
 
 def _v1_convert(flags):
-    input_shapes = None
-    if flags.input_shapes:
-        input_arrays = _parse_array(flags.input_arrays)
-        input_shapes_list = [
-            _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":")
-        ]
-        input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
-
-    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
-        flags.input_path, _parse_array(flags.input_arrays),
-        _parse_array(flags.output_arrays), input_shapes)
+    if flags.model_format == "graph_def":
+        input_shapes = None
+        if flags.input_shapes:
+            input_arrays = _parse_array(flags.input_arrays)
+            input_shapes_list = [
+                _parse_array(shape, type_fn=int)
+                for shape in flags.input_shapes.split(":")
+            ]
+            input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
+
+        converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
+            flags.input_path, _parse_array(flags.input_arrays),
+            _parse_array(flags.output_arrays), input_shapes)
+
+    if flags.model_format == "saved_model":
+        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(flags.input_path)
+
+    if flags.model_format == "keras_model":
+        converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
+            flags.input_path)
 
     converter.allow_custom_ops = True
 
@@ -141,27 +173,35 @@ def _v1_convert(flags):
 
 
 def _v2_convert(flags):
-    file_content = open(flags.input_path, 'rb').read()
-    try:
-        graph_def = tf.compat.v1.GraphDef()
-        graph_def.ParseFromString(file_content)
-    except (_text_format.ParseError, DecodeError):
+    if flags.model_format == "graph_def":
+        file_content = open(flags.input_path, 'rb').read()
         try:
-            _text_format.Merge(file_content, graph_def)
+            graph_def = tf.compat.v1.GraphDef()
+            graph_def.ParseFromString(file_content)
         except (_text_format.ParseError, DecodeError):
-            raise IOError("Unable to parse input file '{}'.".format(flags.input_path))
-
-    wrap_func = wrap_frozen_graph(
-        graph_def,
-        inputs=[
-            _str + ":0" if len(_str.split(":")) == 1 else _str
-            for _str in _parse_array(flags.input_arrays)
-        ],
-        outputs=[
-            _str + ":0" if len(_str.split(":")) == 1 else _str
-            for _str in _parse_array(flags.output_arrays)
-        ])
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
+            try:
+                _text_format.Merge(file_content, graph_def)
+            except (_text_format.ParseError, DecodeError):
+                raise IOError("Unable to parse input file '{}'.".format(flags.input_path))
+
+        wrap_func = wrap_frozen_graph(
+            graph_def,
+            inputs=[
+                _str + ":0" if len(_str.split(":")) == 1 else _str
+                for _str in _parse_array(flags.input_arrays)
+            ],
+            outputs=[
+                _str + ":0" if len(_str.split(":")) == 1 else _str
+                for _str in _parse_array(flags.output_arrays)
+            ])
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
+
+    if flags.model_format == "saved_model":
+        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)
+
+    if flags.model_format == "keras_model":
+        keras_model = tf.keras.models.load_model(flags.input_path)
+        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
 
     converter.allow_custom_ops = True
     converter.experimental_new_converter = True