2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
25 from pathlib import Path
26 from typing import Union
27 from warnings import warn
29 from shapely.geometry.polygon import Polygon
34 import lxml.etree as et
36 import xml.etree.cElementTree as et
def concat_lists(*lists):
    """Flatten the given iterables into one list, preserving order.

    Each positional argument is iterated once and its elements are appended
    in the order the arguments were passed.
    """
    return list(itertools.chain.from_iterable(lists))
43 def get_path(entry: Union[str, Path], is_directory=False):
47 raise TypeError('"{}" is expected to be a path-like'.format(entry))
49 # pathlib.Path.exists throws an exception in case of broken symlink
50 if not os.path.exists(str(path)):
51 raise FileNotFoundError('{}: {}'.format(os.strerror(errno.ENOENT), path))
53 if is_directory and not path.is_dir():
54 raise NotADirectoryError('{}: {}'.format(os.strerror(errno.ENOTDIR), path))
56 # if it exists it is either file (or valid symlink to file) or directory (or valid symlink to directory)
57 if not is_directory and not path.is_file():
58 raise IsADirectoryError('{}: {}'.format(os.strerror(errno.EISDIR), path))
63 def contains_all(container, *args):
64 sequence = set(container)
67 if len(sequence.intersection(arg)) != len(arg):
73 def contains_any(container, *args):
74 sequence = set(container)
77 if sequence.intersection(arg):
def string_to_tuple(string, casting_type=float):
    """Parse a comma-separated string such as ``"(1, 2.5)"`` into a tuple.

    Spaces and round brackets are deleted before splitting on commas; every
    resulting token is converted with ``casting_type``.

    Raises:
        ValueError: propagated from ``casting_type`` when a token cannot be
            converted (e.g. the empty token produced by ``""`` with the
            default ``float``).
    """
    cleaned = string.translate(str.maketrans('', '', ' ()'))
    tokens = cleaned.split(',')
    return tuple([casting_type(token) for token in tokens])
def string_to_list(string):
    """Split a bracketed, comma-separated string such as ``"[a, b]"`` into a list.

    Spaces and square brackets are deleted first; the tokens are returned as
    plain strings with no type conversion.
    """
    cleaned = string.translate(str.maketrans('', '', ' []'))
    return cleaned.split(',')
101 class JSONDecoderWithAutoConversion(json.JSONDecoder):
103 Custom json decoder to convert all strings into numbers (int, float) during reading json file.
106 def decode(self, s, _w=json.decoder.WHITESPACE.match):
107 decoded = super().decode(s, _w)
108 return self._decode(decoded)
110 def _decode(self, entry):
111 if isinstance(entry, str):
120 elif isinstance(entry, dict):
121 return {self._decode(key): self._decode(value) for key, value in entry.items()}
122 elif isinstance(entry, list):
123 return [self._decode(value) for value in entry]
def dict_subset(dict_, key_subset):
    """Return a new dict holding only the entries of ``dict_`` whose key is in ``key_subset``.

    Entry order follows ``dict_``; keys listed in ``key_subset`` but absent
    from ``dict_`` are simply ignored.
    """
    selected = {}
    for key, value in dict_.items():
        if key in key_subset:
            selected[key] = value
    return selected
132 def zipped_transform(fn, *iterables, inplace=False):
133 result = (iterables if inplace else tuple([] for _ in range(len(iterables))))
134 updater = (list.__setitem__ if inplace else lambda container, _, entry: container.append(entry))
136 for idx, values in enumerate(zip(*iterables)):
137 iter_res = fn(*values)
141 for dst, res in zip(result, iter_res):
142 updater(dst, idx, res)
def overrides(obj, attribute_name, base=None):
    """Check whether a class (or an instance's class) redefines an attribute of a base class.

    Args:
        obj: a class, or an instance whose class is inspected.
        attribute_name: name of the attribute (typically a method) to compare.
        base: class to compare against; defaults to the first direct base of
            the inspected class.

    Returns:
        True when the attribute found on the inspected class is truthy and is
        a different object from the one found on ``base``; False otherwise
        (including when the attribute does not exist).
    """
    import inspect

    cls = obj if isinstance(obj, type) else obj.__class__

    base = base or cls.__bases__[0]
    # Static lookup returns the raw class-level objects. Plain ``getattr`` binds
    # classmethods to the class it is called on, so an *inherited* classmethod
    # would compare unequal to the base's and be misreported as overridden.
    obj_attr = inspect.getattr_static(cls, attribute_name, None)
    base_attr = inspect.getattr_static(base, attribute_name, None)

    return bool(obj_attr and obj_attr != base_attr)
def enum_values(enum):
    """Return the ``value`` of each member of ``enum``, in iteration order."""
    return list(map(lambda member: member.value, enum))
161 def get_size_from_config(config, allow_none=False):
162 if contains_all(config, ('size', 'dst_width', 'dst_height')):
163 warn('All parameters: size, dst_width, dst_height are provided. Size will be used. '
164 'You should specify only size or pair values des_width, dst_height in config.')
166 return config['size'], config['size']
167 if contains_all(config, ('dst_width', 'dst_height')):
168 return config['dst_height'], config['dst_width']
170 raise ValueError('Either size or dst_width and dst_height required')
175 def get_size_3d_from_config(config, allow_none=False):
176 if contains_all(config, ('size', 'dst_width', 'dst_height', 'dst_volume')):
177 warn('All parameters: size, dst_width, dst_height, dst_volume are provided. Size will be used. '
178 'You should specify only size or three values des_width, dst_height, dst_volume in config.')
180 return config['size'], config['size'], config['size']
181 if contains_all(config, ('dst_width', 'dst_height', 'dst_volume')):
182 return config['dst_height'], config['dst_width'], config['dst_volume']
184 raise ValueError('Either size or dst_width and dst_height required')
186 return config.get('dst_height'), config.get('dst_width'), config.get('dst_volume')
189 def in_interval(value, interval):
190 minimum = interval[0]
191 maximum = interval[1] if len(interval) >= 2 else None
194 return minimum <= value
196 return minimum <= value < maximum
199 def finalize_metric_result(values, names):
200 result_values, result_names = [], []
201 for value, name in zip(values, names):
205 result_values.append(value)
206 result_names.append(name)
208 return result_values, result_names
def get_representations(values, representation_source):
    """Gather ``entry.get(representation_source)`` from every entry into a flat numpy array.

    Each entry of ``values`` must expose a ``get`` method (e.g. a dict); the
    gathered results are flattened with ``np.reshape(..., -1)``.
    """
    collected = [entry.get(representation_source) for entry in values]
    return np.reshape(collected, -1)
def get_supported_representations(container, supported_types):
    """Return the elements of ``container`` whose type matches one of ``supported_types``.

    A scalar-shaped ``container`` (``np.shape(container) == ()``, e.g. a single
    object) is treated as a one-element sequence. The type match itself is
    delegated to ``check_representation_type``.
    """
    candidates = [container] if np.shape(container) == () else container
    return [
        candidate for candidate in candidates
        if check_representation_type(candidate, supported_types)
    ]
222 def check_representation_type(representation, representation_types):
223 for representation_type in representation_types:
224 if type(representation).__name__ == representation_type.__name__:
229 def is_single_metric_source(source):
233 return np.size(source.split(',')) == 1
def read_txt(file: Union[str, Path], sep='\n', **kwargs):
    """Read a text file and return its non-blank chunks with surrounding whitespace stripped.

    The whole content is split on ``sep``; chunks that are empty or consist
    only of whitespace are dropped, the rest are ``str.strip``-ed.

    Raises:
        Propagates the errors raised by ``get_path`` (e.g. FileNotFoundError
        for a missing path).
    """
    with get_path(file).open() as handle:
        # NOTE(review): ``kwargs`` go to the text stream's ``read``, which in
        # CPython takes no keyword arguments — any non-empty kwargs raise
        # TypeError. Confirm the intended target of these arguments.
        chunks = handle.read(**kwargs).split(sep)
        return [chunk.strip() for chunk in chunks if chunk and not chunk.isspace()]
def read_xml(file: Union[str, Path], *args, **kwargs):
    """Parse an XML file and return its root element.

    Extra positional and keyword arguments are forwarded to ``et.parse``.
    """
    resolved = str(get_path(file))
    tree = et.parse(resolved, *args, **kwargs)
    return tree.getroot()
def read_json(file: Union[str, Path], *args, **kwargs):
    """Load and return the JSON document stored in ``file``.

    Extra positional and keyword arguments are forwarded to ``json.load``.
    """
    with get_path(file).open() as stream:
        parsed = json.load(stream, *args, **kwargs)
    return parsed
def read_pickle(file: Union[str, Path], *args, **kwargs):
    """Unpickle and return the object stored in ``file``.

    Extra positional and keyword arguments are forwarded to ``pickle.load``.
    """
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    with get_path(file).open('rb') as stream:
        obj = pickle.load(stream, *args, **kwargs)
    return obj
261 def read_yaml(file: Union[str, Path], *args, **kwargs):
262 # yaml does not keep order of keys in dictionaries but it is important for reading pre/post processing
263 yaml.add_representer(collections.OrderedDict, lambda dumper, data: dumper.represent_dict(data.items()))
264 yaml.add_constructor(
265 yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
266 lambda loader, node: collections.OrderedDict(loader.construct_pairs(node))
269 with get_path(file).open() as content:
270 return yaml.load(content, Loader=yaml.SafeLoader, *args, **kwargs)
def read_csv(file: Union[str, Path], *args, **kwargs):
    """Read a CSV file and return the rows produced by ``csv.DictReader`` as a list of dicts.

    Extra positional and keyword arguments are forwarded to ``csv.DictReader``.
    """
    # The csv module requires file objects to be opened with newline='';
    # otherwise quoted fields containing line breaks are not parsed correctly.
    with get_path(file).open(newline='') as content:
        return list(csv.DictReader(content, *args, **kwargs))
278 def extract_image_representations(image_representations):
279 images = [rep.data for rep in image_representations]
280 meta = [rep.metadata for rep in image_representations]
def convert_bboxes_xywh_to_x1y1x2y2(x_coord, y_coord, width, height):
    """Convert a box given as (x, y, width, height) into corner coordinates.

    Returns:
        Tuple ``(x, y, x + width, y + height)``.
    """
    x_max = x_coord + width
    y_max = y_coord + height
    return x_coord, y_coord, x_max, y_max
289 def get_or_parse_value(item, supported_values, default=None):
290 if isinstance(item, str):
292 if item in supported_values:
293 return supported_values[item]
296 return string_to_tuple(item)
298 message = 'Invalid value "{}", expected one of precomputed: ({}) or list of values'.format(
299 item, ', '.join(supported_values.keys())
301 raise ValueError(message)
303 if isinstance(item, (float, int)):
def string_to_bool(string):
    """Return True if ``string`` lower-cased is one of 'yes', 'true', 't' or '1'.

    Anything else, including values with surrounding whitespace, yields False.
    """
    normalized = string.lower()
    return normalized in ('yes', 'true', 't', '1')
313 def get_key_by_value(container, target):
314 for key, value in container.items():
322 return '--{}'.format(key)
def to_lower_register(str_list):
    """Lower-case every string in ``str_list``; falsy entries (e.g. None or '') become None."""
    return [item.lower() if item else None for item in str_list]
def polygon_from_points(points):
    """Build a shapely ``Polygon`` from the given vertex coordinates."""
    polygon = Polygon(points)
    return polygon
333 def remove_difficult(difficult, indexes):
338 while id_difficult < len(difficult) and id_removed < len(indexes):
339 if difficult[id_difficult] < indexes[id_removed]:
340 new_difficult.append(difficult[id_difficult] - decrementor)
349 def convert_to_range(entry):
351 if isinstance(entry, str):
352 entry_range = string_to_tuple(entry_range)
353 elif not isinstance(entry_range, tuple) and not isinstance(entry_range, list):
354 entry_range = [entry_range]
359 def add_input_shape_to_meta(meta, shape):
360 meta['input_shape'] = shape