arm_compute v18.02
[platform/upstream/armcl.git] / src / core / CL / cl_kernels / softmax_layer_quantized.cl
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "helpers_asymm.h"
25
26 #define MAX_OP(x, y, type, size) max((x), (y))
27 #define ADD_OP(x, y, type, size) ((x) + (y))
28
29 /* Number of workitems in dimension 0. */
30 #if !defined(GRID_SIZE)
31 #define GRID_SIZE 1
32 #endif /* !defined(GRID_SIZE) */
33
34 #if VECTOR_SIZE == 2
35 __constant uint2 idx__ = (uint2)(0, 1);
36 #define asymm_mult(a, b) ASYMM_MULT(a, b, 2)
37 #define asymm_exp_on_negative_values(a, k_integer_bits) ASYMM_EXP_ON_NEGATIVE_VALUES(a, k_integer_bits, 2)
38 #define asymm_rescale(value, src_integer_bits, dst_integer_bits) ASYMM_RESCALE(value, src_integer_bits, dst_integer_bits, 2)
39
40 #elif VECTOR_SIZE == 4
41 __constant uint4 idx__ = (uint4)(0, 1, 2, 3);
42 #define asymm_mult(a, b) ASYMM_MULT(a, b, 4)
43 #define asymm_exp_on_negative_values(a, k_integer_bits) ASYMM_EXP_ON_NEGATIVE_VALUES(a, k_integer_bits, 4)
44 #define asymm_rescale(value, src_integer_bits, dst_integer_bits) ASYMM_RESCALE(value, src_integer_bits, dst_integer_bits, 4)
45
46 #elif VECTOR_SIZE == 8
47 __constant uint8 idx__ = (uint8)(0, 1, 2, 3, 4, 5, 6, 7);
48 #define asymm_mult(a, b) ASYMM_MULT(a, b, 8)
49 #define asymm_exp_on_negative_values(a, k_integer_bits) ASYMM_EXP_ON_NEGATIVE_VALUES(a, k_integer_bits, 8)
50 #define asymm_rescale(value, src_integer_bits, dst_integer_bits) ASYMM_RESCALE(value, src_integer_bits, dst_integer_bits, 8)
51
52 #else /* VECTOR_SIZE DEFAULT */
53 #define VECTOR_SIZE 16
54 #define LOG_VECTOR_SIZE 4
55 __constant uint16 idx__ = (uint16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
56 #define asymm_mult(a, b) ASYMM_MULT(a, b, 16)
57 #define asymm_exp_on_negative_values(a, k_integer_bits) ASYMM_EXP_ON_NEGATIVE_VALUES(a, k_integer_bits, 16)
58 #define asymm_rescale(value, src_integer_bits, dst_integer_bits) ASYMM_RESCALE(value, src_integer_bits, dst_integer_bits, 16)
59
60 #endif /* VECTOR_SIZE END */
61
62 #define VEC_UCHAR VEC_DATA_TYPE(uchar, VECTOR_SIZE)
63 #define VEC_UINT VEC_DATA_TYPE(uint, VECTOR_SIZE)
64 #define VEC_INT VEC_DATA_TYPE(int, VECTOR_SIZE)
65
66 #if defined(DIFF_MIN)
67
68 VEC_INT mult_by_quantized_multiplier_serial(VEC_INT data)
69 {
70 #if defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT)
71     if(INPUT_BETA_MULTIPLIER > 1)
72     {
73         return asymm_mult(data * (1 << INPUT_BETA_LEFT_SHIFT), INPUT_BETA_MULTIPLIER);
74     }
75 #endif /* defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT) */
76     return data;
77 }
78
79 int4 mult_by_quantized_multiplier_parallel(int4 data)
80 {
81 #if defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT)
82     if(INPUT_BETA_MULTIPLIER > 1)
83     {
84         return ASYMM_MULT(data * (1 << INPUT_BETA_LEFT_SHIFT), INPUT_BETA_MULTIPLIER, 4);
85     }
86 #endif /* defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT) */
87     return data;
88 }
89
90 /** Shifts the values of the input tensor by the max calculated in softmax_layer_max kernel,
91  * then gets the exponent of each element as sums all elements across each row.
92  *
93  * @note In case the input is not multiple of 16 -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
94  * @note Quantized beta can be optionally passed at compile time using -DINPUT_BETA_MULTIPLIER and -DINPUT_BETA_LEFT_SHIFT (if undefined, assume beta equals 1.0)
95  * @note -DDIFF_MIN must be passed at compile time. It is threshold difference between maximum value of input data and current processed value, it defines whether the value will be taken into account or not.
96  *
97  * @param[in]  src_ptr                           Pointer to the source tensor slice. Supported data types: QASYMM8
98  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
99  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
100  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
101  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
102  * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
103  * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
104  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
105  * @param[in]  max_ptr                           Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
106  * @param[in]  max_stride_x                      Stride of the max values tensor in X dimension (in bytes)
107  * @param[in]  max_step_x                        max_stride_x * number of elements along X processed per workitem(in bytes)
108  * @param[in]  max_stride_y                      Stride of the max values tensor in Y dimension (in bytes)
109  * @param[in]  max_step_y                        max_stride_y * number of elements along Y processed per workitem(in bytes)
110  * @param[in]  max_stride_z                      Stride of the max values tensor in Z dimension (in bytes)
111  * @param[in]  max_step_z                        max_stride_z * number of elements along Z processed per workitem(in bytes)
112  * @param[in]  max_offset_first_element_in_bytes The offset of the first element in the max values tensor
113  * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: S32
114  * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
115  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
116  * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
117  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
118  * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
119  * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
120  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
121  * @param[out] sum_ptr                           Pointer to the sum values tensor slice. Supported data types: same as @p dst_ptr
122  * @param[in]  sum_stride_x                      Stride of the sum values tensor in X dimension (in bytes)
123  * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
124  * @param[in]  sum_stride_y                      Stride of the sum values tensor in Y dimension (in bytes)
125  * @param[in]  sum_step_y                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
126  * @param[in]  sum_stride_z                      Stride of the sum values tensor in Z dimension (in bytes)
127  * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
128  * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
129  * @param[in]  width                             Input image width
130  */
131 __kernel void softmax_layer_max_shift_exp_sum_quantized_serial(
132     TENSOR3D_DECLARATION(src),
133     TENSOR3D_DECLARATION(maxo),
134     TENSOR3D_DECLARATION(dst),
135     TENSOR3D_DECLARATION(sum),
136     uint width)
137 {
138     Image src  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
139     Image dst  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
140     Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
141     Image sum  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
142
143     VEC_UCHAR max_val_vec = 0;
144
145     // Calculate max of row
146     const uint width4 = width >> LOG_VECTOR_SIZE;
147     for(uint i = 0; i < width4; i++)
148     {
149         VEC_UCHAR data = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, i << LOG_VECTOR_SIZE, 0));
150         max_val_vec    = MAX_OP(data, max_val_vec, uchar, 16);
151     }
152
153 #ifdef NON_MULTIPLE_OF_VECTOR_SIZE
154     // Handle non multiple of 16
155     VEC_UCHAR uchar_min = (VEC_UCHAR)0;
156     VEC_UCHAR data      = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
157     VEC_UCHAR widx      = CONVERT(((VEC_UINT)(width4 << LOG_VECTOR_SIZE) + idx__) < width, VEC_UCHAR);
158     max_val_vec         = MAX_OP(max_val_vec, select(uchar_min, data, widx), uchar, 16);
159 #endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
160
161     // Perform max reduction
162 #if VECTOR_SIZE == 16
163     max_val_vec.s01234567 = MAX_OP(max_val_vec.s01234567, max_val_vec.s89ABCDEF, uchar, 8);
164 #endif /* VECTOR SIZE 16 END */
165 #if VECTOR_SIZE >= 8
166     max_val_vec.s0123 = MAX_OP(max_val_vec.s0123, max_val_vec.s4567, uchar, 4);
167 #endif /* VECTOR SIZE 8 END */
168 #if VECTOR_SIZE >= 4
169     max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, uchar, 2);
170 #endif /* VECTOR SIZE 4 END */
171     max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, uchar, 1);
172
173     // Store result
174     *((__global uchar *)maxo.ptr) = max_val_vec.s0;
175
176     // Second part
177
178     // Load max value of 1D logits vector (row)
179     int max_val = convert_int(*((__global uchar *)offset(&maxo, 0, 0)));
180
181     // Set sum vector, Q(EXP_ACCUMULATION_INT_BITS)
182     VEC_INT sum1D = 0;
183
184     // Shift values, exp and sum
185     for(uint i = 0; i < width4; i++)
186     {
187         VEC_UCHAR data         = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, i << LOG_VECTOR_SIZE, 0));
188         VEC_INT data_fp        = CONVERT(data, VEC_INT);
189         VEC_INT data_diff      = data_fp - max_val;
190         VEC_INT data_diff_mult = mult_by_quantized_multiplier_serial(data_diff);
191         data_fp                = asymm_exp_on_negative_values(data_diff_mult, SCALED_DIFF_INT_BITS);
192         data_fp                = asymm_rescale(data_fp, 0, EXP_ACCUMULATION_INT_BITS);
193         VSTORE(VECTOR_SIZE)
194         (data_diff, 0, (__global int *)offset(&dst, i << LOG_VECTOR_SIZE, 0));
195         sum1D = sum1D + select(0, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
196     }
197
198 #ifdef NON_MULTIPLE_OF_VECTOR_SIZE
199     // Handle non multiple of 16
200     data                   = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
201     VEC_INT data_fp        = CONVERT(data, VEC_INT);
202     VEC_INT data_diff      = data_fp - max_val;
203     VEC_INT data_diff_mult = mult_by_quantized_multiplier_serial(data_diff);
204     data_fp                = asymm_exp_on_negative_values(data_diff_mult, SCALED_DIFF_INT_BITS);
205     data_fp                = asymm_rescale(data_fp, 0, EXP_ACCUMULATION_INT_BITS);
206     VEC_INT widx_          = CONVERT(((VEC_UINT)(width4 << LOG_VECTOR_SIZE) + idx__) < width, VEC_INT);
207     VSTORE(VECTOR_SIZE)
208     (data_diff, 0, (__global int *)offset(&dst, width4 << LOG_VECTOR_SIZE, 0));
209     data_fp = select(0, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
210     sum1D   = sum1D + select(0, data_fp, widx_);
211 #endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
212
213     // Perform sum reduction
214 #if VECTOR_SIZE == 16
215     sum1D.s01234567 = ADD_OP(sum1D.s01234567, sum1D.s89ABCDEF, uchar, 8);
216 #endif /* VECTOR SIZE 16 END */
217 #if VECTOR_SIZE >= 8
218     sum1D.s0123 = ADD_OP(sum1D.s0123, sum1D.s4567, uchar, 4);
219 #endif /* VECTOR SIZE 8 END */
220 #if VECTOR_SIZE >= 4
221     sum1D.s01 = ADD_OP(sum1D.s01, sum1D.s23, uchar, 2);
222 #endif /* VECTOR SIZE 4 END */
223     sum1D.s0 = ADD_OP(sum1D.s0, sum1D.s1, uchar, 1);
224
225     // Calculate and store result
226     *((__global int *)sum.ptr) = sum1D.s0;
227 }
228
229 /** Identifies the maximum value across the 1st dimension and shifts the values of the input tensor by this maximum value,
230  * then gets the exponent of each element as sums all elements across each row.
231  *
232  * @note Datatype must be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
233  * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
234  * @note In case the input is not a multiple of VECTOR_SIZE (2,4,8,16) -DNON_MULTIPLE_OF_VECTOR_SIZE must be passed.
235  *
236  * @param[in]  src_ptr                            Pointer to the source tensor slice. Supported data types: QS8/QS16/F16/F32
237  * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
238  * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
239  * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
240  * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
241  * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
242  * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
243  * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
244  * @param[in]  maxo_ptr                           Pointer to the max values tensor slice. Supported data types: same as @p src_ptr
245  * @param[in]  maxo_stride_x                      Stride of the max values tensor in X dimension (in bytes)
246  * @param[in]  maxo_step_x                        max_stride_x * number of elements along X processed per workitem(in bytes)
247  * @param[in]  maxo_stride_y                      Stride of the max values tensor in Y dimension (in bytes)
248  * @param[in]  maxo_step_y                        max_stride_y * number of elements along Y processed per workitem(in bytes)
249  * @param[in]  maxo_stride_z                      Stride of the max values tensor in Z dimension (in bytes)
250  * @param[in]  maxo_step_z                        max_stride_z * number of elements along Z processed per workitem(in bytes)
251  * @param[in]  maxo_offset_first_element_in_bytes The offset of the first element in the max values tensor
252  * @param[out] dst_ptr                            Pointer to the destination tensor slice. Supported data types: same as @p src_ptr
253  * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
254  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
255  * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
256  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
257  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
258  * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
259  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
260  * @param[out] sum_ptr                            Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
261  * @param[in]  sum_stride_x                       Stride of the sum values tensor in X dimension (in bytes)
262  * @param[in]  sum_step_x                         sum_stride_x * number of elements along X processed per workitem(in bytes)
263  * @param[in]  sum_stride_y                       Stride of the sum values tensor in Y dimension (in bytes)
264  * @param[in]  sum_step_y                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
265  * @param[in]  sum_stride_z                       Stride of the sum values tensor in Z dimension (in bytes)
266  * @param[in]  sum_step_z                         sum_stride_z * number of elements along Z processed per workitem(in bytes)
267  * @param[in]  sum_offset_first_element_in_bytes  The offset of the first element in the sum values tensor
268  * @param[in]  width                              Input image width
269  */
270 __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
271     TENSOR3D_DECLARATION(src),
272     TENSOR3D_DECLARATION(maxo),
273     TENSOR3D_DECLARATION(dst),
274     TENSOR3D_DECLARATION(sum),
275     uint width)
276 {
277     Image src  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
278     Image dst  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
279     Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
280     Image sum  = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
281
282     const uint4 idx4 = (uint4)(0, 1, 2, 3);
283     const uint  lid  = get_local_id(0);
284
285     // Define one temporary vector per work-item.
286     __local int4 tmp_local[GRID_SIZE];
287     __local uchar max_local;
288
289     uchar4 uchar_min   = (uchar4)0;
290     uchar4 max_val_vec = uchar_min;
291
292     // Number of elements per work-item.
293     const uint row = width / GRID_SIZE;
294     // Number of iterations per work-item.
295     const uint width_ = row >> 2;
296     // Calculate max of row
297     uint i = 0;
298     for(; i < width_; i++)
299     {
300         uchar4 data_max = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
301         max_val_vec     = MAX_OP(data_max, max_val_vec, uchar, 4);
302     }
303 #ifdef NON_MULTIPLE_OF_GRID_SIZE
304     // How many work-items needed to complete the computation.
305     int boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
306     if(lid < boundary_workitems)
307     {
308         uchar4 data_max = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
309         max_val_vec     = MAX_OP(data_max, max_val_vec, uchar, 4);
310     }
311 #ifdef NON_MULTIPLE_OF_VECTOR_SIZE
312     if(boundary_workitems == 0)
313     {
314         boundary_workitems = GRID_SIZE;
315         i--;
316     }
317     if(lid == (boundary_workitems - 1))
318     {
319         // Handle non multiple of 4
320         uchar4 data_max = vload4(0, (__global uchar *)offset(&src, (GRID_SIZE * i * 4) + 4, 0));
321         uchar4 widx     = convert_uchar4(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width);
322         max_val_vec     = MAX_OP(max_val_vec, select(uchar_min, data_max, widx), uchar, 4);
323     }
324 #endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
325 #endif /* NON_MULTIPLE_OF_GRID_SIZE */
326     tmp_local[lid] = convert_int4(max_val_vec);
327
328     barrier(CLK_LOCAL_MEM_FENCE);
329
330     if(GRID_SIZE >= 256)
331     {
332         if(lid < 128)
333         {
334             tmp_local[lid] = MAX_OP(tmp_local[lid + 128], tmp_local[lid], int, 4);
335         }
336         barrier(CLK_LOCAL_MEM_FENCE);
337     }
338     if(GRID_SIZE >= 128)
339     {
340         if(lid < 64)
341         {
342             tmp_local[lid] = MAX_OP(tmp_local[lid + 64], tmp_local[lid], int, 4);
343         }
344         barrier(CLK_LOCAL_MEM_FENCE);
345     }
346     if(GRID_SIZE >= 64)
347     {
348         if(lid < 32)
349         {
350             tmp_local[lid] = MAX_OP(tmp_local[lid + 32], tmp_local[lid], int, 4);
351         }
352         barrier(CLK_LOCAL_MEM_FENCE);
353     }
354     if(GRID_SIZE >= 32)
355     {
356         if(lid < 16)
357         {
358             tmp_local[lid] = MAX_OP(tmp_local[lid + 16], tmp_local[lid], int, 4);
359         }
360         barrier(CLK_LOCAL_MEM_FENCE);
361     }
362     if(GRID_SIZE >= 16)
363     {
364         if(lid < 8)
365         {
366             tmp_local[lid] = MAX_OP(tmp_local[lid + 8], tmp_local[lid], int, 4);
367         }
368         barrier(CLK_LOCAL_MEM_FENCE);
369     }
370     if(GRID_SIZE >= 8)
371     {
372         if(lid < 4)
373         {
374             tmp_local[lid] = MAX_OP(tmp_local[lid + 4], tmp_local[lid], int, 4);
375         }
376         barrier(CLK_LOCAL_MEM_FENCE);
377     }
378     if(GRID_SIZE >= 4)
379     {
380         if(lid < 2)
381         {
382             tmp_local[lid] = MAX_OP(tmp_local[lid + 2], tmp_local[lid], int, 4);
383         }
384         barrier(CLK_LOCAL_MEM_FENCE);
385     }
386     if(lid == 0)
387     {
388         max_val_vec     = MAX_OP(convert_uchar4(tmp_local[lid + 1]), convert_uchar4(tmp_local[lid]), uchar, 4);
389         max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, uchar, 2);
390         max_val_vec.s0  = MAX_OP(max_val_vec.s0, max_val_vec.s1, uchar, 1);
391         max_local       = max_val_vec.s0;
392     }
393     barrier(CLK_LOCAL_MEM_FENCE);
394
395     /* Second section */
396
397     // Set sum vector
398     int4 sum1D   = 0;
399     int  max_val = convert_int(max_local);
400
401     // Shift values, exp and sum
402     for(i = 0; i < width_; i++)
403     {
404         uchar4 data         = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
405         int4 data_fp        = convert_int4(data);
406         int4 data_diff      = data_fp - max_val;
407         int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
408         data_fp             = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
409         data_fp             = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
410         vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4, 0));
411         sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
412     }
413 #ifdef NON_MULTIPLE_OF_GRID_SIZE
414     boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
415     if(lid < boundary_workitems)
416     {
417         uchar4 data         = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
418         int4 data_fp        = convert_int4(data);
419         int4 data_diff      = data_fp - max_val;
420         int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
421         data_fp             = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
422         data_fp             = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
423         vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4, 0));
424         sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
425     }
426 #ifdef NON_MULTIPLE_OF_VECTOR_SIZE
427     if(boundary_workitems == 0)
428     {
429         boundary_workitems = GRID_SIZE;
430         i--;
431     }
432     if(lid == (boundary_workitems - 1))
433     {
434         // Handle non multiple of vector size ((GRID_SIZE * i * 4) + 4, 0); move 4 float positions ahead, *4 is due to the stride
435         uchar4 data         = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4 + 4, 0));
436         int4 data_fp        = convert_int4(data);
437         int4 data_diff      = data_fp - max_val;
438         int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
439         data_fp             = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
440         data_fp             = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
441         int4 widx           = convert_int4(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width);
442         data_fp             = select(0, data_fp, widx);
443         vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4 + 4, 0));
444         sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
445     }
446 #endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
447 #endif /* NON_MULTIPLE_OF_GRID_SIZE */
448     tmp_local[lid] = sum1D;
449
450     barrier(CLK_LOCAL_MEM_FENCE);
451
452     if(GRID_SIZE >= 256)
453     {
454         if(lid < 128)
455         {
456             tmp_local[lid] = ADD_OP(tmp_local[lid + 128], tmp_local[lid], int, 4);
457         }
458         barrier(CLK_LOCAL_MEM_FENCE);
459     }
460     if(GRID_SIZE >= 128)
461     {
462         if(lid < 64)
463         {
464             tmp_local[lid] = ADD_OP(tmp_local[lid + 64], tmp_local[lid], int, 4);
465         }
466         barrier(CLK_LOCAL_MEM_FENCE);
467     }
468     if(GRID_SIZE >= 64)
469     {
470         if(lid < 32)
471         {
472             tmp_local[lid] = ADD_OP(tmp_local[lid + 32], tmp_local[lid], int, 4);
473         }
474         barrier(CLK_LOCAL_MEM_FENCE);
475     }
476     if(GRID_SIZE >= 32)
477     {
478         if(lid < 16)
479         {
480             tmp_local[lid] = ADD_OP(tmp_local[lid + 16], tmp_local[lid], int, 4);
481         }
482         barrier(CLK_LOCAL_MEM_FENCE);
483     }
484     if(GRID_SIZE >= 16)
485     {
486         if(lid < 8)
487         {
488             tmp_local[lid] = ADD_OP(tmp_local[lid + 8], tmp_local[lid], int, 4);
489         }
490         barrier(CLK_LOCAL_MEM_FENCE);
491     }
492     if(GRID_SIZE >= 8)
493     {
494         if(lid < 4)
495         {
496             tmp_local[lid] = ADD_OP(tmp_local[lid + 4], tmp_local[lid], int, 4);
497         }
498         barrier(CLK_LOCAL_MEM_FENCE);
499     }
500     if(GRID_SIZE >= 4)
501     {
502         if(lid < 2)
503         {
504             tmp_local[lid] = ADD_OP(tmp_local[lid + 2], tmp_local[lid], int, 4);
505         }
506         barrier(CLK_LOCAL_MEM_FENCE);
507     }
508     if(lid == 0)
509     {
510         sum1D = ADD_OP(tmp_local[lid + 1], tmp_local[lid], int, 4);
511         // Perform max reduction
512         sum1D.s01                  = ADD_OP(sum1D.s01, sum1D.s23, int, 2);
513         sum1D.s0                   = ADD_OP(sum1D.s0, sum1D.s1, int, 1);
514         *((__global int *)sum.ptr) = sum1D.s0;
515     }
516 }
517
518 /** Divides all the values of the input tensor by the sum calculated from softmax_layer_shift_exp_sum kernel.
519  *
520  * @note Fixed point position must be given as a preprocessor argument using -DFIXED_POINT_POSITION=pos. e.g. DFIXED_POINT_POSITION=4
521  * @note Quantized beta can be optionally passed at compile time using -DINPUT_BETA_MULTIPLIER and -DINPUT_BETA_LEFT_SHIFT (if undefined, assume beta equals 1.0)
522  * @note -DDIFF_MIN must be passed at compile time. It is threshold difference between maximum value of input data and current processed value, it defines whether the value will be taken into account or not.
523  *
524  * @param[in]  src_ptr                           Pointer to the source tensor slice. Supported data types: S32
525  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
526  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
527  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
528  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
529  * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
530  * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
531  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
532  * @param[in]  sum_ptr                           Pointer to the sum values tensor slice. Supported data types: same as @p src_ptr
533  * @param[in]  sum_stride_x                      Stride of the sum values tensor in X dimension (in bytes)
534  * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
535  * @param[in]  sum_stride_y                      Stride of the sum values tensor in Y dimension (in bytes)
536  * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
537  * @param[in]  sum_stride_z                      Stride of the sum values tensor in Z dimension (in bytes)
538  * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
539  * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum values tensor
540  * @param[out] dst_ptr                           Pointer to the destination tensor slice. Supported data types: QASYMM8
541  * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
542  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
543  * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
544  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
545  * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
546  * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
547  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
548  */
549 __kernel void softmax_layer_norm_quantized(
550     TENSOR3D_DECLARATION(src),
551     TENSOR3D_DECLARATION(sum),
552     TENSOR3D_DECLARATION(dst))
553 {
554     Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
555     Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
556     Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(sum);
557
558     // Load max value of 1D logits vector (row)
559     int sum_val = *((__global int *)offset(&sum, 0, get_global_id(1)));
560
561     // It will be better to calculate this in prev layer and pass here as parameter
562     uint  sum_val_u               = convert_uint(sum_val);
563     int   headroom_plus_one       = clz(sum_val_u);
564     int   num_bits_over_unit      = EXP_ACCUMULATION_INT_BITS - headroom_plus_one;
565     int   shifted_sum_minus_one_1 = convert_int((sum_val_u << headroom_plus_one) - (1u << 31));
566     int16 shifted_sum_minus_one   = shifted_sum_minus_one_1;
567     int16 shifted_scale           = ASYMM_ONE_OVER_ONE_PLUS_X_FOR_X_IN_0_1(shifted_sum_minus_one, 16);
568
569     // It was already calculated in prev layer, should be stored into tmp output and reused
570     int16 data_diff      = vload16(0, (__global int *)offset(&src, 0, 0));
571     int16 data_diff_mult = data_diff;
572 #if defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT)
573     if(INPUT_BETA_MULTIPLIER > 1)
574     {
575         data_diff_mult = ASYMM_MULT(data_diff * (1 << INPUT_BETA_LEFT_SHIFT), INPUT_BETA_MULTIPLIER, 16);
576     }
577 #endif /* defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT) */
578     int16 data = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 16);
579
580     data = ASYMM_MULT(shifted_scale, data, 16);
581     data = ASYMM_ROUNDING_DIVIDE_BY_POW2(data, num_bits_over_unit + 31 - 8, 16);
582     data = select(0, data, data_diff >= (int16)(DIFF_MIN));
583     vstore16(convert_uchar16_sat(data), 0, (__global uchar *)offset(&dst, 0, 0));
584 }
585
586 #endif /* defined(DIFF_MIN) */