5ccff0f931ee2659a65c757f4c857bc981c20756
[platform/core/ml/nnfw.git] / compiler / one-cmds / one-codegen
1 #!/usr/bin/env bash
2 ''''export SCRIPT_PATH="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)" # '''
3 ''''export PY_PATH=${SCRIPT_PATH}/venv/bin/python                                       # '''
4 ''''test -f ${PY_PATH} && exec ${PY_PATH} "$0" "$@"                                     # '''
5 ''''echo "Error: Virtual environment not found. Please run 'one-prepare-venv' command." # '''
6 ''''exit 255                                                                            # '''
7
8 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
9 #
10 # Licensed under the Apache License, Version 2.0 (the "License");
11 # you may not use this file except in compliance with the License.
12 # You may obtain a copy of the License at
13 #
14 #    http://www.apache.org/licenses/LICENSE-2.0
15 #
16 # Unless required by applicable law or agreed to in writing, software
17 # distributed under the License is distributed on an "AS IS" BASIS,
18 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 # See the License for the specific language governing permissions and
20 # limitations under the License.
21
22 import argparse
23 import copy
24 import glob
25 import itertools
26 import ntpath
27 import os
28 import sys
29 import shutil
30
31 import onelib.utils as oneutils
32
33 # TODO Find better way to suppress trackback on error
34 sys.tracebacklimit = 0
35
36
37 def _get_backends_list():
38     """
39     [one hierarchy]
40     one
41     ├── backends
42     ├── bin
43     ├── doc
44     ├── include
45     ├── lib
46     └── test
47
48     The list where `one-codegen` finds its backends
49     - `bin` folder where `one-codegen` exists
50     - `backends` folder
51     - System path
52
53     NOTE If there are backends of the same name in different places,
54      the closer to the top in the list, the higher the priority.
55     """
56     dir_path = os.path.dirname(os.path.realpath(__file__))
57     backend_set = set()
58
59     # bin folder
60     files = [f for f in glob.glob(dir_path + '/*-compile')]
61     # backends folder
62     files += [
63         f for f in glob.glob(dir_path + '/../backends/**/*-compile', recursive=True)
64     ]
65     # TODO find backends in `$PATH`
66
67     backends_list = []
68     for cand in files:
69         base = ntpath.basename(cand)
70         if not base in backend_set and os.path.isfile(cand) and os.access(cand, os.X_OK):
71             backend_set.add(base)
72             backends_list.append(cand)
73
74     return backends_list
75
76
77 def _get_parser(backends_list):
78     codegen_usage = 'one-codegen [-h] [-v] [-C CONFIG] [-b BACKEND] [--] [COMMANDS FOR BACKEND]'
79     parser = argparse.ArgumentParser(
80         description='command line tool for code generation', usage=codegen_usage)
81
82     oneutils.add_default_arg(parser)
83
84     # get backend list in the directory
85     backends_name = [ntpath.basename(f) for f in backends_list]
86     if not backends_name:
87         backends_name_message = '(There is no available backend drivers)'
88     else:
89         backends_name_message = '(available backend drivers: ' + ', '.join(
90             backends_name) + ')'
91     backend_help_message = 'backend name to use ' + backends_name_message
92     parser.add_argument('-b', '--backend', type=str, help=backend_help_message)
93
94     return parser
95
96
97 def _verify_arg(parser, args):
98     """verify given arguments"""
99     # check if required arguments is given
100     missing = []
101     if not oneutils.is_valid_attr(args, 'backend'):
102         missing.append('-b/--backend')
103     if len(missing):
104         parser.error('the following arguments are required: ' + ' '.join(missing))
105
106
107 def _parse_arg(parser):
108     codegen_args = []
109     backend_args = []
110     unknown_args = []
111     argv = copy.deepcopy(sys.argv)
112     # delete file name
113     del argv[0]
114     # split by '--'
115     args = [list(y) for x, y in itertools.groupby(argv, lambda z: z == '--') if not x]
116     if len(args) == 0:
117         codegen_args = parser.parse_args(codegen_args)
118     # one-codegen has two interfaces
119     # 1. one-codegen [-h] [-v] [-C CONFIG] [-b BACKEND] [COMMANDS FOR BACKEND]
120     if len(args) == 1:
121         codegen_args = args[0]
122         codegen_args, unknown_args = parser.parse_known_args(codegen_args)
123     # 2. one-codegen [-h] [-v] [-C CONFIG] [-b BACKEND] -- [COMMANDS FOR BACKEND]
124     if len(args) == 2:
125         codegen_args = args[0]
126         backend_args = args[1]
127         codegen_args = parser.parse_args(codegen_args)
128     # print version
129     if len(args) and codegen_args.version:
130         oneutils.print_version_and_exit(__file__)
131
132     return codegen_args, backend_args, unknown_args
133
134
135 def main():
136     # get backend list
137     backends_list = _get_backends_list()
138
139     # parse arguments
140     parser = _get_parser(backends_list)
141     args, backend_args, unknown_args = _parse_arg(parser)
142
143     # parse configuration file
144     oneutils.parse_cfg(args.config, 'one-codegen', args)
145
146     # verify arguments
147     _verify_arg(parser, args)
148
149     # make a command to run given backend driver
150     codegen_path = None
151     backend_base = getattr(args, 'backend') + '-compile'
152     for cand in backends_list:
153         if ntpath.basename(cand) == backend_base:
154             codegen_path = cand
155     if not codegen_path:
156         # Find backend from system path
157         codegen_path = shutil.which(backend_base)
158
159     if not codegen_path:
160         raise FileNotFoundError(backend_base + ' not found')
161     codegen_cmd = [codegen_path] + backend_args + unknown_args
162     if oneutils.is_valid_attr(args, 'command'):
163         codegen_cmd += getattr(args, 'command').split()
164
165     # run backend driver
166     oneutils.run(codegen_cmd, err_prefix=backend_base)
167
168
169 if __name__ == '__main__':
170     oneutils.safemain(main, __file__)