Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / postprocessor / postprocessor.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 import warnings
18 from enum import Enum
19 from ..representation import ContainerRepresentation
20 from ..config import ConfigValidator, StringField, ConfigError, BaseField
21 from ..dependency import ClassProvider
22 from ..utils import (
23     zipped_transform,
24     string_to_list,
25     check_representation_type,
26     get_supported_representations,
27     enum_values
28 )
29
30
class BasePostprocessorConfig(ConfigValidator):
    """Config schema shared by postprocessors.

    ``Postprocessor.validate_config`` instantiates this validator with
    ``ERROR_ON_EXTRA_ARGUMENT`` and runs it against the postprocessor config.
    """

    # Declared without ``optional=True`` — presumably mandatory; confirm against StringField defaults.
    type = StringField()
    # Optional. ``Postprocessor.__init__`` reads this key and passes a non-list value through ``string_to_list``.
    annotation_source = BaseField(optional=True)
    # Optional. Handled by ``Postprocessor.__init__`` the same way as ``annotation_source``.
    prediction_source = BaseField(optional=True)
35
36
class Postprocessor(ClassProvider):
    """Base class for postprocessors (provider type ``'postprocessor'``).

    Subclasses implement ``process_image`` and may narrow the representations
    they accept through ``annotation_types`` / ``prediction_types``.
    """

    __provider_type__ = 'postprocessor'

    # Representation types accepted by this postprocessor; an empty tuple skips the
    # type check performed in ``get_entries`` when explicit sources are configured.
    annotation_types = ()
    prediction_types = ()

    def __init__(self, config, name=None, meta=None, state=None):
        """Store the arguments, normalize the configured sources, then validate and set up.

        Args:
            config: mapping-like postprocessor configuration (read with ``.get``).
            name: postprocessor name, used in validation and error messages.
            meta: stored as ``self.meta``; not used by this base class.
            state: stored as ``self.state``; not used by this base class.
        """
        self.config = config
        self.name = name
        self.meta = meta
        self.state = state
        self.image_size = None

        # Sources may be given as a list or in a string form understood by string_to_list.
        self.annotation_source = self.config.get('annotation_source')
        if self.annotation_source and not isinstance(self.annotation_source, list):
            self.annotation_source = string_to_list(self.annotation_source)

        self.prediction_source = self.config.get('prediction_source')
        if self.prediction_source and not isinstance(self.prediction_source, list):
            self.prediction_source = string_to_list(self.prediction_source)

        self.validate_config()
        self.setup()

    def __call__(self, *args, **kwargs):
        """Shortcut for ``process_all``."""
        return self.process_all(*args, **kwargs)

    def setup(self):
        """Finish initialization; the base implementation only calls ``configure``."""
        self.configure()

    def process_image(self, annotation, prediction):
        """Postprocess the selected entries of one image; must be overridden."""
        raise NotImplementedError

    def process(self, annotation, prediction):
        """Record the image size and run ``process_image`` on one image's entries.

        ``self.image_size`` is set to the first element of the ``'image_size'``
        metadata value of the first annotation, or to ``None`` when the
        annotation list is empty, contains ``None`` or carries no such value.

        Returns:
            The ``(annotation, prediction)`` pair that was passed in.
        """
        # Guard against an empty list as well: indexing [0] would raise IndexError.
        has_annotation = bool(annotation) and None not in annotation
        image_size = annotation[0].metadata.get('image_size') if has_annotation else None
        self.image_size = None
        if image_size:
            self.image_size = image_size[0]
        self.process_image(annotation, prediction)

        return annotation, prediction

    def process_all(self, annotations, predictions):
        """Apply ``get_entries`` and ``process`` over the inputs via ``zipped_transform``.

        Returns:
            The ``(annotations, predictions)`` objects that were passed in.
        """
        zipped_transform(self.process, zipped_transform(self.get_entries, annotations, predictions))
        return annotations, predictions

    def configure(self):
        """Hook for subclasses to read their own config values; no-op by default."""
        pass

    def validate_config(self):
        """Validate ``self.config`` against ``BasePostprocessorConfig``, rejecting extra arguments."""
        BasePostprocessorConfig(
            self.name, on_extra_argument=BasePostprocessorConfig.ERROR_ON_EXTRA_ARGUMENT
        ).validate(self.config)

    def get_entries(self, annotation, prediction):
        """Pick the annotation and prediction entries this postprocessor works on.

        Returns:
            A pair ``(annotation_entries, prediction_entries)`` of lists.

        Raises:
            ConfigError: a configured source is missing (or falsy) in the container.
            TypeError: a configured source has a representation type that is not supported.
        """
        message_not_found = '{}: {} is not found in container'
        message_incorrect_type = "Incorrect type of {}. Postprocessor {} can work only with {}"

        def resolve_container(container, supported_types, entry_name, sources=None):
            # A plain (non-container) representation is used as is; sources cannot apply to it.
            if not isinstance(container, ContainerRepresentation):
                if sources:
                    message = 'Warning: {}_source can be applied only to container. Default value will be used'
                    warnings.warn(message.format(entry_name))

                return [container]

            # No explicit sources: let the helper choose entries by supported type.
            if not sources:
                return get_supported_representations(container.values(), supported_types)

            entries = []
            for source in sources:
                representation = container.get(source)
                if not representation:
                    raise ConfigError(message_not_found.format(entry_name, source))

                if supported_types and not check_representation_type(representation, supported_types):
                    # supported_types holds type objects, which str.join rejects;
                    # render their names so the intended message is actually produced.
                    type_names = ','.join(
                        getattr(supported_type, '__name__', str(supported_type))
                        for supported_type in supported_types
                    )
                    raise TypeError(message_incorrect_type.format(entry_name, self.name, type_names))

                entries.append(representation)

            return entries

        annotation_entries = resolve_container(annotation, self.annotation_types, 'annotation', self.annotation_source)
        prediction_entries = resolve_container(prediction, self.prediction_types, 'prediction', self.prediction_source)

        return annotation_entries, prediction_entries
123
124
class ApplyToOption(Enum):
    """Allowed values of the ``apply_to`` config option."""

    # Process annotations only.
    ANNOTATION = 'annotation'
    # Process predictions only.
    PREDICTION = 'prediction'
    # Process both annotations and predictions.
    ALL = 'all'
129
130
class PostprocessorWithTargetsConfigValidator(BasePostprocessorConfig):
    """Base postprocessor config schema extended with the optional ``apply_to`` field."""

    # Optional; when present it must be one of the ApplyToOption values.
    apply_to = StringField(optional=True, choices=enum_values(ApplyToOption))
133
134
class PostprocessorWithSpecificTargets(Postprocessor):
    """Postprocessor whose targets are chosen by ``apply_to`` or by explicit sources.

    Exactly one of the two selection mechanisms must be configured: either
    ``apply_to`` or at least one of ``annotation_source`` / ``prediction_source``.
    """

    def validate_config(self):
        """Validate ``self.config`` against the schema that also allows ``apply_to``."""
        _config_validator = PostprocessorWithTargetsConfigValidator(
            self.__provider__, on_extra_argument=PostprocessorWithTargetsConfigValidator.ERROR_ON_EXTRA_ARGUMENT
        )
        _config_validator.validate(self.config)

    def setup(self):
        """Read ``apply_to``, check it against the configured sources, then call ``configure``.

        Raises:
            ConfigError: both ``apply_to`` and a source are given, or none of them is.
        """
        apply_to = self.config.get('apply_to')
        self.apply_to = ApplyToOption(apply_to) if apply_to else None

        if (self.annotation_source or self.prediction_source) and self.apply_to:
            raise ConfigError("apply_to and sources both provided. You need specify only one from them")

        if not self.annotation_source and not self.prediction_source and not self.apply_to:
            raise ConfigError("apply_to or annotation_source or prediction_source required for {}".format(self.name))

        self.configure()

    def process(self, annotation, prediction):
        """Record the image size, choose the targets and run ``process_image`` on them.

        ``self.image_size`` is set to the first element of the ``'image_size'``
        metadata value of the first annotation, or to ``None`` when the
        annotation list is empty, contains ``None`` or carries no such value.

        Returns:
            The ``(annotation, prediction)`` pair that was passed in.

        Raises:
            ValueError: neither annotation nor prediction targets were selected.
        """
        # Guard against an empty list as well: indexing [0] would raise IndexError
        # before the "targets not found" check below could report the problem.
        has_annotation = bool(annotation) and None not in annotation
        image_size = annotation[0].metadata.get('image_size') if has_annotation else None
        self.image_size = None
        if image_size:
            self.image_size = image_size[0]

        target_annotations, target_predictions = None, None
        if self.annotation_source or self.prediction_source:
            target_annotations, target_predictions = self._choose_targets_using_sources(annotation, prediction)

        if self.apply_to:
            target_annotations, target_predictions = self._choose_targets_using_apply_to(annotation, prediction)

        if not target_annotations and not target_predictions:
            raise ValueError("Suitable targets for {} not found".format(self.name))

        self.process_image(target_annotations, target_predictions)
        return annotation, prediction

    def _choose_targets_using_sources(self, annotations, predictions):
        """Keep each side only if a source was configured for it; otherwise use an empty list."""
        target_annotations = annotations if self.annotation_source else []
        target_predictions = predictions if self.prediction_source else []

        return target_annotations, target_predictions

    def _choose_targets_using_apply_to(self, annotations, predictions):
        """Return the ``(annotations, predictions)`` pair selected by ``self.apply_to``."""
        targets_specification = {
            ApplyToOption.ANNOTATION: (annotations, []),
            ApplyToOption.PREDICTION: ([], predictions),
            ApplyToOption.ALL: (annotations, predictions)
        }

        return targets_specification[self.apply_to]

    def process_image(self, annotation, prediction):
        """Postprocess the chosen targets of one image; must be overridden."""
        raise NotImplementedError