2 Copyright (c) 2019 Intel Corporation
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
8 http://www.apache.org/licenses/LICENSE-2.0
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
17 from collections import defaultdict
23 from ..adapters import Adapter
24 from ..config import ConfigValidator, StringField, NumberField, BoolField, ConfigError
25 from ..representation import TextDetectionPrediction, CharacterRecognitionPrediction
28 class TextDetectionAdapterConfig(ConfigValidator):
30 pixel_link_out = StringField()
31 pixel_class_out = StringField()
34 class TextDetectionAdapter(Adapter):
35 __provider__ = 'text_detection'
37 def validate_config(self):
38 text_detection_adapter_config = TextDetectionAdapterConfig('TextDetectionAdapter_Config')
39 text_detection_adapter_config.validate(self.launcher_config)
42 self.pixel_link_out = self.launcher_config['pixel_link_out']
43 self.pixel_class_out = self.launcher_config['pixel_class_out']
45 def process(self, raw, identifiers=None, frame_meta=None):
47 predictions = self._extract_predictions(raw, frame_meta)
48 raw_output = zip(identifiers, frame_meta, predictions[self.pixel_link_out], predictions[self.pixel_class_out])
49 for identifier, current_frame_meta, link_data, cls_data in raw_output:
50 link_data = link_data.reshape((1, *link_data.shape))
51 cls_data = cls_data.reshape((1, *cls_data.shape))
52 link_data_shape = link_data.shape
53 new_link_data_shape = (link_data_shape[0], link_data_shape[2], link_data_shape[3], link_data_shape[1] / 2)
54 cls_data_shape = cls_data.shape
55 new_cls_data_shape = (cls_data_shape[0], cls_data_shape[2], cls_data_shape[3], cls_data_shape[1] / 2)
56 link_data = self.softmax(link_data.transpose((0, 2, 3, 1)).reshape(-1))[1::2]
57 cls_data = self.softmax(cls_data.transpose((0, 2, 3, 1)).reshape(-1))[1::2]
58 mask = self.decode_image_by_join(cls_data, new_cls_data_shape, link_data, new_link_data_shape)
59 rects = self.mask_to_boxes(mask, current_frame_meta['image_size'])
60 results.append(TextDetectionPrediction(identifier, rects))
66 for i in np.arange(start=0, stop=data.size, step=2, dtype=int):
67 maximum = max(data[i], data[i + 1])
68 data[i] = np.exp(data[i] - maximum)
69 data[i + 1] = np.exp(data[i + 1] - maximum)
70 sum_data = data[i] + data[i + 1]
72 data[i + 1] /= sum_data
76 def decode_image_by_join(self, cls_data, cls_data_shape, link_data, link_data_shape):
77 k_cls_conf_threshold = 0.7
78 k_link_conf_threshold = 0.7
79 height = cls_data_shape[1]
80 width = cls_data_shape[2]
81 id_pixel_mask = np.argwhere(cls_data >= k_cls_conf_threshold).reshape(-1)
82 pixel_mask = cls_data >= k_cls_conf_threshold
84 pixel_mask[id_pixel_mask] = True
86 for i in id_pixel_mask:
87 points.append((i % width, i // width))
89 link_mask = link_data >= k_link_conf_threshold
90 neighbours = link_data_shape[3]
93 point_x, point_y = point
94 x_neighbours = [point_x - 1, point_x, point_x + 1]
95 y_neighbours = [point_y - 1, point_y, point_y + 1]
96 for neighbour_y in y_neighbours:
97 for neighbour_x in x_neighbours:
98 if neighbour_x == point_x and neighbour_y == point_y:
101 if neighbour_x < 0 or neighbour_x >= width or neighbour_y < 0 or neighbour_y >= height:
104 pixel_value = np.uint8(pixel_mask[neighbour_y * width + neighbour_x])
105 link_value = np.uint8(
106 link_mask[int(point_y * width * neighbours + point_x * neighbours + neighbour)]
109 if pixel_value and link_value:
110 group_mask = self.join(point_x + point_y * width, neighbour_x + neighbour_y * width, group_mask)
114 return self.get_all(points, width, height, group_mask)
116 def join(self, point1, point2, group_mask):
117 root1 = self.find_root(point1, group_mask)
118 root2 = self.find_root(point2, group_mask)
120 group_mask[root1] = root2
124 def get_all(self, points, width, height, group_mask):
126 mask = np.zeros((height, width))
129 point_x, point_y = point
130 point_root = self.find_root(point_x + point_y * width, group_mask)
131 if not root_map.get(point_root):
132 root_map[point_root] = int(len(root_map) + 1)
133 mask[point_y, point_x] = root_map[point_root]
138 def find_root(point, group_mask):
140 update_parent = False
141 while group_mask[root] != -1:
142 root = group_mask[root]
146 group_mask[point] = root
151 def mask_to_boxes(mask, image_size):
152 max_val = np.max(mask).astype(int)
153 resized_mask = cv2.resize(
154 mask.astype(np.float32), (image_size[1], image_size[0]), interpolation=cv2.INTER_NEAREST
157 for i in range(int(max_val + 1)):
158 bbox_mask = resized_mask == i
159 contours_tuple = cv2.findContours(bbox_mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
160 contours = contours_tuple[1] if len(contours_tuple) > 2 else contours_tuple[0]
163 rect = cv2.minAreaRect(contours[0])
165 ignored_height = hw[0] >= image_size[0] - 1
166 ignored_width = hw[1] >= image_size[1] - 1
167 if ignored_height or ignored_width:
169 box = cv2.boxPoints(rect)
175 class LPRAdapter(Adapter):
179 if not self.label_map:
180 raise ConfigError('LPR adapter requires dataset label map for correct decoding.')
182 def process(self, raw, identifiers=None, frame_meta=None):
183 raw_output = self._extract_predictions(raw, frame_meta)
184 predictions = raw_output[self.output_blob]
186 for identifier, output in zip(identifiers, predictions):
187 decoded_out = self.decode(output.reshape(-1))
188 result.append(CharacterRecognitionPrediction(identifier, decoded_out))
192 def decode(self, outputs):
194 for output in outputs:
197 decode_out += str(self.label_map[output])
202 class BeamSearchDecoderConfig(ConfigValidator):
203 beam_size = NumberField(optional=True, floats=False, min_value=1)
204 blank_label = NumberField(optional=True, floats=False, min_value=0)
205 softmaxed_probabilities = BoolField(optional=True)
208 class BeamSearchDecoder(Adapter):
209 __provider__ = 'beam_search_decoder'
211 def validate_config(self):
212 beam_search_decoder_config = BeamSearchDecoderConfig(
213 'BeamSearchDecoder_Config',
214 BeamSearchDecoderConfig.IGNORE_ON_EXTRA_ARGUMENT
216 beam_search_decoder_config.validate(self.launcher_config)
219 if not self.label_map:
220 raise ConfigError('Beam Search Decoder requires dataset label map for correct decoding.')
222 self.beam_size = self.launcher_config.get('beam_size', 10)
223 self.blank_label = self.launcher_config.get('blank_label', len(self.label_map))
224 self.softmaxed_probabilities = self.launcher_config.get('softmaxed_probabilities', False)
226 def process(self, raw, identifiers=None, frame_meta=None):
227 raw_output = self._extract_predictions(raw, frame_meta)
228 output = raw_output[self.output_blob]
229 output = np.swapaxes(output, 0, 1)
232 for identifier, data in zip(identifiers, output):
233 if self.softmaxed_probabilities:
235 seq = self.decode(data, self.beam_size, self.blank_label)
236 decoded = ''.join(str(self.label_map[char]) for char in seq)
237 result.append(CharacterRecognitionPrediction(identifier, decoded))
241 def decode(probabilities, beam_size=10, blank_id=None):
243 Decode given output probabilities to sequence of labels.
245 probabilities: The output log probabilities for each time step.
246 Should be an array of shape (time x output dim).
247 beam_size (int): Size of the beam to use during decoding.
248 blank_id (int): Index of the CTC blank label.
249 Returns the output label sequence.
252 return defaultdict(lambda: (-np.inf, -np.inf))
254 def log_sum_exp(*args):
255 if all(a == -np.inf for a in args):
258 lsp = np.log(np.sum(np.exp(a - a_max) for a in args))
262 times, symbols = probabilities.shape
263 # Initialize the beam with the empty sequence, a probability of 1 for ending in blank
264 # and zero for ending in non-blank (in log space).
265 beam = [(tuple(), (0.0, -np.inf))]
267 for time in range(times):
268 # A default dictionary to store the next step candidates.
269 next_beam = make_new_beam()
271 for symbol_id in range(symbols):
272 current_prob = probabilities[time, symbol_id]
274 for prefix, (prob_blank, prob_non_blank) in beam:
275 # If propose a blank the prefix doesn't change.
276 # Only the probability of ending in blank gets updated.
277 if symbol_id == blank_id:
278 next_prob_blank, next_prob_non_blank = next_beam[prefix]
279 next_prob_blank = log_sum_exp(
280 next_prob_blank, prob_blank + current_prob, prob_non_blank + current_prob
282 next_beam[prefix] = (next_prob_blank, next_prob_non_blank)
284 # Extend the prefix by the new character symbol and add it to the beam.
285 # Only the probability of not ending in blank gets updated.
286 end_t = prefix[-1] if prefix else None
287 next_prefix = prefix + (symbol_id,)
288 next_prob_blank, next_prob_non_blank = next_beam[next_prefix]
289 if symbol_id != end_t:
290 next_prob_non_blank = log_sum_exp(
291 next_prob_non_blank, prob_blank + current_prob, prob_non_blank + current_prob
294 # Don't include the previous probability of not ending in blank (prob_non_blank) if symbol
295 # is repeated at the end. The CTC algorithm merges characters not separated by a blank.
296 next_prob_non_blank = log_sum_exp(next_prob_non_blank, prob_blank + current_prob)
298 next_beam[next_prefix] = (next_prob_blank, next_prob_non_blank)
299 # If symbol is repeated at the end also update the unchanged prefix. This is the merging case.
300 if symbol_id == end_t:
301 next_prob_blank, next_prob_non_blank = next_beam[prefix]
302 next_prob_non_blank = log_sum_exp(next_prob_non_blank, prob_non_blank + current_prob)
303 next_beam[prefix] = (next_prob_blank, next_prob_non_blank)
305 beam = sorted(next_beam.items(), key=lambda x: log_sum_exp(*x[1]), reverse=True)[:beam_size]