125db3e379078f3a7ed3f261e38e1f46d84af264
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-infer
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import copy
24 import glob
25 import itertools
26 import ntpath
27 import os
28 import sys
29
30 import onelib.utils as oneutils
31
32 # TODO Find better way to suppress trackback on error
33 sys.tracebacklimit = 0
34
35
36 def _search_backend_driver(driver):
37     """
38     [one hierarchy]
39     one
40     ├── backends
41     ├── bin
42     ├── doc
43     ├── include
44     ├── lib
45     ├── optimization
46     └── test
47
48     The list where `one-infer` finds its backend driver
49     - `bin` folder where `one-infer` exists
50     - `backends/**/bin/` folder
51
52     NOTE If there are drivers of the same name in different places,
53      the closer to the top in the list, the higher the priority.
54     """
55     dir_path = os.path.dirname(os.path.realpath(__file__))
56
57     # CASE 1: one/bin/{driver} is found
58     driver_path = dir_path + '/' + driver
59     if os.path.isfile(driver_path) and os.access(driver_path, os.X_OK):
60         return driver_path
61
62     # CASE 2: one/backends/**/bin/{driver} is found
63     for driver_path in glob.glob(
64             dir_path + '/../backends/**/bin/' + driver, recursive=True):
65         if os.path.isfile(driver_path) and os.access(driver_path, os.X_OK):
66             return driver_path
67
68     # CASE 3: {driver} is found in nowhere
69     return None
70
71
72 def _get_parser():
73     infer_usage = 'one-infer [-h] [-v] [-C CONFIG] [-d DRIVER] [--post-process POST_PROCESS] [--] [COMMANDS FOR BACKEND DRIVER]'
74     infer_detail = """
75 one-infer provides post-processing after invoking backend inference driver
76 use python script and its arguments to '--post-process' argument as below
77 one-infer -d dummy-infer --post-process "script.py arg1 arg2" -- [arguments for dummy-infer]
78 """
79     parser = argparse.ArgumentParser(
80         description='command line tool to infer model',
81         usage=infer_usage,
82         epilog=infer_detail,
83         formatter_class=argparse.RawTextHelpFormatter)
84
85     oneutils.add_default_arg(parser)
86
87     driver_help_message = 'backend inference driver name to execute'
88     parser.add_argument('-d', '--driver', type=str, help=driver_help_message)
89
90     post_process_help_message = 'post processing python script and arguments which can be used to convert I/O data to standard format'
91     parser.add_argument('--post-process', type=str, help=post_process_help_message)
92
93     return parser
94
95
96 def _verify_arg(parser, args):
97     """verify given arguments"""
98     missing = []
99     if not oneutils.is_valid_attr(args, 'driver'):
100         missing.append('-d/--driver')
101     if len(missing):
102         parser.error('the following arguments are required: ' + ' '.join(missing))
103
104
105 def _parse_arg(parser):
106     infer_args = []
107     backend_args = []
108     argv = copy.deepcopy(sys.argv)
109     # delete file name
110     del argv[0]
111     # split by '--'
112     args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
113
114     # one-infer [-h] [-v] [-C CONFIG] [-d DRIVER] [--post-process POST_PROCESS] -- [COMMANDS FOR BACKEND DRIVER]
115     if len(args):
116         infer_args = args[0]
117         infer_args = parser.parse_args(infer_args)
118         backend_args = backend_args if len(args) < 2 else args[1]
119     else:
120         infer_args = parser.parse_args(infer_args)
121     # print version
122     if len(args) and infer_args.version:
123         oneutils.print_version_and_exit(__file__)
124
125     return infer_args, backend_args
126
127
128 def _get_executable(args):
129     driver = oneutils.is_valid_attr(args, 'driver')
130
131     executable = _search_backend_driver(driver)
132     if executable:
133         return executable
134     else:
135         raise FileNotFoundError(driver + ' not found')
136
137
138 def main():
139     # parse arguments
140     parser = _get_parser()
141     args, backend_args = _parse_arg(parser)
142
143     # parse configuration file
144     oneutils.parse_cfg(args.config, 'one-infer', args)
145
146     # verify arguments
147     _verify_arg(parser, args)
148
149     # make a command to run given backend driver
150     driver_path = _get_executable(args)
151     infer_cmd = [driver_path] + backend_args
152     if oneutils.is_valid_attr(args, 'command'):
153         infer_cmd += getattr(args, 'command').split()
154
155     # run backend driver
156     oneutils.run(infer_cmd, err_prefix=ntpath.basename(driver_path))
157
158     # run post process script if it's given
159     if oneutils.is_valid_attr(args, 'post_process'):
160         # NOTE: the given python script will be executed by venv of ONE
161         python_path = sys.executable
162         post_process_command = [python_path] + getattr(args,
163                                                        'post_process').strip().split(' ')
164         oneutils.run(post_process_command, err_prefix='one-infer')
165
166
167 if __name__ == '__main__':
168     oneutils.safemain(main, __file__)