# Bash/Python polyglot launcher. Under bash, each ''''...# ''' line runs as a
# shell command: the leading '''' is two empty shell strings glued to the
# command word, and the trailing ''' sits inside a shell comment. Under Python,
# each such line is just a string-literal expression statement and does nothing.
# Locate this script's directory and its bundled virtualenv interpreter; if that
# interpreter exists, re-exec this same file with it, forwarding all arguments.
# Expansions are quoted so paths containing spaces or glob characters work.
''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python" # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." >&2 # '''
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
14 # http://www.apache.org/licenses/LICENSE-2.0
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
28 import utils as _utils
# NOTE(review): the enclosing function header and its return statement are not
# visible in this excerpt; main() calls _get_parser() and binds the result to
# `parser`, so this is presumably that function's body — confirm.
# Build the argparse parser for the TensorFlow -> circle import tool.
32 parser = argparse.ArgumentParser(
33 description='command line tool to convert TensorFlow to circle')
# Register the common/default options via the utils helper (not visible here).
35 _utils._add_default_arg(parser)
37 ## tf2tfliteV2 arguments
38 tf2tfliteV2_group = parser.add_argument_group('converter arguments')
# Converter version: mutually exclusive options that both store into the
# same dest, 'converter_version_cmd'.
41 converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
42 converter_version.add_argument(
45 dest='converter_version_cmd',
47 help='use TensorFlow Lite Converter 1.x')
48 converter_version.add_argument(
51 dest='converter_version_cmd',
53 help='use TensorFlow Lite Converter 2.x')
# NOTE(review): dead, commented-out default below; remove it or document why it is kept.
55 #converter_version.set_defaults(converter_version='--v1')
# Hidden option (help=argparse.SUPPRESS keeps it out of --help) with its own
# dest 'converter_version'. NOTE(review): presumably populated from the
# configuration file rather than typed by users — confirm against _utils._parse_cfg.
57 parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)
# Model format: mutually exclusive options that all store into the same dest,
# 'model_format_cmd' (visible consts: '--saved_model', '--keras_model').
60 model_format_arg = tf2tfliteV2_group.add_mutually_exclusive_group()
61 model_format_arg.add_argument(
64 dest='model_format_cmd',
66 help='use graph def file(default)')
67 model_format_arg.add_argument(
70 dest='model_format_cmd',
71 const='--saved_model',
72 help='use saved model')
73 model_format_arg.add_argument(
76 dest='model_format_cmd',
77 const='--keras_model',
78 help='use keras model')
# Hidden counterpart of the model-format options (dest 'model_format'); see the
# NOTE(review) on '--converter_version' above.
80 parser.add_argument('--model_format', type=str, help=argparse.SUPPRESS)
82 # input and output path.
# Both paths are checked later by _verify_arg rather than marked required here.
83 tf2tfliteV2_group.add_argument(
84 '-i', '--input_path', type=str, help='full filepath of the input file')
85 tf2tfliteV2_group.add_argument(
86 '-o', '--output_path', type=str, help='full filepath of the output file')
88 # input and output arrays.
89 tf2tfliteV2_group.add_argument(
93 help='names of the input arrays, comma-separated')
94 tf2tfliteV2_group.add_argument(
99 'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
101 tf2tfliteV2_group.add_argument(
105 help='names of the output arrays, comma-separated')
110 def _verify_arg(parser, args):
111 """verify that the required arguments (input_path and output_path) are given; on failure, report via parser.error()"""
112 # check that the required arguments are given
# NOTE(review): original line 113 (not visible) presumably initializes `missing = []`.
114 if not _utils._is_valid_attr(args, 'input_path'):
115 missing.append('-i/--input_path')
116 if not _utils._is_valid_attr(args, 'output_path'):
117 missing.append('-o/--output_path')
# NOTE(review): the guard on original line 118 is not visible; presumably it
# reports only when `missing` is non-empty — confirm.
# argparse's parser.error() prints the usage and this message to stderr and
# exits the program (status 2).
119 parser.error('the following arguments are required: ' + ' '.join(missing))
122 def _parse_arg(parser):
"""Parse the process's command-line arguments with `parser`."""
123 args = parser.parse_args()
# NOTE(review): original lines 124-125 (not visible) presumably guard this call
# with a version-option check; the helper's name suggests it prints the version
# and exits — confirm in utils.
126 _utils._print_version_and_exit(__file__)
# NOTE(review): the rest of this function is not visible; main() binds its
# result to `args`, so it presumably returns the parsed namespace.
# NOTE(review): the enclosing function header is not visible in this excerpt;
# this body reads `args` (input_path, output_path, converter options).
# Two-stage conversion: run tf2tfliteV2.py to produce an intermediate .tflite,
# then run tflite2circle to produce args.output_path. Each executed command line
# is recorded in "<realpath(output_path)>.log".
132 # get file path to log
# Helper tools are looked up in this script's own directory.
133 dir_path = os.path.dirname(os.path.realpath(__file__))
134 logfile_path = os.path.realpath(args.output_path) + '.log'
# The log is opened in binary mode (commands are written .encode()d); the
# TemporaryDirectory is removed automatically when the `with` block exits.
136 with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
137 # make a command to convert from tf to tflite
138 tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
# Intermediate file name: output basename without its extension, plus '.tflite'.
# NOTE(review): the first os.path.join argument (original line 140) is not
# visible; presumably `tmpdir`, which would keep the .tflite in the temp dir — confirm.
139 tf2tfliteV2_output_path = os.path.join(
141 os.path.splitext(os.path.basename(args.output_path))[0]) + '.tflite'
142 tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
143 getattr(args, 'input_path'),
144 tf2tfliteV2_output_path)
# Record the exact command line in the log file.
146 f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())
148 # convert tf to tflite
# Run the converter with stderr merged into stdout and echo its output, as raw
# bytes, to this process's stdout.
149 with subprocess.Popen(
150 tf2tfliteV2_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
152 for line in p.stdout:
153 sys.stdout.buffer.write(line)
# NOTE(review): p.returncode is only set once the child has been waited on
# (Popen's context-manager exit does this). Confirm this check sits after the
# `with` block; otherwise returncode may still be None here, `None != 0` is
# True, and sys.exit(None) would exit with status 0.
155 if p.returncode != 0:
156 sys.exit(p.returncode)
158 # make a command to convert from tflite to circle
159 tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
160 tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
161 tf2tfliteV2_output_path,
162 getattr(args, 'output_path'))
# Record the exact command line in the log file.
164 f.write((' '.join(tflite2circle_cmd) + '\n').encode())
166 # convert tflite to circle
# Same streaming pattern and returncode caveat as the tf -> tflite step.
167 with subprocess.Popen(
169 stdout=subprocess.PIPE,
170 stderr=subprocess.STDOUT,
172 for line in p.stdout:
173 sys.stdout.buffer.write(line)
175 if p.returncode != 0:
176 sys.exit(p.returncode)
# NOTE(review): the enclosing `def main():` header is not visible in this excerpt.
# Top-level driver: build the parser, parse the command line, merge settings from
# the configuration file, then validate the required arguments.
181 parser = _get_parser()
182 args = _parse_arg(parser)
184 # parse configuration file
# 'one-import-tf' is passed to the utils helper (presumably the config section name — confirm).
185 _utils._parse_cfg(args, 'one-import-tf')
# Verification runs after the config merge, presumably so that values supplied
# by the configuration file also satisfy the required-argument check — confirm.
188 _verify_arg(parser, args)
# NOTE(review): the remaining statements of main (original lines 189-193) and
# the body of the guard below are not visible in this excerpt.
194 if __name__ == '__main__':