Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / config / config_reader.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import copy
18 from pathlib import Path
19
20 import warnings
21
22 from ..utils import read_yaml, to_lower_register, contains_any
23 from .config_validator import ConfigError
24
25
class ConfigReader:
    """
    Class for parsing input config.

    Reads the local config and the optional global definitions file, validates
    the local config, merges global entries into local ones, applies
    command-line options/path prefixes and filters launchers by the requested
    targets.
    """

    @staticmethod
    def merge(arguments):
        """
        Build the final configuration from the config files named in the arguments.

        Args:
            arguments: command-line arguments; must expose ``config`` and
                ``definitions`` attributes.
        Returns:
            dictionary containing configuration.
        Raises:
            ConfigError: if the local config is empty or misses required entries.
        """

        global_config, local_config = ConfigReader._read_configs(arguments)
        if not local_config:
            raise ConfigError('Missing local config')

        ConfigReader._check_local_config(local_config)
        ConfigReader._prepare_global_configs(global_config)

        config = ConfigReader._merge_configs(global_config, local_config)

        ConfigReader._provide_cmd_arguments(arguments, config)
        ConfigReader._merge_paths_with_prefixes(arguments, config)
        ConfigReader._filter_launchers(config, arguments)

        return config

    @staticmethod
    def _read_configs(arguments):
        """
        Read the global definitions file (optional) and the local config file.

        Returns:
            tuple ``(global_config, local_config)``; ``global_config`` is None
            when no definitions file was given.
        """
        global_config = read_yaml(arguments.definitions) if arguments.definitions else None
        local_config = read_yaml(arguments.config)

        return global_config, local_config

    @staticmethod
    def _check_local_config(config):
        """
        Validate the structure of the local config.

        Every model must have non-empty ``name``, ``launchers`` and ``datasets``
        entries, and every dataset must have a non-empty ``name``.

        Raises:
            ConfigError: if ``models`` or any required entry is missing or empty.
        """
        models = config.get('models')
        if not models:
            raise ConfigError('Missed "{}" in local config'.format('models'))

        def _is_requirements_missed(target, requirements):
            # Falsy values (None, '', [], {}) are treated as missing too.
            return any(not target.get(entry) for entry in requirements)

        required_model_entries = ['name', 'launchers', 'datasets']
        required_dataset_entries = ['name']
        required_dataset_error = 'Model {} must specify {} for each dataset'
        for model in models:
            if _is_requirements_missed(model, required_model_entries):
                raise ConfigError('Each model must specify {}'.format(required_model_entries))

            if any(_is_requirements_missed(dataset, required_dataset_entries) for dataset in model['datasets']):
                raise ConfigError(required_dataset_error.format(model['name'], ','.join(required_dataset_entries)))

    @staticmethod
    def _prepare_global_configs(global_configs):
        """
        Merge global preprocessing/metrics/postprocessing definitions into global datasets.

        Entries of each global dataset are matched by ``type`` against the
        top-level lists of the same name in the definitions file. The global
        config is modified in place; nothing happens without global datasets.
        """
        if not global_configs or 'datasets' not in global_configs:
            return

        datasets = global_configs['datasets']

        def merge(local_entries, global_entries, identifier):
            if not local_entries or not global_entries:
                return

            for i, local in enumerate(local_entries):
                local_identifier = local.get(identifier)
                if not local_identifier:
                    continue

                local_entries[i] = ConfigReader._merge_configs_by_identifier(global_entries, local, identifier)

        for dataset in datasets:
            merge(dataset.get('preprocessing'), global_configs.get('preprocessing'), 'type')
            merge(dataset.get('metrics'), global_configs.get('metrics'), 'type')
            merge(dataset.get('postprocessing'), global_configs.get('postprocessing'), 'type')

    @staticmethod
    def _merge_configs(global_configs, local_config):
        """
        Merge global launcher and dataset definitions into a copy of the local config.

        Launchers are matched by ``framework`` and datasets by ``name``; local
        values override global ones. The local config itself is not modified.

        Args:
            global_configs: parsed definitions file or None. Either the
                ``launchers`` or the ``datasets`` section may be absent.
            local_config: parsed local config (already validated).
        Returns:
            merged configuration dictionary.
        """
        config = copy.deepcopy(local_config)
        if not global_configs:
            return config

        # A definitions file may declare only one of the sections; a missing
        # section simply means there is nothing to merge from it.
        global_launchers = global_configs.get('launchers') or []
        global_datasets = global_configs.get('datasets') or []

        for model in config.get('models'):
            for i, launcher_entry in enumerate(model['launchers']):
                model['launchers'][i] = ConfigReader._merge_configs_by_identifier(
                    global_launchers, launcher_entry, 'framework'
                )

            for i, dataset in enumerate(model['datasets']):
                model['datasets'][i] = ConfigReader._merge_configs_by_identifier(
                    global_datasets, dataset, 'name'
                )

        return config

    @staticmethod
    def _merge_configs_by_identifier(global_config, local_config, identifier):
        """
        Overlay ``local_config`` on the first global entry with the same identifier.

        Args:
            global_config: list of global entries (dicts).
            local_config: local entry (dict).
            identifier: key used to match entries (e.g. ``'name'``, ``'type'``).
        Returns:
            ``local_config`` itself if it has no identifier; otherwise a new
            dict: a deep copy of the first matching global entry (or an empty
            dict when nothing matches) updated with the local values.
        """
        local_identifier = local_config.get(identifier)
        if local_identifier is None:
            return local_config

        # local_identifier is not None, so equality also excludes global entries without an identifier.
        matched = next((entry for entry in global_config if entry.get(identifier) == local_identifier), {})

        config = copy.deepcopy(matched)
        config.update(local_config)

        return config

    @staticmethod
    def _merge_paths_with_prefixes(arguments, config):
        """
        Resolve relative paths in launchers and datasets against command-line prefixes.

        Absolute paths are converted to ``Path`` as-is. Relative paths are joined
        with the matching prefix argument when that argument is set; otherwise
        they are left unchanged. Inside ``annotation_conversion``, every key
        ending in ``file`` or ``dir`` is resolved against the ``source`` prefix.
        The config is modified in place.

        Args:
            arguments: command-line arguments (namespace or dict).
            config: merged configuration dictionary.
        """
        args = arguments if isinstance(arguments, dict) else vars(arguments)
        entries_paths = {
            'launchers': {
                'model': 'models',
                'weights': 'models',
                'caffe_model': 'models',
                'caffe_weights': 'models',
                'tf_model': 'models',
                'mxnet_weights': 'models',
                'onnx_model': 'models',
                'kaldi_model': 'models',
                'cpu_extensions': 'extensions',
                'gpu_extensions': 'extensions',
                'bitstream': 'bitstreams',
                'affinity_map': 'affinity_map'
            },
            'datasets': {
                'data_source': 'source',
                'segmentation_masks_source': 'source',
                'annotation': 'annotations',
                'dataset_meta': 'annotations'
            }
        }

        def merge_entry_paths(keys, value):
            for field, argument in keys.items():
                if field not in value:
                    continue

                config_path = Path(value[field])
                if config_path.is_absolute():
                    value[field] = config_path
                    continue

                # A missing or empty prefix leaves the relative path untouched.
                prefix = args.get(argument)
                if not prefix:
                    continue

                value[field] = prefix / config_path

        def create_command_line_for_conversion(conversion_config):
            return {key: 'source' for key in conversion_config if key.endswith(('file', 'dir'))}

        for model in config['models']:
            for entry, command_line_arg in entries_paths.items():
                if entry not in model:
                    continue

                for config_entry in model[entry]:
                    if entry == 'datasets':
                        annotation_conversion_config = config_entry.get('annotation_conversion')
                        if annotation_conversion_config:
                            command_line_conversion = create_command_line_for_conversion(annotation_conversion_config)
                            merge_entry_paths(command_line_conversion, annotation_conversion_config)
                    merge_entry_paths(command_line_arg, config_entry)

    @staticmethod
    def _provide_cmd_arguments(arguments, config):
        """
        Copy command-line options into every ``dlsdk`` launcher entry.

        The options are copied as private (``_``-prefixed) keys. When a
        converted-models directory is given, ``mo_params['output_dir']`` is
        resolved against it. The config is modified in place.

        Args:
            arguments: command-line arguments (namespace or dict); missing
                options are treated as unset.
            config: merged configuration dictionary.
        """
        def merge_converted_model_path(converted_models_dir, mo_output_dir):
            if mo_output_dir:
                mo_output_dir = Path(mo_output_dir)
                if mo_output_dir.is_absolute():
                    return mo_output_dir
                return converted_models_dir / mo_output_dir
            return converted_models_dir

        additional_keys = [
            'model_optimizer', 'tf_custom_op_config_dir',
            'tf_obj_detection_api_pipeline_config_path',
            'cpu_extensions_mode'
        ]
        # Read every option through the same mapping so dict arguments work too.
        arguments_dict = arguments if isinstance(arguments, dict) else vars(arguments)
        update_launcher_entry = {}

        for key in additional_keys:
            value = arguments_dict.get(key)
            if value:
                update_launcher_entry['_{}'.format(key)] = value

        models_prefix = arguments_dict.get('models')
        converted_models = arguments_dict.get('converted_models')
        aocl = arguments_dict.get('aocl')

        for model in config['models']:
            for launcher_entry in model['launchers']:
                if launcher_entry['framework'].lower() != 'dlsdk':
                    continue

                launcher_entry.update(update_launcher_entry)
                if models_prefix:
                    launcher_entry['_models_prefix'] = models_prefix

                if not converted_models:
                    continue

                mo_params = launcher_entry.get('mo_params', {})

                mo_params.update({
                    'output_dir': merge_converted_model_path(converted_models, mo_params.get('output_dir'))
                })

                launcher_entry['mo_params'] = mo_params

                # NOTE(review): `_aocl` is only forwarded when a converted-models
                # directory is given (pre-existing behavior, kept unchanged).
                if aocl:
                    launcher_entry['_aocl'] = aocl

    @staticmethod
    def _filter_launchers(config, arguments):
        """
        Drop launchers that do not match the requested tags, framework or devices.

        Emits a warning for each model left without launchers. The config is
        modified in place.

        Args:
            config: merged configuration dictionary.
            arguments: command-line arguments (namespace or dict) with optional
                ``target_tags``, ``target_framework`` and ``target_devices``.
        """
        args = arguments if isinstance(arguments, dict) else vars(arguments)
        target_tags = args.get('target_tags') or []
        target_framework = args.get('target_framework')
        target_devices = to_lower_register(args.get('target_devices') or [])

        def filtered(launcher, targets):
            if target_tags and not contains_any(target_tags, launcher.get('tags', [])):
                return True

            config_framework = launcher['framework'].lower()
            if (target_framework or config_framework).lower() != config_framework:
                return True

            # A missing or null device never matches an explicit device list.
            return targets and (launcher.get('device') or '').lower() not in targets

        for model in config['models']:
            launchers = [launcher for launcher in model['launchers'] if not filtered(launcher, target_devices)]

            if not launchers:
                warnings.warn('Model "{}" has no launchers'.format(model['name']))

            model['launchers'] = launchers