1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "include/include_all.cl"
17 #include "include/detection_output_common.cl"
// Returns the value used as the sort key for bounding box `idx_bbox` of image `idx_image`.
// The key is read from `input_bboxes` at row (idx_bbox + idx_image * NUM_OF_IMAGE_BBOXES),
// where each row is OUTPUT_ROW_SIZE elements wide, at element INPUT_OFFSET + SCORE_OFFSET.
// `idx_bbox == KEEP_BBOXES_NUM` is checked first and treated as a dummy index.
// NOTE(review): the body of that branch is not visible in this listing; per the comment
// below it returns -1 so the element is excluded from sorting - confirm against the full file.
19 UNIT_TYPE FUNC(get_score_sort)(__global UNIT_TYPE* input_bboxes, const uint idx_bbox, const uint idx_image)
21 if (idx_bbox == KEEP_BBOXES_NUM)
23 // Idx set to dummy value, return -1 to exclude this element from sorting
28 return input_bboxes[(idx_bbox + idx_image * NUM_OF_IMAGE_BBOXES) * OUTPUT_ROW_SIZE + INPUT_OFFSET + SCORE_OFFSET];
// Kernel: for the image selected by get_global_id(0) / NUM_CLASSES_IN, copies rows of
// OUTPUT_ROW_SIZE elements from `input_bboxes` into `output`.
//  - When input_bboxes[image_id] <= KEEP_TOP_K, rows whose first element is not -1 are
//    copied to the output in input order.
//  - The remaining code selects rows by score using the `indexes` array, records them per
//    class in `sorted_output`, and then writes them out grouped by class.
//  - Afterwards, unused output rows are filled with -1 followed by zeros.
// NOTE(review): this listing omits some original lines (braces, a few statements and parts
// of conditions), so the comments below describe only what the visible lines show.
32 KERNEL (detection_output_sort)(__global UNIT_TYPE* input_bboxes, __global UNIT_TYPE* output)
// Work-group local state: one entry per class in `indexes` and `num_out_per_class`.
34 __local uint indexes[NUM_CLASSES_IN];
35 __local bool stillSorting;
// NOTE(review): the initialisation of output_count is not visible in this excerpt.
36 __local uint output_count;
37 __local uint num_out_per_class[NUM_CLASSES_IN];
// Each work-item zeroes the counter at its own local id.
40 num_out_per_class[get_local_id(0)] = 0;
42 const uint image_id = get_global_id(0) / NUM_CLASSES_IN;
// First `indexes` slot handled by this work-item; it handles NUM_OF_ITEMS_SORT consecutive slots.
43 const uint local_id = get_local_id(0) * NUM_OF_ITEMS_SORT; // All bboxes from one image in work group
45 uint image_offset_input = image_id * NUM_OF_IMAGE_BBOXES;
// Sum min(input_bboxes[i], KEEP_TOP_K) over all preceding images.
// NOTE(review): input_bboxes[i] is presumably the detection count of image i, and the
// declaration of count_sum is not visible here - confirm against the full file.
48 for (uint i = 0; i < image_id; i++)
50 count_sum += (input_bboxes[i] < KEEP_TOP_K)? input_bboxes[i] : KEEP_TOP_K;
// Element offset in `output` at which this image's rows start.
53 uint image_offset_output = count_sum * OUTPUT_ROW_SIZE;
55 // If there are fewer elements than needed, write input to output
56 if (input_bboxes[image_id] <= KEEP_TOP_K)
60 for (uint class = 0; class < NUM_CLASSES_IN; class++)
// The background class is tested here; the statement guarded by this condition is not
// visible in this excerpt (presumably it skips the class - TODO confirm).
62 if (class == BACKGROUND_LABEL_ID && !HIDDEN_CLASS)
66 for (uint i = 0; i < NUM_OF_CLASS_BBOXES; i++)
68 uint input_idx = (i + image_offset_input + class * NUM_OF_CLASS_BBOXES) * OUTPUT_ROW_SIZE + INPUT_OFFSET;
// A first element of -1 marks a row that is not copied.
69 if (input_bboxes[input_idx] != -1)
71 uint out_idx = output_count * OUTPUT_ROW_SIZE + image_offset_output;
// Copy the whole row element by element.
73 for (uint idx = 0; idx < OUTPUT_ROW_SIZE; idx++)
75 output[out_idx + idx] = input_bboxes[input_idx + idx];
// Per-class slots of KEEP_TOP_K entries, each holding an element index into input_bboxes.
90 uint sorted_output[KEEP_TOP_K * NUM_CLASSES_IN];
// Initialise each slot handled by this work-item to slot * NUM_OF_CLASS_BBOXES.
92 for (uint it = 0; it < NUM_OF_ITEMS_SORT; it++)
94 indexes[local_id + it] = (local_id + it) * NUM_OF_CLASS_BBOXES;
// Repeat until output_count reaches KEEP_BBOXES_NUM.
97 while (output_count < KEEP_BBOXES_NUM)
103 barrier(CLK_LOCAL_MEM_FENCE);
// NOTE(review): no other use of stillSorting is visible in this excerpt.
104 stillSorting = false;
105 for (uint it = 0; it < NUM_OF_ITEMS_SORT; it++)
107 uint item_id = local_id + it;
// Two passes: the pair (item_id, item_id + 1) is considered when i and item_id have the
// same parity and item_id is not the last slot. This resembles an odd-even transposition
// step, but part of the condition is missing from this listing - TODO confirm.
108 for (uint i = 0; i < 2; i++)
111 uint idx1 = indexes[item_id];
112 uint idx2 = indexes[item_id+1];
113 bool perform = false;
114 if ((((i % 2) && (item_id % 2)) ||
115 ((!(i % 2)) && (!(item_id % 2)))) &&
116 (item_id != (NUM_CLASSES_IN - 1)))
// Swap when the first entry's score is lower than the second's, so that the higher score
// moves to the lower slot.
122 (FUNC_CALL(get_score_sort)(input_bboxes, idx1, image_id) <
123 FUNC_CALL(get_score_sort)(input_bboxes, idx2, image_id)))
125 indexes[item_id] = idx2;
126 indexes[item_id+1] = idx1;
// Work-group barrier after the writes to `indexes` above.
129 barrier(CLK_LOCAL_MEM_FENCE);
// Score of the entry currently in slot 0.
136 UNIT_TYPE top_score = FUNC_CALL(get_score_sort)(input_bboxes, indexes[0], image_id);
// Record every slot whose score equals top_score, stopping once output_count reaches KEEP_BBOXES_NUM.
140 for (uint it = 0; (it < NUM_CLASSES_IN) && (output_count < KEEP_BBOXES_NUM); it++)
142 if (FUNC_CALL(get_score_sort)(input_bboxes, indexes[it], image_id) == top_score)
144 // write to output, create counter, and check if keep_top_k is satisfied.
145 uint input_idx = (indexes[it] + image_offset_input) * OUTPUT_ROW_SIZE + INPUT_OFFSET;
// The row element at offset 1 is read as the class label, shifted down by HIDDEN_CLASS.
146 uint class_idx = input_bboxes[input_idx + 1] - HIDDEN_CLASS;
// Append the row's element index to its class's slot range and bump the class counter.
148 sorted_output[class_idx * KEEP_TOP_K + num_out_per_class[class_idx]] = input_idx;
149 num_out_per_class[class_idx]++;
// NOTE(review): the statements that advance indexes[it] and output_count are not visible
// in this excerpt (original lines 150-153).
154 // If all class elements are written to output, set dummy value to exclude class from sorting.
155 if ((indexes[it] % NUM_OF_CLASS_BBOXES) == 0)
157 indexes[it] = KEEP_BBOXES_NUM;
164 // There are no more significant results to sort.
// Setting the counter to the loop bound ends the enclosing while loop.
165 output_count = KEEP_BBOXES_NUM;
168 barrier(CLK_LOCAL_MEM_FENCE);
// Write the recorded rows to the output, class by class, in the order they were recorded.
173 for (uint i = 0; i < NUM_CLASSES_IN; i++)
175 for (uint j = 0; j < num_out_per_class[i]; j++)
178 uint out_idx = output_count * OUTPUT_ROW_SIZE + image_offset_output;
179 for (uint idx = 0; idx < OUTPUT_ROW_SIZE; idx++)
181 output[out_idx + idx] = input_bboxes[sorted_output[i * KEEP_TOP_K + j] + idx];
// Fill this image's remaining rows, up to min(input_bboxes[image_id], KEEP_TOP_K), with
// -1 followed by six zeros.
// NOTE(review): seven elements are written per row - presumably OUTPUT_ROW_SIZE == 7; confirm.
186 uint image_count_sum = (input_bboxes[image_id] < KEEP_TOP_K)? input_bboxes[image_id] : KEEP_TOP_K;
187 for (output_count; output_count < image_count_sum; output_count++)
189 uint out_idx = output_count * OUTPUT_ROW_SIZE + image_offset_output;
190 output[out_idx] = -1.0;
191 output[out_idx + 1] = 0.0;
192 output[out_idx + 2] = 0.0;
193 output[out_idx + 3] = 0.0;
194 output[out_idx + 4] = 0.0;
195 output[out_idx + 5] = 0.0;
196 output[out_idx + 6] = 0.0;
// For the last image, fill every output row from (output_count + count_sum) up to
// KEEP_TOP_K * NUM_IMAGES with the same -1 / zero pattern.
// NOTE(review): the first part of this condition is not visible in this excerpt.
202 image_id == (NUM_IMAGES - 1))
204 for (output_count += count_sum; output_count < (KEEP_TOP_K * NUM_IMAGES); output_count++ )
206 uint out_idx = output_count * OUTPUT_ROW_SIZE;
207 output[out_idx] = -1.0;
208 output[out_idx + 1] = 0.0;
209 output[out_idx + 2] = 0.0;
210 output[out_idx + 3] = 0.0;
211 output[out_idx + 4] = 0.0;
212 output[out_idx + 5] = 0.0;
213 output[out_idx + 6] = 0.0;