Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / cl_kernels / detection_output_sort.cl
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "include/include_all.cl"
17 #include "include/detection_output_common.cl"
18
// Looks up the value used as the sort key for one bounding box.
//
//   input_bboxes - buffer holding the bounding-box rows of all images
//   idx_bbox     - index of the box inside its image, or KEEP_BBOXES_NUM when
//                  the slot has been marked as a dummy
//   idx_image    - index of the image the box belongs to
//
// Returns the element at SCORE_OFFSET of the addressed row, or -1 for a dummy
// index so that the slot is excluded from sorting.
UNIT_TYPE FUNC(get_score_sort)(__global UNIT_TYPE* input_bboxes, const uint idx_bbox, const uint idx_image)
{
    // Dummy index: there is no row to read, report -1 instead.
    if (idx_bbox == KEEP_BBOXES_NUM)
    {
        return -1;
    }

    // Rows of one image are stored back to back, NUM_OF_IMAGE_BBOXES per image.
    const uint row = idx_bbox + idx_image * NUM_OF_IMAGE_BBOXES;
    return input_bboxes[row * OUTPUT_ROW_SIZE + INPUT_OFFSET + SCORE_OFFSET];
}
31
// Writes one padding row starting at output[out_idx]: the first field is -1 and
// the following six fields are 0.
void FUNC(write_padding_row)(__global UNIT_TYPE* output, const uint out_idx)
{
    output[out_idx] = -1.0f;
    output[out_idx + 1] = 0.0f;
    output[out_idx + 2] = 0.0f;
    output[out_idx + 3] = 0.0f;
    output[out_idx + 4] = 0.0f;
    output[out_idx + 5] = 0.0f;
    output[out_idx + 6] = 0.0f;
}

// Compacts the bounding-box rows of every image from input_bboxes into output.
//
//   input_bboxes - the first entries are read as per-image box counts
//                  (input_bboxes[image_id]); box rows of OUTPUT_ROW_SIZE
//                  elements are read starting at INPUT_OFFSET.
//   output       - receives the selected rows, image after image; each image
//                  occupies min(count, KEEP_TOP_K) rows.
//
// Per image:
//   * count <= KEEP_TOP_K: every row whose first field is not -1 is copied in
//     class order (the background class is skipped unless HIDDEN_CLASS is set).
//   * otherwise: the per-class head indexes are repeatedly sorted by score with
//     alternating even/odd compare-and-swap passes, the best scored boxes are
//     collected until KEEP_BBOXES_NUM boxes were taken or only zero scores
//     remain, and the collected rows are written grouped by class.
// The work item with local_id == 0 of the last image pads the rest of output
// (up to KEEP_TOP_K * NUM_IMAGES rows) with padding rows.
//
// NOTE(review): the barriers below sit inside a branch on input_bboxes[image_id],
// so all work items of a work-group are assumed to share one image_id
// (one work-group per image) - confirm against the host-side dispatch.
KERNEL (detection_output_sort)(__global UNIT_TYPE* input_bboxes, __global UNIT_TYPE* output)
{
    __local uint indexes[NUM_CLASSES_IN];
    __local bool stillSorting;
    __local uint output_count;
    __local uint num_out_per_class[NUM_CLASSES_IN];

    const uint image_id = get_global_id(0) / NUM_CLASSES_IN;
    const uint local_id = get_local_id(0) * NUM_OF_ITEMS_SORT; // All bboxes from one image in work group

    // Initialize the shared counters. Each work item clears its own slice of
    // num_out_per_class (the same slicing that is used for indexes below).
    if (local_id == 0)
    {
        output_count = 0;
    }
    for (uint it = 0; it < NUM_OF_ITEMS_SORT; it++)
    {
        const uint item_id = local_id + it;
        if (item_id < NUM_CLASSES_IN)
        {
            num_out_per_class[item_id] = 0;
        }
    }
    // The counters must be fully initialized before any work item updates them.
    barrier(CLK_LOCAL_MEM_FENCE);

    uint image_offset_input = image_id * NUM_OF_IMAGE_BBOXES;

    // Number of output rows taken by all preceding images.
    uint count_sum = 0;
    for (uint i = 0; i < image_id; i++)
    {
        count_sum += (input_bboxes[i] < KEEP_TOP_K)? input_bboxes[i] : KEEP_TOP_K;
    }

    uint image_offset_output = count_sum * OUTPUT_ROW_SIZE;

    // If there are no more elements than needed, write input to output
    if (input_bboxes[image_id] <= KEEP_TOP_K)
    {
        if (local_id == 0)
        {
            for (uint class_id = 0; class_id < NUM_CLASSES_IN; class_id++)
            {
                if (class_id == BACKGROUND_LABEL_ID && !HIDDEN_CLASS)
                {
                    continue;
                }
                for (uint i = 0; i < NUM_OF_CLASS_BBOXES; i++)
                {
                    uint input_idx = (i + image_offset_input + class_id * NUM_OF_CLASS_BBOXES) * OUTPUT_ROW_SIZE + INPUT_OFFSET;

                    // A row starting with -1 ends the list of boxes of this class.
                    if (input_bboxes[input_idx] == -1)
                    {
                        break;
                    }

                    uint out_idx = output_count * OUTPUT_ROW_SIZE + image_offset_output;
                    for (uint idx = 0; idx < OUTPUT_ROW_SIZE; idx++)
                    {
                        output[out_idx + idx] = input_bboxes[input_idx + idx];
                    }
                    output_count++;
                }
            }
        }
    }
    else
    {
        // Start offsets of the selected input rows, KEEP_TOP_K slots per class.
        // Only written and read by the work item with local_id == 0.
        uint sorted_output[KEEP_TOP_K * NUM_CLASSES_IN];

        // Every class starts at its first bounding box.
        for (uint it = 0; it < NUM_OF_ITEMS_SORT; it++)
        {
            const uint item_id = local_id + it;
            if (item_id < NUM_CLASSES_IN)
            {
                indexes[item_id] = item_id * NUM_OF_CLASS_BBOXES;
            }
        }

        while (output_count < KEEP_BBOXES_NUM)
        {
            stillSorting = true;

            while (stillSorting)
            {
                // All work items have evaluated the loop condition.
                barrier(CLK_LOCAL_MEM_FENCE);
                stillSorting = false;
                // The reset has to be complete before a swap can raise the flag
                // again, otherwise a late reset could erase it.
                barrier(CLK_LOCAL_MEM_FENCE);

                for (uint it = 0; it < NUM_OF_ITEMS_SORT; it++)
                {
                    const uint item_id = local_id + it;
                    for (uint i = 0; i < 2; i++)
                    {
                        // Pair (item_id, item_id + 1) is handled in the phase whose
                        // parity matches item_id; the pair must lie inside indexes.
                        const bool perform = ((i % 2) == (item_id % 2)) &&
                                             ((item_id + 1) < NUM_CLASSES_IN);

                        if (perform)
                        {
                            const uint idx1 = indexes[item_id];
                            const uint idx2 = indexes[item_id + 1];

                            if (FUNC_CALL(get_score_sort)(input_bboxes, idx1, image_id) <
                                FUNC_CALL(get_score_sort)(input_bboxes, idx2, image_id))
                            {
                                indexes[item_id] = idx2;
                                indexes[item_id + 1] = idx1;
                                stillSorting = true;
                            }
                        }
                        // Reached by every work item, whether it swapped or not.
                        barrier(CLK_LOCAL_MEM_FENCE);
                    }
                }
            }

            if (local_id == 0)
            {
                const uint top_idx = indexes[0];
                UNIT_TYPE top_score = FUNC_CALL(get_score_sort)(input_bboxes, top_idx, image_id);

                // A dummy index on top means that every class is exhausted.
                if ((top_idx != KEEP_BBOXES_NUM) && (top_score != 0))
                {
                    for (uint it = 0; (it < NUM_CLASSES_IN) && (output_count < KEEP_BBOXES_NUM); it++)
                    {
                        // Dummy entries never refer to a real row, skip them.
                        if ((indexes[it] != KEEP_BBOXES_NUM) &&
                            (FUNC_CALL(get_score_sort)(input_bboxes, indexes[it], image_id) == top_score))
                        {
                            // write to output, create counter, and check if keep_top_k is satisfied.
                            uint input_idx = (indexes[it] + image_offset_input) * OUTPUT_ROW_SIZE + INPUT_OFFSET;
                            uint class_idx = input_bboxes[input_idx + 1] - HIDDEN_CLASS;

                            sorted_output[class_idx * KEEP_TOP_K + num_out_per_class[class_idx]] = input_idx;
                            num_out_per_class[class_idx]++;

                            indexes[it]++;
                            output_count++;

                            // If all class elements are written to output, set dummy value to exclude class from sorting.
                            if ((indexes[it] % NUM_OF_CLASS_BBOXES) == 0)
                            {
                                indexes[it] = KEEP_BBOXES_NUM;
                            }
                        }
                    }
                }
                else
                {
                    // There are no more significant results to sort.
                    output_count = KEEP_BBOXES_NUM;
                }
            }
            barrier(CLK_LOCAL_MEM_FENCE);
        }

        if (local_id == 0)
        {
            // Emit the collected rows grouped by class.
            output_count = 0;
            for (uint i = 0; i < NUM_CLASSES_IN; i++)
            {
                for (uint j = 0; j < num_out_per_class[i]; j++)
                {
                    uint out_idx = output_count * OUTPUT_ROW_SIZE + image_offset_output;
                    for (uint idx = 0; idx < OUTPUT_ROW_SIZE; idx++)
                    {
                        output[out_idx + idx] = input_bboxes[sorted_output[i * KEEP_TOP_K + j] + idx];
                    }
                    output_count++;
                }
            }

            // Fill the rows reserved for this image that were not used.
            uint image_count_sum = (input_bboxes[image_id] < KEEP_TOP_K)? input_bboxes[image_id] : KEEP_TOP_K;
            for (; output_count < image_count_sum; output_count++)
            {
                FUNC_CALL(write_padding_row)(output, output_count * OUTPUT_ROW_SIZE + image_offset_output);
            }
        }
    }

    // The last image pads everything behind its own rows.
    if (local_id == 0 &&
        image_id == (NUM_IMAGES - 1))
    {
        for (output_count += count_sum; output_count < (KEEP_TOP_K * NUM_IMAGES); output_count++)
        {
            FUNC_CALL(write_padding_row)(output, output_count * OUTPUT_ROW_SIZE);
        }
    }
}