# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
# Copyright (C) 2018 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys

import tensorflow as tf

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen GraphDef as a callable ConcreteFunction.

    The graph is imported (with an empty name scope) into a wrapped
    tf.Graph, then pruned down to the requested input/output tensors.
    """

    def _import_fn():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import_fn, [])
    lookup = wrapped.graph.as_graph_element
    pruned_inputs = tf.nest.map_structure(lookup, inputs)
    pruned_outputs = tf.nest.map_structure(lookup, outputs)
    return wrapped.prune(pruned_inputs, pruned_outputs)
39 Returns an ArgumentParser for TensorFlow Lite Converter.
41 parser = argparse.ArgumentParser(
42 description=("Command line tool to run TensorFlow Lite Converter."))
45 converter_version = parser.add_mutually_exclusive_group(required=True)
46 converter_version.add_argument(
47 "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
48 converter_version.add_argument(
49 "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
52 model_format_arg = parser.add_mutually_exclusive_group()
53 model_format_arg.add_argument(
58 help="Use graph def file(default)")
59 model_format_arg.add_argument(
64 help="Use saved model")
65 model_format_arg.add_argument(
70 help="Use keras model")
72 # Input and output path.
77 help="Full filepath of the input file.",
83 help="Full filepath of the output file.",
86 # Input and output arrays.
91 help="Names of the input arrays, comma-separated.",
98 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
104 help="Names of the output arrays, comma-separated.",
108 parser.set_defaults(model_format="graph_def")
112 def _check_flags(flags):
114 Checks the parsed flags to ensure they are valid.
121 raise ValueError(invalid + " options must be used with v2")
124 if tf.__version__.find("2.") != 0:
126 "Imported TensorFlow should have version >= 2.0 but you have " +
133 raise ValueError(invalid + " options must be used with v1")
135 if flags.input_shapes:
136 if not flags.input_arrays:
137 raise ValueError("--input_shapes must be used with --input_arrays")
138 if flags.input_shapes.count(":") != flags.input_arrays.count(","):
139 raise ValueError("--input_shapes and --input_arrays must have the same "
143 def _parse_array(arrays, type_fn=str):
144 return list(map(type_fn, arrays.split(",")))
def _v1_convert(flags):
    """Convert the input model to tflite using the TF 1.x converter API.

    The converter is chosen by flags.model_format (graph_def / saved_model /
    keras_model); the resulting flatbuffer is written to flags.output_path.
    """
    if flags.model_format == "graph_def":
        input_shapes = None
        if flags.input_shapes:
            input_arrays = _parse_array(flags.input_arrays)
            input_shapes_list = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            # Pair each input array name with its parsed shape.
            input_shapes = dict(zip(input_arrays, input_shapes_list))

        converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            flags.input_path, _parse_array(flags.input_arrays),
            _parse_array(flags.output_arrays), input_shapes)

    if flags.model_format == "saved_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
            flags.input_path)

    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Context manager ensures the output handle is closed even on error.
    with open(flags.output_path, "wb") as out_file:
        out_file.write(tflite_model)
def _v2_convert(flags):
    """Convert the input model to tflite using the TF 2.x converter API.

    For graph_def models the frozen graph is parsed (binary protobuf first,
    then text protobuf as a fallback), wrapped into a ConcreteFunction, and
    converted; saved_model / keras_model inputs use their native loaders.
    """
    if flags.model_format == "graph_def":
        # Read once; the context manager closes the handle deterministically.
        with open(flags.input_path, 'rb') as model_file:
            file_content = model_file.read()
        try:
            # Binary protobuf format first.
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            try:
                # Fall back to the text protobuf format.
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError):
                raise IOError("Unable to parse input file '{}'.".format(flags.input_path))

        def _tensor_name(name):
            # Inputs/outputs must be tensor names; append the default output
            # index ":0" when the user supplied a bare op name.
            return name + ":0" if len(name.split(":")) == 1 else name

        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[_tensor_name(s) for s in _parse_array(flags.input_arrays)],
            outputs=[_tensor_name(s) for s in _parse_array(flags.output_arrays)])
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    with open(flags.output_path, "wb") as out_file:
        out_file.write(tflite_model)
def main():
    """
    Parse command-line flags, validate them, and run the selected converter.

    Input frozen graph must be from TensorFlow 1.13.1
    """
    parser = _get_parser()

    # Check if the flags are valid.
    flags = parser.parse_known_args(args=sys.argv[1:])
    _check_flags(flags[0])

    # NOTE(review): the dispatch below was lost in extraction and is
    # reconstructed from the --v1/--v2 flag semantics — confirm against
    # upstream.
    if flags[0].v1:
        _v1_convert(flags[0])
    if flags[0].v2:
        _v2_convert(flags[0])


if __name__ == "__main__":
    main()