3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright (C) 2018 The TensorFlow Authors
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
18 import tensorflow as tf
22 from google.protobuf.message import DecodeError
23 from google.protobuf import text_format as _text_format
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Import a GraphDef into a wrapped function and prune it to given tensors.

    Args:
        graph_def: GraphDef to import (imported with an empty name prefix).
        inputs: Structure of names resolved against the imported graph via
            `as_graph_element`; used as the feeds of the pruned function.
        outputs: Structure of names resolved the same way; used as the
            fetches of the pruned function.

    Returns:
        The result of `prune(...)` on the wrapped import function.
    """

    def _imports_graph_def():
        # name="" keeps the node names exactly as stored in graph_def.
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_imports_graph_def, [])
    resolve = wrapped.graph.as_graph_element
    feeds = tf.nest.map_structure(resolve, inputs)
    fetches = tf.nest.map_structure(resolve, outputs)
    return wrapped.prune(feeds, fetches)
39 Returns an ArgumentParser for TensorFlow Lite Converter.
41 parser = argparse.ArgumentParser(
42 description=("Command line tool to run TensorFlow Lite Converter."))
45 converter_version = parser.add_mutually_exclusive_group(required=True)
46 converter_version.add_argument(
47 "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
48 converter_version.add_argument(
49 "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
51 # Input and output path.
56 help="Full filepath of the input file.",
62 help="Full filepath of the output file.",
65 # Input and output arrays.
70 help="Names of the input arrays, comma-separated.",
77 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
83 help="Names of the output arrays, comma-separated.",
89 def _check_flags(flags):
91 Checks the parsed flags to ensure they are valid.
98 raise ValueError(invalid + " options must be used with v2")
101 if tf.__version__.find("2.") != 0:
103 "Imported TensorFlow should have version >= 2.0 but you have " +
110 raise ValueError(invalid + " options must be used with v1")
112 if flags.input_shapes:
113 if not flags.input_arrays:
114 raise ValueError("--input_shapes must be used with --input_arrays")
115 if flags.input_shapes.count(":") != flags.input_arrays.count(","):
116 raise ValueError("--input_shapes and --input_arrays must have the same "
120 def _parse_array(arrays, type_fn=str):
121 return list(map(type_fn, arrays.split(",")))
def _v1_convert(flags):
    """Convert a frozen graph file to TFLite with the 1.x converter.

    Reads `flags.input_path`, `flags.input_arrays`, `flags.output_arrays`
    and the optional `flags.input_shapes`, then writes the converted model
    to `flags.output_path`.

    Args:
        flags: Parsed command-line flags.
    """
    # NOTE(review): the default below is reconstructed from the unconditional
    # use of `input_shapes` further down -- confirm against the original file.
    input_shapes = None
    if flags.input_shapes:
        # Pair every input array name with its colon-separated shape.
        input_arrays = _parse_array(flags.input_arrays)
        input_shapes_list = [
            _parse_array(shape, type_fn=int)
            for shape in flags.input_shapes.split(":")
        ]
        input_shapes = dict(zip(input_arrays, input_shapes_list))

    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        flags.input_path, _parse_array(flags.input_arrays),
        _parse_array(flags.output_arrays), input_shapes)

    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Context manager guarantees the output file is flushed and closed.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
def _v2_convert(flags):
    """Convert a frozen graph file to TFLite with the 2.x converter.

    The input file is first parsed as a binary GraphDef; if that fails it is
    parsed as a text-format GraphDef. The graph is wrapped as a function
    over the requested input/output tensors and converted, and the result is
    written to `flags.output_path`.

    Args:
        flags: Parsed command-line flags (`input_path`, `output_path`,
            `input_arrays`, `output_arrays` are read here).

    Raises:
        IOError: If the input file parses as neither binary nor text format.
    """

    def _as_tensor_name(name):
        # A bare name (no ":") gets ":0" appended; anything else is kept.
        return name if ":" in name else name + ":0"

    # Context manager guarantees the input file handle is closed.
    with open(flags.input_path, 'rb') as input_file:
        file_content = input_file.read()

    try:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(file_content)
    except (_text_format.ParseError, DecodeError):
        try:
            _text_format.Merge(file_content, graph_def)
        except (_text_format.ParseError, DecodeError):
            raise IOError("Unable to parse input file '{}'.".format(flags.input_path))

    wrap_func = wrap_frozen_graph(
        graph_def,
        inputs=[
            _as_tensor_name(name) for name in _parse_array(flags.input_arrays)
        ],
        outputs=[
            _as_tensor_name(name) for name in _parse_array(flags.output_arrays)
        ])
    converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Context manager guarantees the output file is flushed and closed.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
183 Input frozen graph must be from TensorFlow 1.13.1
189 parser = _get_parser()
191 # Check if the flags are valid.
192 flags = parser.parse_known_args(args=sys.argv[1:])
193 _check_flags(flags[0])
199 if __name__ == "__main__":