Publishing 2019 R1 content
[platform/upstream/dldt.git] / tools / accuracy_checker / accuracy_checker / postprocessor / filter.py
1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16 from functools import singledispatch
17 from typing import Union
18 import numpy as np
19
20 from ..config import BaseField, BoolField
21 from ..dependency import ClassProvider
22 from ..postprocessor.postprocessor import PostprocessorWithSpecificTargets, PostprocessorWithTargetsConfigValidator
23 from ..representation import (DetectionAnnotation, DetectionPrediction, TextDetectionAnnotation,
24                               TextDetectionPrediction, PoseEstimationPrediction, PoseEstimationAnnotation)
25 from ..utils import in_interval, polygon_from_points, convert_to_range
26
27
class FilterConfig(PostprocessorWithTargetsConfigValidator):
    """Config validator for the 'filter' postprocessor.

    Besides the optional boolean ``remove_filtered`` flag, one optional field
    is added for every name found in ``BaseFilter.providers``.
    """

    remove_filtered = BoolField(optional=True)

    def __init__(self, config_uri, **kwargs):
        super().__init__(config_uri, **kwargs)
        # Each provider name becomes an optional config key whose value is
        # not checked further by this validator.
        for filter_name in BaseFilter.providers:
            self.fields[filter_name] = BaseField(optional=True)
35
36
class FilterPostprocessor(PostprocessorWithSpecificTargets):
    """Postprocessor that applies every configured filter to annotations and predictions."""

    __provider__ = 'filter'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def __init__(self, *args, **kwargs):
        # These attributes are set before the base initializer runs, exactly
        # as in the original ordering.
        self._filters = []
        self.remove_filtered = False
        super().__init__(*args, **kwargs)

    def validate_config(self):
        """Validate ``self.config`` with ``FilterConfig``, treating extra arguments as errors."""
        validator = FilterConfig(self.__provider__, on_extra_argument=FilterConfig.ERROR_ON_EXTRA_ARGUMENT)
        validator.validate(self.config)

    def configure(self):
        """Build the filter functors from a copy of the configuration.

        The 'type' key must be present; 'remove_filtered' and the source /
        apply_to keys are consumed if present. Every remaining key is
        treated as a filter name and its value as that filter's argument.
        """
        options = self.config.copy()
        options.pop('type')
        self.remove_filtered = options.pop('remove_filtered', False)
        for service_key in ('annotation_source', 'prediction_source', 'apply_to'):
            options.pop(service_key, None)

        self._filters.extend(
            BaseFilter.provide(filter_name, filter_arg) for filter_name, filter_arg in options.items()
        )

    def process_image(self, annotation, prediction):
        """Run every filter over each annotation entry, then over each prediction entry."""
        for functor in self._filters:
            for entries in (annotation, prediction):
                for target in entries:
                    self._filter_entry_by(target, functor)

        return annotation, prediction

    def _filter_entry_by(self, entry, functor):
        """Apply one filter to one entry and return the entry.

        When ``remove_filtered`` is false and the entry is one of the listed
        representation types, the indices returned by the filter are appended
        to ``entry.metadata['difficult_boxes']``. In every other case they are
        passed to ``entry.remove``.
        """
        ignored_key = 'difficult_boxes'
        markable_types = (
            DetectionAnnotation, DetectionPrediction,
            TextDetectionAnnotation, TextDetectionPrediction,
            PoseEstimationAnnotation, PoseEstimationPrediction,
        )

        if self.remove_filtered or not isinstance(entry, markable_types):
            entry.remove(functor(entry))
            return entry

        ignored = entry.metadata.setdefault(ignored_key, [])
        ignored.extend(functor(entry))
        return entry
85
86
class BaseFilter(ClassProvider):
    """Base class for filters: stores the configured argument and forwards calls to ``apply_filter``."""

    __provider_type__ = 'filter'

    def __init__(self, filter_arg):
        self.filter_arg = filter_arg

    def __call__(self, entry):
        """Apply the filter to ``entry`` using the stored argument."""
        argument = self.filter_arg
        return self.apply_filter(entry, argument)

    def apply_filter(self, entry, filter_arg):
        """Return the indices selected by the filter; subclasses must override this."""
        raise NotImplementedError
98
99
class FilterByLabels(BaseFilter):
    """Selects the indices of entries whose label is contained in the configured labels."""

    __provider__ = 'labels'

    def apply_filter(self, entry, labels):
        """Return positions in ``entry.labels`` whose value is in ``labels``."""
        return [position for position, label in enumerate(entry.labels) if label in labels]
110
111
class FilterByMinConfidence(BaseFilter):
    """Selects the indices of entries whose score is below the configured minimum."""

    __provider__ = 'min_confidence'

    def apply_filter(self, entry, min_confidence):
        """Return positions in ``entry.scores`` with a score strictly below ``min_confidence``.

        ``DetectionAnnotation`` entries are never filtered: an empty list is
        returned for them without reading any scores.
        """
        if isinstance(entry, DetectionAnnotation):
            return []

        return [position for position, score in enumerate(entry.scores) if score < min_confidence]
126
127
class FilterByHeightRange(BaseFilter):
    """Selects the indices of boxes / text polygons whose height is outside the configured range."""

    __provider__ = 'height_range'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def apply_filter(self, entry, height_range):
        """Return indices of items in ``entry`` whose height is not in ``height_range``.

        ``height_range`` is passed through ``convert_to_range`` before use.
        Entries of a type with no registered handler yield an empty list.
        """
        @singledispatch
        def filtering(entry_value, height_range_):
            # Fallback for representation types that have no notion of height here.
            return []

        # Each class is registered separately: passing typing.Union to
        # register() is only accepted by newer interpreters (3.11+) and raises
        # TypeError on older ones, while stacked per-class registration is
        # supported everywhere singledispatch exists.
        @filtering.register(DetectionAnnotation)
        @filtering.register(DetectionPrediction)
        def _filter_boxes(entry_value, height_range_):
            filtered = []
            for index, (y_min, y_max) in enumerate(zip(entry_value.y_mins, entry_value.y_maxs)):
                height = y_max - y_min
                if not in_interval(height, height_range_):
                    filtered.append(index)

            return filtered

        @filtering.register(TextDetectionAnnotation)
        @filtering.register(TextDetectionPrediction)
        def _filter_polygons(entry_values, height_range_):
            filtered = []
            for index, polygon_points in enumerate(entry_values.points):
                # Each polygon is unpacked into exactly four points; the height
                # is the mean length of the left and right sides.
                left_bottom_point, left_top_point, right_top_point, right_bottom_point = polygon_points
                left_side_height = np.linalg.norm(left_bottom_point - left_top_point)
                right_side_height = np.linalg.norm(right_bottom_point - right_top_point)
                if not in_interval(np.mean([left_side_height, right_side_height]), height_range_):
                    filtered.append(index)

            return filtered

        return filtering(entry, convert_to_range(height_range))
162
163
class FilterByWidthRange(BaseFilter):
    """Selects the indices of boxes / text polygons whose width is outside the configured range."""

    __provider__ = 'width_range'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def apply_filter(self, entry, width_range):
        """Return indices of items in ``entry`` whose width is not in ``width_range``.

        ``width_range`` is passed through ``convert_to_range`` before use.
        Entries of a type with no registered handler yield an empty list.
        """
        @singledispatch
        def filtering(entry_value, width_range_):
            # Fallback for representation types that have no notion of width here.
            return []

        # Each class is registered separately: passing typing.Union to
        # register() is only accepted by newer interpreters (3.11+) and raises
        # TypeError on older ones, while stacked per-class registration is
        # supported everywhere singledispatch exists.
        @filtering.register(DetectionAnnotation)
        @filtering.register(DetectionPrediction)
        def _filter_boxes(entry_value, width_range_):
            filtered = []
            for index, (x_min, x_max) in enumerate(zip(entry_value.x_mins, entry_value.x_maxs)):
                width = x_max - x_min
                if not in_interval(width, width_range_):
                    filtered.append(index)

            return filtered

        @filtering.register(TextDetectionAnnotation)
        @filtering.register(TextDetectionPrediction)
        def _filter_polygons(entry_values, width_range_):
            filtered = []
            for index, polygon_points in enumerate(entry_values.points):
                # Each polygon is unpacked into exactly four points; it is
                # filtered when either its top or its bottom side is out of range.
                left_bottom_point, left_top_point, right_top_point, right_bottom_point = polygon_points
                top_width = np.linalg.norm(right_top_point - left_top_point)
                bottom_width = np.linalg.norm(right_bottom_point - left_bottom_point)
                if not in_interval(top_width, width_range_) or not in_interval(bottom_width, width_range_):
                    filtered.append(index)

            return filtered

        return filtering(entry, convert_to_range(width_range))
198
199
class FilterByAreaRange(BaseFilter):
    """Selects the indices of poses / text polygons whose area is outside the configured range."""

    __provider__ = 'area_range'

    annotation_types = (TextDetectionAnnotation, PoseEstimationAnnotation)
    prediction_types = (TextDetectionPrediction, )

    def apply_filter(self, entry, area_range):
        """Return indices of items in ``entry`` whose area is not in ``area_range``.

        ``area_range`` is passed through ``convert_to_range`` before use.
        Entries of a type with no registered handler yield an empty list.
        """
        converted_range = convert_to_range(area_range)

        @singledispatch
        def filtering(entry_value, area_range_):
            # Fallback for representation types without an area handler.
            return []

        # Classes are registered one by one instead of through a typing.Union
        # annotation, which register() only understands on newer interpreters.
        @filtering.register(PoseEstimationAnnotation)
        @filtering.register(PoseEstimationPrediction)
        def _filter_poses(entry_value, area_range_):
            filtered = []
            for area_id, area in enumerate(entry_value.areas):
                if not in_interval(area, area_range_):
                    filtered.append(area_id)
            return filtered

        # The handler must accept both positional arguments because the
        # dispatcher is always invoked as filtering(entry, range); the
        # one-parameter form failed with TypeError on every text entry.
        @filtering.register(TextDetectionAnnotation)
        @filtering.register(TextDetectionPrediction)
        def _filter_polygons(entry_value, area_range_):
            filtered = []
            for index, polygon_points in enumerate(entry_value.points):
                if not in_interval(polygon_from_points(polygon_points).area, area_range_):
                    filtered.append(index)
            return filtered

        return filtering(entry, converted_range)
231
232
class FilterEmpty(BaseFilter):
    """Selects the indices of boxes whose width or height is not positive."""

    __provider__ = 'is_empty'

    def apply_filter(self, entry: DetectionAnnotation, is_empty):
        """Return indices of boxes with ``x_max - x_min <= 0`` or ``y_max - y_min <= 0``.

        The ``is_empty`` argument is not consulted by this implementation.
        """
        no_width = entry.x_maxs - entry.x_mins <= 0
        no_height = entry.y_maxs - entry.y_mins <= 0
        degenerate = np.bitwise_or(no_width, no_height)
        return np.where(degenerate)[0]
238
239
class FilterByVisibility(BaseFilter):
    """Selects the indices of entries whose visibility is below the configured minimum level."""

    __provider__ = 'min_visibility'

    # Visibility names mapped to comparable levels; a higher number means more visible.
    _VISIBILITY_LEVELS = {
        'heavy occluded': 0,
        'partially occluded': 1,
        'visible': 2
    }

    def apply_filter(self, entry, min_visibility):
        """Return positions in ``entry.metadata['visibilities']`` below ``min_visibility``.

        The threshold is resolved first, so an unknown ``min_visibility``
        raises ``ValueError`` even when the entry carries no visibilities.
        """
        threshold = self.visibility_level(min_visibility)
        visibilities = entry.metadata.get('visibilities', [])
        return [
            position for position, visibility in enumerate(visibilities)
            if self.visibility_level(visibility) < threshold
        ]

    def visibility_level(self, visibility):
        """Translate a visibility name into its numeric level.

        Raises:
            ValueError: if ``visibility`` is not a known level name.
        """
        known_levels = self._VISIBILITY_LEVELS
        if visibility in known_levels:
            return known_levels[visibility]

        message = 'Unknown visibility level "{}". Supported only "{}"'
        raise ValueError(message.format(visibility, ','.join(known_levels)))
265
266
class FilterByAspectRatio(BaseFilter):
    """Selects the indices of boxes whose height / width ratio is outside the configured range."""

    __provider__ = 'aspect_ratio'

    def apply_filter(self, entry, aspect_ratio):
        """Return indices of boxes whose height-to-width ratio is not in ``aspect_ratio``.

        ``aspect_ratio`` is passed through ``convert_to_range`` before use.
        """
        ratio_range = convert_to_range(aspect_ratio)
        # Lower bound for the denominator so that a zero width does not divide by zero.
        smallest_width = np.finfo(np.float64).eps

        boxes = zip(entry.x_mins, entry.y_mins, entry.x_maxs, entry.y_maxs)
        return [
            position for position, (x_min, y_min, x_max, y_max) in enumerate(boxes)
            if not in_interval((y_max - y_min) / np.maximum(x_max - x_min, smallest_width), ratio_range)
        ]
281
282
class FilterByAreaRatio(BaseFilter):
    """Selects annotation boxes by their size relative to the image, plus boxes flagged as occluded."""

    __provider__ = 'area_ratio'

    def apply_filter(self, entry, area_ratio):
        """Return indices of boxes to filter from a ``DetectionAnnotation``.

        A box is selected when the square root of (box area / image area) is
        not in ``area_ratio`` or when its index is listed in
        ``entry.metadata['is_occluded']``. An empty list is returned for any
        other entry type and when the 'image_size' metadata is missing or empty.
        """
        ratio_range = convert_to_range(area_ratio)

        if not isinstance(entry, DetectionAnnotation):
            return []

        sizes = entry.metadata.get('image_size')
        if not sizes:
            return []

        # Only the first element of 'image_size' is used; its first two values
        # are multiplied to get the image area.
        first_size = sizes[0]
        image_area = first_size[0] * first_size[1]

        occluded = entry.metadata.get('is_occluded', [])
        selected = []
        boxes = zip(entry.x_mins, entry.y_mins, entry.x_maxs, entry.y_maxs)
        for position, (x_min, y_min, x_max, y_max) in enumerate(boxes):
            box_width = x_max - x_min
            box_height = y_max - y_min
            relative_size = np.sqrt(
                float(box_width * box_height) / np.maximum(image_area, np.finfo(np.float64).eps)
            )
            if not in_interval(relative_size, ratio_range) or position in occluded:
                selected.append(position)

        return selected
309
310
class FilterInvalidBoxes(BaseFilter):
    """Selects the indices of boxes that have at least one non-finite coordinate."""

    __provider__ = 'invalid_boxes'

    def apply_filter(self, entry, invalid_boxes):
        """Return a list of indices of boxes with a non-finite x or y coordinate.

        The ``invalid_boxes`` argument is not consulted by this implementation.
        """
        bad_x_min = ~np.isfinite(entry.x_mins)
        bad_x_max = ~np.isfinite(entry.x_maxs)
        bad_y_min = ~np.isfinite(entry.y_mins)
        bad_y_max = ~np.isfinite(entry.y_maxs)

        bad_x = np.logical_or(bad_x_min, bad_x_max)
        bad_y = np.logical_or(bad_y_min, bad_y_max)
        bad_box = np.logical_or(bad_x, bad_y)

        return np.argwhere(bad_box).reshape(-1).tolist()
318
319         return np.argwhere(infinite_mask).reshape(-1).tolist()