# Bash/Python polyglot launcher. Under bash, each line below is a shell
# command (the leading '''' is two empty strings glued to the command word)
# and the trailing "# '''" is a shell comment; under Python, each line is a
# harmless triple-quoted string literal. Bash re-executes this file with the
# ONE virtualenv interpreter, or reports the problem and stops so that it
# never tries to interpret the Python code that follows as shell.
''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
''''export PY_PATH="${SCRIPT_PATH}/venv/bin/python" # '''
''''test -f "${PY_PATH}" && exec "${PY_PATH}" "$0" "$@" # '''
''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." >&2 # '''
''''exit 255 # '''
8 # Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
14 # http://www.apache.org/licenses/LICENSE-2.0
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
30 import onelib.utils as oneutils
# TODO Find better way to suppress traceback on error
# NOTE: with a limit of 0, uncaught exceptions are reported as just the
# exception type and message, without the stack trace.
sys.tracebacklimit = 0
def _search_backend_driver(driver):
    """
    Return the path of the executable backend driver named `driver`,
    or None when no such executable exists.

    The list where `one-infer` finds its backend driver
    - `bin` folder where `one-infer` exists
    - `backends/**/bin/` folder

    NOTE If there are drivers of the same name in different places,
     the closer to the top in the list, the higher the priority.
     Among several matches under `backends/`, the lexicographically
     smallest path wins, so the choice does not depend on the order in
     which the file system lists directories.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))

    def _is_executable(path):
        # a regular file that the current user may execute
        return os.path.isfile(path) and os.access(path, os.X_OK)

    # CASE 1: one/bin/{driver} is found
    driver_path = dir_path + '/' + driver
    if _is_executable(driver_path):
        return driver_path

    # CASE 2: one/backends/**/bin/{driver} is found
    # Escape the literal parts of the pattern so that glob metacharacters
    # ('[', '*', '?') in the install path or in the driver name are matched
    # verbatim instead of being treated as wildcards.
    pattern = (glob.escape(dir_path + '/../backends/') + '**/bin/' +
               glob.escape(driver))
    for driver_path in sorted(glob.glob(pattern, recursive=True)):
        if _is_executable(driver_path):
            return driver_path

    # CASE 3: {driver} is found in nowhere
    return None
def _get_parser():
    """Build the argument parser for `one-infer`.

    Besides the common options registered by `oneutils.add_default_arg`
    (presumably -v/--version and -C/--config, per the usage line), it
    accepts -d/--driver and --post-process. The epilog shows how to pass a
    post-processing script together with its own arguments.
    """
    usage_text = 'one-infer [-h] [-v] [-C CONFIG] [-d DRIVER] [--post-process POST_PROCESS] [--] [COMMANDS FOR BACKEND DRIVER]'
    epilog_text = """
one-infer provides post-processing after invoking backend inference driver
use python script and its arguments to '--post-process' argument as below
one-infer -d dummy-infer --post-process "script.py arg1 arg2" -- [arguments for dummy-infer]
"""
    # RawTextHelpFormatter keeps the epilog's line breaks exactly as written
    parser = argparse.ArgumentParser(
        description='command line tool to infer model',
        usage=usage_text,
        epilog=epilog_text,
        formatter_class=argparse.RawTextHelpFormatter)

    oneutils.add_default_arg(parser)

    parser.add_argument(
        '-d',
        '--driver',
        type=str,
        help='backend inference driver name to execute')
    parser.add_argument(
        '--post-process',
        type=str,
        help=
        'post processing python script and arguments which can be used to convert I/O data to standard format'
    )

    return parser
def _verify_arg(parser, args):
    """Verify that every required option has a value.

    Missing options are reported through `parser.error`, using the same
    wording argparse uses for required arguments. The check is done here
    rather than with `required=True` because `main` runs it after the
    configuration file has been parsed (presumably so the value may come
    from the config as well - confirm against `oneutils.parse_cfg`).
    """
    # (attribute name on args, option spelling shown to the user)
    required_options = [('driver', '-d/--driver')]
    missing = [
        flag for attr, flag in required_options
        if not oneutils.is_valid_attr(args, attr)
    ]
    if missing:
        parser.error('the following arguments are required: ' + ' '.join(missing))
def _parse_arg(parser):
    """Split the command line into one-infer options and backend arguments.

    Everything before the first `--` is parsed with `parser`; everything
    after it is handed to the backend driver unchanged, so backend arguments
    may themselves contain `--`. Without a `--`, all arguments belong to
    one-infer and the backend argument list is empty.

    Args:
        parser: the one-infer argument parser.

    Returns:
        tuple: (parsed one-infer arguments, list of backend driver arguments)
    """
    # drop the program name; slicing already produces an independent list
    argv = sys.argv[1:]

    # one-infer [-h] [-v] [-C CONFIG] [-d DRIVER] [--post-process POST_PROCESS] -- [COMMANDS FOR BACKEND DRIVER]
    if '--' in argv:
        sep = argv.index('--')
        infer_argv, backend_args = argv[:sep], argv[sep + 1:]
    else:
        infer_argv, backend_args = argv, []

    infer_args = parser.parse_args(infer_argv)

    # print version
    if infer_args.version:
        oneutils.print_version_and_exit(__file__)

    return infer_args, backend_args
def _get_executable(args):
    """Resolve the driver named in `args` to the path of its executable.

    Raises:
        FileNotFoundError: if `_search_backend_driver` finds no executable
            driver of that name.
    """
    driver_name = oneutils.is_valid_attr(args, 'driver')
    found_path = _search_backend_driver(driver_name)
    if not found_path:
        raise FileNotFoundError(driver_name + ' not found')
    return found_path
def main():
    """Entry point of `one-infer`.

    Parses the command line and the configuration file, runs the selected
    backend inference driver with the backend arguments, and finally runs
    the optional post-processing script with this same Python interpreter.
    """
    # parse arguments
    parser = _get_parser()
    args, backend_args = _parse_arg(parser)

    # parse configuration file
    oneutils.parse_cfg(args.config, 'one-infer', args)

    # verify arguments
    _verify_arg(parser, args)

    # make a command to run given backend driver
    driver_path = _get_executable(args)
    infer_cmd = [driver_path] + backend_args
    # 'command' is not a command-line option of one-infer; presumably it is
    # filled in from the configuration file - confirm against parse_cfg
    if oneutils.is_valid_attr(args, 'command'):
        infer_cmd += getattr(args, 'command').split()

    # run backend driver
    oneutils.run(infer_cmd, err_prefix=ntpath.basename(driver_path))

    # run post process script if it's given
    if oneutils.is_valid_attr(args, 'post_process'):
        # NOTE: the given python script will be executed by venv of ONE
        python_path = sys.executable
        # split on runs of whitespace (like 'command' above) so repeated
        # spaces do not become empty-string arguments for the script
        post_process_command = [python_path] + getattr(args, 'post_process').split()
        oneutils.run(post_process_command, err_prefix='one-infer')


if __name__ == '__main__':
    oneutils.safemain(main, __file__)