3d37f9d49729f071bd01025f715e325d25a75424
[platform/core/ml/nnfw.git] / compiler / one-cmds / onelib / utils.py
1 #!/usr/bin/env python
2
3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 #    http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 import argparse
18 import configparser
19 import glob
20 import importlib
21 import ntpath
22 import os
23 import subprocess
24 import sys
25
26 from typing import Union
27
28 import onelib.constant as _constant
29
30
31 def add_default_arg(parser):
32     # version
33     parser.add_argument(
34         '-v',
35         '--version',
36         action='store_true',
37         help='show program\'s version number and exit')
38
39     # verbose
40     parser.add_argument(
41         '-V',
42         '--verbose',
43         action='store_true',
44         help='output additional information to stdout or stderr')
45
46     # configuration file
47     parser.add_argument('-C', '--config', type=str, help='run with configuation file')
48     # section name that you want to run in configuration file
49     parser.add_argument('-S', '--section', type=str, help=argparse.SUPPRESS)
50
51
52 def add_default_arg_no_CS(parser):
53     """
54     This adds -v -V args only (no -C nor -S)
55     """
56     # version
57     parser.add_argument(
58         '-v',
59         '--version',
60         action='store_true',
61         help='show program\'s version number and exit')
62
63     # verbose
64     parser.add_argument(
65         '-V',
66         '--verbose',
67         action='store_true',
68         help='output additional information to stdout or stderr')
69
70
71 def is_accumulated_arg(arg, driver):
72     if driver == "one-quantize":
73         accumulables = [
74             "tensor_name", "scale", "zero_point", "src_tensor_name", "dst_tensor_name"
75         ]
76         if arg in accumulables:
77             return True
78
79     return False
80
81
82 def is_valid_attr(args, attr):
83     return hasattr(args, attr) and getattr(args, attr)
84
85
86 def parse_cfg_and_overwrite(config_path, section, args):
87     """
88     parse given section of configuration file and set the values of args.
89     Even if the values parsed from the configuration file already exist in args,
90     the values are overwritten.
91     """
92     if config_path == None:
93         # DO NOTHING
94         return
95     config = configparser.ConfigParser()
96     # make option names case sensitive
97     config.optionxform = str
98     parsed = config.read(config_path)
99     if not parsed:
100         raise FileNotFoundError('Not found given configuration file')
101     if not config.has_section(section):
102         raise AssertionError('configuration file doesn\'t have \'' + section +
103                              '\' section')
104     for key in config[section]:
105         setattr(args, key, config[section][key])
106     # TODO support accumulated arguments
107
108
109 def parse_cfg(config_path: Union[str, None], section_to_parse: str, args):
110     """
111     parse configuration file and store the information to args
112     
113     :param config_path: path to configuration file
114     :param section_to_parse: section name to parse
115     :param args: object to store the parsed information
116     """
117     if config_path is None:
118         return
119
120     parser = configparser.ConfigParser()
121     parser.optionxform = str
122     parser.read(config_path)
123
124     if not parser.has_section(section_to_parse):
125         raise AssertionError('configuration file must have \'' + section_to_parse +
126                              '\' section')
127
128     for key in parser[section_to_parse]:
129         if is_accumulated_arg(key, section_to_parse):
130             if not is_valid_attr(args, key):
131                 setattr(args, key, [parser[section_to_parse][key]])
132             else:
133                 getattr(args, key).append(parser[section_to_parse][key])
134             continue
135         if hasattr(args, key) and getattr(args, key):
136             continue
137         setattr(args, key, parser[section_to_parse][key])
138
139
140 def print_version_and_exit(file_path):
141     """print version of the file located in the file_path"""
142     script_path = os.path.realpath(file_path)
143     dir_path = os.path.dirname(script_path)
144     script_name = os.path.splitext(os.path.basename(script_path))[0]
145     # run one-version
146     subprocess.call([os.path.join(dir_path, 'one-version'), script_name])
147     sys.exit()
148
149
150 def safemain(main, mainpath):
151     """execute given method and print with program name for all uncaught exceptions"""
152     try:
153         main()
154     except Exception as e:
155         prog_name = os.path.basename(mainpath)
156         print(f"{prog_name}: {type(e).__name__}: " + str(e), file=sys.stderr)
157         sys.exit(255)
158
159
160 def run(cmd, err_prefix=None, logfile=None):
161     """Execute command in subprocess
162
163     Args:
164         cmd: command to be executed in subprocess
165         err_prefix: prefix to be put before every stderr lines
166         logfile: file stream to which both of stdout and stderr lines will be written
167     """
168     with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
169         import select
170         inputs = set([p.stdout, p.stderr])
171         while inputs:
172             readable, _, _ = select.select(inputs, [], [])
173             for x in readable:
174                 line = x.readline()
175                 if len(line) == 0:
176                     inputs.discard(x)
177                     continue
178                 if x == p.stdout:
179                     out = sys.stdout
180                 if x == p.stderr:
181                     out = sys.stderr
182                     if err_prefix:
183                         line = f"{err_prefix}: ".encode() + line
184                 out.buffer.write(line)
185                 out.buffer.flush()
186                 if logfile != None:
187                     logfile.write(line)
188     if p.returncode != 0:
189         sys.exit(p.returncode)
190
191
192 def remove_prefix(str, prefix):
193     if str.startswith(prefix):
194         return str[len(prefix):]
195     return str
196
197
198 def remove_suffix(str, suffix):
199     if str.endswith(suffix):
200         return str[:-len(suffix)]
201     return str
202
203
204 def get_optimization_list(get_name=False):
205     """
206     returns a list of optimization. If `get_name` is True,
207     only basename without extension is returned rather than full file path.
208
209     [one hierarchy]
210     one
211     ├── backends
212     ├── bin
213     ├── doc
214     ├── include
215     ├── lib
216     ├── optimization
217     └── test
218
219     Optimization options must be placed in `optimization` folder
220     """
221     dir_path = os.path.dirname(os.path.realpath(__file__))
222
223     # optimization folder
224     files = [
225         f for f in glob.glob(dir_path + '/../../optimization/O*.cfg', recursive=True)
226     ]
227     # exclude if the name has space
228     files = [s for s in files if not ' ' in s]
229
230     opt_list = []
231     for cand in files:
232         base = ntpath.basename(cand)
233         if os.path.isfile(cand) and os.access(cand, os.R_OK):
234             opt_list.append(cand)
235
236     if get_name == True:
237         # NOTE the name includes prefix 'O'
238         # e.g. O1, O2, ONCHW not just 1, 2, NCHW
239         opt_list = [ntpath.basename(f) for f in opt_list]
240         opt_list = [remove_suffix(s, '.cfg') for s in opt_list]
241
242     return opt_list
243
244
245 def detect_one_import_drivers(search_path):
246     """Looks for import drivers in given directory
247
248     Args:
249         search_path: path to the directory where to search import drivers
250
251     Returns:
252     dict: each entry is related to single detected driver,
253           key is a config section name, value is a driver name
254
255     """
256     import_drivers_dict = {}
257     for module_name in os.listdir(search_path):
258         full_path = os.path.join(search_path, module_name)
259         if not os.path.isfile(full_path):
260             continue
261         if module_name.find("one-import-") != 0:
262             continue
263         module_loader = importlib.machinery.SourceFileLoader(module_name, full_path)
264         module_spec = importlib.util.spec_from_loader(module_name, module_loader)
265         module = importlib.util.module_from_spec(module_spec)
266         try:
267             module_loader.exec_module(module)
268             if hasattr(module, "get_driver_cfg_section"):
269                 section = module.get_driver_cfg_section()
270                 import_drivers_dict[section] = module_name
271         except:
272             pass
273     return import_drivers_dict