-#!/usr/bin/env python
+#!/usr/bin/env python3
# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
# Copyright (C) 2018 The TensorFlow Authors
converter_version.add_argument(
"--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
+ # Input model format
+ model_format_arg = parser.add_mutually_exclusive_group()
+ model_format_arg.add_argument(
+ "--graph_def",
+ action="store_const",
+ dest="model_format",
+ const="graph_def",
+ help="Use graph def file(default)")
+ model_format_arg.add_argument(
+ "--saved_model",
+ action="store_const",
+ dest="model_format",
+ const="saved_model",
+ help="Use saved model")
+ model_format_arg.add_argument(
+ "--keras_model",
+ action="store_const",
+ dest="model_format",
+ const="keras_model",
+ help="Use keras model")
+
# Input and output path.
parser.add_argument(
"-i",
help="Names of the output arrays, comma-separated.",
required=True)
+ # Set default value
+ parser.set_defaults(model_format="graph_def")
return parser
def _v1_convert(flags):
- input_shapes = None
- if flags.input_shapes:
- input_arrays = _parse_array(flags.input_arrays)
- input_shapes_list = [
- _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":")
- ]
- input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
-
- converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
- flags.input_path, _parse_array(flags.input_arrays),
- _parse_array(flags.output_arrays), input_shapes)
+ if flags.model_format == "graph_def":
+ input_shapes = None
+ if flags.input_shapes:
+ input_arrays = _parse_array(flags.input_arrays)
+ input_shapes_list = [
+ _parse_array(shape, type_fn=int)
+ for shape in flags.input_shapes.split(":")
+ ]
+ input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
+
+ converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
+ flags.input_path, _parse_array(flags.input_arrays),
+ _parse_array(flags.output_arrays), input_shapes)
+
+ if flags.model_format == "saved_model":
+ converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(flags.input_path)
+
+ if flags.model_format == "keras_model":
+ converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
+ flags.input_path)
converter.allow_custom_ops = True
def _v2_convert(flags):
- file_content = open(flags.input_path, 'rb').read()
- try:
- graph_def = tf.compat.v1.GraphDef()
- graph_def.ParseFromString(file_content)
- except (_text_format.ParseError, DecodeError):
+ if flags.model_format == "graph_def":
+ file_content = open(flags.input_path, 'rb').read()
try:
- _text_format.Merge(file_content, graph_def)
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
- raise IOError("Unable to parse input file '{}'.".format(flags.input_path))
-
- wrap_func = wrap_frozen_graph(
- graph_def,
- inputs=[
- _str + ":0" if len(_str.split(":")) == 1 else _str
- for _str in _parse_array(flags.input_arrays)
- ],
- outputs=[
- _str + ":0" if len(_str.split(":")) == 1 else _str
- for _str in _parse_array(flags.output_arrays)
- ])
- converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
+ try:
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise IOError("Unable to parse input file '{}'.".format(flags.input_path))
+
+ wrap_func = wrap_frozen_graph(
+ graph_def,
+ inputs=[
+ _str + ":0" if len(_str.split(":")) == 1 else _str
+ for _str in _parse_array(flags.input_arrays)
+ ],
+ outputs=[
+ _str + ":0" if len(_str.split(":")) == 1 else _str
+ for _str in _parse_array(flags.output_arrays)
+ ])
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
+
+ if flags.model_format == "saved_model":
+ converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)
+
+ if flags.model_format == "keras_model":
+ keras_model = tf.keras.models.load_model(flags.input_path)
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.allow_custom_ops = True
converter.experimental_new_converter = True