3 # Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
26 from typing import Union
28 import onelib.constant as _constant
# Register the common command-line options on `parser`. The visible calls add
# -C/--config and a hidden -S/--section; the two help texts below belong to
# the version (-V) and verbose (-v) options, whose add_argument calls are not
# visible in this excerpt.
31 def add_default_arg(parser):
37 help='show program\'s version number and exit')
44 help='output additional information to stdout or stderr')
# NOTE(review): 'configuation' in the help text below is a typo for
# 'configuration'; it is shown verbatim in --help output.
47 parser.add_argument('-C', '--config', type=str, help='run with configuation file')
48 # section name that you want to run in configuration file
# -S is hidden from --help output (help=argparse.SUPPRESS).
49 parser.add_argument('-S', '--section', type=str, help=argparse.SUPPRESS)
# Register only the version (-V) and verbose (-v) options on `parser`, without
# the -C/--config and -S/--section options. The add_argument calls that own
# the two help texts below are not visible in this excerpt.
52 def add_default_arg_no_CS(parser):
54 This adds -v -V args only (no -C nor -S)
61 help='show program\'s version number and exit')
68 help='output additional information to stdout or stderr')
# Tell whether option `arg` of `driver` may be given several times (its values
# are collected into a list) instead of being overwritten. Only the
# "one-quantize" branch is visible: its accumulable options are tensor_name,
# scale, zero_point, src_tensor_name and dst_tensor_name.
# NOTE(review): the return statements after the membership test are not
# visible in this excerpt; presumably True on a match and False otherwise.
# Confirm against the full file.
71 def is_accumulated_arg(arg, driver):
72 if driver == "one-quantize":
74 "tensor_name", "scale", "zero_point", "src_tensor_name", "dst_tensor_name"
76 if arg in accumulables:
def is_valid_attr(args, attr):
    """Return the value of ``args.<attr>`` when that attribute exists, else False.

    Callers use the result for its truthiness. An attribute that is missing,
    ``None``, empty or otherwise falsy counts as "not set".
    """
    if not hasattr(args, attr):
        return False
    return getattr(args, attr)
86 def parse_cfg_and_overwrite(config_path, section, args):
88 parse given section of configuration file and set the values of args.
89 Even if the values parsed from the configuration file already exist in args,
90 the values are overwritten.
92 if config_path == None:
# NOTE(review): `config_path is None` is the idiomatic None test (PEP 8).
# The body of this branch is not visible in this excerpt.
95 config = configparser.ConfigParser()
96 # make option names case sensitive
97 config.optionxform = str
# ConfigParser.read returns the list of files it parsed successfully. The
# guard for the raise below is out of view; it presumably tests that list.
98 parsed = config.read(config_path)
100 raise FileNotFoundError('Not found given configuration file')
# A missing section is reported as AssertionError.
101 if not config.has_section(section):
102 raise AssertionError('configuration file doesn\'t have \'' + section +
# Copy every option of the section onto `args` as the raw string value that
# configparser returns, replacing whatever value was there before.
104 for key in config[section]:
105 setattr(args, key, config[section][key])
106 # TODO support accumulated arguments
109 def parse_cfg(config_path: Union[str, None], section_to_parse: str, args):
111 parse configuration file and store the information to args
113 :param config_path: path to configuration file
114 :param section_to_parse: section name to parse
115 :param args: object to store the parsed information
117 if config_path is None:
# The body of this branch is not visible in this excerpt.
120 parser = configparser.ConfigParser()
# keep option names case sensitive
121 parser.optionxform = str
# NOTE(review): `read` returns the list of successfully parsed files, and that
# list is discarded here. A missing or unreadable config file therefore shows
# up only as the "must have section" AssertionError below, not as a
# FileNotFoundError.
122 parser.read(config_path)
124 if not parser.has_section(section_to_parse):
125 raise AssertionError('configuration file must have \'' + section_to_parse +
128 for key in parser[section_to_parse]:
# The section name doubles as the driver name, so the accumulation rules are
# looked up per section (e.g. "one-quantize").
129 if is_accumulated_arg(key, section_to_parse):
# Accumulable option: start a one-element list when args has no truthy value
# yet. Otherwise append to the existing value, which is assumed to be a list
# (e.g. collected from the command line). Confirm.
130 if not is_valid_attr(args, key):
131 setattr(args, key, [parser[section_to_parse][key]])
133 getattr(args, key).append(parser[section_to_parse][key])
# NOTE(review): this condition repeats is_valid_attr(args, key). The statement
# taken when it holds is not visible; presumably it skips the key, so values
# already set on args take precedence over the config file.
135 if hasattr(args, key) and getattr(args, key):
137 setattr(args, key, parser[section_to_parse][key])
140 def print_version_and_exit(file_path):
141 """print version of the file located in the file_path"""
# Resolve symlinks so the `one-version` helper is looked up next to the real
# script file, and pass it the script's base name without extension.
142 script_path = os.path.realpath(file_path)
143 dir_path = os.path.dirname(script_path)
144 script_name = os.path.splitext(os.path.basename(script_path))[0]
# NOTE(review): the exit status of one-version is ignored. The process exit
# implied by the function name is not visible in this excerpt.
146 subprocess.call([os.path.join(dir_path, 'one-version'), script_name])
150 def safemain(main, mainpath):
151 """execute given method and print with program name for all uncaught exceptions"""
# Only Exception subclasses are caught; SystemExit and KeyboardInterrupt
# propagate unchanged. The matching `try:` and the call to `main` are not
# visible in this excerpt.
154 except Exception as e:
# Report as "<program file name>: <ExceptionType>: <message>" on stderr.
155 prog_name = os.path.basename(mainpath)
156 print(f"{prog_name}: {type(e).__name__}: " + str(e), file=sys.stderr)
# Run `cmd` as a child process and forward its stdout/stderr output line by
# line. NOTE(review): the read loop, the choice of `out` and the `logfile`
# handling are in lines not visible in this excerpt.
160 def run(cmd, err_prefix=None, logfile=None):
161 """Execute command in subprocess
164 cmd: command to be executed in subprocess
165 err_prefix: prefix to be put before every stderr line
166 logfile: file stream to which both of stdout and stderr lines will be written
168 with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
# Both pipes are watched with select(), so whichever one has data is read
# first. NOTE(review): on Windows select() accepts only sockets, so this pipe
# handling is POSIX-only.
170 inputs = set([p.stdout, p.stderr])
172 readable, _, _ = select.select(inputs, [], [])
# Pipe data is bytes, so the stderr prefix is encoded before being prepended.
183 line = f"{err_prefix}: ".encode() + line
# Raw bytes go to the stream's underlying binary buffer. `out` is assigned out
# of view and is presumably sys.stdout or sys.stderr. Confirm.
184 out.buffer.write(line)
# A failing child terminates the caller with the same exit status.
188 if p.returncode != 0:
189 sys.exit(p.returncode)
# Return `str` with a leading `prefix` removed. The fall-through return for
# the no-match case is not visible in this excerpt.
# NOTE(review): the parameter name shadows the builtin `str`. On Python 3.9+
# the method `str.removeprefix` does the same job.
192 def remove_prefix(str, prefix):
193 if str.startswith(prefix):
194 return str[len(prefix):]
# Return `str` with a trailing `suffix` removed. The fall-through return for
# the no-match case is not visible in this excerpt.
# NOTE(review): BUG for an empty suffix. `endswith('')` is always True and
# `str[:-0]` is `str[:0]`, i.e. '', so the whole string is dropped instead of
# being returned unchanged. Guard with `if suffix and str.endswith(suffix):`
# (or use `str.removesuffix` on Python 3.9+). The parameter name also shadows
# the builtin `str`.
198 def remove_suffix(str, suffix):
199 if str.endswith(suffix):
200 return str[:-len(suffix)]
# Collect the optimization option files (O*.cfg) that ship with the tool.
204 def get_optimization_list(get_name=False):
206 returns a list of optimization. If `get_name` is True,
207 only basename without extension is returned rather than full file path.
219 Optimization options must be placed in `optimization` folder
221 dir_path = os.path.dirname(os.path.realpath(__file__))
# Candidates are looked up relative to this module's real location:
# <module dir>/../../optimization/O*.cfg
# NOTE(review): `recursive=True` has no effect here because the pattern
# contains no `**` component.
223 # optimization folder
225 f for f in glob.glob(dir_path + '/../../optimization/O*.cfg', recursive=True)
227 # exclude if the name has space
# NOTE(review): this tests the whole path, not just the file name. If
# dir_path itself contains a space, every candidate is dropped. Also,
# `' ' not in s` is the idiomatic (PEP 8) spelling.
228 files = [s for s in files if not ' ' in s]
# NOTE(review): `base` is not used in the visible lines.
232 base = ntpath.basename(cand)
# Keep only regular files that the current user can read.
233 if os.path.isfile(cand) and os.access(cand, os.R_OK):
234 opt_list.append(cand)
237 # NOTE the name includes prefix 'O'
238 # e.g. O1, O2, ONCHW not just 1, 2, NCHW
# Reduce each path to its file name without the '.cfg' extension. Per the
# docstring this is the get_name=True case; the guard is out of view.
239 opt_list = [ntpath.basename(f) for f in opt_list]
240 opt_list = [remove_suffix(s, '.cfg') for s in opt_list]
# Scan `search_path` for "one-import-*" driver scripts. Map the config section
# that each one reports to the driver's file name.
245 def detect_one_import_drivers(search_path):
246 """Looks for import drivers in given directory
249 search_path: path to the directory where to search import drivers
252 dict: each entry is related to single detected driver,
253 key is a config section name, value is a driver name
256 import_drivers_dict = {}
257 for module_name in os.listdir(search_path):
258 full_path = os.path.join(search_path, module_name)
# Non-files and names that do not start with "one-import-" are filtered out.
# The statements taken on these conditions are not visible in this excerpt.
259 if not os.path.isfile(full_path):
# NOTE(review): `not module_name.startswith("one-import-")` is the clearer
# spelling of this test.
261 if module_name.find("one-import-") != 0:
# SourceFileLoader loads the file as Python source by path, regardless of its
# extension. Executing it runs the script's top-level code in this process,
# so search_path should only contain trusted files.
263 module_loader = importlib.machinery.SourceFileLoader(module_name, full_path)
264 module_spec = importlib.util.spec_from_loader(module_name, module_loader)
265 module = importlib.util.module_from_spec(module_spec)
267 module_loader.exec_module(module)
# Drivers that expose get_driver_cfg_section() are registered under the
# section name they return. If two drivers return the same section, the one
# seen later in os.listdir() order (arbitrary) overwrites the earlier entry.
268 if hasattr(module, "get_driver_cfg_section"):
269 section = module.get_driver_cfg_section()
270 import_drivers_dict[section] = module_name
273 return import_drivers_dict