2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@" # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
8 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
14 # http://www.apache.org/licenses/LICENSE-2.0
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
31 import onelib.utils as oneutils
# TODO Find better way to suppress traceback on error
34 sys.tracebacklimit = 0
class InputOutputPath:
    '''
    Class that remembers input circle file and output circle file of section k.

    The first call to enter_new_section() keeps the initial input path and only
    records the output path of that section. After every later call,
    output path in section k will be used as input path of section k+1.
    '''

    def __init__(self, initial_input_path: str):
        '''
        initial_input_path: input path reported for the first section
        '''
        self._first_step = True
        self._input_path = initial_input_path
        self._output_path = ''

    def enter_new_section(self, section_output_path: str) -> None:
        '''
        Call this when starting a section

        section_output_path: output path of the section being entered
        '''
        if not self._first_step:
            # chain the sections: what the previous section wrote is what
            # the section being entered reads
            self._input_path = self._output_path
        self._output_path = section_output_path
        self._first_step = False

    def input_path(self) -> str:
        '''Return the input path of the current section.'''
        return self._input_path

    def output_path(self) -> str:
        '''Return the output path of the current section ('' before any section).'''
        return self._output_path
class CommentableConfigParser(configparser.ConfigParser):
    """
    ConfigParser where comment can be stored

    A comment line of an ini file (starting with ';') is stored here as a key
    whose value is None, which ConfigParser writes back as the bare key.
    Ref: https://stackoverflow.com/questions/6620637/writing-comments-to-files-with-configparser
    """

    def __init__(self):
        # allow_no_value=True to add comment
        # ref: https://stackoverflow.com/a/19432072
        super().__init__(allow_no_value=True)
        # keep keys exactly as written; the default optionxform lower-cases them
        self.optionxform = str

    def add_comment(self, section, comment):
        """
        Store `comment` in `section` as a '; <comment>' key without a value.

        section: name of an already existing section
        comment: comment text, without the leading comment sign
        """
        comment_sign = ';'
        self[section][f'{comment_sign} {comment}'] = None
def _get_backends_list():
    """
    Collect the backend init drivers (executable `*-init` files) known to `one-init`.

    The list where `one-init` finds its backends
    - `bin` folder where `one-init` exists
    - `backends` folder next to it (searched recursively)

    NOTE If there are backends of the same name in different places,
     the closer to the top in the list, the higher the priority.

    Returns a list of file paths, at most one per driver file name.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))

    # bin folder first, so that it wins over a same-named driver in backends
    candidates = glob.glob(dir_path + '/*-init')
    # backends folder
    candidates += glob.glob(dir_path + '/../backends/**/*-init', recursive=True)
    # TODO find backends in `$PATH`

    seen_names = set()
    backends_list = []
    for cand in candidates:
        base = ntpath.basename(cand)
        if base in seen_names:
            # an earlier (higher priority) location already provided this driver
            continue
        if os.path.isfile(cand) and os.access(cand, os.X_OK):
            seen_names.add(base)
            backends_list.append(cand)

    return backends_list
127 # TODO Add support for TF graphdef and bcq
def _get_parser(backends_list):
    """
    Build the command line parser of `one-init`.

    backends_list: paths of the available backend init drivers; only their file
                   names are used, to list them in the help text of --backend

    Returns the configured argparse.ArgumentParser.
    """
    # NOTE(review): the usage fragments for -i/-o/-m/-b, the '-m' short flag and
    # the store_true actions below were not all visible when this block was
    # rewritten - confirm them against the complete file.
    init_usage = (
        'one-init [-h] [-v] [-V] '
        '[-i INPUT_PATH] '
        '[-o OUTPUT_PATH] '
        '[-m MODEL_TYPE] '
        '[-b BACKEND] '
        # args for onnx model
        '[--convert_nchw_to_nhwc] '
        '[--nchw_to_nhwc_input_shape] '
        '[--nchw_to_nhwc_output_shape] '
        # args for backend driver
        '[--] [COMMANDS FOR BACKEND DRIVER]')

    # NOTE
    # layout options for onnx model could be difficult to users.
    # In one-init, we could consider easier args for the above three:
    # For example, we could have another option, e.g., --input_img_layout LAYOUT
    #   - When LAYOUT is NHWC, apply 'nchw_to_nhwc_input_shape=True' into cfg
    #   - When LAYOUT is NCHW, apply 'nchw_to_nhwc_input_shape=False' into cfg

    parser = argparse.ArgumentParser(
        description='Command line tool to generate initial cfg file. '
        'Currently tflite and onnx models are supported',
        usage=init_usage)

    oneutils.add_default_arg_no_CS(parser)

    parser.add_argument(
        '-i', '--input_path', type=str, help='full filepath of the input model file')
    parser.add_argument(
        '-o', '--output_path', type=str, help='full filepath of the output cfg file')
    parser.add_argument(
        '-m',
        '--model_type',
        type=str,
        help=('type of input model: "onnx", "tflite". '
              'If the file extension passed to --input_path is '
              '".tflite" or ".onnx", this arg can be omitted.'))

    onnx_group = parser.add_argument_group('arguments when model type is onnx')
    onnx_group.add_argument(
        '--convert_nchw_to_nhwc',
        action='store_true',
        help=
        'Convert NCHW operators to NHWC under the assumption that input model is NCHW.')
    onnx_group.add_argument(
        '--nchw_to_nhwc_input_shape',
        action='store_true',
        help='Convert the input shape of the model (argument for convert_nchw_to_nhwc)')
    onnx_group.add_argument(
        '--nchw_to_nhwc_output_shape',
        action='store_true',
        help='Convert the output shape of the model (argument for convert_nchw_to_nhwc)')

    # get backend list in the directory
    backends_name = [ntpath.basename(f) for f in backends_list]
    if backends_name:
        backends_name_message = ('(available backend drivers: ' +
                                 ', '.join(backends_name) + ')')
    else:
        backends_name_message = '(There is no available backend drivers)'
    backend_help_message = 'backend name to use ' + backends_name_message
    parser.add_argument('-b', '--backend', type=str, help=backend_help_message)

    return parser
def _verify_arg(parser, args):
    """
    Check the parsed arguments and report problems through parser.error().

    parser: parser used to report the error
    args:   parsed one-init arguments

    An argument counts as given when oneutils.is_valid_attr() accepts it.
    All missing arguments are reported together in a single error message.
    """
    # check if required arguments is given
    required = (
        ('input_path', '-i/--input_path'),
        ('output_path', '-o/--output_path'),
        ('backend', '-b/--backend'),
    )
    missing = [flag for attr, flag in required if not oneutils.is_valid_attr(args, attr)]

    if oneutils.is_valid_attr(args, 'model_type'):
        # TODO Support model types other than onnx and tflite (e.g., TF)
        if args.model_type not in ('onnx', 'tflite'):
            parser.error('Allowed value for --model_type: "onnx" or "tflite"')

    # Both shape options are arguments for convert_nchw_to_nhwc. Report the
    # missing option once, even when both shape options are given.
    shape_option_given = any(
        oneutils.is_valid_attr(args, attr)
        for attr in ('nchw_to_nhwc_input_shape', 'nchw_to_nhwc_output_shape'))
    if shape_option_given and not oneutils.is_valid_attr(args, 'convert_nchw_to_nhwc'):
        missing.append('--convert_nchw_to_nhwc')

    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))
def _parse_arg(parser):
    """
    Split sys.argv into one-init arguments and backend driver arguments.

    Everything before the first '--' is parsed with `parser`; everything after
    it is returned untouched, so the backend command may itself contain '--'.

    parser: parser used for the one-init part of the command line

    Returns (init_args, backend_args):
      init_args    - the parsed namespace, or an empty list when no one-init
                     argument was given
      backend_args - list of the raw arguments that follow the first '--'
    """
    # drop the file name; slicing copies, so sys.argv itself stays untouched
    argv = sys.argv[1:]

    # split by the first '--' only
    if '--' in argv:
        separator = argv.index('--')
        init_argv = argv[:separator]
        backend_args = argv[separator + 1:]
    else:
        init_argv = argv
        backend_args = []

    # one-init [-h] [-v] ...
    init_args = []
    if init_argv:
        init_args = parser.parse_args(init_argv)
        # print version
        if init_args.version:
            oneutils.print_version_and_exit(__file__)

    return init_args, backend_args
def _get_executable(args, backends_list):
    """
    Find the init driver of the backend named by args.backend.

    args:          parsed one-init arguments
    backends_list: candidate driver paths, highest priority first

    Returns the first path in backends_list whose file name is
    '<backend>-init', or None when no backend argument is given.
    Raises FileNotFoundError when no candidate matches.
    """
    if not oneutils.is_valid_attr(args, 'backend'):
        return None

    backend_base = args.backend + '-init'
    for cand in backends_list:
        if ntpath.basename(cand) == backend_base:
            return cand
    raise FileNotFoundError(backend_base + ' not found')
253 # TODO Support workflow format (https://github.com/Samsung/ONE/pull/9354)
def _generate(args, model_type: str, inout_path: InputOutputPath):
    """
    Generate the initial cfg file and write it to args.output_path.

    args:       parsed one-init arguments; input_path, output_path and backend
                are read
    model_type: model type used to name the import section
                (f'one-import-{model_type}')
    inout_path: tracker that chains the output path of one section into the
                input path of the next one

    The generated circle files are placed in the directory of args.input_path
    and are named after the part of its file name before the first '.'.
    """
    # generate cfg file
    config = CommentableConfigParser()
    model_dir = os.path.dirname(args.input_path)
    model_name = os.path.basename(args.input_path).split('.')[0]

    def _assert_section(section: str):
        if not config.has_section(section):
            raise RuntimeError(f'Cannot find section: {section}')

    def _add_onecc_sections():
        '''
        This adds all sections
        '''
        config.add_section('onecc')
        sections = [
            f'one-import-{model_type}', 'one-optimize', 'one-quantize', 'one-codegen'
        ]

        for section in sections:
            config['onecc'][section] = 'True'
            # add empty section as a preparation of next procedure
            config.add_section(section)

    def _gen_circle_section(section: str, output_suffix: str):
        '''
        Fill input_path/output_path of a section that produces a circle file
        '''
        _assert_section(section)

        output_path = os.path.join(model_dir, f'{model_name}{output_suffix}')
        inout_path.enter_new_section(section_output_path=output_path)
        config[section]['input_path'] = inout_path.input_path()
        config[section]['output_path'] = inout_path.output_path()

    def _gen_codegen():
        section = 'one-codegen'
        _assert_section(section)

        # [backend]-init must provide default value for 'command'
        config[section]['backend'] = args.backend

    #
    # NYI: one-profile, one-partition, one-pack, one-infer
    #

    _add_onecc_sections()

    # the order matters: each section reads what the previous one wrote
    _gen_circle_section(f'one-import-{model_type}', '.circle')
    # TODO Add optimization options
    _gen_circle_section('one-optimize', '.opt.circle')
    _gen_circle_section('one-quantize', '.q.circle')
    _gen_codegen()

    with open(args.output_path, 'w') as f:
        config.write(f)
def _get_model_type(parser, args):
    """
    Return the model type: 'onnx' or 'tflite' when derived from the file name.

    parser: parser used to report errors through parser.error()
    args:   parsed one-init arguments

    An explicitly given args.model_type is returned as it is. Otherwise the
    type is taken from the file extension of args.input_path; an unsupported
    extension or a missing input path is reported as an error.
    """
    if oneutils.is_valid_attr(args, 'model_type'):
        return args.model_type

    if oneutils.is_valid_attr(args, 'input_path'):
        _, ext = os.path.splitext(args.input_path)

        # ext would be, e.g., '.tflite' or '.onnx'.
        # Note: when args.input_path does not have an extension, e.g., '/home/foo'
        # ext after os.path.splitext() is '' and ''[1:] is still ''.
        # TODO support tensorflow model
        ext = ext[1:]  # drop the leading '.' before comparing
        if ext in ('tflite', 'onnx'):
            return ext
        parser.error('following file extensions are supported: ".onnx" ".tflite"')

    parser.error('the following argument is required: --input_path')
def main():
    """
    Entry point of `one-init`.

    Generates the initial cfg file for the given model and then runs the init
    driver of the selected backend with the arguments given after '--'.
    """
    # get backend list
    backends_list = _get_backends_list()

    # parse arguments
    parser = _get_parser(backends_list)
    args, backend_args = _parse_arg(parser)

    # verify arguments
    _verify_arg(parser, args)

    model_type = _get_model_type(parser, args)
    inout_path = InputOutputPath(args.input_path)
    _generate(args, model_type, inout_path)

    # make a command to run given backend driver
    driver_path = _get_executable(args, backends_list)
    init_cmd = [driver_path] + backend_args

    # run backend driver
    oneutils.run(init_cmd, err_prefix=ntpath.basename(driver_path))

    # NOTE(review): raised unconditionally, even after the backend driver was
    # run - presumably a placeholder while one-init is not complete; confirm.
    raise NotImplementedError("NYI")


if __name__ == '__main__':
    oneutils.safemain(main, __file__)