Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / config / config_validator.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import enum
18 import math
19 import re
20 import warnings
21 from collections import OrderedDict
22 from copy import copy
23 from functools import partial
24 from pathlib import Path
25
26 from ..utils import get_path, string_to_bool
27
28
29 class ConfigError(ValueError):
30     pass
31
32
33 class BaseValidator:
34     def __init__(self, on_error=None, additional_validator=None):
35         self.on_error = on_error
36         self.additional_validator = additional_validator
37
38         self.field_uri = None
39
40     def validate(self, entry, field_uri=None):
41         field_uri = field_uri or self.field_uri
42         if self.additional_validator and not self.additional_validator(entry, field_uri):
43             self.raise_error(entry, field_uri)
44
45     def raise_error(self, value, field_uri, reason=None):
46         if self.on_error:
47             self.on_error(value, field_uri, reason)
48
49         error_message = 'Invalid value "{value}" for {field_uri}'.format(value=value, field_uri=field_uri)
50         if reason:
51             error_message = '{error_message}: {reason}'.format(error_message=error_message, reason=reason)
52
53         raise ConfigError(error_message.format(value, field_uri))
54
55
56 class _ExtraArgumentBehaviour(enum.Enum):
57     WARN = 'warn'
58     IGNORE = 'ignore'
59     ERROR = 'error'
60
61
62 def _is_dict_like(entry):
63     return hasattr(entry, '__iter__') and hasattr(entry, '__getitem__')
64
65
66 class ConfigValidator(BaseValidator):
67     WARN_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.WARN
68     ERROR_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.ERROR
69     IGNORE_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.IGNORE
70
71     def __init__(self, config_uri, on_extra_argument=WARN_ON_EXTRA_ARGUMENT, **kwargs):
72         super().__init__(**kwargs)
73         self.on_extra_argument = on_extra_argument
74
75         self.fields = OrderedDict()
76         self.field_uri = config_uri
77         for name in dir(self):
78             value = getattr(self, name)
79             if not isinstance(value, BaseField):
80                 continue
81
82             field_copy = copy(value)
83             field_copy.field_uri = "{}.{}".format(config_uri, name)
84             self.fields[name] = field_copy
85
86     def validate(self, entry, field_uri=None):
87         super().validate(entry, field_uri)
88         field_uri = field_uri or self.field_uri
89         if not _is_dict_like(entry):
90             raise ConfigError("{} is expected to be dict-like".format(field_uri))
91
92         extra_arguments = []
93         for key in entry:
94             if key not in self.fields:
95                 extra_arguments.append(key)
96                 continue
97
98             self.fields[key].validate(entry[key])
99
100         required_fields = set(name for name, value in self.fields.items() if not value.optional)
101         missing_arguments = required_fields.difference(entry)
102
103         if missing_arguments:
104             arguments = ', '.join(map(str, missing_arguments))
105             self.raise_error(
106                 entry, field_uri, "Invalid config for {}: missing required fields: {}".format(field_uri, arguments)
107             )
108
109         if extra_arguments:
110             unknown_options_error = "specifies unknown options: {}".format(extra_arguments)
111             message = "{} {}".format(field_uri, unknown_options_error)
112
113             if self.on_extra_argument == _ExtraArgumentBehaviour.WARN:
114                 warnings.warn(message)
115             if self.on_extra_argument == _ExtraArgumentBehaviour.ERROR:
116                 self.raise_error(entry, field_uri, message)
117
118     @property
119     def known_fields(self):
120         return set(self.fields)
121
122     def raise_error(self, value, field_uri, reason=None):
123         if self.on_error:
124             self.on_error(value, field_uri, reason)
125         else:
126             raise ConfigError(reason)
127
128
129 class BaseField(BaseValidator):
130     def __init__(self, optional=False, allow_none=False, **kwargs):
131         super().__init__(**kwargs)
132         self.optional = optional
133         self.allow_none = allow_none
134
135     def validate(self, entry, field_uri=None):
136         super().validate(entry, field_uri)
137         field_uri = field_uri or self.field_uri
138         if not self.allow_none and entry is None:
139             raise ConfigError("{} is not allowed to be None".format(field_uri))
140
141     @property
142     def type(self):
143         return str
144
145
146 class StringField(BaseField):
147     def __init__(self, choices=None, regex=None, case_sensitive=False, **kwargs):
148         super().__init__(**kwargs)
149         self.choices = choices if case_sensitive or not choices else list(map(str.lower, choices))
150         self.regex = re.compile(regex, flags=re.IGNORECASE if not case_sensitive else 0) if regex else None
151         self.case_sensitive = case_sensitive
152
153     def validate(self, entry, field_uri=None):
154         super().validate(entry, field_uri)
155         if entry is None:
156             return
157
158         field_uri = field_uri or self.field_uri
159         source_entry = entry
160
161         if not isinstance(entry, str):
162             raise ConfigError("{} is expected to be str".format(source_entry))
163
164         if not self.case_sensitive:
165             entry = entry.lower()
166
167         if self.choices and entry not in self.choices:
168             reason = "unsupported option, expected one of: {}".format(', '.join(map(str, self.choices)))
169             self.raise_error(source_entry, field_uri, reason)
170
171         if self.regex and not self.regex.match(entry):
172             self.raise_error(source_entry, field_uri, reason=None)
173
174     @property
175     def type(self):
176         return str
177
178
179 class DictField(BaseField):
180     def __init__(self, key_type=None, value_type=None, validate_keys=True, validate_values=True, allow_empty=True,
181                  **kwargs):
182         super().__init__(**kwargs)
183         self.validate_keys = validate_keys if key_type else False
184         self.validate_values = validate_values if value_type else False
185         self.key_type = _get_field_type(key_type)
186         self.value_type = _get_field_type(value_type)
187
188         self.allow_empty = allow_empty
189
190     def validate(self, entry, field_uri=None):
191         super().validate(entry, field_uri)
192         if entry is None:
193             return
194
195         field_uri = field_uri or self.field_uri
196         if not isinstance(entry, dict):
197             raise ConfigError("{} is expected to be dict".format(field_uri))
198
199         if not entry and not self.allow_empty:
200             self.raise_error(entry, field_uri, "value is empty")
201
202         for k, v in entry.items():
203             if self.validate_keys:
204                 uri = "{}.keys.{}".format(field_uri, k)
205                 self.key_type.validate(k, uri)
206
207             if self.validate_values:
208                 uri = "{}.{}".format(field_uri, k)
209
210                 self.value_type.validate(v, uri)
211     @property
212     def type(self):
213         return dict
214
215
216 class ListField(BaseField):
217     def __init__(self, value_type=None, validate_values=True, allow_empty=True, **kwargs):
218         super().__init__(**kwargs)
219         self.validate_values = validate_values if value_type else False
220         self.value_type = _get_field_type(value_type)
221         self.allow_empty = allow_empty
222
223     def validate(self, entry, field_uri=None):
224         super().validate(entry, field_uri)
225         if entry is None:
226             return
227
228         if not isinstance(entry, list):
229             raise ConfigError("{} is expected to be list".format(field_uri))
230
231         if not entry and not self.allow_empty:
232             self.raise_error(entry, field_uri, "value is empty")
233
234         if self.validate_values:
235             for i, val in enumerate(entry):
236                 self.value_type.validate(val, "{}[{}]".format(val, i))
237
238     @property
239     def type(self):
240         return list
241
242
243 class NumberField(BaseField):
244     def __init__(self, floats=True, min_value=None, max_value=None, allow_inf=False, allow_nan=False, **kwargs):
245         super().__init__(**kwargs)
246         self.floats = floats
247         self.min = min_value
248         self.max = max_value
249         self.allow_inf = allow_inf
250         self.allow_nan = allow_nan
251
252     def validate(self, entry, field_uri=None):
253         super().validate(entry, field_uri)
254         if entry is None:
255             return
256
257         field_uri = field_uri or self.field_uri
258         if not self.floats and isinstance(entry, float):
259             raise ConfigError("{} is expected to be int".format(field_uri))
260         if not isinstance(entry, int) and not isinstance(entry, float):
261             raise ConfigError("{} is expected to be number".format(field_uri))
262
263         if self.min is not None and entry < self.min:
264             reason = "value is less than minimal allowed - {}".format(self.min)
265             self.raise_error(entry, field_uri, reason)
266         if self.max is not None and entry > self.max:
267             reason = "value is greater than maximal allowed - {}".format(self.max)
268             self.raise_error(entry, field_uri, reason)
269
270         if math.isinf(entry) and not self.allow_inf:
271             self.raise_error(entry, field_uri, "value is infinity")
272         if math.isnan(entry) and not self.allow_nan:
273             self.raise_error(entry, field_uri, "value is NaN")
274
275     @property
276     def type(self):
277         return float if self.floats else int
278
279
280 class PathField(BaseField):
281     def __init__(self, is_directory=False, **kwargs):
282         super().__init__(**kwargs)
283         self.is_directory = is_directory
284
285     def validate(self, entry, field_uri=None):
286         super().validate(entry, field_uri)
287         if entry is None:
288             return
289
290         field_uri = field_uri or self.field_uri
291         try:
292             get_path(entry, self.is_directory)
293         except TypeError:
294             self.raise_error(entry, field_uri, "values is expected to be path-like")
295         except FileNotFoundError:
296             self.raise_error(entry, field_uri, "path does not exist")
297         except NotADirectoryError:
298             self.raise_error(entry, field_uri, "path is not a directory")
299         except IsADirectoryError:
300             self.raise_error(entry, field_uri, "path is a directory, regular file expected")
301
302     @property
303     def type(self):
304         return Path
305
306
307 class BoolField(BaseField):
308     def validate(self, entry, field_uri=None):
309         super().validate(entry, field_uri)
310         if entry is None:
311             return
312
313         field_uri = field_uri or self.field_uri
314         if not isinstance(entry, bool):
315             raise ConfigError("{} is expected to be bool".format(field_uri))
316
317     @property
318     def type(self):
319         return string_to_bool
320
321
322 def _get_field_type(key_type):
323     if not isinstance(key_type, BaseField):
324         type_ = _TYPE_TO_FIELD_CLASS.get(key_type)
325         if callable(type_):
326             return type_()
327
328     return key_type
329
330
331 _TYPE_TO_FIELD_CLASS = {
332     int: partial(NumberField, floats=False),
333     float: partial(NumberField, floats=True),
334     dict: partial(DictField, validate_keys=False, validate_values=False),
335     list: partial(ListField, validate_values=False),
336     Path: PathField,
337     str: StringField,
338     bool: BoolField,
339 }