"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import copy
import warnings
from pathlib import Path

from ..utils import read_yaml, to_lower_register, contains_any
from .config_validator import ConfigError
class ConfigReader:
    """
    Class for parsing input config.

    All methods are static: ``merge`` is the public entry point and the
    underscore-prefixed methods are the individual stages it runs.
    """

    @staticmethod
    def merge(arguments):
        """
        Build the final evaluation config from command-line arguments.

        Stages: read configs, validate the local config, expand the global
        definitions, merge global into local, inject command-line options,
        resolve path prefixes, filter launchers.

        Args:
            arguments: command-line arguments (``argparse.Namespace``-like object
                or a dict with the same keys).
        Returns:
            dictionary containing configuration.
        Raises:
            ConfigError: if the local config is missing or malformed.
        """
        global_config, local_config = ConfigReader._read_configs(arguments)
        if not local_config:
            raise ConfigError('Missing local config')

        ConfigReader._check_local_config(local_config)
        ConfigReader._prepare_global_configs(global_config)

        config = ConfigReader._merge_configs(global_config, local_config)

        ConfigReader._provide_cmd_arguments(arguments, config)
        ConfigReader._merge_paths_with_prefixes(arguments, config)
        ConfigReader._filter_launchers(config, arguments)

        return config

    @staticmethod
    def _arguments_as_dict(arguments):
        """Return ``arguments`` as a dict (dicts as-is, namespaces via ``vars``)."""
        return arguments if isinstance(arguments, dict) else vars(arguments)

    @staticmethod
    def _read_configs(arguments):
        """
        Load the local config and, if given, the global definitions file.

        Returns:
            tuple ``(global_config, local_config)``; ``global_config`` is None
            when no ``definitions`` path was provided.
        """
        args = ConfigReader._arguments_as_dict(arguments)
        definitions = args.get('definitions')
        global_config = read_yaml(definitions) if definitions else None
        local_config = read_yaml(args['config'])

        return global_config, local_config

    @staticmethod
    def _check_local_config(config):
        """
        Validate the structure of the local config.

        Raises:
            ConfigError: if ``models`` is missing or empty, if a model lacks a
                non-empty ``name``/``launchers``/``datasets``, or if any of its
                datasets lacks a non-empty ``name``.
        """
        models = config.get('models')
        if not models:
            raise ConfigError('Missed "{}" in local config'.format('models'))

        def _is_requirements_missed(target, requirements):
            # An entry counts as missing when it is absent or falsy (e.g. []).
            return [entry for entry in requirements if not target.get(entry)]

        required_model_entries = ['name', 'launchers', 'datasets']
        required_dataset_entries = ['name']
        required_dataset_error = 'Model {} must specify {} for each dataset'
        for model in models:
            if _is_requirements_missed(model, required_model_entries):
                raise ConfigError('Each model must specify {}'.format(required_model_entries))

            if any(_is_requirements_missed(dataset, required_dataset_entries) for dataset in model['datasets']):
                raise ConfigError(required_dataset_error.format(model['name'], ','.join(required_dataset_entries)))

    @staticmethod
    def _prepare_global_configs(global_configs):
        """
        Expand dataset definitions of the global config in place.

        For every global dataset, each entry of its ``preprocessing``,
        ``metrics`` and ``postprocessing`` lists that has a truthy ``type`` is
        replaced by the first top-level global entry of the same section and
        ``type``, overridden by the dataset's own keys. Entries without a
        ``type`` are left untouched. Does nothing when there are no global
        datasets.
        """
        if not global_configs or 'datasets' not in global_configs:
            return

        datasets = global_configs['datasets'] or []

        def merge(local_entries, global_entries, identifier):
            if not local_entries or not global_entries:
                return

            for i, local in enumerate(local_entries):
                local_identifier = local.get(identifier)
                if not local_identifier:
                    continue

                local_entries[i] = ConfigReader._merge_configs_by_identifier(global_entries, local, identifier)

        for dataset in datasets:
            for section in ('preprocessing', 'metrics', 'postprocessing'):
                merge(dataset.get(section), global_configs.get(section), 'type')

    @staticmethod
    def _merge_configs(global_configs, local_config):
        """
        Return a deep copy of ``local_config`` completed from global definitions.

        Launchers are matched to global launchers by ``framework`` and datasets
        to global datasets by ``name``; local values override global ones.
        ``local_config`` itself is never modified.
        """
        config = copy.deepcopy(local_config)
        if not global_configs:
            return config

        # A definitions file may define only launchers or only datasets; treat a
        # missing section as empty instead of failing with KeyError.
        global_launchers = global_configs.get('launchers') or []
        global_datasets = global_configs.get('datasets') or []

        for model in config.get('models') or []:
            model['launchers'] = [
                ConfigReader._merge_configs_by_identifier(global_launchers, launcher_entry, 'framework')
                for launcher_entry in model['launchers']
            ]
            model['datasets'] = [
                ConfigReader._merge_configs_by_identifier(global_datasets, dataset, 'name')
                for dataset in model['datasets']
            ]

        return config

    @staticmethod
    def _merge_configs_by_identifier(global_config, local_config, identifier):
        """
        Overlay ``local_config`` on the first global entry with the same identifier.

        Args:
            global_config: iterable of candidate global entries (dicts); may be
                empty or None.
            local_config: local entry (dict).
            identifier: key used to match entries, e.g. ``'framework'``.
        Returns:
            ``local_config`` itself when it has no ``identifier`` key (or it is
            None); otherwise a new dict built from a deep copy of the matched
            global entry (empty if nothing matches) updated with the local keys.
        """
        local_identifier = local_config.get(identifier)
        if local_identifier is None:
            return local_config

        matched = next(
            (
                entry for entry in global_config or []
                if entry.get(identifier) is not None and entry.get(identifier) == local_identifier
            ),
            {}
        )

        config = copy.deepcopy(matched)
        config.update(local_config)

        return config

    @staticmethod
    def _merge_paths_with_prefixes(arguments, config):
        """
        Resolve relative paths in launcher/dataset entries against CLI prefixes.

        Absolute paths are converted to ``Path`` and kept. Relative paths are
        joined with the prefix argument mapped to the field; when that prefix
        is unset or absent from ``arguments``, the value is left unchanged.
        Inside a dataset's ``annotation_conversion`` section, every key ending
        in ``file`` or ``dir`` is resolved against the ``source`` prefix.
        Entries are updated in place.
        """
        args = ConfigReader._arguments_as_dict(arguments)
        entries_paths = {
            'launchers': {
                'model': 'models',
                'weights': 'models',
                'caffe_model': 'models',
                'caffe_weights': 'models',
                'tf_model': 'models',
                'mxnet_weights': 'models',
                'onnx_model': 'models',
                'kaldi_model': 'models',
                'cpu_extensions': 'extensions',
                'gpu_extensions': 'extensions',
                'bitstream': 'bitstreams',
                'affinity_map': 'affinity_map',
            },
            'datasets': {
                'data_source': 'source',
                'segmentation_masks_source': 'source',
                'annotation': 'annotations',
                'dataset_meta': 'annotations',
            },
        }

        def merge_entry_paths(keys, value):
            # keys maps config field -> name of the prefix argument.
            for field, argument in keys.items():
                if field not in value:
                    continue

                config_path = Path(value[field])
                if config_path.is_absolute():
                    value[field] = config_path
                    continue

                # .get: a dict of arguments need not contain every prefix key.
                prefix = args.get(argument)
                if not prefix:
                    continue

                value[field] = prefix / config_path

        def create_command_line_for_conversion(conversion_config):
            # Every *file / *dir option of the annotation converter uses the 'source' prefix.
            return {
                key: 'source' for key in conversion_config
                if key.endswith('file') or key.endswith('dir')
            }

        for model in config['models']:
            for entry, command_line_arg in entries_paths.items():
                if entry not in model:
                    continue

                for config_entry in model[entry]:
                    if entry == 'datasets':
                        annotation_conversion_config = config_entry.get('annotation_conversion')
                        if annotation_conversion_config:
                            command_line_conversion = create_command_line_for_conversion(annotation_conversion_config)
                            merge_entry_paths(command_line_conversion, annotation_conversion_config)
                    merge_entry_paths(command_line_arg, config_entry)

    @staticmethod
    def _provide_cmd_arguments(arguments, config):
        """
        Copy command-line options into every DLSDK launcher entry (in place).

        Non-empty options listed in ``additional_keys`` are stored as
        ``_<option>``. Non-empty ``models`` and ``aocl`` become
        ``_models_prefix`` and ``_aocl``. When ``converted_models`` is set,
        ``mo_params['output_dir']`` is resolved against it: absolute output
        dirs are kept, relative ones are placed under it, and a missing one
        defaults to it. Launchers of other frameworks are not touched.
        """
        def merge_converted_model_path(converted_models_dir, mo_output_dir):
            if mo_output_dir:
                mo_output_dir = Path(mo_output_dir)
                if mo_output_dir.is_absolute():
                    return mo_output_dir
                return converted_models_dir / mo_output_dir
            return converted_models_dir

        additional_keys = [
            'model_optimizer', 'tf_custom_op_config_dir',
            'tf_obj_detection_api_pipeline_config_path',
            'cpu_extensions_mode'
        ]
        # Read every option through the dict view: `arguments` may be a plain
        # dict, which does not support attribute access.
        arguments_dict = ConfigReader._arguments_as_dict(arguments)
        update_launcher_entry = {
            '_{}'.format(key): arguments_dict[key]
            for key in additional_keys if arguments_dict.get(key)
        }
        models_prefix = arguments_dict.get('models')
        aocl = arguments_dict.get('aocl')
        converted_models = arguments_dict.get('converted_models')

        for model in config['models']:
            for launcher_entry in model['launchers']:
                if launcher_entry['framework'].lower() != 'dlsdk':
                    continue

                launcher_entry.update(update_launcher_entry)
                if models_prefix:
                    launcher_entry['_models_prefix'] = models_prefix

                # aocl is unrelated to model conversion, so it must be applied
                # even when converted_models is not given.
                if aocl:
                    launcher_entry['_aocl'] = aocl

                if not converted_models:
                    continue

                # `mo_params:` left empty in YAML yields None; treat it as {}.
                mo_params = launcher_entry.get('mo_params') or {}
                mo_params['output_dir'] = merge_converted_model_path(converted_models, mo_params.get('output_dir'))
                launcher_entry['mo_params'] = mo_params

    @staticmethod
    def _filter_launchers(config, arguments):
        """
        Remove launchers that do not match the requested tags, framework or devices.

        Filters come from ``target_tags``, ``target_framework`` and
        ``target_devices`` in ``arguments``; an unset filter keeps everything.
        Launchers are filtered in place per model, and a warning is emitted for
        each model left without launchers.
        """
        args = ConfigReader._arguments_as_dict(arguments)
        target_tags = args.get('target_tags') or []
        target_framework = args.get('target_framework')
        target_devices = to_lower_register(args.get('target_devices') or [])

        def filtered(launcher, targets):
            if target_tags and not contains_any(target_tags, launcher.get('tags', [])):
                return True

            config_framework = launcher['framework'].lower()
            if config_framework != (target_framework or config_framework).lower():
                return True

            return targets and launcher.get('device', '').lower() not in targets

        for model in config['models']:
            launchers = [launcher for launcher in model['launchers'] if not filtered(launcher, target_devices)]
            if not launchers:
                warnings.warn('Model "{}" has no launchers'.format(model['name']))

            model['launchers'] = launchers