Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / utils.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import collections
18 import csv
19 import errno
20 import itertools
21 import json
22 import os
23 import pickle
24
25 from pathlib import Path
26 from typing import Union
27 from warnings import warn
28
29 from shapely.geometry.polygon import Polygon
30 import numpy as np
31 import yaml
32
33 try:
34     import lxml.etree as et
35 except ImportError:
36     import xml.etree.cElementTree as et
37
38
39 def concat_lists(*lists):
40     return list(itertools.chain(*lists))
41
42
43 def get_path(entry: Union[str, Path], is_directory=False):
44     try:
45         path = Path(entry)
46     except TypeError:
47         raise TypeError('"{}" is expected to be a path-like'.format(entry))
48
49     # pathlib.Path.exists throws an exception in case of broken symlink
50     if not os.path.exists(str(path)):
51         raise FileNotFoundError('{}: {}'.format(os.strerror(errno.ENOENT), path))
52
53     if is_directory and not path.is_dir():
54         raise NotADirectoryError('{}: {}'.format(os.strerror(errno.ENOTDIR), path))
55
56     # if it exists it is either file (or valid symlink to file) or directory (or valid symlink to directory)
57     if not is_directory and not path.is_file():
58         raise IsADirectoryError('{}: {}'.format(os.strerror(errno.EISDIR), path))
59
60     return path
61
62
63 def contains_all(container, *args):
64     sequence = set(container)
65
66     for arg in args:
67         if len(sequence.intersection(arg)) != len(arg):
68             return False
69
70     return True
71
72
73 def contains_any(container, *args):
74     sequence = set(container)
75
76     for arg in args:
77         if sequence.intersection(arg):
78             return True
79
80     return False
81
82
83 def string_to_tuple(string, casting_type=float):
84     processed = string.replace(' ', '')
85     processed = processed.replace('(', '')
86     processed = processed.replace(')', '')
87     processed = processed.split(',')
88
89     return tuple([casting_type(entry) for entry in processed])
90
91
92 def string_to_list(string):
93     processed = string.replace(' ', '')
94     processed = processed.replace('[', '')
95     processed = processed.replace(']', '')
96     processed = processed.split(',')
97
98     return list(entry for entry in processed)
99
100
101 class JSONDecoderWithAutoConversion(json.JSONDecoder):
102     """
103     Custom json decoder to convert all strings into numbers (int, float) during reading json file.
104     """
105
106     def decode(self, s, _w=json.decoder.WHITESPACE.match):
107         decoded = super().decode(s, _w)
108         return self._decode(decoded)
109
110     def _decode(self, entry):
111         if isinstance(entry, str):
112             try:
113                 return int(entry)
114             except ValueError:
115                 pass
116             try:
117                 return float(entry)
118             except ValueError:
119                 pass
120         elif isinstance(entry, dict):
121             return {self._decode(key): self._decode(value) for key, value in entry.items()}
122         elif isinstance(entry, list):
123             return [self._decode(value) for value in entry]
124
125         return entry
126
127
128 def dict_subset(dict_, key_subset):
129     return {key: value for key, value in dict_.items() if key in key_subset}
130
131
132 def zipped_transform(fn, *iterables, inplace=False):
133     result = (iterables if inplace else tuple([] for _ in range(len(iterables))))
134     updater = (list.__setitem__ if inplace else lambda container, _, entry: container.append(entry))
135
136     for idx, values in enumerate(zip(*iterables)):
137         iter_res = fn(*values)
138         if not iter_res:
139             continue
140
141         for dst, res in zip(result, iter_res):
142             updater(dst, idx, res)
143
144     return result
145
146
147 def overrides(obj, attribute_name, base=None):
148     cls = obj if isinstance(obj, type) else obj.__class__
149
150     base = base or cls.__bases__[0]
151     obj_attr = getattr(cls, attribute_name, None)
152     base_attr = getattr(base, attribute_name, None)
153
154     return obj_attr and obj_attr != base_attr
155
156
157 def enum_values(enum):
158     return [member.value for member in enum]
159
160
161 def get_size_from_config(config, allow_none=False):
162     if contains_all(config, ('size', 'dst_width', 'dst_height')):
163         warn('All parameters: size, dst_width, dst_height are provided. Size will be used. '
164              'You should specify only size or pair values des_width, dst_height in config.')
165     if 'size' in config:
166         return config['size'], config['size']
167     if contains_all(config, ('dst_width', 'dst_height')):
168         return config['dst_height'], config['dst_width']
169     if not allow_none:
170         raise ValueError('Either size or dst_width and dst_height required')
171
172     return None, None
173
174
175 def get_size_3d_from_config(config, allow_none=False):
176     if contains_all(config, ('size', 'dst_width', 'dst_height', 'dst_volume')):
177         warn('All parameters: size, dst_width, dst_height, dst_volume are provided. Size will be used. '
178              'You should specify only size or three values des_width, dst_height, dst_volume in config.')
179     if 'size' in config:
180         return config['size'], config['size'], config['size']
181     if contains_all(config, ('dst_width', 'dst_height', 'dst_volume')):
182         return config['dst_height'], config['dst_width'], config['dst_volume']
183     if not allow_none:
184         raise ValueError('Either size or dst_width and dst_height required')
185
186     return config.get('dst_height'), config.get('dst_width'), config.get('dst_volume')
187
188
189 def in_interval(value, interval):
190     minimum = interval[0]
191     maximum = interval[1] if len(interval) >= 2 else None
192
193     if not maximum:
194         return minimum <= value
195
196     return minimum <= value < maximum
197
198
199 def finalize_metric_result(values, names):
200     result_values, result_names = [], []
201     for value, name in zip(values, names):
202         if np.isnan(value):
203             continue
204
205         result_values.append(value)
206         result_names.append(name)
207
208     return result_values, result_names
209
210
211 def get_representations(values, representation_source):
212     return np.reshape([value.get(representation_source) for value in values], -1)
213
214
215 def get_supported_representations(container, supported_types):
216     if np.shape(container) == ():
217         container = [container]
218
219     return list(filter(lambda rep: check_representation_type(rep, supported_types), container))
220
221
222 def check_representation_type(representation, representation_types):
223     for representation_type in representation_types:
224         if type(representation).__name__ == representation_type.__name__:
225             return True
226     return False
227
228
229 def is_single_metric_source(source):
230     if not source:
231         return False
232
233     return np.size(source.split(',')) == 1
234
235
236 def read_txt(file: Union[str, Path], sep='\n', **kwargs):
237     def is_empty(string):
238         return not string or string.isspace()
239
240     with get_path(file).open() as content:
241         content = content.read(**kwargs).split(sep)
242         content = list(filter(lambda string: not is_empty(string), content))
243
244         return list(map(str.strip, content))
245
246
247 def read_xml(file: Union[str, Path], *args, **kwargs):
248     return et.parse(str(get_path(file)), *args, **kwargs).getroot()
249
250
251 def read_json(file: Union[str, Path], *args, **kwargs):
252     with get_path(file).open() as content:
253         return json.load(content, *args, **kwargs)
254
255
256 def read_pickle(file: Union[str, Path], *args, **kwargs):
257     with get_path(file).open('rb') as content:
258         return pickle.load(content, *args, **kwargs)
259
260
261 def read_yaml(file: Union[str, Path], *args, **kwargs):
262     # yaml does not keep order of keys in dictionaries but it is important for reading pre/post processing
263     yaml.add_representer(collections.OrderedDict, lambda dumper, data: dumper.represent_dict(data.items()))
264     yaml.add_constructor(
265         yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
266         lambda loader, node: collections.OrderedDict(loader.construct_pairs(node))
267     )
268
269     with get_path(file).open() as content:
270         return yaml.load(content, Loader=yaml.SafeLoader, *args, **kwargs)
271
272
273 def read_csv(file: Union[str, Path], *args, **kwargs):
274     with get_path(file).open() as content:
275         return list(csv.DictReader(content, *args, **kwargs))
276
277
278 def extract_image_representations(image_representations):
279     images = [rep.data for rep in image_representations]
280     meta = [rep.metadata for rep in image_representations]
281
282     return images, meta
283
284
285 def convert_bboxes_xywh_to_x1y1x2y2(x_coord, y_coord, width, height):
286     return x_coord, y_coord, x_coord + width, y_coord + height
287
288
289 def get_or_parse_value(item, supported_values, default=None):
290     if isinstance(item, str):
291         item = item.lower()
292         if item in supported_values:
293             return supported_values[item]
294
295         try:
296             return string_to_tuple(item)
297         except ValueError:
298             message = 'Invalid value "{}", expected one of precomputed: ({}) or list of values'.format(
299                 item, ', '.join(supported_values.keys())
300             )
301             raise ValueError(message)
302
303     if isinstance(item, (float, int)):
304         return (item, )
305
306     return default
307
308
309 def string_to_bool(string):
310     return string.lower() in ['yes', 'true', 't', '1']
311
312
313 def get_key_by_value(container, target):
314     for key, value in container.items():
315         if value == target:
316             return key
317
318     return None
319
320
321 def format_key(key):
322     return '--{}'.format(key)
323
324
325 def to_lower_register(str_list):
326     return list(map(lambda item: item.lower() if item else None, str_list))
327
328
329 def polygon_from_points(points):
330     return Polygon(points)
331
332
333 def remove_difficult(difficult, indexes):
334     new_difficult = []
335     decrementor = 0
336     id_difficult = 0
337     id_removed = 0
338     while id_difficult < len(difficult) and id_removed < len(indexes):
339         if difficult[id_difficult] < indexes[id_removed]:
340             new_difficult.append(difficult[id_difficult] - decrementor)
341             id_difficult += 1
342         else:
343             decrementor += 1
344             id_removed += 1
345
346     return new_difficult
347
348
349 def convert_to_range(entry):
350     entry_range = entry
351     if isinstance(entry, str):
352         entry_range = string_to_tuple(entry_range)
353     elif not isinstance(entry_range, tuple) and not isinstance(entry_range, list):
354         entry_range = [entry_range]
355
356     return entry_range
357
358
359 def add_input_shape_to_meta(meta, shape):
360     meta['input_shape'] = shape
361     return meta