Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / tf2tfliteV2 / tf2tfliteV2.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright (C) 2018 The TensorFlow Authors
5 #
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
9 #
10 #    http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
17
18 import tensorflow as tf
19 import argparse
20 import sys
21
22 from google.protobuf.message import DecodeError
23 from google.protobuf import text_format as _text_format
24
25
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen GraphDef so it can be used as a TF2 concrete function.

    The graph is imported with an empty name scope inside
    ``tf.compat.v1.wrap_function``. The wrapped function is then pruned down
    to the requested feeds and fetches.

    Args:
        graph_def: the ``GraphDef`` to import.
        inputs: tensor name(s) to feed (callers in this file pass a list of
            names such as ``"input:0"``); any structure accepted by
            ``tf.nest.map_structure``.
        outputs: tensor name(s) to fetch, in the same form as ``inputs``.

    Returns:
        The result of ``WrappedFunction.prune`` for those feeds/fetches.
    """
    def _import_into_current_graph():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import_into_current_graph, [])
    # Resolve tensor names against the graph the import was traced into.
    lookup = wrapped.graph.as_graph_element
    feeds = tf.nest.map_structure(lookup, inputs)
    fetches = tf.nest.map_structure(lookup, outputs)
    return wrapped.prune(feeds, fetches)
35
36
def _get_parser():
    """Create the argument parser for this TensorFlow Lite converter wrapper.

    Options:
        --v1 / --v2: converter API to use; exactly one is required.
        --graph_def / --saved_model / --keras_model: input model format,
            stored in ``model_format`` (``"graph_def"`` when omitted).
        -i/--input_path, -o/--output_path: model file locations (required).
        -I/--input_arrays, -O/--output_arrays: comma-separated array names
            (required).
        -s/--input_shapes: optional colon-separated shapes, one per input.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Command line tool to run TensorFlow Lite Converter.")

    # Converter version: one of --v1 / --v2 must be given.
    version_group = parser.add_mutually_exclusive_group(required=True)
    for option, help_text in (
            ("--v1", "Use TensorFlow Lite Converter 1.x"),
            ("--v2", "Use TensorFlow Lite Converter 2.x"),
    ):
        version_group.add_argument(option, action="store_true", help=help_text)

    # Input model format: each option stores its own name into model_format.
    format_group = parser.add_mutually_exclusive_group()
    for fmt, help_text in (
            ("graph_def", "Use graph def file(default)"),
            ("saved_model", "Use saved model"),
            ("keras_model", "Use keras model"),
    ):
        format_group.add_argument(
            "--" + fmt,
            action="store_const",
            dest="model_format",
            const=fmt,
            help=help_text)

    # Model file locations.
    parser.add_argument("-i", "--input_path", type=str, required=True,
                        help="Full filepath of the input file.")
    parser.add_argument("-o", "--output_path", type=str, required=True,
                        help="Full filepath of the output file.")

    # Arrays to feed and fetch.
    parser.add_argument("-I", "--input_arrays", type=str, required=True,
                        help="Names of the input arrays, comma-separated.")
    shapes_help = ("Shapes corresponding to --input_arrays, colon-separated."
                   "(ex:\"1,4,4,3:1,20,20,3\")")
    parser.add_argument("-s", "--input_shapes", type=str, help=shapes_help)
    parser.add_argument("-O", "--output_arrays", type=str, required=True,
                        help="Names of the output arrays, comma-separated.")

    # model_format falls back to graph_def when no format option is given.
    parser.set_defaults(model_format="graph_def")
    return parser
110
111
def _check_flags(flags):
    """Check the parsed flags to ensure they are valid.

    Args:
        flags: the ``argparse.Namespace`` produced by ``_get_parser()``.

    Raises:
        ValueError: if ``--v2`` is requested but the imported TensorFlow has a
            major version below 2, or if ``--input_shapes`` is given without
            ``--input_arrays`` or with a different number of items.
    """
    if flags.v1:
        invalid = ""
        # Hook for options that are only valid with --v2 (none yet).

        if invalid:
            raise ValueError(invalid + " options must be used with v2")

    if flags.v2:
        # Compare the major version numerically so that every release from
        # 2.0 onward is accepted, matching the error message below. The old
        # prefix test ("2.") would have rejected any later major version.
        major = tf.__version__.split(".", 1)[0]
        if not major.isdigit() or int(major) < 2:
            raise ValueError(
                "Imported TensorFlow should have version >= 2.0 but you have " +
                tf.__version__)

        invalid = ""
        # Hook for options that are only valid with --v1 (none yet).

        if invalid:
            raise ValueError(invalid + " options must be used with v1")

    if flags.input_shapes:
        if not flags.input_arrays:
            raise ValueError("--input_shapes must be used with --input_arrays")
        # One colon-separated shape per comma-separated input array.
        if flags.input_shapes.count(":") != flags.input_arrays.count(","):
            raise ValueError("--input_shapes and --input_arrays must have the same "
                             "number of items")
142
def _parse_array(arrays, type_fn=str):
    """Split a comma-separated option value and convert each item.

    Whitespace around every item is stripped before conversion, so values
    such as ``"input_1, input_2"`` do not produce names with a leading space.

    Args:
        arrays: comma-separated string, e.g. ``"input_1,input_2"`` or
            ``"1,224,224,3"``.
        type_fn: callable applied to each stripped item (this file uses
            ``str`` and ``int``).

    Returns:
        list: the converted items, in input order.
    """
    return [type_fn(item.strip()) for item in arrays.split(",")]
145
146
def _v1_convert(flags):
    """Convert a model with the TensorFlow 1.x ``TFLiteConverter`` API.

    Builds a converter for ``flags.model_format`` (``graph_def``,
    ``saved_model`` or ``keras_model``), sets ``allow_custom_ops`` and writes
    the converted model bytes to ``flags.output_path``.

    Args:
        flags: parsed command-line flags (see ``_get_parser``).

    Raises:
        ValueError: if ``flags.model_format`` is not a supported format.
    """
    if flags.model_format == "graph_def":
        input_arrays = _parse_array(flags.input_arrays)
        input_shapes = None
        if flags.input_shapes:
            # "1,4,4,3:1,20,20,3" -> one int list per input array, keyed by
            # the corresponding input array name.
            input_shapes_list = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            input_shapes = dict(zip(input_arrays, input_shapes_list))

        converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            flags.input_path, input_arrays,
            _parse_array(flags.output_arrays), input_shapes)
    elif flags.model_format == "saved_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
            flags.input_path)
    elif flags.model_format == "keras_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
            flags.input_path)
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            "Unsupported model format: {}".format(flags.model_format))

    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Use a context manager so the output file is flushed and closed.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
173
174
def _v2_convert(flags):
    """Convert a model with the TensorFlow 2.x ``TFLiteConverter`` API.

    For ``graph_def`` input, the file is first parsed as a binary
    ``GraphDef`` and, if that fails, as a text-format ``GraphDef``. It is then
    wrapped as a concrete function restricted to ``--input_arrays`` /
    ``--output_arrays``. ``saved_model`` and ``keras_model`` inputs are
    loaded directly.

    The converter is configured with ``allow_custom_ops``,
    ``experimental_new_converter`` and ``TFLITE_BUILTINS`` as the supported
    op set. The result is written to ``flags.output_path``.
    ``flags.input_shapes`` is not used on this path.

    Args:
        flags: parsed command-line flags (see ``_get_parser``).

    Raises:
        IOError: if a ``graph_def`` input parses neither as binary nor as
            text protobuf.
        ValueError: if ``flags.model_format`` is not a supported format.
    """
    def _to_tensor_name(name):
        # Names without an explicit output index get ":0" (first output).
        return name if ":" in name else name + ":0"

    if flags.model_format == "graph_def":
        with open(flags.input_path, "rb") as input_file:
            file_content = input_file.read()
        try:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            # A failed binary parse may leave partially-filled fields behind,
            # so merge the text form into a fresh message instead.
            graph_def = tf.compat.v1.GraphDef()
            try:
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError) as err:
                raise IOError("Unable to parse input file '{}'.".format(
                    flags.input_path)) from err

        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[
                _to_tensor_name(name)
                for name in _parse_array(flags.input_arrays)
            ],
            outputs=[
                _to_tensor_name(name)
                for name in _parse_array(flags.output_arrays)
            ])
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
    elif flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)
    elif flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    else:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            "Unsupported model format: {}".format(flags.model_format))

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Use a context manager so the output file is flushed and closed.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
213
214
def _convert(flags):
    """Run the conversion path selected by --v1 (otherwise the v2 path)."""
    convert_fn = _v1_convert if flags.v1 else _v2_convert
    convert_fn(flags)
220
221
# NOTE: the input frozen graph must be from TensorFlow 1.13.1.


def main():
    """Command-line entry point: parse, validate, then convert."""
    # Parse argument.
    parser = _get_parser()

    # Unknown arguments are tolerated (parse_known_args) but reported, so a
    # mistyped option is no longer silently ignored.
    flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
    if unparsed:
        print("Warning: ignoring unrecognized arguments: " + " ".join(unparsed),
              file=sys.stderr)

    # Check if the flags are valid.
    _check_flags(flags)

    # Convert
    _convert(flags)


if __name__ == "__main__":
    main()