#!/usr/bin/env bash
''''export SCRIPT_PATH="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python" # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." >&2 # '''
''''exit 255 # '''

# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import subprocess
import sys
import tempfile

import utils as _utils
import generate_bcq_output_arrays as _bcq_info_gen

def _get_parser():
    """Build the command-line parser for one-import-bcq.

    Returns:
        argparse.ArgumentParser carrying the default ONE options (added by
        ``_utils._add_default_arg``) plus the tf2tfliteV2 converter options.
    """
    parser = argparse.ArgumentParser(
        description='command line tool to convert TensorFlow with BCQ to circle')

    _utils._add_default_arg(parser)

    ## tf2tfliteV2 arguments
    tf2tfliteV2_group = parser.add_argument_group('converter arguments')

    # converter version: --v1 / --v2 are mutually exclusive and each stores its
    # own flag string into `converter_version_cmd`
    converter_version = tf2tfliteV2_group.add_mutually_exclusive_group()
    converter_version.add_argument(
        '--v1',
        action='store_const',
        dest='converter_version_cmd',
        const='--v1',
        help='use TensorFlow Lite Converter 1.x')
    converter_version.add_argument(
        '--v2',
        action='store_const',
        dest='converter_version_cmd',
        const='--v2',
        help='use TensorFlow Lite Converter 2.x')

    # hidden from --help; NOTE(review): presumably populated via the
    # configuration file (see `_utils._parse_cfg`) — confirm
    parser.add_argument('--converter_version', type=str, help=argparse.SUPPRESS)

    # input and output path.
    tf2tfliteV2_group.add_argument(
        '-i', '--input_path', type=str, help='full filepath of the input file')
    tf2tfliteV2_group.add_argument(
        '-o', '--output_path', type=str, help='full filepath of the output file')

    # input and output arrays.
    tf2tfliteV2_group.add_argument(
        '-I',
        '--input_arrays',
        type=str,
        help='names of the input arrays, comma-separated')
    tf2tfliteV2_group.add_argument(
        '-s',
        '--input_shapes',
        type=str,
        help=(
            'shapes corresponding to --input_arrays, colon-separated (ex:"1,4,4,3:1,20,20,3")'
        ))
    tf2tfliteV2_group.add_argument(
        '-O',
        '--output_arrays',
        type=str,
        help='names of the output arrays, comma-separated')

    return parser


def _verify_arg(parser, args):
    """Verify that the arguments required for conversion were given.

    This runs after the configuration file has been merged into ``args`` (see
    main), which is presumably why argparse's own ``required=`` is not used.

    Args:
        parser: the argparse.ArgumentParser used to report the error.
        args: the parsed argument namespace.

    Exits via ``parser.error`` (SystemExit, status 2) when any required
    option is missing; all missing options are reported at once.
    """
    missing = []
    if not _utils._is_valid_attr(args, 'input_path'):
        missing.append('-i/--input_path')
    if not _utils._is_valid_attr(args, 'output_path'):
        missing.append('-o/--output_path')
    # only fail when something is actually missing
    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))


def _parse_arg(parser):
    """Parse the command line and honour a version request.

    Args:
        parser: the argparse.ArgumentParser built by ``_get_parser``.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    args = parser.parse_args()
    # print version and exit only when requested; the `version` option is
    # presumably registered by `_utils._add_default_arg` (TODO confirm), so
    # tolerate its absence instead of raising AttributeError
    if getattr(args, 'version', False):
        _utils._print_version_and_exit(__file__)

    return args


def _make_generate_bcq_metadata_cmd(args, driver_path, output_path):
    """Make a command for running generate_bcq_metadata.

    Args:
        args: parsed arguments; ``input_path`` and ``output_arrays`` are
            forwarded when set.
        driver_path: path to the generate_bcq_metadata.py script.
        output_path: where the driver writes its result; used instead of
            ``args.output_path``.

    Returns:
        list of command-line tokens that runs the driver with the current
        Python interpreter (``sys.executable``).
    """
    cmd = [sys.executable, driver_path]
    # input_path ('~' expanded)
    if _utils._is_valid_attr(args, 'input_path'):
        cmd.append('--input_path')
        cmd.append(os.path.expanduser(getattr(args, 'input_path')))
    # output_path: gated on args.output_path (presumably always set because
    # _verify_arg runs first), but the value passed is the `output_path`
    # parameter, i.e. the intermediate file, not the final circle path
    if _utils._is_valid_attr(args, 'output_path'):
        cmd.append('--output_path')
        cmd.append(os.path.expanduser(output_path))
    # output_arrays (forwarded verbatim)
    if _utils._is_valid_attr(args, 'output_arrays'):
        cmd.append('--output_arrays')
        cmd.append(getattr(args, 'output_arrays'))

    return cmd


def _run_and_log(cmd, logfile):
    """Run `cmd`, tee its combined stdout/stderr to our stdout and `logfile`.

    The command line itself is written to `logfile` first. If the child exits
    with a non-zero status, the whole process exits with that status.

    Args:
        cmd: list of command-line tokens.
        logfile: file object opened in binary write mode.
    """
    logfile.write((' '.join(cmd) + '\n').encode())

    with subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as p:
        for line in p.stdout:
            sys.stdout.buffer.write(line)
            logfile.write(line)
    # Popen.__exit__ waits for the child, so returncode is populated here
    if p.returncode != 0:
        sys.exit(p.returncode)


def _convert(args):
    """Convert a TensorFlow model with BCQ information to circle.

    Steps (intermediate files live in a temporary directory):
      1. generate_bcq_metadata.py writes ``<input name>_withmeta.pb``.
      2. tf2tfliteV2.py converts that to tflite; its ``--output_arrays``
         value is replaced by the list returned from
         ``_bcq_info_gen.get_bcq_output_arrays``.
      3. tflite2circle converts the tflite file to ``args.output_path``.

    Every command and its output are logged to ``<output_path>.log``.
    Exits with the failing child's return code if any step fails.
    """
    # get file path to log; expand '~' like the tool commands do so the log
    # is written next to the real output instead of under a literal '~' dir
    dir_path = os.path.dirname(os.path.realpath(__file__))
    logfile_path = os.path.realpath(os.path.expanduser(args.output_path)) + '.log'

    with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
        # make a command to generate BCQ information metadata
        generate_bcq_metadata_path = os.path.join(dir_path, 'generate_bcq_metadata.py')
        generate_bcq_metadata_output_path = os.path.join(
            tmpdir,
            os.path.splitext(os.path.basename(args.input_path))[0] + '_withmeta.pb')
        generate_bcq_metadata_cmd = _make_generate_bcq_metadata_cmd(
            args, generate_bcq_metadata_path, generate_bcq_metadata_output_path)

        # generate BCQ information metadata
        _run_and_log(generate_bcq_metadata_cmd, f)

        # get output_arrays with BCQ
        bcq_output_arrays = _bcq_info_gen.get_bcq_output_arrays(
            generate_bcq_metadata_output_path, getattr(args, 'output_arrays'))

        # make a command to convert from tf with BCQ to tflite
        tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
        tf2tfliteV2_output_path = os.path.join(
            tmpdir,
            os.path.splitext(
                os.path.basename(generate_bcq_metadata_output_path))[0]) + '.tflite'
        tf2tfliteV2_cmd = _utils._make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
                                                       generate_bcq_metadata_output_path,
                                                       tf2tfliteV2_output_path)
        # swap in the BCQ-aware output arrays; when the command carries no
        # --output_arrays there is nothing to replace
        if '--output_arrays' in tf2tfliteV2_cmd:
            output_arrays_idx = tf2tfliteV2_cmd.index('--output_arrays')
            tf2tfliteV2_cmd[output_arrays_idx + 1] = ','.join(bcq_output_arrays)

        # convert tf to tflite
        _run_and_log(tf2tfliteV2_cmd, f)

        # make a command to convert from tflite to circle
        tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
        tflite2circle_cmd = _utils._make_tflite2circle_cmd(tflite2circle_path,
                                                           tf2tfliteV2_output_path,
                                                           getattr(args, 'output_path'))

        # convert tflite to circle
        _run_and_log(tflite2circle_cmd, f)


def main():
    """Entry point: parse arguments, merge the config file, verify, convert."""
    # parse arguments
    parser = _get_parser()
    args = _parse_arg(parser)

    # parse configuration file
    _utils._parse_cfg(args, 'one-import-bcq')

    # verify arguments
    _verify_arg(parser, args)

    # convert
    _convert(args)


if __name__ == '__main__':
    main()