49009d3313dcc57070443440dd4f68fa84d3895b
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-import-tf
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"                  # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import os
24 import subprocess
25 import sys
26 import tempfile
27
28 import utils as _utils
29
30
31 def _get_parser():
32     parser = argparse.ArgumentParser(
33         description='command line tool to convert TensorFlow to circle')
34
35     _utils._add_default_arg(parser)
36
37     ## tf2tfliteV2 arguments
38     tf2tfliteV2_group = parser.add_argument_group('converter arguments')
39
40     # converter version
41     converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
42     converter_version.add_argument(
43         '--v1',
44         action='store_const',
45         dest='converter_version_cmd',
46         const='--v1',
47         help='use TensorFlow Lite Converter 1.x')
48     converter_version.add_argument(
49         '--v2',
50         action='store_const',
51         dest='converter_version_cmd',
52         const='--v2',
53         help='use TensorFlow Lite Converter 2.x')
54
55     #converter_version.set_defaults(converter_version='--v1')
56
57     parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
58
59     # input model format
60     model_format_arg = tf2tfliteV2_group.add_mutually_exclusive_group()
61     model_format_arg.add_argument(
62         '--graph_def',
63         action='store_const',
64         dest='model_format_cmd',
65         const='--graph_def',
66         help='use graph def file(default)')
67     model_format_arg.add_argument(
68         '--saved_model',
69         action='store_const',
70         dest='model_format_cmd',
71         const='--saved_model',
72         help='use saved model')
73     model_format_arg.add_argument(
74         '--keras_model',
75         action='store_const',
76         dest='model_format_cmd',
77         const='--keras_model',
78         help='use keras model')
79
80     parser.add_argument('--model_format', type=str, help=argparse.SUPPRESS)
81
82     # input and output path.
83     tf2tfliteV2_group.add_argument(
84         '-i', '--input_path', type=str, help='full filepath of the input file')
85     tf2tfliteV2_group.add_argument(
86         '-o', '--output_path', type=str, help='full filepath of the output file')
87
88     # input and output arrays.
89     tf2tfliteV2_group.add_argument(
90         '-I',
91         '--input_arrays',
92         type=str,
93         help='names of the input arrays, comma-separated')
94     tf2tfliteV2_group.add_argument(
95         '-s',
96         '--input_shapes',
97         type=str,
98         help=
99         'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
100     )
101     tf2tfliteV2_group.add_argument(
102         '-O',
103         '--output_arrays',
104         type=str,
105         help='names of the output arrays, comma-separated')
106
107     return parser
108
109
110 def _verify_arg(parser, args):
111     """verify given arguments"""
112     # check if required arguments is given
113     missing = []
114     if not _utils._is_valid_attr(args, 'input_path'):
115         missing.append('-i/--input_path')
116     if not _utils._is_valid_attr(args, 'output_path'):
117         missing.append('-o/--output_path')
118     if len(missing):
119         parser.error('the following arguments are required: ' + ' '.join(missing))
120
121
122 def _parse_arg(parser):
123     args = parser.parse_args()
124     # print version
125     if args.version:
126         _utils._print_version_and_exit(__file__)
127
128     return args
129
130
131 def _convert(args):
132     # get file path to log
133     dir_path = os.path.dirname(os.path.realpath(__file__))
134     logfile_path = os.path.realpath(args.output_path) + '.log'
135
136     with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
137         # make a command to convert from tf to tflite
138         tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
139         tf2tfliteV2_output_path = os.path.join(
140             tmpdir,
141             os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
142         tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
143                                                        getattr(args, 'input_path'),
144                                                        tf2tfliteV2_output_path)
145
146         f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
147
148         # convert tf to tflite
149         with subprocess.Popen(
150                 tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
151                 bufsize=1) as p:
152             for line in p.stdout:
153                 sys.stdout.buffer.write(line)
154                 f.write(line)
155         if p.returncode != 0:
156             sys.exit(p.returncode)
157
158         # make a command to convert from tflite to circle
159         tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
160         tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
161                                                            tf2tfliteV2_output_path,
162                                                            getattr(args, 'output_path'))
163
164         f.write((' '.join(tflite2circle_cmd) + '\n').encode())
165
166         # convert tflite to circle
167         with subprocess.Popen(
168                 tflite2circle_cmd,
169                 stdout=subprocess.PIPE,
170                 stderr=subprocess.STDOUT,
171                 bufsize=1) as p:
172             for line in p.stdout:
173                 sys.stdout.buffer.write(line)
174                 f.write(line)
175         if p.returncode != 0:
176             sys.exit(p.returncode)
177
178
179 def main():
180     # parse arguments
181     parser = _get_parser()
182     args = _parse_arg(parser)
183
184     # parse configuration file
185     _utils._parse_cfg(args, 'one-import-tf')
186
187     # verify arguments
188     _verify_arg(parser, args)
189
190     # convert
191     _convert(args)
192
193
194 if __name__ == '__main__':
195     main()