2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from collections import OrderedDict
23 from functools import partial
24 from pathlib import Path
26 from ..utils import get_path, string_to_bool
class ConfigError(ValueError):
    """Raised when a configuration entry fails validation."""


class BaseValidator:
    """Common base for config validators.

    Args:
        on_error: optional callback ``on_error(value, field_uri, reason)`` that is
            invoked before a validation error is raised.
        additional_validator: optional predicate ``additional_validator(entry, field_uri)``;
            a falsy result marks the entry as invalid.
    """

    def __init__(self, on_error=None, additional_validator=None):
        self.on_error = on_error
        self.additional_validator = additional_validator

        # Location of the validated entry in the config; assigned later by owners
        # (e.g. a ConfigValidator copying its fields) or overridden per call.
        self.field_uri = None

    def validate(self, entry, field_uri=None):
        """Apply ``additional_validator`` (if any) to ``entry``.

        Raises:
            ConfigError: if the additional validator rejects the entry.
        """
        field_uri = field_uri or self.field_uri
        if self.additional_validator and not self.additional_validator(entry, field_uri):
            self.raise_error(entry, field_uri)

    def raise_error(self, value, field_uri, reason=None):
        """Notify ``on_error`` (when set), then raise a ConfigError describing ``value``.

        Raises:
            ConfigError: always.
        """
        if self.on_error:
            self.on_error(value, field_uri, reason)

        error_message = 'Invalid value "{value}" for {field_uri}'.format(value=value, field_uri=field_uri)
        if reason:
            error_message = '{error_message}: {reason}'.format(error_message=error_message, reason=reason)

        # The message is already fully formatted. Formatting it again would treat
        # braces coming from the value (e.g. a dict) as placeholders and crash.
        raise ConfigError(error_message)
56 class _ExtraArgumentBehaviour(enum.Enum):
62 def _is_dict_like(entry):
63 return hasattr(entry, '__iter__') and hasattr(entry, '__getitem__')
class ConfigValidator(BaseValidator):
    """Validate a dict-like config section against a set of declared fields.

    Every ``BaseField`` found among the instance's attributes (typically class
    attributes of a subclass) is copied into ``self.fields``. Each copy gets
    ``field_uri`` set to ``"<config_uri>.<name>"``.

    Args:
        config_uri: location of this section in the config, used in messages.
        on_extra_argument: one of the ``*_ON_EXTRA_ARGUMENT`` constants below.
            It controls what happens when the entry has keys that match no field.
    """

    WARN_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.WARN
    ERROR_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.ERROR
    IGNORE_ON_EXTRA_ARGUMENT = _ExtraArgumentBehaviour.IGNORE

    def __init__(self, config_uri, on_extra_argument=WARN_ON_EXTRA_ARGUMENT, **kwargs):
        super().__init__(**kwargs)
        self.on_extra_argument = on_extra_argument

        self.fields = OrderedDict()
        self.field_uri = config_uri
        for name in dir(self):
            value = getattr(self, name)
            if not isinstance(value, BaseField):
                continue

            # Copy so that assigning field_uri does not mutate the shared declaration.
            field_copy = copy(value)
            field_copy.field_uri = "{}.{}".format(config_uri, name)
            self.fields[name] = field_copy

    def validate(self, entry, field_uri=None):
        """Validate ``entry`` field by field.

        Raises:
            ConfigError: if ``entry`` is not dict-like, or if a nested field is
                invalid. Missing required fields, and unknown keys when configured
                as an error, go through ``raise_error``; that method defers to
                ``on_error`` when set.
        """
        super().validate(entry, field_uri)
        field_uri = field_uri or self.field_uri
        if not _is_dict_like(entry):
            raise ConfigError("{} is expected to be dict-like".format(field_uri))

        extra_arguments = []
        for key in entry:
            if key not in self.fields:
                extra_arguments.append(key)
                continue

            # Locate nested errors relative to the uri this call was made with;
            # when no uri is passed this equals the field's own field_uri.
            self.fields[key].validate(entry[key], "{}.{}".format(field_uri, key))

        required_fields = {name for name, value in self.fields.items() if not value.optional}
        missing_arguments = required_fields.difference(entry)

        if missing_arguments:
            # Sorted so the message does not depend on set iteration order.
            arguments = ', '.join(sorted(map(str, missing_arguments)))
            self.raise_error(
                entry, field_uri, "Invalid config for {}: missing required fields: {}".format(field_uri, arguments)
            )

        if extra_arguments:
            unknown_options_error = "specifies unknown options: {}".format(extra_arguments)
            message = "{} {}".format(field_uri, unknown_options_error)

            if self.on_extra_argument == _ExtraArgumentBehaviour.WARN:
                warnings.warn(message)
            if self.on_extra_argument == _ExtraArgumentBehaviour.ERROR:
                self.raise_error(entry, field_uri, message)

    @property
    def known_fields(self):
        """Names of all declared fields."""
        return set(self.fields)

    def raise_error(self, value, field_uri, reason=None):
        """Delegate to ``on_error`` when set; otherwise raise ConfigError(reason).

        ``reason`` already contains the uri, so it is used as the message directly.
        When no reason is given (e.g. an ``additional_validator`` rejection), a
        generic message is used instead of the literal text "None".
        """
        if self.on_error:
            self.on_error(value, field_uri, reason)
        else:
            raise ConfigError(reason if reason else "Invalid config for {}".format(field_uri))
class BaseField(BaseValidator):
    """Base class for one typed entry of a config section.

    Args:
        optional: when True, an owning ConfigValidator does not require the key.
        allow_none: when True, ``None`` is an accepted value.
    """

    def __init__(self, optional=False, allow_none=False, **kwargs):
        super().__init__(**kwargs)
        self.allow_none = allow_none
        self.optional = optional

    def validate(self, entry, field_uri=None):
        """Run the base checks and reject ``None`` unless ``allow_none`` is set."""
        super().validate(entry, field_uri)
        uri = field_uri if field_uri else self.field_uri
        if entry is None and not self.allow_none:
            raise ConfigError("{} is not allowed to be None".format(uri))

    @property
    def type(self):
        """Type associated with this field's values; subclasses override it."""
        return str
class StringField(BaseField):
    """String field, optionally restricted to a set of choices and/or a regex.

    Args:
        choices: allowed values. When ``case_sensitive`` is False they are stored
            lower-cased and compared with the lower-cased value.
        regex: pattern checked with ``re.match`` (anchored at the start only).
        case_sensitive: when False, the regex is compiled with ``re.IGNORECASE``.
    """

    def __init__(self, choices=None, regex=None, case_sensitive=False, **kwargs):
        super().__init__(**kwargs)
        self.choices = choices if case_sensitive or not choices else list(map(str.lower, choices))
        self.regex = re.compile(regex, flags=re.IGNORECASE if not case_sensitive else 0) if regex else None
        self.case_sensitive = case_sensitive

    def validate(self, entry, field_uri=None):
        """Check type, choices and regex.

        Raises:
            ConfigError: if the value is not a str, is not among ``choices``, or
                does not match ``regex``.
        """
        super().validate(entry, field_uri)
        if entry is None:
            return

        field_uri = field_uri or self.field_uri
        # Keep the original spelling for messages; ``entry`` may be lower-cased below.
        source_entry = entry

        if not isinstance(entry, str):
            # Report the field location, like every other field type does.
            raise ConfigError("{} is expected to be str".format(field_uri))

        if not self.case_sensitive:
            entry = entry.lower()

        if self.choices and entry not in self.choices:
            reason = "unsupported option, expected one of: {}".format(', '.join(map(str, self.choices)))
            self.raise_error(source_entry, field_uri, reason)

        if self.regex and not self.regex.match(entry):
            self.raise_error(source_entry, field_uri, reason=None)

    @property
    def type(self):
        """Type associated with this field's values."""
        return str
class DictField(BaseField):
    """Dict field with optional validation of every key and/or value.

    ``key_type`` / ``value_type`` may be a BaseField instance or a Python type
    resolved through ``_get_field_type``. Per-item validation is enabled only
    when the corresponding type is given.
    """

    def __init__(self, key_type=None, value_type=None, validate_keys=True, validate_values=True, allow_empty=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.validate_keys = bool(key_type) and validate_keys
        self.validate_values = bool(value_type) and validate_values
        self.key_type = _get_field_type(key_type)
        self.value_type = _get_field_type(value_type)

        self.allow_empty = allow_empty

    def validate(self, entry, field_uri=None):
        """Check that ``entry`` is a dict, then validate its items."""
        super().validate(entry, field_uri)
        if entry is None:
            return

        uri = field_uri or self.field_uri
        if not isinstance(entry, dict):
            raise ConfigError("{} is expected to be dict".format(uri))

        if not self.allow_empty and not entry:
            self.raise_error(entry, uri, "value is empty")

        for key, item in entry.items():
            # Keys are located as "<uri>.keys.<key>", values as "<uri>.<key>".
            if self.validate_keys:
                self.key_type.validate(key, "{}.keys.{}".format(uri, key))
            if self.validate_values:
                self.value_type.validate(item, "{}.{}".format(uri, key))

    @property
    def type(self):
        """Type associated with this field's values."""
        return dict
class ListField(BaseField):
    """List field with optional validation of every element.

    ``value_type`` may be a BaseField instance or a Python type resolved through
    ``_get_field_type``. Element validation is enabled only when it is given.
    """

    def __init__(self, value_type=None, validate_values=True, allow_empty=True, **kwargs):
        super().__init__(**kwargs)
        self.validate_values = validate_values if value_type else False
        self.value_type = _get_field_type(value_type)
        self.allow_empty = allow_empty

    def validate(self, entry, field_uri=None):
        """Check that ``entry`` is a list, then validate its elements.

        Raises:
            ConfigError: if ``entry`` is not a list, is empty while ``allow_empty``
                is False, or an element fails ``value_type`` validation.
        """
        super().validate(entry, field_uri)
        if entry is None:
            return

        # Resolve the uri like the other fields so messages never report "None".
        field_uri = field_uri or self.field_uri
        if not isinstance(entry, list):
            raise ConfigError("{} is expected to be list".format(field_uri))

        if not entry and not self.allow_empty:
            self.raise_error(entry, field_uri, "value is empty")

        if self.validate_values:
            for i, val in enumerate(entry):
                # Element location is "<list uri>[<index>]".
                self.value_type.validate(val, "{}[{}]".format(field_uri, i))

    @property
    def type(self):
        """Type associated with this field's values."""
        return list
class NumberField(BaseField):
    """Numeric field with optional bounds and inf/NaN control.

    Args:
        floats: when False, float values are rejected and only ints are accepted.
        min_value / max_value: inclusive bounds; ``None`` disables the check.
        allow_inf / allow_nan: whether infinite / NaN floats are accepted.
    """

    def __init__(self, floats=True, min_value=None, max_value=None, allow_inf=False, allow_nan=False, **kwargs):
        super().__init__(**kwargs)
        self.floats = floats
        self.min = min_value
        self.max = max_value
        self.allow_inf = allow_inf
        self.allow_nan = allow_nan

    def validate(self, entry, field_uri=None):
        """Check type, bounds and inf/NaN.

        Raises:
            ConfigError: if the value is not a number (or is a float while
                ``floats`` is False), is out of bounds, or is a disallowed inf/NaN.
        """
        super().validate(entry, field_uri)
        if entry is None:
            return

        field_uri = field_uri or self.field_uri
        if not self.floats and isinstance(entry, float):
            raise ConfigError("{} is expected to be int".format(field_uri))
        # bool subclasses int, but True/False is not a meaningful numeric value.
        if isinstance(entry, bool) or not isinstance(entry, (int, float)):
            raise ConfigError("{} is expected to be number".format(field_uri))

        if self.min is not None and entry < self.min:
            reason = "value is less than minimal allowed - {}".format(self.min)
            self.raise_error(entry, field_uri, reason)
        if self.max is not None and entry > self.max:
            reason = "value is greater than maximal allowed - {}".format(self.max)
            self.raise_error(entry, field_uri, reason)

        # Only floats can be infinite or NaN. math.isinf/isnan convert ints to
        # float first and raise OverflowError for very large ints.
        if isinstance(entry, float):
            if math.isinf(entry) and not self.allow_inf:
                self.raise_error(entry, field_uri, "value is infinity")
            if math.isnan(entry) and not self.allow_nan:
                self.raise_error(entry, field_uri, "value is NaN")

    @property
    def type(self):
        """Type of this field's values: float, or int when ``floats`` is False."""
        return float if self.floats else int
class PathField(BaseField):
    """Field holding a filesystem path, checked with ``get_path``.

    Args:
        is_directory: forwarded to ``get_path``; presumably selects between
            expecting a directory and expecting a regular file (see the error
            reasons below; confirm against ``get_path``).
    """

    def __init__(self, is_directory=False, **kwargs):
        super().__init__(**kwargs)
        self.is_directory = is_directory

    def validate(self, entry, field_uri=None):
        """Map each error ``get_path`` can raise to a reason and report it."""
        super().validate(entry, field_uri)
        if entry is None:
            return

        uri = field_uri or self.field_uri
        # Checked in order, mirroring an ordered chain of except clauses.
        known_failures = (
            (TypeError, "values is expected to be path-like"),
            (FileNotFoundError, "path does not exist"),
            (NotADirectoryError, "path is not a directory"),
            (IsADirectoryError, "path is a directory, regular file expected"),
        )
        try:
            get_path(entry, self.is_directory)
        except (TypeError, FileNotFoundError, NotADirectoryError, IsADirectoryError) as failure:
            reason = next(text for kind, text in known_failures if isinstance(failure, kind))
            self.raise_error(entry, uri, reason)

    @property
    def type(self):
        """Type associated with this field's values."""
        return Path
class BoolField(BaseField):
    """Field whose value must be a real ``bool``."""

    def validate(self, entry, field_uri=None):
        """Accept ``None`` (after the base checks) or any bool; reject everything else."""
        super().validate(entry, field_uri)
        if entry is None or isinstance(entry, bool):
            return
        raise ConfigError("{} is expected to be bool".format(field_uri or self.field_uri))

    @property
    def type(self):
        """Converter associated with this field (``string_to_bool``); presumably parses textual booleans."""
        return string_to_bool
def _get_field_type(key_type):
    """Turn a type spec into a field validator.

    BaseField instances are returned unchanged. Python types registered in
    ``_TYPE_TO_FIELD_CLASS`` are replaced by a fresh instance of the matching
    field. Anything else (including ``None``) is returned as-is.
    """
    if isinstance(key_type, BaseField):
        return key_type

    factory = _TYPE_TO_FIELD_CLASS.get(key_type)
    return factory() if callable(factory) else key_type
331 _TYPE_TO_FIELD_CLASS = {
332 int: partial(NumberField, floats=False),
333 float: partial(NumberField, floats=True),
334 dict: partial(DictField, validate_keys=False, validate_values=False),
335 list: partial(ListField, validate_values=False),