2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
19 from ..representation import ContainerRepresentation
20 from ..config import ConfigValidator, StringField, ConfigError, BaseField
21 from ..dependency import ClassProvider
25 check_representation_type,
26 get_supported_representations,
class BasePostprocessorConfig(ConfigValidator):
    """Config schema shared by every postprocessor."""
    type = StringField()
    # Optional names of entries inside a ContainerRepresentation to process.
    # Declared as BaseField so either a single value or a list is accepted;
    # Postprocessor.__init__ converts non-list values with string_to_list.
    annotation_source = BaseField(optional=True)
    prediction_source = BaseField(optional=True)
class Postprocessor(ClassProvider):
    """Base class for postprocessors of (annotation, prediction) pairs.

    Calling an instance runs ``process_all``. For every pair, ``get_entries``
    picks the representations to work on, and ``process`` hands them to
    ``process_image``, which subclasses implement. Subclasses narrow the
    accepted representations via ``annotation_types`` / ``prediction_types``
    and can put their own initialisation into ``configure``.
    """

    __provider_type__ = 'postprocessor'

    # Representation types accepted by this postprocessor; forwarded to
    # ``check_representation_type`` / ``get_supported_representations``.
    annotation_types = ()
    prediction_types = ()

    def __init__(self, config, name=None, meta=None, state=None):
        """Store the arguments, normalise the ``*_source`` options and validate.

        Args:
            config: dict-like postprocessor configuration (read with ``.get``).
            name: postprocessor name, used for validation and error messages.
            meta: kept as ``self.meta``.
            state: kept as ``self.state``.
        """
        self.config = config
        self.name = name
        self.meta = meta
        self.state = state
        # Refreshed by ``process`` for every pair.
        self.image_size = None

        # ``*_source`` may be configured as a single value: lists are kept as-is,
        # any other truthy value is converted with ``string_to_list``.
        self.annotation_source = self.config.get('annotation_source')
        if self.annotation_source and not isinstance(self.annotation_source, list):
            self.annotation_source = string_to_list(self.annotation_source)

        self.prediction_source = self.config.get('prediction_source')
        if self.prediction_source and not isinstance(self.prediction_source, list):
            self.prediction_source = string_to_list(self.prediction_source)

        self.validate_config()
        self.setup()

    def __call__(self, *args, **kwargs):
        """Alias for ``process_all``."""
        return self.process_all(*args, **kwargs)

    def setup(self):
        """Post-validation step of ``__init__``; by default it only calls ``configure``."""
        self.configure()

    def process_image(self, annotation, prediction):
        """Transform the selected entries of one pair; must be overridden."""
        raise NotImplementedError

    def process(self, annotation, prediction):
        """Run ``process_image`` on one pair of entry lists.

        ``annotation`` and ``prediction`` are the entry lists produced by
        ``get_entries``. Before ``process_image`` runs, ``self.image_size`` is
        set to the first element of the first annotation's
        ``metadata['image_size']``, or None when that is unavailable.

        Returns:
            The same ``(annotation, prediction)`` objects.
        """
        image_size = None
        # Only read metadata when there is a first entry and no entry is None;
        # an empty list used to fail with IndexError on ``annotation[0]``.
        if annotation and None not in annotation:
            image_size = annotation[0].metadata.get('image_size')
        self.image_size = image_size[0] if image_size else None

        self.process_image(annotation, prediction)

        return annotation, prediction

    def process_all(self, annotations, predictions):
        """Apply the postprocessor to parallel sequences of annotations and predictions.

        Entries are selected with ``get_entries`` for every pair and then passed
        to ``process``. The return values of ``process`` are discarded; the input
        ``annotations`` and ``predictions`` objects are returned.
        """
        entries = zipped_transform(self.get_entries, annotations, predictions)
        zipped_transform(self.process, entries)
        return annotations, predictions

    def configure(self):
        """Hook for subclass-specific initialisation; a no-op here."""

    def validate_config(self):
        """Validate ``self.config`` with ``BasePostprocessorConfig`` using ``ERROR_ON_EXTRA_ARGUMENT``."""
        BasePostprocessorConfig(
            self.name, on_extra_argument=BasePostprocessorConfig.ERROR_ON_EXTRA_ARGUMENT
        ).validate(self.config)

    def get_entries(self, annotation, prediction):
        """Select the representations of one pair that this postprocessor works on.

        For a ``ContainerRepresentation``, the entries are the values named by
        ``annotation_source`` / ``prediction_source``. When no sources are
        configured, they are the contained values returned by
        ``get_supported_representations``. Any other representation is wrapped
        in a one-element list; if sources were configured for it, a warning is
        issued because they cannot be applied.

        Returns:
            Tuple ``(annotation_entries, prediction_entries)``.

        Raises:
            ConfigError: a configured source is missing from the container.
            TypeError: a configured source fails ``check_representation_type``.
        """
        message_not_found = '{}: {} is not found in container'
        message_incorrect_type = "Incorrect type of {}. Postprocessor {} can work only with {}"

        def type_names(supported_types):
            # ``str.join`` only accepts strings, but the supported types may be
            # classes; fall back to ``str`` for anything without ``__name__``.
            return ','.join(getattr(type_, '__name__', str(type_)) for type_ in supported_types)

        def resolve_container(container, supported_types, entry_name, sources=None):
            if not isinstance(container, ContainerRepresentation):
                if sources:
                    message = 'Warning: {}_source can be applied only to container. Default value will be used'
                    warnings.warn(message.format(entry_name))

                return [container]

            if not sources:
                return get_supported_representations(container.values(), supported_types)

            entries = []
            for source in sources:
                representation = container.get(source)
                # ``.get`` signals a missing key with None; a present representation
                # may still be falsy, so do not rely on truthiness here.
                if representation is None:
                    raise ConfigError(message_not_found.format(entry_name, source))

                if supported_types and not check_representation_type(representation, supported_types):
                    raise TypeError(message_incorrect_type.format(entry_name, self.name, type_names(supported_types)))

                entries.append(representation)

            return entries

        annotation_entries = resolve_container(annotation, self.annotation_types, 'annotation', self.annotation_source)
        prediction_entries = resolve_container(prediction, self.prediction_types, 'prediction', self.prediction_source)

        return annotation_entries, prediction_entries
class ApplyToOption(Enum):
    """Which side of an (annotation, prediction) pair is targeted.

    Used by PostprocessorWithSpecificTargets to pick its targets. The member
    values are what ``enum_values`` exposes as choices for ``apply_to``.
    """
    ANNOTATION = 'annotation'
    PREDICTION = 'prediction'
    ALL = 'all'


class PostprocessorWithTargetsConfigValidator(BasePostprocessorConfig):
    """BasePostprocessorConfig extended with the optional ``apply_to`` option."""
    # Restricted to the values of ApplyToOption.
    apply_to = StringField(optional=True, choices=enum_values(ApplyToOption))
class PostprocessorWithSpecificTargets(Postprocessor):
    """Postprocessor whose targets are chosen explicitly by configuration.

    Targets come either from ``apply_to`` (an ``ApplyToOption`` value) or from
    the ``annotation_source`` / ``prediction_source`` options. ``setup``
    requires exactly one of the two mechanisms.
    """

    def validate_config(self):
        """Validate ``self.config`` against ``PostprocessorWithTargetsConfigValidator``."""
        _config_validator = PostprocessorWithTargetsConfigValidator(
            self.__provider__, on_extra_argument=PostprocessorWithTargetsConfigValidator.ERROR_ON_EXTRA_ARGUMENT
        )
        _config_validator.validate(self.config)

    def setup(self):
        """Resolve ``apply_to``, check the target options, then call ``configure``.

        Raises:
            ConfigError: both ``apply_to`` and a ``*_source`` option are set, or
                neither of them is.
        """
        apply_to = self.config.get('apply_to')
        self.apply_to = ApplyToOption(apply_to) if apply_to else None

        has_sources = bool(self.annotation_source or self.prediction_source)
        if has_sources and self.apply_to:
            raise ConfigError("apply_to and sources both provided. You need specify only one from them")

        if not has_sources and not self.apply_to:
            raise ConfigError("apply_to or annotation_source or prediction_source required for {}".format(self.name))

        self.configure()

    def process(self, annotation, prediction):
        """Run ``process_image`` on the configured targets of one pair.

        ``self.image_size`` is refreshed as in ``Postprocessor.process``. A side
        that is not targeted is passed to ``process_image`` as an empty list.

        Returns:
            The same ``(annotation, prediction)`` objects.

        Raises:
            ValueError: no target entries were selected for this pair.
        """
        image_size = None
        # Guard against an empty entry list as well as None entries before
        # indexing ``annotation[0]``.
        if annotation and None not in annotation:
            image_size = annotation[0].metadata.get('image_size')
        self.image_size = image_size[0] if image_size else None

        target_annotations, target_predictions = None, None
        if self.annotation_source or self.prediction_source:
            target_annotations, target_predictions = self._choose_targets_using_sources(annotation, prediction)

        if self.apply_to:
            target_annotations, target_predictions = self._choose_targets_using_apply_to(annotation, prediction)

        if not target_annotations and not target_predictions:
            raise ValueError("Suitable targets for {} not found".format(self.name))

        self.process_image(target_annotations, target_predictions)
        return annotation, prediction

    def _choose_targets_using_sources(self, annotations, predictions):
        """Target each side whose ``*_source`` option is set; the other side gets ``[]``."""
        target_annotations = annotations if self.annotation_source else []
        target_predictions = predictions if self.prediction_source else []

        return target_annotations, target_predictions

    def _choose_targets_using_apply_to(self, annotations, predictions):
        """Return the ``(annotations, predictions)`` targets selected by ``self.apply_to``."""
        targets_specification = {
            ApplyToOption.ANNOTATION: (annotations, []),
            ApplyToOption.PREDICTION: ([], predictions),
            ApplyToOption.ALL: (annotations, predictions)
        }

        return targets_specification[self.apply_to]

    def process_image(self, annotation, prediction):
        """Transform the selected targets of one pair; must be overridden."""
        raise NotImplementedError