2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
21 from ..config import BaseField, BoolField, ConfigValidator, NumberField, StringField, ConfigError
22 from ..dependency import ClassProvider
23 from ..utils import get_size_from_config, get_or_parse_value, string_to_tuple, get_size_3d_from_config
# Validator base that the per-preprocessor _ConfigValidator classes below derive from.
# NOTE(review): the class body is not visible in this listing (embedded numbering skips 27-29).
26 class BasePreprocessorConfig(ConfigValidator):
# Base class for the preprocessors defined in this file.
30 class Preprocessor(ClassProvider):
31 __provider_type__ = 'preprocessor'
# NOTE(review): the embedded numbering jumps from 33 to 37, so part of the constructor body
# is not visible here; only the validate_config() call can be confirmed.
33 def __init__(self, config, name=None):
37 self.validate_config()
# Instances are callable: every argument is forwarded unchanged to process().
40 def __call__(self, *args, **kwargs):
41 return self.process(*args, **kwargs)
# Hook for subclasses; the base implementation only raises NotImplementedError.
43 def process(self, image, annotation_meta=None):
44 raise NotImplementedError
# Validates self.config with BasePreprocessorConfig, passing ERROR_ON_EXTRA_ARGUMENT as the
# extra-argument policy (presumably rejects unknown keys -- confirm in ConfigValidator).
49 def validate_config(self):
50 config = BasePreprocessorConfig(self.name, on_extra_argument=BasePreprocessorConfig.ERROR_ON_EXTRA_ARGUMENT)
51 config.validate(self.config)
def scale_width(dst_width, dst_height, image_width, image_height):
    """Return ``(width, height)`` with the width rescaled by the image proportions.

    The width is ``dst_width * image_width / image_height`` truncated to an
    int; ``dst_height`` is passed through unchanged.
    """
    rescaled_width = dst_width * image_width / image_height
    return int(rescaled_width), dst_height
def scale_height(dst_width, dst_height, image_width, image_height):
    """Return ``(width, height)`` with the height rescaled by the image proportions.

    The height is ``dst_height * image_height / image_width`` truncated to an
    int; ``dst_width`` is passed through unchanged.
    """
    rescaled_height = dst_height * image_height / image_width
    return dst_width, int(rescaled_height)
def scale_greater(dst_width, dst_height, image_width, image_height):
    """Rescale along whichever image side is longer.

    Delegates to ``scale_height`` when the image is taller than it is wide,
    and to ``scale_width`` otherwise (including square images).
    """
    taller_than_wide = image_height > image_width
    rescale = scale_height if taller_than_wide else scale_width
    return rescale(dst_width, dst_height, image_width, image_height)
# Resize preprocessor, registered under the provider name 'resize'.
68 class Resize(Preprocessor):
69 __provider__ = 'resize'
# Interpolation names used when resizing through PIL, mapped to Image constants.
# NOTE(review): the embedded numbering skips some entries and the closing brace, so this
# listing of the table is partial.
71 PILLOW_INTERPOLATION = {
72 'NEAREST': Image.NEAREST,
75 'BILINEAR': Image.BILINEAR,
76 'LINEAR': Image.LINEAR,
77 'HAMMING': Image.HAMMING,
78 'BICUBIC': Image.BICUBIC,
80 'LANCZOS': Image.LANCZOS,
81 'ANTIALIAS': Image.ANTIALIAS,
# Interpolation names used when resizing through OpenCV, mapped to cv2.INTER_* flags.
84 OPENCV_INTERPOLATION = {
85 'NEAREST': cv2.INTER_NEAREST,
86 'LINEAR': cv2.INTER_LINEAR,
87 'CUBIC': cv2.INTER_CUBIC,
88 'AREA': cv2.INTER_AREA,
90 'BITS': cv2.INTER_BITS,
91 'BITS2': cv2.INTER_BITS2,
92 'LANCZOS4': cv2.INTER_LANCZOS4,
# Aspect-ratio strategies: process() calls the selected value as
# f(dst_width, dst_height, image_w, image_h) and unpacks (new_width, new_height) from it.
95 ASPECT_RATIO_SCALE = {
97 'height': scale_height,
98 'greater': scale_greater,
# Declares the accepted config keys. Every field shown is optional, and the interpolation
# choices are the union of the PIL and OpenCV table keys.
101 def validate_config(self):
102 class _ConfigValidator(BasePreprocessorConfig):
103 size = NumberField(floats=False, optional=True, min_value=1)
104 dst_width = NumberField(floats=False, optional=True, min_value=1)
105 dst_height = NumberField(floats=False, optional=True, min_value=1)
106 aspect_ratio_scale = StringField(choices=set(Resize.ASPECT_RATIO_SCALE), optional=True)
107 interpolation = StringField(
108 choices=set(Resize.PILLOW_INTERPOLATION) | set(Resize.OPENCV_INTERPOLATION), optional=True
110 use_pil = BoolField(optional=True)
112 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
# The statements below copy config values onto the instance (the `def` line that encloses
# them is not visible in this listing).
115 self.dst_height, self.dst_width = get_size_from_config(self.config)
116 self.use_pil = self.config.get('use_pil', False)
# 'LINEAR' is a key of both interpolation tables above, so the default works on either path.
118 interpolation = self.config.get('interpolation', 'LINEAR')
# .get() yields None when 'aspect_ratio_scale' is absent, which skips rescaling in process().
120 self.scaling_func = Resize.ASPECT_RATIO_SCALE.get(self.config.get('aspect_ratio_scale'))
# NOTE(review): both membership checks use interpolation.upper(), but the table lookups
# further down index with the raw string; a lower-case name that passed these checks would
# raise KeyError there unless validation has already rejected it -- confirm.
122 if self.use_pil and interpolation.upper() not in Resize.PILLOW_INTERPOLATION:
123 raise ValueError("Incorrect interpolation option: {} for resize preprocessing".format(interpolation))
124 if not self.use_pil and interpolation.upper() not in Resize.OPENCV_INTERPOLATION:
125 raise ValueError("Incorrect interpolation option: {} for resize preprocessing".format(interpolation))
128 self.interpolation = Resize.PILLOW_INTERPOLATION[interpolation]
130 self.interpolation = Resize.OPENCV_INTERPOLATION[interpolation]
# Resizes the image data to the configured size, adjusted by the aspect-ratio strategy if set.
# NOTE(review): the assignment of `data` and the PIL/OpenCV branch lines are not visible here.
132 def process(self, image, annotation_meta=None):
134 new_height, new_width = self.dst_height, self.dst_width
135 if self.scaling_func:
136 image_h, image_w = data.shape[:2]
137 new_width, new_height = self.scaling_func(self.dst_width, self.dst_height, image_w, image_h)
# Record the larger of the computed and configured extents; Padding.process() reads these keys.
139 image.metadata['preferable_width'] = max(new_width, self.dst_width)
140 image.metadata['preferable_height'] = max(new_height, self.dst_height)
# PIL path: convert to an Image, resize, and store the result back as an ndarray.
143 data = Image.fromarray(data)
144 data = data.resize((new_width, new_height), self.interpolation)
145 image.data = np.array(data)
# OpenCV path: the result is cast to float32, and a 2-D result gets a trailing channel axis.
148 data = cv2.resize(data, (new_width, new_height), interpolation=self.interpolation).astype(np.float32)
149 if len(data.shape) == 2:
150 data = np.expand_dims(data, axis=-1)
# Mean/std normalization preprocessor, registered under the provider name 'normalization'.
156 class Normalize(Preprocessor):
157 __provider__ = 'normalization'
# Passed to get_or_parse_value() as the lookup table for 'mean' (presumably lets the config
# name a preset instead of listing numbers -- confirm in utils).
159 PRECOMPUTED_MEANS = {
160 'imagenet': (104.00698793, 116.66876762, 122.67891434),
161 'cifar10': (125.307, 122.961, 113.8575),
# NOTE(review): the header of the dict these two entries belong to is not visible (the
# embedded numbering skips 162-164). Normalize.PRECOMPUTED_STDS is referenced below; if these
# entries are that table, they repeat the mean values exactly -- confirm that is intended.
165 'imagenet': (104.00698793, 116.66876762, 122.67891434),
166 'cifar10': (125.307, 122.961, 113.8575),
# Both config keys are optional at the validation stage; the checks further down require
# that at least one of them resolves to a value.
169 def validate_config(self):
170 class _ConfigValidator(BasePreprocessorConfig):
171 mean = BaseField(optional=True)
172 std = BaseField(optional=True)
174 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
# The statements below resolve and sanity-check mean/std (their enclosing `def` line is not
# visible in this listing).
177 self.mean = get_or_parse_value(self.config.get('mean'), Normalize.PRECOMPUTED_MEANS)
178 self.std = get_or_parse_value(self.config.get('std'), Normalize.PRECOMPUTED_STDS)
179 if not self.mean and not self.std:
180 raise ConfigError('mean or std value should be provided')
# A zero entry would make the division in process() divide by zero.
182 if self.std and 0 in self.std:
183 raise ConfigError('std value should not contain 0')
# Only lengths 1 and 3 are accepted for either value.
185 if self.mean and not (len(self.mean) == 3 or len(self.mean) == 1):
186 raise ConfigError('mean should be one value or comma-separated list channel-wise values')
188 if self.std and not (len(self.std) == 3 or len(self.std) == 1):
189 raise ConfigError('std should be one value or comma-separated list channel-wise values')
# Subtracts the mean and divides by std.
# NOTE(review): any conditions guarding these two statements are not visible here.
191 def process(self, image, annotation_meta=None):
193 image.data = image.data - self.mean
195 image.data = image.data / self.std
# Channel-order conversion preprocessor, registered under the provider name 'bgr_to_rgb'.
200 class BgrToRgb(Preprocessor):
201 __provider__ = 'bgr_to_rgb'
# Replaces image.data with the cv2 BGR->RGB conversion of itself.
203 def process(self, image, annotation_meta=None):
204 image.data = cv2.cvtColor(image.data, cv2.COLOR_BGR2RGB)
# Grayscale conversion preprocessor, registered under the provider name 'bgr_to_gray'.
208 class BgrToGray(Preprocessor):
209 __provider__ = 'bgr_to_gray'
# Converts BGR to gray, casts to float32 and appends a trailing axis via expand_dims(..., -1).
211 def process(self, image, annotation_meta=None):
212 image.data = np.expand_dims(cv2.cvtColor(image.data, cv2.COLOR_BGR2GRAY).astype(np.float32), -1)
# Flip preprocessor, registered under the provider name 'flip'.
# NOTE(review): Flip.FLIP_MODES is referenced below but its definition is not visible in this
# listing (the embedded numbering skips 218-223).
216 class Flip(Preprocessor):
217 __provider__ = 'flip'
# NOTE(review): `mode` is declared without optional=True here, yet the config read below
# supplies a 'horizontal' default -- confirm which of the two is intended.
224 def validate_config(self):
225 class _ConfigValidator(BasePreprocessorConfig):
226 mode = StringField(choices=Flip.FLIP_MODES.keys())
228 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
# Translates the configured mode name into the value later handed to cv2.flip. In the
# visible code self.mode is only assigned when `mode` is a str.
231 mode = self.config.get('mode', 'horizontal')
232 if isinstance(mode, str):
233 self.mode = Flip.FLIP_MODES[mode]
# Replaces image.data with cv2.flip(image.data, self.mode).
235 def process(self, image, annotation_meta=None):
236 image.data = cv2.flip(image.data, self.mode)
# Center-crop preprocessor, registered under the provider name 'crop'.
240 class Crop(Preprocessor):
241 __provider__ = 'crop'
# All three size-related keys are optional integers >= 1.
243 def validate_config(self):
244 class _ConfigValidator(BasePreprocessorConfig):
245 size = NumberField(floats=False, optional=True, min_value=1)
246 dst_width = NumberField(floats=False, optional=True, min_value=1)
247 dst_height = NumberField(floats=False, optional=True, min_value=1)
249 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
252 self.dst_height, self.dst_width = get_size_from_config(self.config)
# Takes the centered dst_height x dst_width window of the image; an image smaller than the
# target in either dimension is enlarged with cv2.resize first.
# NOTE(review): the assignment of `data` is not visible in this listing.
254 def process(self, image, annotation_meta=None):
# Unpacking three values requires a 3-dimensional array.
256 height, width, _ = data.shape
257 if width < self.dst_width or height < self.dst_height:
# resized holds [width, height]; each step multiplies both entries by the same factor, so
# the short side reaches the target while the proportions are kept.
258 resized = np.array([width, height])
259 if resized[0] < self.dst_width:
260 resized = resized * self.dst_width / resized[0]
261 if resized[1] < self.dst_height:
262 resized = resized * self.dst_height / resized[1]
# Round up so the enlarged image is never smaller than the target; the size is (width, height).
264 data = cv2.resize(data, tuple(np.ceil(resized).astype(int)))
# Re-read the shape and split the surplus evenly (floor division) to center the window.
266 height, width, _ = data.shape
267 start_height = (height - self.dst_height) // 2
268 start_width = (width - self.dst_width) // 2
270 image.data = data[start_height:start_height + self.dst_height, start_width:start_width + self.dst_width]
# Crops the image to the rectangle stored under annotation_meta['rect']; registered under
# the provider name 'crop_rect'.
274 class CropRect(Preprocessor):
275 __provider__ = 'crop_rect'
277 def process(self, image, annotation_meta=None):
# NOTE(review): annotation_meta defaults to None, but .get() is called on it unconditionally.
# Any handling of a missing rect is not visible in this listing (numbering skips 279-281).
278 rect = annotation_meta.get('rect')
282 rows, cols = image.data.shape[:2]
283 rect_x_min, rect_y_min, rect_x_max, rect_y_max = rect
# Clamp the top-left corner to the image.
284 start_width, start_height = max(0, rect_x_min), max(0, rect_y_min)
# The far edge is the clamped start plus the rectangle's original extent, capped at the
# image size.
286 width = min(start_width + (rect_x_max - rect_x_min), cols)
287 height = min(start_height + (rect_y_max - rect_y_min), rows)
289 image.data = image.data[start_height:height, start_width:width]
# Pads the image by border replication so that a rectangle enlarged by augmentation_param
# fits, then stores an updated rect in annotation_meta; registered as 'extend_around_rect'.
293 class ExtendAroundRect(Preprocessor):
294 __provider__ = 'extend_around_rect'
296 def validate_config(self):
297 class _ConfigValidator(BasePreprocessorConfig):
298 augmentation_param = NumberField(floats=True, optional=True)
300 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
# Defaults to 0, which makes both extents below zero.
303 self.augmentation_param = self.config.get('augmentation_param', 0)
305 def process(self, image, annotation_meta=None):
# NOTE(review): annotation_meta defaults to None, but .get() is called on it (and it is
# written to at the end) unconditionally.
306 rect = annotation_meta.get('rect')
307 rows, cols = image.data.shape[:2]
# A missing/falsy rect falls back to the whole image; the rect is then clamped to the image.
309 rect_x_left, rect_y_top, rect_x_right, rect_y_bottom = rect or (0, 0, cols, rows)
310 rect_x_left = max(0, rect_x_left)
311 rect_y_top = max(0, rect_y_top)
312 rect_x_right = min(rect_x_right, cols)
313 rect_y_bottom = min(rect_y_bottom, rows)
315 rect_w = rect_x_right - rect_x_left
316 rect_h = rect_y_bottom - rect_y_top
# The extents use the clamped size plus one, whereas rect_w / rect_h above omit the +1.
318 width_extent = (rect_x_right - rect_x_left + 1) * self.augmentation_param
319 height_extent = (rect_y_bottom - rect_y_top + 1) * self.augmentation_param
# Move each leading edge outwards; whatever falls below 0 becomes the border to add on
# that side, and the edge itself is clamped back to 0.
320 rect_x_left = rect_x_left - width_extent
321 border_left = abs(min(0, rect_x_left))
322 rect_x_left = int(max(0, rect_x_left))
324 rect_y_top = rect_y_top - height_extent
325 border_top = abs(min(0, rect_y_top))
326 rect_y_top = int(max(0, rect_y_top))
# Trailing edges are shifted by the leading border, extended, and rounded (+0.5 then int);
# the overshoot past the image size becomes the trailing border.
328 rect_y_bottom += border_top
329 rect_y_bottom = int(rect_y_bottom + height_extent + 0.5)
330 border_bottom = abs(max(0, rect_y_bottom - rows))
332 rect_x_right += border_left
333 rect_x_right = int(rect_x_right + width_extent + 0.5)
334 border_right = abs(max(0, rect_x_right - cols))
# cv2.copyMakeBorder takes the borders in the order top, bottom, left, right.
336 image.data = cv2.copyMakeBorder(
337 image.data, int(border_top), int(border_bottom), int(border_left), int(border_right), cv2.BORDER_REPLICATE
# NOTE(review): the line that opens this expression is not visible (numbering skips 338-340);
# the values are x_left, y_top, x_left + extended width, y_top + extended height.
341 int(rect_x_left), int(rect_y_top),
342 int(rect_x_left) + int(rect_w + width_extent * 2), int(rect_y_top) + int(rect_h + height_extent * 2)
344 annotation_meta['rect'] = rect
# Warps the image so that the keypoints supplied in annotation_meta line up with a fixed set
# of reference landmarks; registered under the provider name 'point_alignment'.
349 class PointAligner(Preprocessor):
350 __provider__ = 'point_alignment'
# Five reference (x, y) pairs, each stored as a fraction of 96 (x) and 112 (y).
352 ref_landmarks = np.array([
353 30.2946 / 96, 51.6963 / 112,
354 65.5318 / 96, 51.5014 / 112,
355 48.0252 / 96, 71.7366 / 112,
356 33.5493 / 96, 92.3655 / 112,
357 62.7299 / 96, 92.2041 / 112
358 ], dtype=np.float64).reshape(5, 2)
360 def validate_config(self):
361 class _ConfigValidator(BasePreprocessorConfig):
362 draw_points = BoolField(optional=True)
363 normalize = BoolField(optional=True)
364 size = NumberField(floats=False, optional=True, min_value=1)
365 dst_width = NumberField(floats=False, optional=True, min_value=1)
366 dst_height = NumberField(floats=False, optional=True, min_value=1)
368 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
# Config defaults: draw_points=False, normalize=True (the enclosing `def` line is not
# visible in this listing).
371 self.draw_points = self.config.get('draw_points', False)
372 self.normalize = self.config.get('normalize', True)
373 self.dst_height, self.dst_width = get_size_from_config(self.config)
375 def process(self, image, annotation_meta=None):
# NOTE(review): annotation_meta defaults to None, but .get() is called on it unconditionally.
376 keypoints = annotation_meta.get('keypoints')
377 image.data = self.align(image.data, keypoints)
# Resizes img to dst_width x dst_height and warps it with the transform estimated between
# the reference landmarks and the rescaled keypoints.
380 def align(self, img, points):
# `points` is a flat sequence that is reshaped into (x, y) rows.
384 points_number = len(points) // 2
385 points = np.array(points).reshape(points_number, 2)
389 inp_shape = img.shape
# Rescale the keypoints from input-image coordinates to the destination frame.
391 keypoints = points.copy().astype(np.float64)
392 keypoints[:, 0] *= (float(self.dst_width) / inp_shape[1])
393 keypoints[:, 1] *= (float(self.dst_height) / inp_shape[0])
# ref_landmarks has 5 rows, so these column assignments only succeed when points_number == 5.
395 keypoints_ref = np.zeros((points_number, 2), dtype=np.float64)
396 keypoints_ref[:, 0] = self.ref_landmarks[:, 0] * self.dst_width
397 keypoints_ref[:, 1] = self.ref_landmarks[:, 1] * self.dst_height
399 transformation_matrix = self.transformation_from_points(np.array(keypoints_ref), np.array(keypoints))
400 img = cv2.resize(img, (self.dst_width, self.dst_height))
# Draws a filled circle of radius 5 at every keypoint (any condition guarding this loop is
# not visible in this listing).
402 for point in keypoints:
403 cv2.circle(img, (int(point[0]), int(point[1])), 5, (255, 0, 0), -1)
# With WARP_INVERSE_MAP, cv2.warpAffine treats the matrix as the destination->source mapping.
405 return cv2.warpAffine(img, transformation_matrix, (self.dst_width, self.dst_height), flags=cv2.WARP_INVERSE_MAP)
# Estimates the transform relating points1 to points2 via an SVD of their product.
# NOTE(review): the definitions of s1, s2 and r are not visible in this listing (the
# embedded numbering skips 414-417 and 423-424), nor is any decorator above the def.
408 def transformation_from_points(points1, points2):
# Both inputs become np.matrix, so `*` below is matrix multiplication.
409 points1 = np.matrix(points1.astype(np.float64))
410 points2 = np.matrix(points2.astype(np.float64))
412 c1 = np.mean(points1, axis=0)
413 c2 = np.mean(points2, axis=0)
# eps guards against division by zero.
418 points1 /= np.maximum(s1, np.finfo(np.float64).eps)
# NOTE(review): points2 is divided by s1 here rather than s2 -- confirm whether that is intended.
419 points2 /= np.maximum(s1, np.finfo(np.float64).eps)
420 points_std_ratio = s2 / np.maximum(s1, np.finfo(np.float64).eps)
422 u, _, vt = np.linalg.svd(points1.T * points2)
# Result: the scaled matrix r stacked horizontally with a translation column.
425 return np.hstack((points_std_ratio * r, c2.T - points_std_ratio * r * c1.T))
# Pads the image with a constant value up to a preferred size rounded up to a multiple of
# `stride`; registered under the provider name 'padding'.
428 class Padding(Preprocessor):
429 __provider__ = 'padding'
431 def validate_config(self):
432 class _ConfigValidator(BasePreprocessorConfig):
433 stride = NumberField(floats=False, min_value=1, optional=True)
434 pad_value = StringField(optional=True)
435 size = NumberField(floats=False, optional=True, min_value=1)
436 dst_width = NumberField(floats=False, optional=True, min_value=1)
437 dst_height = NumberField(floats=False, optional=True, min_value=1)
# NOTE(review): unlike the other validators in this file, no on_extra_argument policy is
# passed here -- confirm whether that is deliberate.
439 _ConfigValidator(self.name).validate(self.config)
# Config defaults: stride=1, pad_value='0,0,0' (the enclosing `def` line is not visible in
# this listing).
442 self.stride = self.config.get('stride', 1)
443 pad_val = self.config.get('pad_value', '0,0,0')
# An int is repeated three times; a str is parsed with string_to_tuple. In the visible code
# self.pad_value is left unassigned for any other type.
444 if isinstance(pad_val, int):
445 self.pad_value = (pad_val, pad_val, pad_val)
446 if isinstance(pad_val, str):
447 self.pad_value = string_to_tuple(pad_val, int)
448 self.dst_height, self.dst_width = get_size_from_config(self.config, allow_none=True)
450 def process(self, image, annotation_meta=None):
# Unpacking three values requires a 3-dimensional array.
451 height, width, _ = image.data.shape
# A configured size wins; otherwise fall back to the 'preferable_*' metadata (written by
# Resize.process()), and finally to the current image size.
452 pref_height = self.dst_height or image.metadata.get('preferable_height', height)
453 pref_width = self.dst_width or image.metadata.get('preferable_width', width)
# NOTE(review): the height path takes min(height, pref_height) while the width path takes
# max(pref_width, width) -- the two axes are treated differently; confirm this is intended.
454 height = min(height, pref_height)
455 pref_height = math.ceil(pref_height / float(self.stride)) * self.stride
456 pref_width = max(pref_width, width)
457 pref_width = math.ceil(pref_width / float(self.stride)) * self.stride
# pad = [top, left, bottom, right]: the leading side gets the floored half of the surplus,
# the trailing side the remainder (the creation of `pad` is not visible in this listing).
459 pad.append(int(math.floor((pref_height - height) / 2.0)))
460 pad.append(int(math.floor((pref_width - width) / 2.0)))
461 pad.append(int(pref_height - height - pad[0]))
462 pad.append(int(pref_width - width - pad[1]))
463 image.metadata['padding'] = pad
# cv2.copyMakeBorder takes the borders in the order top, bottom, left, right.
464 image.data = cv2.copyMakeBorder(
465 image.data, pad[0], pad[2], pad[1], pad[3], cv2.BORDER_CONSTANT, value=self.pad_value
# Splits the image into overlapping dst_height x dst_width tiles; registered under the
# provider name 'tiling'.
470 class Tiling(Preprocessor):
471 __provider__ = 'tiling'
# `margin` is the only field declared without optional=True.
473 def validate_config(self):
474 class _ConfigValidator(BasePreprocessorConfig):
475 margin = NumberField(floats=False, min_value=1)
476 size = NumberField(floats=False, optional=True, min_value=1)
477 dst_width = NumberField(floats=False, optional=True, min_value=1)
478 dst_height = NumberField(floats=False, optional=True, min_value=1)
480 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
483 self.dst_height, self.dst_width = get_size_from_config(self.config)
484 self.margin = self.config['margin']
# NOTE(review): the assignments that create `data` and `tiled_data` are not visible in this
# listing (the embedded numbering skips 487 and 494).
486 def process(self, image, annotation_meta=None):
488 image_size = data.shape
# Tiles advance by the tile size minus a margin on each side.
# NOTE(review): if 2 * margin >= dst_height (or dst_width) the step is <= 0 -- confirm that
# the config rules this out.
489 output_height = self.dst_height - 2 * self.margin
490 output_width = self.dst_width - 2 * self.margin
# Reflect-pad `margin` pixels on all four sides.
491 data = cv2.copyMakeBorder(data, *np.full(4, self.margin), cv2.BORDER_REFLECT_101)
# Ceiling division of the unpadded image size by the step.
492 num_tiles_h = image_size[0] // output_height + (1 if image_size[0] % output_height else 0)
493 num_tiles_w = image_size[1] // output_width + (1 if image_size[1] % output_width else 0)
495 for height in range(num_tiles_h):
496 for width in range(num_tiles_w):
497 offset = [output_height * height, output_width * width]
498 tile = data[offset[0]:offset[0] + self.dst_height, offset[1]:offset[1] + self.dst_width, :]
# Tiles cut short by the image edge are reflect-padded on the bottom/right up to full size.
499 margin = [0, self.dst_height - tile.shape[0], 0, self.dst_width - tile.shape[1]]
500 tile = cv2.copyMakeBorder(tile, *margin, cv2.BORDER_REFLECT_101)
501 tiled_data.append(tile)
# image.data becomes the list of tiles; the grid shape and the multi_infer flag go to metadata.
502 image.data = tiled_data
503 image.metadata['tiles_shape'] = (num_tiles_h, num_tiles_w)
504 image.metadata['multi_infer'] = True
# Center-crop preprocessor for volumetric data, registered under the provider name 'crop3d'.
508 class Crop3D(Preprocessor):
509 __provider__ = 'crop3d'
# NOTE(review): `size` is declared without optional=True here, unlike the other size fields
# -- confirm that it is meant to be mandatory.
511 def validate_config(self):
512 class _ConfigValidator(BasePreprocessorConfig):
513 size = NumberField(floats=False, min_value=1)
514 dst_width = NumberField(floats=False, optional=True, min_value=1)
515 dst_height = NumberField(floats=False, optional=True, min_value=1)
516 dst_volume = NumberField(floats=False, optional=True, min_value=1)
518 _ConfigValidator(self.name, on_extra_argument=_ConfigValidator.ERROR_ON_EXTRA_ARGUMENT).validate(self.config)
521 self.dst_height, self.dst_width, self.dst_volume = get_size_3d_from_config(self.config)
523 def process(self, image, annotation_meta=None):
# NOTE(review): crop_center's parameters are (img, cropx, cropy, cropz), so dst_height is
# passed as cropx and dst_width as cropy -- confirm this ordering is intended.
524 image.data = self.crop_center(image.data, self.dst_height, self.dst_width, self.dst_volume)
def crop_center(img, cropx, cropy, cropz):
    """Cut a centered window out of a 4-D array.

    ``img.shape`` is unpacked as ``(z, y, x, _)``. Along each of the first
    three axes the window starts at ``extent // 2 - crop // 2`` (never below 0)
    and ends ``crop`` elements later (never past the axis length). The last
    axis is kept whole.
    """
    depth, rows, cols, _ = img.shape

    def _window(extent, crop):
        # Clamp the start at 0 and the end at the axis length.
        begin = max(extent // 2 - (crop // 2), 0)
        return begin, min(begin + crop, extent)

    x_from, x_to = _window(cols, cropx)
    y_from, y_to = _window(rows, cropy)
    z_from, z_to = _window(depth, cropz)

    return img[z_from:z_to, y_from:y_to, x_from:x_to, :]
# Per-channel standardization preprocessor, registered under the provider name "normalize3d".
545 class Normalize3d(Preprocessor):
546 __provider__ = "normalize3d"
548 def process(self, image, annotation_meta=None):
549 data = self.normalize_img(image.data)
# NOTE(review): the creation of image_list and the loop that yields `img` are not visible in
# this listing (the embedded numbering skips 550-551).
552 image_list.append(img)
# image.data becomes a list and the multi_infer flag is set in the metadata.
553 image.data = image_list
554 image.metadata['multi_infer'] = True
# For every index along axis 3: subtract that slice's mean, divide by its standard
# deviation, and write the result back into img, so the input array is modified in place.
# NOTE(review): a constant channel has std 0, so the division is by zero -- confirm inputs.
# NOTE(review): no return statement is visible before this listing ends, although process()
# uses the call's result -- the definition may continue past the visible lines.
559 def normalize_img(img):
560 for channel in range(img.shape[3]):
561 channel_val = img[:, :, :, channel] - np.mean(img[:, :, :, channel])
562 channel_val /= np.std(img[:, :, :, channel])
563 img[:, :, :, channel] = channel_val