299255c6fc6b8ffbe7a510cb97863384816e2b11
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-init
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import copy
24 import glob
25 import itertools
26 import ntpath
27 import os
28 import sys
29
30 import configparser
31 import onelib.utils as oneutils
32
33 # TODO Find better way to suppress trackback on error
34 sys.tracebacklimit = 0
35
36
37 class InputOutputPath:
38     '''
39     Class that remembers input circle file and output circle file of section k,
40
41     After calling enter_new_section(),
42     output path in section k will be used as input path of section k+1
43     '''
44
45     def __init__(self, initial_input_path: str):
46         self._first_step = True
47         self._input_path = initial_input_path
48         self._output_path = ''
49
50     def enter_new_section(self, section_output_path: str):
51         '''
52         Call this when starting a section
53         '''
54         if self._first_step == True:
55             self._output_path = section_output_path
56         else:
57             self._input_path = self._output_path
58             self._output_path = section_output_path
59
60         self._first_step = False
61
62     def input_path(self):
63         return self._input_path
64
65     def output_path(self):
66         return self._output_path
67
68
69 class CommentableConfigParser(configparser.ConfigParser):
70     """
71     ConfigParser where comment can be stored
72     In Python ConfigParser, comment in ini file ( starting with ';') is considered a key of which
73     value is None.
74     Ref: https://stackoverflow.com/questions/6620637/writing-comments-to-files-with-configparser
75     """
76
77     def __init__(self):
78         # allow_no_value=True to add comment
79         # ref: https://stackoverflow.com/a/19432072
80         configparser.ConfigParser.__init__(self, allow_no_value=True)
81         self.optionxform = str
82
83     def add_comment(self, section, comment):
84         comment_sign = ';'
85         self[section][f'{comment_sign} {comment}'] = None
86
87
88 def _get_backends_list():
89     """
90     [one hierarchy]
91     one
92     ├── backends
93     ├── bin
94     ├── doc
95     ├── include
96     ├── lib
97     ├── optimization
98     └── test
99
100     The list where `one-init` finds its backends
101     - `bin` folder where `one-init` exists
102     - `backends` folder
103
104     NOTE If there are backends of the same name in different places,
105      the closer to the top in the list, the higher the priority.
106     """
107     dir_path = os.path.dirname(os.path.realpath(__file__))
108     backend_set = set()
109
110     # bin folder
111     files = [f for f in glob.glob(dir_path + '/*-init')]
112     # backends folder
113     files += [f for f in glob.glob(dir_path + '/../backends/**/*-init', recursive=True)]
114     # TODO find backends in `$PATH`
115
116     backends_list = []
117     for cand in files:
118         base = ntpath.basename(cand)
119         if (not base in backend_set) and os.path.isfile(cand) and os.access(
120                 cand, os.X_OK):
121             backend_set.add(base)
122             backends_list.append(cand)
123
124     return backends_list
125
126
127 # TODO Add support for TF graphdef and bcq
128 def _get_parser(backends_list):
129     init_usage = (
130         'one-init [-h] [-v] [-V] '
131         '[-i INPUT_PATH] '
132         '[-o OUTPUT_PATH] '
133         '[-m MODEL_TYPE] '
134         '[-b BACKEND] '
135         # args for onnx model
136         '[--convert_nchw_to_nhwc] '
137         '[--nchw_to_nhwc_input_shape] '
138         '[--nchw_to_nhwc_output_shape] '
139         # args for backend driver
140         '[--] [COMMANDS FOR BACKEND DRIVER]')
141     """
142     NOTE
143     layout options for onnx model could be difficult to users.
144     In one-init, we could consider easier args for the the above three:
145     For example, we could have another option, e.g., --input_img_layout LAYOUT
146       - When LAYOUT is NHWC, apply 'nchw_to_nhwc_input_shape=True' into cfg
147       - When LAYOUT is NCHW, apply 'nchw_to_nhwc_input_shape=False' into cfg
148     """
149
150     parser = argparse.ArgumentParser(
151         description='Command line tool to generate initial cfg file. '
152         'Currently tflite and onnx models are supported',
153         usage=init_usage)
154
155     oneutils.add_default_arg_no_CS(parser)
156
157     parser.add_argument(
158         '-i', '--input_path', type=str, help='full filepath of the input model file')
159     parser.add_argument(
160         '-o', '--output_path', type=str, help='full filepath of the output cfg file')
161     parser.add_argument(
162         '-m',
163         '--model_type',
164         type=str,
165         help=('type of input model: "onnx", "tflite". '
166               'If the file extension passed to --input_path is '
167               '".tflite" or ".onnx", this arg can be omitted.'))
168
169     onnx_group = parser.add_argument_group('arguments when model type is onnx')
170     onnx_group.add_argument(
171         '--convert_nchw_to_nhwc',
172         action='store_true',
173         help=
174         'Convert NCHW operators to NHWC under the assumption that input model is NCHW.')
175     onnx_group.add_argument(
176         '--nchw_to_nhwc_input_shape',
177         action='store_true',
178         help='Convert the input shape of the model (argument for convert_nchw_to_nhwc)')
179     onnx_group.add_argument(
180         '--nchw_to_nhwc_output_shape',
181         action='store_true',
182         help='Convert the output shape of the model (argument for convert_nchw_to_nhwc)')
183
184     # get backend list in the directory
185     backends_name = [ntpath.basename(f) for f in backends_list]
186     if not backends_name:
187         backends_name_message = '(There is no available backend drivers)'
188     else:
189         backends_name_message = '(available backend drivers: ' + ', '.join(
190             backends_name) + ')'
191     backend_help_message = 'backend name to use ' + backends_name_message
192     parser.add_argument('-b', '--backend', type=str, help=backend_help_message)
193
194     return parser
195
196
197 def _verify_arg(parser, args):
198     # check if required arguments is given
199     missing = []
200     if not oneutils.is_valid_attr(args, 'input_path'):
201         missing.append('-i/--input_path')
202     if not oneutils.is_valid_attr(args, 'output_path'):
203         missing.append('-o/--output_path')
204     if not oneutils.is_valid_attr(args, 'backend'):
205         missing.append('-b/--backend')
206
207     if oneutils.is_valid_attr(args, 'model_type'):
208         # TODO Support model types other than onnx and tflite (e.g., TF)
209         if getattr(args, 'model_type') not in ['onnx', 'tflite']:
210             parser.error('Allowed value for --model_type: "onnx" or "tflite"')
211
212     if oneutils.is_valid_attr(args, 'nchw_to_nhwc_input_shape'):
213         if not oneutils.is_valid_attr(args, 'convert_nchw_to_nhwc'):
214             missing.append('--convert_nchw_to_nhwc')
215     if oneutils.is_valid_attr(args, 'nchw_to_nhwc_output_shape'):
216         if not oneutils.is_valid_attr(args, 'convert_nchw_to_nhwc'):
217             missing.append('--convert_nchw_to_nhwc')
218
219     if len(missing):
220         parser.error('the following arguments are required: ' + ' '.join(missing))
221
222
223 def _parse_arg(parser):
224     init_args = []
225     backend_args = []
226     argv = copy.deepcopy(sys.argv)
227     # delete file name
228     del argv[0]
229     # split by '--'
230     args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
231
232     # one-init [-h] [-v] ...
233     if len(args):
234         init_args = args[0]
235         init_args = parser.parse_args(init_args)
236         backend_args = backend_args if len(args) < 2 else args[1]
237     # print version
238     if len(args) and init_args.version:
239         oneutils.print_version_and_exit(__file__)
240
241     return init_args, backend_args
242
243
244 def _get_executable(args, backends_list):
245     if oneutils.is_valid_attr(args, 'backend'):
246         backend_base = getattr(args, 'backend') + '-init'
247         for cand in backends_list:
248             if ntpath.basename(cand) == backend_base:
249                 return cand
250         raise FileNotFoundError(backend_base + ' not found')
251
252
253 # TODO Support workflow format (https://github.com/Samsung/ONE/pull/9354)
254 def _generate(args, model_type: str, inout_path: InputOutputPath):
255     # generate cfg file
256     config = CommentableConfigParser()
257     model_dir = os.path.dirname(args.input_path)
258     model_name = os.path.basename(args.input_path).split('.')[0]
259
260     def _assert_section(section: str):
261         if not config.has_section(section):
262             raise RuntimeError(f'Cannot find section: {section}')
263
264     def _add_onecc_sections():
265         '''
266         This adds all sections
267         '''
268         config.add_section('onecc')
269         sections = [
270             f'one-import-{model_type}', 'one-optimize', 'one-quantize', 'one-codegen'
271         ]
272
273         for section in sections:
274             config['onecc'][section] = 'True'
275             # add empty section as a preperation of next procedure
276             config.add_section(section)
277
278     def _gen_import():
279         section = f'one-import-{model_type}'
280         _assert_section(section)
281
282         output_path = os.path.join(model_dir, f'{model_name}.circle')
283         inout_path.enter_new_section(section_output_path=output_path)
284         config[section]['input_path'] = inout_path.input_path()
285         config[section]['output_path'] = inout_path.output_path()
286
287     def _gen_optimize():
288         section = 'one-optimize'
289         _assert_section(section)
290
291         output_path = os.path.join(model_dir, f'{model_name}.opt.circle')
292         inout_path.enter_new_section(section_output_path=output_path)
293         config[section]['input_path'] = inout_path.input_path()
294         config[section]['output_path'] = inout_path.output_path()
295
296         # TODO Add optimization optinos
297
298     def _gen_quantize():
299         section = 'one-quantize'
300         _assert_section(section)
301
302         output_path = os.path.join(model_dir, f'{model_name}.q.circle')
303         inout_path.enter_new_section(section_output_path=output_path)
304         config[section]['input_path'] = inout_path.input_path()
305         config[section]['output_path'] = inout_path.output_path()
306
307     def _gen_codegen():
308         section = 'one-codegen'
309         _assert_section(section)
310
311         # [backend]-init must provide default value for 'command'
312         config[section]['backend'] = args.backend
313
314     #
315     # NYI: one-profile, one-partition, one-pack, one-infer
316     #
317
318     _add_onecc_sections()
319
320     _gen_import()
321     _gen_optimize()
322     _gen_quantize()
323     _gen_codegen()
324
325     with open(args.output_path, 'w') as f:
326         config.write(f)
327
328
329 def _get_model_type(parser, args):
330     if oneutils.is_valid_attr(args, 'model_type'):
331         return args.model_type
332
333     if oneutils.is_valid_attr(args, 'input_path'):
334         _, ext = os.path.splitext(args.input_path)
335
336         # ext would be, e.g., '.tflite' or '.onnx'.
337         # Note: when args.input_path does not have an extension, e.g., '/home/foo'
338         # ext after os.path.splitext() is '' and ''[1:] is still ''.
339         # TODO support tensorflow model
340         ext = ext[1:]
341         if ext in ["tflite", "onnx"]:
342             return ext
343         else:
344             parser.error(f'following file extensions are supported: ".onnx" ".tflite"')
345
346     parser.error(f'the following argument is required: --input_path')
347
348
349 def main():
350     # get backend list
351     backends_list = _get_backends_list()
352
353     # parse arguments
354     parser = _get_parser(backends_list)
355     args, backend_args = _parse_arg(parser)
356
357     # verify arguments
358     _verify_arg(parser, args)
359
360     model_type = _get_model_type(parser, args)
361     inout_path = InputOutputPath(args.input_path)
362     _generate(args, model_type, inout_path)
363
364     # make a command to run given backend driver
365     driver_path = _get_executable(args, backends_list)
366     init_cmd = [driver_path] + backend_args
367
368     # run backend driver
369     oneutils.run(init_cmd, err_prefix=ntpath.basename(driver_path))
370
371     raise NotImplementedError("NYI")
372
373
374 if __name__ == '__main__':
375     oneutils.safemain(main, __file__)