Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / tf2tfliteV2 / tf2tfliteV2.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 # Copyright (C) 2018 The TensorFlow Authors
5 #
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
9 #
10 #    http://www.apache.org/licenses/LICENSE-2.0
11 #
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
17
18 import tensorflow as tf
19 import argparse
20 import sys
21
22 from google.protobuf.message import DecodeError
23 from google.protobuf import text_format as _text_format
24
25
26 def wrap_frozen_graph(graph_def, inputs, outputs):
27     def _imports_graph_def():
28         tf.compat.v1.import_graph_def(graph_def, name="")
29
30     wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
31     import_graph = wrapped_import.graph
32     return wrapped_import.prune(
33         tf.nest.map_structure(import_graph.as_graph_element, inputs),
34         tf.nest.map_structure(import_graph.as_graph_element, outputs))
35
36
37 def _get_parser():
38     """
39   Returns an ArgumentParser for TensorFlow Lite Converter.
40   """
41     parser = argparse.ArgumentParser(
42         description=("Command line tool to run TensorFlow Lite Converter."))
43
44     # Converter version.
45     converter_version = parser.add_mutually_exclusive_group(required=True)
46     converter_version.add_argument(
47         "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
48     converter_version.add_argument(
49         "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")
50
51     # Input and output path.
52     parser.add_argument(
53         "-i",
54         "--input_path",
55         type=str,
56         help="Full filepath of the input file.",
57         required=True)
58     parser.add_argument(
59         "-o",
60         "--output_path",
61         type=str,
62         help="Full filepath of the output file.",
63         required=True)
64
65     # Input and output arrays.
66     parser.add_argument(
67         "-I",
68         "--input_arrays",
69         type=str,
70         help="Names of the input arrays, comma-separated.",
71         required=True)
72     parser.add_argument(
73         "-s",
74         "--input_shapes",
75         type=str,
76         help=
77         "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
78     )
79     parser.add_argument(
80         "-O",
81         "--output_arrays",
82         type=str,
83         help="Names of the output arrays, comma-separated.",
84         required=True)
85
86     return parser
87
88
89 def _check_flags(flags):
90     """
91   Checks the parsed flags to ensure they are valid.
92   """
93     if flags.v1:
94         invalid = ""
95         # To be filled
96
97         if invalid:
98             raise ValueError(invalid + " options must be used with v2")
99
100     if flags.v2:
101         if tf.__version__.find("2.") != 0:
102             raise ValueError(
103                 "Imported TensorFlow should have version >= 2.0 but you have " +
104                 tf.__version__)
105
106         invalid = ""
107         # To be filled
108
109         if invalid:
110             raise ValueError(invalid + " options must be used with v1")
111
112     if flags.input_shapes:
113         if not flags.input_arrays:
114             raise ValueError("--input_shapes must be used with --input_arrays")
115         if flags.input_shapes.count(":") != flags.input_arrays.count(","):
116             raise ValueError("--input_shapes and --input_arrays must have the same "
117                              "number of items")
118
119
120 def _parse_array(arrays, type_fn=str):
121     return list(map(type_fn, arrays.split(",")))
122
123
124 def _v1_convert(flags):
125     input_shapes = None
126     if flags.input_shapes:
127         input_arrays = _parse_array(flags.input_arrays)
128         input_shapes_list = [
129             _parse_array(shape, type_fn=int) for shape in flags.input_shapes.split(":")
130         ]
131         input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
132
133     converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
134         flags.input_path, _parse_array(flags.input_arrays),
135         _parse_array(flags.output_arrays), input_shapes)
136
137     converter.allow_custom_ops = True
138
139     tflite_model = converter.convert()
140     open(flags.output_path, "wb").write(tflite_model)
141
142
143 def _v2_convert(flags):
144     file_content = open(flags.input_path, 'rb').read()
145     try:
146         graph_def = tf.compat.v1.GraphDef()
147         graph_def.ParseFromString(file_content)
148     except (_text_format.ParseError, DecodeError):
149         try:
150             _text_format.Merge(file_content, graph_def)
151         except (_text_format.ParseError, DecodeError):
152             raise IOError("Unable to parse input file '{}'.".format(flags.input_path))
153
154     wrap_func = wrap_frozen_graph(
155         graph_def,
156         inputs=[
157             _str + ":0" if len(_str.split(":")) == 1 else _str
158             for _str in _parse_array(flags.input_arrays)
159         ],
160         outputs=[
161             _str + ":0" if len(_str.split(":")) == 1 else _str
162             for _str in _parse_array(flags.output_arrays)
163         ])
164     converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
165
166     converter.allow_custom_ops = True
167     converter.experimental_new_converter = True
168
169     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
170
171     tflite_model = converter.convert()
172     open(flags.output_path, "wb").write(tflite_model)
173
174
175 def _convert(flags):
176     if (flags.v1):
177         _v1_convert(flags)
178     else:
179         _v2_convert(flags)
180
181
182 """
183 Input frozen graph must be from TensorFlow 1.13.1
184 """
185
186
187 def main():
188     # Parse argument.
189     parser = _get_parser()
190
191     # Check if the flags are valid.
192     flags = parser.parse_known_args(args=sys.argv[1:])
193     _check_flags(flags[0])
194
195     # Convert
196     _convert(flags[0])
197
198
199 if __name__ == "__main__":
200     main()