Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-tf
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"                  # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import subprocess
25 import sys
26 import tempfile
27
28 import utils as _utils
29
30
31 def _get_parser():
32     parser = argparse.ArgumentParser(
33         description='command line tool to convert TensorFlow to circle')
34
35     _utils._add_default_arg(parser)
36
37     ## tf2tfliteV2 arguments
38     tf2tfliteV2_group = parser.add_argument_group('converter arguments')
39
40     # converter version
41     converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
42     converter_version.add_argument(
43         '--v1',
44         action='store_const',
45         dest='converter_version_cmd',
46         const='--v1',
47         help='use TensorFlow Lite Converter 1.x')
48     converter_version.add_argument(
49         '--v2',
50         action='store_const',
51         dest='converter_version_cmd',
52         const='--v2',
53         help='use TensorFlow Lite Converter 2.x')
54
55     parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
56
57     # input model format
58     model_format_arg = tf2tfliteV2_group.add_mutually_exclusive_group()
59     model_format_arg.add_argument(
60         '--graph_def',
61         action='store_const',
62         dest='model_format_cmd',
63         const='--graph_def',
64         help='use graph def file(default)')
65     model_format_arg.add_argument(
66         '--saved_model',
67         action='store_const',
68         dest='model_format_cmd',
69         const='--saved_model',
70         help='use saved model')
71     model_format_arg.add_argument(
72         '--keras_model',
73         action='store_const',
74         dest='model_format_cmd',
75         const='--keras_model',
76         help='use keras model')
77
78     parser.add_argument('--model_format', type=str, help=argparse.SUPPRESS)
79
80     # input and output path.
81     tf2tfliteV2_group.add_argument(
82         '-i', '--input_path', type=str, help='full filepath of the input file')
83     tf2tfliteV2_group.add_argument(
84         '-o', '--output_path', type=str, help='full filepath of the output file')
85
86     # input and output arrays.
87     tf2tfliteV2_group.add_argument(
88         '-I',
89         '--input_arrays',
90         type=str,
91         help='names of the input arrays, comma-separated')
92     tf2tfliteV2_group.add_argument(
93         '-s',
94         '--input_shapes',
95         type=str,
96         help=
97         'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
98     )
99     tf2tfliteV2_group.add_argument(
100         '-O',
101         '--output_arrays',
102         type=str,
103         help='names of the output arrays, comma-separated')
104
105     return parser
106
107
108 def _verify_arg(parser, args):
109     """verify given arguments"""
110     # check if required arguments is given
111     missing = []
112     if not _utils._is_valid_attr(args, 'input_path'):
113         missing.append('-i/--input_path')
114     if not _utils._is_valid_attr(args, 'output_path'):
115         missing.append('-o/--output_path')
116     if len(missing):
117         parser.error('the following arguments are required: ' + ' '.join(missing))
118
119
120 def _parse_arg(parser):
121     args = parser.parse_args()
122     # print version
123     if args.version:
124         _utils._print_version_and_exit(__file__)
125
126     return args
127
128
129 def _convert(args):
130     # get file path to log
131     dir_path = os.path.dirname(os.path.realpath(__file__))
132     logfile_path = os.path.realpath(args.output_path) + '.log'
133
134     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
135         # make a command to convert from tf to tflite
136         tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
137         tf2tfliteV2_output_path = os.path.join(
138             tmpdir,
139             os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
140         tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
141                                                        getattr(args, 'input_path'),
142                                                        tf2tfliteV2_output_path)
143
144         f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
145
146         # convert tf to tflite
147         with subprocess.Popen(
148                 tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
149                 bufsize=1) as p:
150             for line in p.stdout:
151                 sys.stdout.buffer.write(line)
152                 f.write(line)
153         if p.returncode != 0:
154             sys.exit(p.returncode)
155
156         # make a command to convert from tflite to circle
157         tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
158         tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
159                                                            tf2tfliteV2_output_path,
160                                                            getattr(args, 'output_path'))
161
162         f.write((' '.join(tflite2circle_cmd) + '\n').encode())
163
164         # convert tflite to circle
165         with subprocess.Popen(
166                 tflite2circle_cmd,
167                 stdout=subprocess.PIPE,
168                 stderr=subprocess.STDOUT,
169                 bufsize=1) as p:
170             for line in p.stdout:
171                 sys.stdout.buffer.write(line)
172                 f.write(line)
173         if p.returncode != 0:
174             sys.exit(p.returncode)
175
176
177 def main():
178     # parse arguments
179     parser = _get_parser()
180     args = _parse_arg(parser)
181
182     # parse configuration file
183     _utils._parse_cfg(args, 'one-import-tf')
184
185     # verify arguments
186     _verify_arg(parser, args)
187
188     # convert
189     _convert(args)
190
191
192 if __name__ == '__main__':
193     main()