1 """
2 Copyright (c) 2019 Intel Corporation
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 """
16
17 from collections import defaultdict
18
19 import cv2
20 import numpy as np
21
22
23 from ..adapters import Adapter
24 from ..config import ConfigValidator, StringField, NumberField, BoolField, ConfigError
25 from ..representation import TextDetectionPrediction, CharacterRecognitionPrediction
26
27
28 class TextDetectionAdapterConfig(ConfigValidator):
29     type = StringField()
30     pixel_link_out = StringField()
31     pixel_class_out = StringField()
32
33
34 class TextDetectionAdapter(Adapter):
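    """
    Decodes pixel/link segmentation outputs (PixelLink-style): a pixel
    classification map and a pixel link map are combined into instance masks,
    which are then converted to rotated bounding boxes.
    """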
    __provider__ = 'text_detection'

    def validate_config(self):
        text_detection_adapter_config = TextDetectionAdapterConfig('TextDetectionAdapter_Config')
        text_detection_adapter_config.validate(self.launcher_config)

    def configure(self):
        self.pixel_link_out = self.launcher_config['pixel_link_out']
        self.pixel_class_out = self.launcher_config['pixel_class_out']

    def process(self, raw, identifiers=None, frame_meta=None):
        results = []
        predictions = self._extract_predictions(raw, frame_meta)
        raw_output = zip(identifiers, frame_meta, predictions[self.pixel_link_out], predictions[self.pixel_class_out])
        for identifier, current_frame_meta, link_data, cls_data in raw_output:
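            # Restore a leading batch axis of size 1; the maps are processed
            # one frame at a time.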
            link_data = link_data.reshape((1, *link_data.shape))
            cls_data = cls_data.reshape((1, *cls_data.shape))
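            # Target NHWC shapes: the channel dimension holds paired logits,
            # so after the pairwise softmax below only half the channels
            # (the positive scores taken by the [1::2] slices) remain.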
            link_data_shape = link_data.shape
            new_link_data_shape = (link_data_shape[0], link_data_shape[2], link_data_shape[3], link_data_shape[1] // 2)
            cls_data_shape = cls_data.shape
            new_cls_data_shape = (cls_data_shape[0], cls_data_shape[2], cls_data_shape[3], cls_data_shape[1] // 2)
            link_data = self.softmax(link_data.transpose((0, 2, 3, 1)).reshape(-1))[1::2]
            cls_data = self.softmax(cls_data.transpose((0, 2, 3, 1)).reshape(-1))[1::2]
            mask = self.decode_image_by_join(cls_data, new_cls_data_shape, link_data, new_link_data_shape)
            rects = self.mask_to_boxes(mask, current_frame_meta['image_size'])
            results.append(TextDetectionPrediction(identifier, rects))

        return results

    @staticmethod
    def softmax(data):
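        """
        In-place two-class softmax over consecutive logit pairs in a
        flattened array.
        """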
        for i in np.arange(start=0, stop=data.size, step=2, dtype=int):
            maximum = max(data[i], data[i + 1])
            data[i] = np.exp(data[i] - maximum)
            data[i + 1] = np.exp(data[i + 1] - maximum)
            sum_data = data[i] + data[i + 1]
            data[i] /= sum_data
            data[i + 1] /= sum_data

        return data

    def decode_image_by_join(self, cls_data, cls_data_shape, link_data, link_data_shape):
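        """
        Groups positive pixels into text instances: two neighbouring pixels are
        joined when the neighbour is above the class threshold and the link
        score towards it is above the link threshold. Grouping uses a
        union-find structure (see join/find_root).
        """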
        k_cls_conf_threshold = 0.7
        k_link_conf_threshold = 0.7
        height = cls_data_shape[1]
        width = cls_data_shape[2]
        id_pixel_mask = np.argwhere(cls_data >= k_cls_conf_threshold).reshape(-1)
        pixel_mask = cls_data >= k_cls_conf_threshold
        group_mask = {}
        pixel_mask[id_pixel_mask] = True
        points = []
        for i in id_pixel_mask:
            points.append((i % width, i // width))
            group_mask[i] = -1
        link_mask = link_data >= k_link_conf_threshold
        neighbours = link_data_shape[3]
        for point in points:
            neighbour = 0
            point_x, point_y = point
            x_neighbours = [point_x - 1, point_x, point_x + 1]
            y_neighbours = [point_y - 1, point_y, point_y + 1]
            for neighbour_y in y_neighbours:
                for neighbour_x in x_neighbours:
                    if neighbour_x == point_x and neighbour_y == point_y:
                        continue

                    if neighbour_x < 0 or neighbour_x >= width or neighbour_y < 0 or neighbour_y >= height:
                        continue

                    pixel_value = np.uint8(pixel_mask[neighbour_y * width + neighbour_x])
                    link_value = np.uint8(
                        link_mask[int(point_y * width * neighbours + point_x * neighbours + neighbour)]
                    )

                    if pixel_value and link_value:
                        group_mask = self.join(point_x + point_y * width, neighbour_x + neighbour_y * width, group_mask)

                    neighbour += 1

        return self.get_all(points, width, height, group_mask)

    def join(self, point1, point2, group_mask):
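        """Union step: merge the groups containing point1 and point2."""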
        root1 = self.find_root(point1, group_mask)
        root2 = self.find_root(point2, group_mask)
        if root1 != root2:
            group_mask[root1] = root2

        return group_mask

    def get_all(self, points, width, height, group_mask):
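        """
        Builds an instance mask: pixels of each connected group share a unique
        positive label; the background stays 0.
        """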
        root_map = {}
        mask = np.zeros((height, width))

        for point in points:
            point_x, point_y = point
            point_root = self.find_root(point_x + point_y * width, group_mask)
            if point_root not in root_map:
                root_map[point_root] = int(len(root_map) + 1)
            mask[point_y, point_x] = root_map[point_root]

        return mask

    @staticmethod
    def find_root(point, group_mask):
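        """Find step with path compression: re-parents point directly to its root."""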
        root = point
        update_parent = False
        while group_mask[root] != -1:
            root = group_mask[root]
            update_parent = True

        if update_parent:
            group_mask[point] = root

        return root

    @staticmethod
    def mask_to_boxes(mask, image_size):
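        """
        Resizes the instance mask to the original image size and converts each
        instance into a rotated rectangle (4 corner points). Rectangles
        spanning nearly the whole image are treated as background and skipped.
        """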
        max_val = np.max(mask).astype(int)
        resized_mask = cv2.resize(
            mask.astype(np.float32), (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST
        )
        bboxes = []
        for i in range(int(max_val + 1)):
            bbox_mask = resized_mask == i
            contours_tuple = cv2.findContours(bbox_mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            contours = contours_tuple[1] if len(contours_tuple) > 2 else contours_tuple[0]
            if not contours:
                continue
            rect = cv2.minAreaRect(contours[0])
            _, hw, _ = rect
            ignored_height = hw[0] >= image_size[0] - 1
            ignored_width = hw[1] >= image_size[1] - 1
            if ignored_height or ignored_width:
                continue
            box = cv2.boxPoints(rect)
            bboxes.append(box)

        return bboxes


class LPRAdapter(Adapter):
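    """Decodes license plate recognition outputs into character sequences."""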
    __provider__ = 'lpr'

    def configure(self):
        if not self.label_map:
            raise ConfigError('LPR adapter requires dataset label map for correct decoding.')

    def process(self, raw, identifiers=None, frame_meta=None):
        raw_output = self._extract_predictions(raw, frame_meta)
        predictions = raw_output[self.output_blob]
        result = []
        for identifier, output in zip(identifiers, predictions):
            decoded_out = self.decode(output.reshape(-1))
            result.append(CharacterRecognitionPrediction(identifier, decoded_out))

        return result

    def decode(self, outputs):
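        """
        Maps class indices to characters via the dataset label map; -1 marks
        the end of the sequence.
        """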
        decode_out = ''
        for output in outputs:
            if output == -1:
                break
            decode_out += str(self.label_map[output])

        return decode_out


class BeamSearchDecoderConfig(ConfigValidator):
    beam_size = NumberField(optional=True, floats=False, min_value=1)
    blank_label = NumberField(optional=True, floats=False, min_value=0)
    softmaxed_probabilities = BoolField(optional=True)


class BeamSearchDecoder(Adapter):
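    """
    Decodes per-step class probabilities into a character sequence with CTC
    beam search.
    """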
    __provider__ = 'beam_search_decoder'

    def validate_config(self):
        beam_search_decoder_config = BeamSearchDecoderConfig(
            'BeamSearchDecoder_Config',
            BeamSearchDecoderConfig.IGNORE_ON_EXTRA_ARGUMENT
        )
        beam_search_decoder_config.validate(self.launcher_config)

    def configure(self):
        if not self.label_map:
            raise ConfigError('Beam Search Decoder requires dataset label map for correct decoding.')

        self.beam_size = self.launcher_config.get('beam_size', 10)
        self.blank_label = self.launcher_config.get('blank_label', len(self.label_map))
        self.softmaxed_probabilities = self.launcher_config.get('softmaxed_probabilities', False)

    def process(self, raw, identifiers=None, frame_meta=None):
        raw_output = self._extract_predictions(raw, frame_meta)
        output = raw_output[self.output_blob]
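        # The model emits (time, batch, classes); swap to (batch, time, classes)
        # so each item can be decoded independently.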
        output = np.swapaxes(output, 0, 1)

        result = []
        for identifier, data in zip(identifiers, output):
            if self.softmaxed_probabilities:
                data = np.log(data)
            seq = self.decode(data, self.beam_size, self.blank_label)
            decoded = ''.join(str(self.label_map[char]) for char in seq)
            result.append(CharacterRecognitionPrediction(identifier, decoded))

        return result

    @staticmethod
    def decode(probabilities, beam_size=10, blank_id=None):
242         """
243          Decode given output probabilities to sequence of labels.
244         Arguments:
245             probabilities: The output log probabilities for each time step.
246             Should be an array of shape (time x output dim).
247             beam_size (int): Size of the beam to use during decoding.
248             blank_id (int): Index of the CTC blank label.
249         Returns the output label sequence.
250         """
        def make_new_beam():
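            # Each unseen prefix starts with log-probability -inf for both the
            # blank-ending and non-blank-ending cases.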
            return defaultdict(lambda: (-np.inf, -np.inf))

        def log_sum_exp(*args):
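            # Numerically stable log(sum(exp(args))): shift by the max before
            # exponentiating.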
            if all(a == -np.inf for a in args):
                return -np.inf
            a_max = np.max(args)
            lsp = np.log(sum(np.exp(a - a_max) for a in args))

            return a_max + lsp
261
262         times, symbols = probabilities.shape
263         # Initialize the beam with the empty sequence, a probability of 1 for ending in blank
264         # and zero for ending in non-blank (in log space).
265         beam = [(tuple(), (0.0, -np.inf))]
266
267         for time in range(times):
268             # A default dictionary to store the next step candidates.
269             next_beam = make_new_beam()
270
271             for symbol_id in range(symbols):
272                 current_prob = probabilities[time, symbol_id]
273
274                 for prefix, (prob_blank, prob_non_blank) in beam:
                    # If we propose a blank, the prefix doesn't change.
                    # Only the probability of ending in blank gets updated.
                    if symbol_id == blank_id:
                        next_prob_blank, next_prob_non_blank = next_beam[prefix]
                        next_prob_blank = log_sum_exp(
                            next_prob_blank, prob_blank + current_prob, prob_non_blank + current_prob
                        )
                        next_beam[prefix] = (next_prob_blank, next_prob_non_blank)
                        continue
                    # Extend the prefix by the new character symbol and add it to the beam.
                    # Only the probability of not ending in blank gets updated.
                    end_t = prefix[-1] if prefix else None
                    next_prefix = prefix + (symbol_id,)
                    next_prob_blank, next_prob_non_blank = next_beam[next_prefix]
                    if symbol_id != end_t:
                        next_prob_non_blank = log_sum_exp(
                            next_prob_non_blank, prob_blank + current_prob, prob_non_blank + current_prob
                        )
                    else:
                        # Don't include the previous probability of not ending in blank (prob_non_blank) if the
                        # symbol is repeated at the end. The CTC algorithm merges characters not separated by a blank.
                        next_prob_non_blank = log_sum_exp(next_prob_non_blank, prob_blank + current_prob)

                    next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank)
                    # If the symbol is repeated at the end, also update the unchanged prefix. This is the merging case.
                    if symbol_id == end_t:
                        next_prob_blank, next_prob_non_blank = next_beam[prefix]
                        next_prob_non_blank = log_sum_exp(next_prob_non_blank, prob_non_blank + current_prob)
                        next_beam[prefix] = (next_prob_blank, next_prob_non_blank)

            beam = sorted(next_beam.items(), key=lambda x: log_sum_exp(*x[1]), reverse=True)[:beam_size]

        best = beam[0]

        return best[0]