"""
Copyright (c) 2019 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from functools import singledispatch
from typing import Union

import numpy as np

from ..config import BaseField, BoolField
from ..dependency import ClassProvider
from ..postprocessor.postprocessor import PostprocessorWithSpecificTargets, PostprocessorWithTargetsConfigValidator
from ..representation import (DetectionAnnotation, DetectionPrediction, TextDetectionAnnotation,
                              TextDetectionPrediction, PoseEstimationPrediction, PoseEstimationAnnotation)
from ..utils import in_interval, polygon_from_points, convert_to_range
class FilterConfig(PostprocessorWithTargetsConfigValidator):
    """Config validator for the ``filter`` postprocessor.

    Declares the optional boolean ``remove_filtered`` field and, at
    construction time, one optional untyped field for every name found in
    ``BaseFilter.providers``.
    """

    remove_filtered = BoolField(optional=True)

    def __init__(self, config_uri, **kwargs):
        super().__init__(config_uri, **kwargs)
        # Every registered filter name becomes an accepted, optional config key.
        for provider_name in BaseFilter.providers:
            self.fields[provider_name] = BaseField(optional=True)
class FilterPostprocessor(PostprocessorWithSpecificTargets):
    """Postprocessor that applies the configured filters to annotations and predictions.

    Each filter yields indices of objects inside a representation. Depending on
    ``remove_filtered`` those objects are either removed from the
    representation or recorded under the ``'difficult_boxes'`` metadata key.
    """

    __provider__ = 'filter'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def __init__(self, *args, **kwargs):
        # Attributes are created before delegating to the base initialiser so
        # they exist whenever the configuration hooks below run.
        # NOTE(review): presumably the base initialiser triggers
        # validate_config()/configure() - confirm against the base class.
        self._filters = []
        self.remove_filtered = False
        super().__init__(*args, **kwargs)

    def validate_config(self):
        """Validate ``self.config`` with ``FilterConfig``; unknown keys are an error."""
        filter_config = FilterConfig(self.__provider__, on_extra_argument=FilterConfig.ERROR_ON_EXTRA_ARGUMENT)
        filter_config.validate(self.config)

    # NOTE(review): the header of this method was missing from the reviewed
    # text; the name ``configure`` is reconstructed - confirm it matches the
    # hook the base class calls.
    def configure(self):
        """Read ``remove_filtered`` and build one filter per remaining config key."""
        # Work on a copy so the stored config is left untouched.
        config = self.config.copy()

        self.remove_filtered = config.pop('remove_filtered', False)
        # These keys are not filter names, so drop them before building filters.
        for service_key in ('annotation_source', 'prediction_source', 'apply_to'):
            config.pop(service_key, None)

        for key, value in config.items():
            self._filters.append(BaseFilter.provide(key, value))

    def process_image(self, annotation, prediction):
        """Apply every configured filter to each annotation and prediction entry.

        Returns the same ``annotation`` and ``prediction`` objects it received.
        """
        for functor in self._filters:
            for target in annotation:
                self._filter_entry_by(target, functor)

            for target in prediction:
                self._filter_entry_by(target, functor)

        return annotation, prediction

    def _filter_entry_by(self, entry, functor):
        """Apply one filter to one entry, either marking or removing the hits."""
        ignored_key = 'difficult_boxes'
        markable_types = (DetectionAnnotation, DetectionPrediction,
                          TextDetectionAnnotation, TextDetectionPrediction,
                          PoseEstimationAnnotation, PoseEstimationPrediction)

        filtered = functor(entry)
        if not self.remove_filtered and isinstance(entry, markable_types):
            # Keep the objects but record their indices in the entry metadata.
            entry.metadata.setdefault(ignored_key, []).extend(filtered)
        else:
            # The two actions are alternatives: without this ``else`` the
            # marked objects would be removed as well.
            entry.remove(filtered)

        return entry
class BaseFilter(ClassProvider):
    """Base class of all filters; stores one argument and applies it on call.

    Subclasses implement ``apply_filter``; calling an instance forwards the
    entry together with the stored ``filter_arg`` to that method.
    """

    __provider_type__ = 'filter'

    def __init__(self, filter_arg):
        # Value taken from the postprocessor config for this filter.
        self.filter_arg = filter_arg

    def __call__(self, entry):
        """Return ``apply_filter(entry, self.filter_arg)``."""
        stored_argument = self.filter_arg
        return self.apply_filter(entry, stored_argument)

    def apply_filter(self, entry, filter_arg):
        """Hook for subclasses; the base implementation always raises."""
        raise NotImplementedError
class FilterByLabels(BaseFilter):
    """Filter that selects objects by their label."""

    __provider__ = 'labels'

    def apply_filter(self, entry, labels):
        """Return the indices of ``entry.labels`` whose label is in ``labels``.

        NOTE(review): the condition guarding the append was missing from the
        reviewed text; the membership test below is reconstructed - confirm
        that the configured labels are the ones to be filtered.
        """
        return [index for index, label in enumerate(entry.labels) if label in labels]
class FilterByMinConfidence(BaseFilter):
    """Filter that selects objects whose score is below a threshold."""

    __provider__ = 'min_confidence'

    def apply_filter(self, entry, min_confidence):
        """Return indices of ``entry.scores`` that are below ``min_confidence``.

        ``DetectionAnnotation`` entries are never filtered: the result for
        them is an empty list and ``entry.scores`` is not accessed.
        """
        if isinstance(entry, DetectionAnnotation):
            return []

        return [index for index, score in enumerate(entry.scores) if score < min_confidence]
class FilterByHeightRange(BaseFilter):
    """Filter that selects objects whose height lies outside a range."""

    __provider__ = 'height_range'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def apply_filter(self, entry, height_range):
        """Return indices of objects whose height is outside ``height_range``.

        ``height_range`` is normalised with ``convert_to_range`` and every
        height is checked with ``in_interval``. Boxes use ``y_max - y_min``;
        text polygons use the mean length of their left and right sides.
        """

        @singledispatch
        def filtering(entry_value, height_range_):
            # NOTE(review): fallback for representations without a registered
            # handler; its body was missing and is reconstructed as "filter
            # nothing" - confirm.
            return []

        def filter_boxes(entry_value, height_range_):
            return [
                index
                for index, (y_min, y_max) in enumerate(zip(entry_value.y_mins, entry_value.y_maxs))
                if not in_interval(y_max - y_min, height_range_)
            ]

        def filter_polygons(entry_values, height_range_):
            filtered = []
            for index, polygon_points in enumerate(entry_values.points):
                left_bottom_point, left_top_point, right_top_point, right_bottom_point = polygon_points
                left_side_height = np.linalg.norm(left_bottom_point - left_top_point)
                right_side_height = np.linalg.norm(right_bottom_point - right_top_point)
                if not in_interval(np.mean([left_side_height, right_side_height]), height_range_):
                    filtered.append(index)

            return filtered

        # Register per concrete class: passing typing.Union to register() is
        # only accepted by functools.singledispatch from Python 3.11 on.
        for box_type in (DetectionAnnotation, DetectionPrediction):
            filtering.register(box_type, filter_boxes)
        for text_type in (TextDetectionAnnotation, TextDetectionPrediction):
            filtering.register(text_type, filter_polygons)

        return filtering(entry, convert_to_range(height_range))
class FilterByWidthRange(BaseFilter):
    """Filter that selects objects whose width lies outside a range."""

    __provider__ = 'width_range'

    annotation_types = (DetectionAnnotation, TextDetectionAnnotation)
    prediction_types = (DetectionPrediction, TextDetectionPrediction)

    def apply_filter(self, entry, width_range):
        """Return indices of objects whose width is outside ``width_range``.

        ``width_range`` is normalised with ``convert_to_range`` and every
        width is checked with ``in_interval``. Boxes use ``x_max - x_min``;
        a text polygon is selected when its top side or its bottom side is
        outside the range.
        """

        @singledispatch
        def filtering(entry_value, width_range_):
            # NOTE(review): fallback for representations without a registered
            # handler; its body was missing and is reconstructed as "filter
            # nothing" - confirm.
            return []

        def filter_boxes(entry_value, width_range_):
            return [
                index
                for index, (x_min, x_max) in enumerate(zip(entry_value.x_mins, entry_value.x_maxs))
                if not in_interval(x_max - x_min, width_range_)
            ]

        def filter_polygons(entry_values, width_range_):
            filtered = []
            for index, polygon_points in enumerate(entry_values.points):
                left_bottom_point, left_top_point, right_top_point, right_bottom_point = polygon_points
                top_width = np.linalg.norm(right_top_point - left_top_point)
                bottom_width = np.linalg.norm(right_bottom_point - left_bottom_point)
                if not in_interval(top_width, width_range_) or not in_interval(bottom_width, width_range_):
                    filtered.append(index)

            return filtered

        # Register per concrete class: passing typing.Union to register() is
        # only accepted by functools.singledispatch from Python 3.11 on.
        for box_type in (DetectionAnnotation, DetectionPrediction):
            filtering.register(box_type, filter_boxes)
        for text_type in (TextDetectionAnnotation, TextDetectionPrediction):
            filtering.register(text_type, filter_polygons)

        return filtering(entry, convert_to_range(width_range))
class FilterByAreaRange(BaseFilter):
    """Filter that selects objects whose area lies outside a range."""

    __provider__ = 'area_range'

    annotation_types = (TextDetectionAnnotation, PoseEstimationAnnotation)
    prediction_types = (TextDetectionPrediction, )

    def apply_filter(self, entry, area_range):
        """Return indices of objects whose area is outside ``area_range``.

        ``area_range`` is normalised with ``convert_to_range`` and every area
        is checked with ``in_interval``. Text polygons use the ``area`` of
        ``polygon_from_points(points)``.
        """
        area_range = convert_to_range(area_range)

        @singledispatch
        def filtering(entry, area_range):
            # NOTE(review): fallback for representations without a registered
            # handler; its body was missing and is reconstructed as "filter
            # nothing" - confirm.
            return []

        def filter_poses(entry, area_range):
            # NOTE(review): the statement defining ``areas`` was missing from
            # the reviewed text; ``entry.areas`` is an assumption - confirm.
            areas = entry.areas
            return [area_id for area_id, area in enumerate(areas) if not in_interval(area, area_range)]

        # The dispatcher is always called with (entry, area_range), so the
        # handler has to accept both positional arguments.
        def filter_polygons(entry, area_range):
            return [
                index
                for index, polygon_points in enumerate(entry.points)
                if not in_interval(polygon_from_points(polygon_points).area, area_range)
            ]

        # Register per concrete class: Union annotations are only understood
        # by functools.singledispatch from Python 3.11 on.
        for pose_type in (PoseEstimationAnnotation, PoseEstimationPrediction):
            filtering.register(pose_type, filter_poses)
        for text_type in (TextDetectionAnnotation, TextDetectionPrediction):
            filtering.register(text_type, filter_polygons)

        return filtering(entry, area_range)
class FilterEmpty(BaseFilter):
    """Filter that selects boxes with non-positive width or height."""

    __provider__ = 'is_empty'

    def apply_filter(self, entry: DetectionAnnotation, is_empty):
        """Return indices of boxes where ``x_max - x_min <= 0`` or ``y_max - y_min <= 0``.

        The ``is_empty`` argument is not used by the computation.
        """
        widths = entry.x_maxs - entry.x_mins
        heights = entry.y_maxs - entry.y_mins
        degenerate = np.bitwise_or(widths <= 0, heights <= 0)
        return np.where(degenerate)[0]
class FilterByVisibility(BaseFilter):
    """Filter that selects objects whose visibility level is below a minimum."""

    __provider__ = 'min_visibility'

    # NOTE(review): only this entry was present in the reviewed text; the
    # table appears to have had further levels - restore them from the
    # original file before relying on this mapping.
    _VISIBILITY_LEVELS = {
        'partially occluded': 1,
    }

    def apply_filter(self, entry, min_visibility):
        """Return indices of ``entry.metadata['visibilities']`` below ``min_visibility``.

        A missing ``'visibilities'`` key is treated as an empty list.

        Raises:
            ValueError: if ``min_visibility`` or any stored visibility is not
                a key of ``_VISIBILITY_LEVELS``.
        """
        min_visibility_level = self.visibility_level(min_visibility)
        return [
            index
            for index, visibility in enumerate(entry.metadata.get('visibilities', []))
            if self.visibility_level(visibility) < min_visibility_level
        ]

    def visibility_level(self, visibility):
        """Map a visibility name to its numeric level.

        Raises:
            ValueError: if ``visibility`` is not a known level name.
        """
        level = self._VISIBILITY_LEVELS.get(visibility)
        # Only unknown names are an error; a known level is returned to the caller.
        if level is None:
            message = 'Unknown visibility level "{}". Supported only "{}"'
            raise ValueError(message.format(visibility, ','.join(self._VISIBILITY_LEVELS.keys())))

        return level
class FilterByAspectRatio(BaseFilter):
    """Filter that selects boxes whose height/width ratio lies outside a range."""

    __provider__ = 'aspect_ratio'

    def apply_filter(self, entry, aspect_ratio):
        """Return indices of boxes whose aspect ratio is outside ``aspect_ratio``.

        The ratio is ``(y_max - y_min) / (x_max - x_min)``; the range is
        normalised with ``convert_to_range`` and checked with ``in_interval``.
        """
        aspect_ratio = convert_to_range(aspect_ratio)
        # Lower bound for the denominator so a zero-width box does not divide by zero.
        min_width = np.finfo(np.float64).eps

        filtered = []
        coordinates = zip(entry.x_mins, entry.y_mins, entry.x_maxs, entry.y_maxs)
        for index, (x_min, y_min, x_max, y_max) in enumerate(coordinates):
            ratio = (y_max - y_min) / np.maximum(x_max - x_min, min_width)
            if not in_interval(ratio, aspect_ratio):
                filtered.append(index)

        return filtered
class FilterByAreaRatio(BaseFilter):
    """Filter that selects annotated boxes by their area relative to the image."""

    __provider__ = 'area_ratio'

    def apply_filter(self, entry, area_ratio):
        """Return indices of boxes whose relative size is outside ``area_ratio``.

        The value compared is ``sqrt(box_area / image_area)``. Indices listed
        in ``entry.metadata['is_occluded']`` are always selected. Entries
        that are not ``DetectionAnnotation``, or that carry no
        ``'image_size'`` metadata, yield an empty list.
        """
        area_ratio = convert_to_range(area_ratio)

        if not isinstance(entry, DetectionAnnotation):
            return []

        image_size = entry.metadata.get('image_size')
        # Without a recorded image size the ratio cannot be computed.
        if image_size is None:
            return []
        # The metadata value is indexed twice: its first element is the size
        # whose first two items are multiplied to get the image area.
        image_size = image_size[0]

        image_area = image_size[0] * image_size[1]
        # Lower bound for the denominator so a zero image area does not divide by zero.
        safe_image_area = np.maximum(image_area, np.finfo(np.float64).eps)

        filtered = []
        occluded_indices = entry.metadata.get('is_occluded', [])
        coordinates = zip(entry.x_mins, entry.y_mins, entry.x_maxs, entry.y_maxs)
        for index, (x_min, y_min, x_max, y_max) in enumerate(coordinates):
            width, height = x_max - x_min, y_max - y_min
            area = np.sqrt(float(width * height) / safe_image_area)
            if not in_interval(area, area_ratio) or index in occluded_indices:
                filtered.append(index)

        return filtered
class FilterInvalidBoxes(BaseFilter):
    """Filter that selects boxes having any non-finite coordinate."""

    __provider__ = 'invalid_boxes'

    def apply_filter(self, entry, invalid_boxes):
        """Return, as a list, the indices of boxes with a non-finite coordinate.

        The ``invalid_boxes`` argument is not used by the computation.
        """
        # A box is valid only when all four of its coordinates are finite.
        finite_x = np.logical_and(np.isfinite(entry.x_mins), np.isfinite(entry.x_maxs))
        finite_y = np.logical_and(np.isfinite(entry.y_mins), np.isfinite(entry.y_maxs))
        invalid_mask = ~np.logical_and(finite_x, finite_y)

        return np.argwhere(invalid_mask).reshape(-1).tolist()