#!/usr/bin/env bash
# Polyglot launcher. To bash, each '''' line below is two empty quoted
# strings followed by a command and a trailing '#' comment. To Python, each
# such line is a no-op triple-quoted string literal. Bash therefore
# re-executes this file with the virtualenv interpreter that lives next to
# it, while Python ignores these lines. bash (not sh) is required because
# ${BASH_SOURCE[0]} is a bashism.
''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python" # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." >&2 # '''
# Stop with a non-zero status. Without an explicit exit, bash would go on to
# interpret the Python source below as shell commands.
''''exit 255 # '''
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
14 # http://www.apache.org/licenses/LICENSE-2.0
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
28 import utils as _utils
# Builds the argparse parser for the TensorFlow-to-circle import tool.
# NOTE(review): the enclosing `def` line and the final `return parser` are not
# visible in this view. main() calls `_get_parser()` and uses the result as the
# parser, so presumably this body returns `parser` — confirm.
32 parser = argparse.ArgumentParser(
33 description='command line tool to convert TensorFlow to circle')
# Registers options defined in utils (not visible here); presumably options
# shared by the one-* tools — confirm against utils._add_default_arg.
35 _utils._add_default_arg(parser)
37 ## tf2tfliteV2 arguments
38 tf2tfliteV2_group = parser.add_argument_group('converter arguments')
# Converter version: both options below write into the same dest,
# `converter_version_cmd`. The mutually exclusive group lets at most one be given.
# NOTE(review): the option strings, action and const of these two options are
# not visible in this view — TODO confirm.
41 converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
42 converter_version.add_argument(
45 dest='converter_version_cmd',
47 help='use TensorFlow Lite Converter 1.x')
48 converter_version.add_argument(
51 dest='converter_version_cmd',
53 help='use TensorFlow Lite Converter 2.x')
# Hidden from --help by argparse.SUPPRESS. Presumably this lets the version be
# given as a plain string (e.g. via the configuration file) — confirm.
55 parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
# Input model format: the options share dest `model_format_cmd` and are
# mutually exclusive. The visible const values are '--saved_model' and
# '--keras_model'. The help text marks graph def as the default; where that
# default is applied is not visible here.
58 model_format_arg = tf2tfliteV2_group.add_mutually_exclusive_group()
59 model_format_arg.add_argument(
62 dest='model_format_cmd',
64 help='use graph def file(default)')
65 model_format_arg.add_argument(
68 dest='model_format_cmd',
69 const='--saved_model',
70 help='use saved model')
71 model_format_arg.add_argument(
74 dest='model_format_cmd',
75 const='--keras_model',
76 help='use keras model')
# Hidden string form of the model format, analogous to --converter_version.
78 parser.add_argument('--model_format', type=str, help=argparse.SUPPRESS)
80 # input and output path.
# These two are not declared required here; _verify_arg checks that both are
# present once parsing is done.
81 tf2tfliteV2_group.add_argument(
82 '-i', '--input_path', type=str, help='full filepath of the input file')
83 tf2tfliteV2_group.add_argument(
84 '-o', '--output_path', type=str, help='full filepath of the output file')
86 # input and output arrays.
# NOTE(review): the option strings and types of the three array options are
# not visible in this view. Per its help text, the shapes option is a
# colon-separated list of comma-separated shapes, one per input array.
87 tf2tfliteV2_group.add_argument(
91 help='names of the input arrays, comma-separated')
92 tf2tfliteV2_group.add_argument(
97 'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
99 tf2tfliteV2_group.add_argument(
103 help='names of the output arrays, comma-separated')
def _verify_arg(parser, args):
    """Verify that the arguments needed for conversion were given.

    main() calls this after the configuration file has been applied to
    ``args``. A value supplied through either the command line or the
    configuration step is therefore accepted.

    Args:
        parser: the argparse parser, used only to report the error.
        args: the parsed argument namespace.

    Raises:
        SystemExit: via ``parser.error`` (usage + message on stderr, exit
            status 2) when ``input_path`` and/or ``output_path`` is missing.
    """
    # Collect every missing option first so that one error message lists
    # all of them, instead of failing on the first.
    missing = []
    if not _utils._is_valid_attr(args, 'input_path'):
        missing.append('-i/--input_path')
    if not _utils._is_valid_attr(args, 'output_path'):
        missing.append('-o/--output_path')
    # Only abort when something is actually missing.
    if missing:
        parser.error('the following arguments are required: ' +
                     ' '.join(missing))
# Parses the command line with `parser`. The result is used by main() as the
# argument namespace.
# NOTE(review): in this view `_print_version_and_exit` is called
# unconditionally and no `return args` is visible. As shown, the function
# would always exit and never return `args`, yet main() uses its return
# value. Presumably a guard on a version flag (added by
# utils._add_default_arg) and `return args` exist outside this view — confirm.
120 def _parse_arg(parser):
121 args = parser.parse_args()
124 _utils._print_version_and_exit(__file__)
# Conversion pipeline: TensorFlow model -> .tflite (tf2tfliteV2.py) ->
# .circle (tflite2circle). Both generated commands are written to
# '<output_path>.log'.
# NOTE(review): the enclosing `def` line is not visible in this view; `args`
# is presumably the namespace from _parse_arg/_parse_cfg — confirm.
130 # get file path to log
# Directory containing this script (symlinks resolved); the helper tools are
# looked up next to it.
131 dir_path = os.path.dirname(os.path.realpath(__file__))
132 logfile_path = os.path.realpath(args.output_path) + '.log'
# 'wb' truncates any previous log. The temporary directory is deleted when the
# `with` block exits.
134 with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
135 # make a command to convert from tf to tflite
136 tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
# The intermediate .tflite is named after the output file's basename with its
# extension stripped.
# NOTE(review): as visible, this os.path.join call has a single argument and
# `tmpdir` is otherwise unused. That would put the intermediate file
# relative to the current directory instead of inside `tmpdir`. Presumably a
# `tmpdir,` argument is missing from this view — confirm.
137 tf2tfliteV2_output_path = os.path.join(
139 os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
140 tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
141 getattr(args, 'input_path'),
142 tf2tfliteV2_output_path)
# Record the exact command line in the log.
144 f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
146 # convert tf to tflite
# stderr is merged into stdout, and each output line is forwarded as raw bytes
# to this process's stdout.
# NOTE(review): in this view the tool output is not also written to `f`; only
# the commands are logged — confirm whether that is intended.
147 with subprocess.Popen(
148 tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
150 for line in p.stdout:
151 sys.stdout.buffer.write(line)
# Propagate the tool's failure status. This must sit after the Popen `with`
# block (Popen.__exit__ waits for the process). Inside the block, returncode
# could still be None, and sys.exit(None) would exit with status 0.
153 if p.returncode != 0:
154 sys.exit(p.returncode)
156 # make a command to convert from tflite to circle
157 tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
158 tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
159 tf2tfliteV2_output_path,
160 getattr(args, 'output_path'))
162 f.write((' '.join(tflite2circle_cmd) + '\n').encode())
164 # convert tflite to circle
# Same streaming/exit-status pattern as the first step. The first Popen
# argument (presumably tflite2circle_cmd) is not visible in this view.
165 with subprocess.Popen(
167 stdout=subprocess.PIPE,
168 stderr=subprocess.STDOUT,
170 for line in p.stdout:
171 sys.stdout.buffer.write(line)
173 if p.returncode != 0:
174 sys.exit(p.returncode)
# Entry-point sequence: build the parser, parse the command line, apply the
# configuration file, then verify the required arguments.
# NOTE(review): the enclosing `def` line and the call that runs the
# conversion are not visible in this view — confirm they exist.
179 parser = _get_parser()
180 args = _parse_arg(parser)
182 # parse configuration file
# Semantics live in utils (not visible). Presumably this fills `args` from the
# 'one-import-tf' section of a configuration file — confirm.
183 _utils._parse_cfg(args, 'one-import-tf')
# Verification runs only after the configuration file has been applied, so
# values supplied there count toward the required arguments.
186 _verify_arg(parser, args)
# Script entry guard; its body is not visible in this view.
192 if __name__ == '__main__':