arm_compute v18.05
[platform/upstream/armcl.git] / src / core / CL / cl_kernels / depthwise_convolution_quantized.cl
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24
25 #include "helpers_asymm.h"
26
27 #if defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT)
28
29 #if defined(FUSED_ACTIVATION)
30 #define DATA_TYPE uchar
31 #ifndef VEC_SIZE
32 #define VEC_SIZE 8
33 #endif /* VEC_SIZE */
34 #include "activation_layer_qa8.cl"
35 #define ACTIVATION_FUNC(x) PERFORM_ACTIVATION_QA8(FUSED_ACTIVATION, x)
36 #else /* defined(FUSED_ACTIVATION) */
37 #define ACTIVATION_FUNC(x) (x)
38 #endif /* defined(FUSED_ACTIVATION) */
39
40 #if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X)
41
42 #if CONV_STRIDE_X > 3
43 #error "Stride X not supported"
44 #endif /* CONV_STRIDE_X > 3 */
45
46 #if CONV_STRIDE_X == 1
47 #define GET_VALUES(first_value, left, middle, right)                              \
48     ({                                                                            \
49         int8 temp0 = CONVERT(vload8(0, first_value), int8);                       \
50         int2 temp1 = CONVERT(vload2(0, (first_value + 8 * sizeof(uchar))), int2); \
51         \
52         left   = CONVERT(temp0.s01234567, int8);                                  \
53         middle = CONVERT((int8)(temp0.s1234, temp0.s567, temp1.s0), int8);        \
54         right  = CONVERT((int8)(temp0.s2345, temp0.s67, temp1.s01), int8);        \
55     })
56 #elif CONV_STRIDE_X == 2
57 #define GET_VALUES(first_value, left, middle, right)                     \
58     ({                                                                   \
59         int16 temp0 = CONVERT(vload16(0, first_value), int16);           \
60         int   temp1 = CONVERT(*(first_value + 16 * sizeof(uchar)), int); \
61         \
62         left   = CONVERT(temp0.s02468ace, int8);                         \
63         middle = CONVERT(temp0.s13579bdf, int8);                         \
64         right  = CONVERT((int8)(temp0.s2468, temp0.sace, temp1), int8);  \
65     })
66 #else /* CONV_STRIDE_X */
67 #define GET_VALUES(first_value, left, middle, right)                                \
68     ({                                                                              \
69         int16 temp0 = CONVERT(vload16(0, first_value), int16);                      \
70         int8  temp1 = CONVERT(vload8(0, (first_value + 16 * sizeof(uchar))), int8); \
71         \
72         left   = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8);          \
73         middle = CONVERT((int8)(temp0.s147a, temp0.sd, temp1.s036), int8);          \
74         right  = CONVERT((int8)(temp0.s258b, temp0.se, temp1.s147), int8);          \
75     })
76 #endif /* CONV_STRIDE_X */
77
78 /** This function computes the depthwise convolution quantized.
79  *
80  * @param[in] src_ptr                               Pointer to the source image. Supported data types: QASYMM8
81  * @param[in] src_stride_x                          Stride of the source image in X dimension (in bytes)
82  * @param[in] src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
83  * @param[in] src_stride_y                          Stride of the source image in Y dimension (in bytes)
84  * @param[in] src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
85  * @param[in] src_offset_first_element_in_bytes     The offset of the first element in the source image
86  * @param[in] src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
87  * @param[in] src_step_z                            src_stride_z * number of elements along Y processed per workitem(in bytes)
88  * @param[in] dst_ptr                               Pointer to the destination tensor. Supported data types: QASYMM8
89  * @param[in] dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
90  * @param[in] dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
91  * @param[in] dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
92  * @param[in] dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
93  * @param[in] dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
94  * @param[in] dst_step_z                            dst_stride_z * number of elements along Y processed per workitem(in bytes)
95  * @param[in] dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
96  * @param[in] weights_ptr                           Pointer to the weights tensor. Supported data types: QASYMM8
97  * @param[in] weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
98  * @param[in] weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
99  * @param[in] weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
100  * @param[in] weights_step_y                        weights_stride_y * number of elements along Y processed per workitem(in bytes)
101  * @param[in] weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
102  * @param[in] weights_step_z                        weights_stride_z * number of elements along Y processed per workitem(in bytes)
103  * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
104  * @param[in] biases_ptr                            (Optional) Pointer to the biases vector. Supported data types: QASYMM8
105  * @param[in] biases_stride_x                       (Optional) Stride of the biases vector in X dimension (in bytes)
106  * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
107  * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases vector
108  */
109
110 __kernel void depthwise_convolution_3x3_quantized_nchw(
111     TENSOR3D_DECLARATION(src),
112     TENSOR3D_DECLARATION(dst),
113     TENSOR3D_DECLARATION(weights)
114 #if defined(HAS_BIAS)
115     ,
116     VECTOR_DECLARATION(biases)
117 #endif //defined(HAS_BIAS)
118 )
119 {
120     Image    src     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
121     Image    dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
122     Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
123 #if defined(HAS_BIAS)
124     Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
125
126     int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
127 #endif //defined(HAS_BIAS)
128
129     src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
130
131     uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
132     uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
133     uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
134
135     int8 values0 = 0;
136     int8 sum0    = 0;
137 #if CONV_STRIDE_Y == 1
138     int8 values1 = 0;
139     int8 sum1    = 0;
140 #endif /* CONV_STRIDE_Y */
141
142     // Row0
143     int8 left, middle, right;
144     GET_VALUES(src.ptr + 0 * src_stride_y, left, middle, right);
145     values0 += left * (int8)(w0.s0);
146     values0 += middle * (int8)(w0.s1);
147     values0 += right * (int8)(w0.s2);
148
149 #if WEIGHTS_OFFSET != 0
150     sum0 += left + middle + right;
151 #endif /* WEIGHTS_OFFSET != 0 */
152
153     // Row1
154     GET_VALUES(src.ptr + 1 * src_stride_y, left, middle, right);
155     values0 += left * (int8)(w1.s0);
156     values0 += middle * (int8)(w1.s1);
157     values0 += right * (int8)(w1.s2);
158 #if CONV_STRIDE_Y == 1
159     values1 += left * (int8)(w0.s0);
160     values1 += middle * (int8)(w0.s1);
161     values1 += right * (int8)(w0.s2);
162 #endif /* CONV_STRIDE_Y == 1 */
163
164 #if WEIGHTS_OFFSET != 0
165     int8 tmp = left + middle + right;
166     sum0 += tmp;
167 #if CONV_STRIDE_Y == 1
168     sum1 += tmp;
169 #endif /* CONV_STRIDE_Y == 1 */
170 #endif /* WEIGHTS_OFFSET != 0 */
171
172     // Row2
173     GET_VALUES(src.ptr + 2 * src_stride_y, left, middle, right);
174     values0 += left * (int8)(w2.s0);
175     values0 += middle * (int8)(w2.s1);
176     values0 += right * (int8)(w2.s2);
177 #if CONV_STRIDE_Y == 1
178     values1 += left * (int8)(w1.s0);
179     values1 += middle * (int8)(w1.s1);
180     values1 += right * (int8)(w1.s2);
181 #endif /* CONV_STRIDE_Y == 1 */
182
183 #if WEIGHTS_OFFSET != 0
184     tmp = left + middle + right;
185     sum0 += tmp;
186 #if CONV_STRIDE_Y == 1
187     sum1 += tmp;
188 #endif /* CONV_STRIDE_Y == 1 */
189 #endif /* WEIGHTS_OFFSET != 0 */
190
191 #if CONV_STRIDE_Y == 1
192     // Row3
193     GET_VALUES(src.ptr + 3 * src_stride_y, left, middle, right);
194     values1 += left * (int8)(w2.s0);
195     values1 += middle * (int8)(w2.s1);
196     values1 += right * (int8)(w2.s2);
197
198 #if WEIGHTS_OFFSET != 0
199     sum1 += left + middle + right;
200 #endif /* WEIGHTS_OFFSET != 0 */
201 #endif /* CONV_STRIDE_Y == 1 */
202
203 #if defined(HAS_BIAS)
204     values0 += (int8)(bias_value);
205 #if CONV_STRIDE_Y == 1
206     values1 += (int8)(bias_value);
207 #endif /* CONV_STRIDE_Y == 1 */
208 #endif //defined(HAS_BIAS)
209
210 #if WEIGHTS_OFFSET != 0
211     values0 += sum0 * (int8)(WEIGHTS_OFFSET);
212 #if CONV_STRIDE_Y == 1
213     values1 += sum1 * (int8)(WEIGHTS_OFFSET);
214 #endif /* CONV_STRIDE_Y == 1 */
215 #endif /* WEIGHTS_OFFSET != 0 */
216
217 #if INPUT_OFFSET != 0
218     ushort  sum_weights = 0;
219     ushort3 tmp_we      = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
220     sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
221     values0 += sum_weights * (int8)(INPUT_OFFSET);
222 #if CONV_STRIDE_Y == 1
223     values1 += sum_weights * (int8)(INPUT_OFFSET);
224 #endif /* CONV_STRIDE_Y == 1 */
225 #endif /* INPUT_OFFSET != 0 */
226
227 #if K_OFFSET != 0
228     values0 += (int8)(K_OFFSET);
229 #if CONV_STRIDE_Y == 1
230     values1 += (int8)(K_OFFSET);
231 #endif /* CONV_STRIDE_Y == 1 */
232 #endif /* K_OFFSET != 0 */
233
234     values0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
235     values0 += (int8)OUTPUT_OFFSET;
236     uchar8 res0 = convert_uchar8_sat(values0);
237     res0        = max(res0, (uchar8)0);
238     res0        = min(res0, (uchar8)255);
239
240     vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr);
241 #if CONV_STRIDE_Y == 1
242
243     values1 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
244     values1 += (int8)OUTPUT_OFFSET;
245     uchar8 res1 = convert_uchar8_sat(values1);
246     res1        = max(res1, (uchar8)0);
247     res1        = min(res1, (uchar8)255);
248
249     vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y);
250 #endif /* CONV_STRIDE_Y == 1 */
251 }
252
253 #endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) */
254
255 #if defined(VEC_SIZE) && defined(SRC_DEPTH) && defined(CONV_PAD_TOP) && defined(ROWS_READ)
256
257 #define asymm_mult_by_quant_multiplier_less_than_one(x, y, z) ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(x, y, z, VEC_SIZE)
258
259 #define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE)
260 #define VEC_UCHAR VEC_DATA_TYPE(uchar, VEC_SIZE)
261
262 #define BIFROST_MAD_4(acc, x, y)               \
263     ({                                         \
264         acc.s0 += (ushort)x.s0 * (ushort)y.s0; \
265         acc.s1 += (ushort)x.s1 * (ushort)y.s1; \
266         acc.s2 += (ushort)x.s2 * (ushort)y.s2; \
267         acc.s3 += (ushort)x.s3 * (ushort)y.s3; \
268     })
269
270 #if WEIGHTS_OFFSET != 0
271 #define BIFROST_MAD_ACC_4(acc, sum, x, y) \
272     ({                                    \
273         sum += CONVERT(x, VEC_INT);       \
274         BIFROST_MAD_4(acc, x, y);         \
275     })
276 #else /* WEIGHTS_OFFSET != 0 */
277 #define BIFROST_MAD_ACC_4(acc, sum, x, y) BIFROST_MAD_4(acc, x, y)
278 #endif /* WEIGHTS_OFFSET != 0 */
279
280 /** This function computes the depthwise convolution quantized.
281  *
282  * @param[in] src_ptr                               Pointer to the source image. Supported data types: QASYMM8
283  * @param[in] src_stride_x                          Stride of the source image in X dimension (in bytes)
284  * @param[in] src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
285  * @param[in] src_stride_y                          Stride of the source image in Y dimension (in bytes)
286  * @param[in] src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
287  * @param[in] src_offset_first_element_in_bytes     The offset of the first element in the source image
288  * @param[in] src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
289  * @param[in] src_step_z                            src_stride_z * number of elements along Y processed per workitem(in bytes)
290  * @param[in] dst_ptr                               Pointer to the destination tensor. Supported data types: QASYMM8
291  * @param[in] dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
292  * @param[in] dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
293  * @param[in] dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
294  * @param[in] dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
295  * @param[in] dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
296  * @param[in] dst_step_z                            dst_stride_z * number of elements along Y processed per workitem(in bytes)
297  * @param[in] dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
298  * @param[in] weights_ptr                           Pointer to the weights tensor. Supported data types: QASYMM8
299  * @param[in] weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
300  * @param[in] weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
301  * @param[in] weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
302  * @param[in] weights_step_y                        weights_stride_y * number of elements along Y processed per workitem(in bytes)
303  * @param[in] weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
304  * @param[in] weights_step_z                        weights_stride_z * number of elements along Y processed per workitem(in bytes)
305  * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
306  * @param[in] biases_ptr                            (Optional) Pointer to the biases vector. Supported data types: QASYMM8
307  * @param[in] biases_stride_x                       (Optional) Stride of the biases vector in X dimension (in bytes)
308  * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
309  * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases vector
310  */
311
312 __kernel void depthwise_convolution_3x3_quantized_nhwc_stride1(
313     TENSOR3D_DECLARATION(src),
314     TENSOR3D_DECLARATION(dst),
315     TENSOR3D_DECLARATION(weights),
316 #if defined(HAS_BIAS)
317     VECTOR_DECLARATION(biases)
318 #endif /* defined(HAS_BIAS) */
319 )
320 {
321     Image  dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
322     Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
323 #if defined(HAS_BIAS)
324     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
325
326     VEC_INT bias_values = VLOAD(VEC_SIZE)(0, (__global int *)biases.ptr);
327 #endif /* defined(HAS_BIAS) */
328
329     __global uchar *first_elem = src_ptr + src_offset_first_element_in_bytes;
330
331     const int z         = get_global_id(2);
332     const int pad_offs  = -ROWS_READ * src_stride_y;
333     const int src_offs0 = get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + z * src_step_z - CONV_PAD_TOP * src_stride_z;
334     const int src_offs1 = src_offs0 + src_stride_z;
335     const int src_offs2 = src_offs1 + src_stride_z;
336
337     const int cond_top    = z - CONV_PAD_TOP < 0;
338     const int cond_bottom = z * (src_step_z / src_stride_z) + 2 > SRC_DEPTH;
339
340     __global uchar *src_addr0 = first_elem + select(src_offs0, pad_offs, cond_top);
341     __global uchar *src_addr1 = first_elem + src_offs1;
342     __global uchar *src_addr2 = first_elem + select(src_offs2, pad_offs, cond_bottom);
343
344     VEC_INT sum_we = 0;
345     VEC_INT acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
346     VEC_INT sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
347
348     // z == 0
349     VEC_UCHAR w0, w1, w2;
350     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
351     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
352     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
353
354 #if INPUT_OFFSET != 0
355     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
356 #endif /* INPUT_OFFSET != 0 */
357
358     VEC_UCHAR values = VLOAD(VEC_SIZE)(0, src_addr0);
359     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
360
361     src_addr0 += src_stride_y;
362     values = VLOAD(VEC_SIZE)(0, src_addr0);
363     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
364     BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
365
366     src_addr0 += src_stride_y;
367     values = VLOAD(VEC_SIZE)(0, src_addr0);
368     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
369     BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
370     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
371
372     src_addr0 += src_stride_y;
373     values = VLOAD(VEC_SIZE)(0, src_addr0);
374     BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
375     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
376     BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
377
378     src_addr0 += src_stride_y;
379     values = VLOAD(VEC_SIZE)(0, src_addr0);
380     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
381     BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
382
383     src_addr0 += src_stride_y;
384     values = VLOAD(VEC_SIZE)(0, src_addr0);
385     BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
386
387     weights.ptr += weights_stride_z;
388
389     // z == 1
390     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
391     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
392     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
393
394 #if INPUT_OFFSET != 0
395     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
396 #endif /* INPUT_OFFSET != 0 */
397
398     values = VLOAD(VEC_SIZE)(0, src_addr1);
399     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
400
401     src_addr1 += src_stride_y;
402     values = VLOAD(VEC_SIZE)(0, src_addr1);
403     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
404     BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
405
406     src_addr1 += src_stride_y;
407     values = VLOAD(VEC_SIZE)(0, src_addr1);
408     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
409     BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
410     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
411
412     src_addr1 += src_stride_y;
413     values = VLOAD(VEC_SIZE)(0, src_addr1);
414     BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
415     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
416     BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
417
418     src_addr1 += src_stride_y;
419     values = VLOAD(VEC_SIZE)(0, src_addr1);
420     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
421     BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
422
423     src_addr1 += src_stride_y;
424     values = VLOAD(VEC_SIZE)(0, src_addr1);
425     BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
426
427     weights.ptr += weights_stride_z;
428
429     // z == 2
430     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
431     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
432     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
433
434 #if INPUT_OFFSET != 0
435     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
436 #endif /* INPUT_OFFSET != 0 */
437
438     values = VLOAD(VEC_SIZE)(0, src_addr2);
439     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
440
441     src_addr2 += src_stride_y;
442     values = VLOAD(VEC_SIZE)(0, src_addr2);
443     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
444     BIFROST_MAD_ACC_4(acc1, sum1, values, w0);
445
446     src_addr2 += src_stride_y;
447     values = VLOAD(VEC_SIZE)(0, src_addr2);
448     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
449     BIFROST_MAD_ACC_4(acc1, sum1, values, w1);
450     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
451
452     src_addr2 += src_stride_y;
453     values = VLOAD(VEC_SIZE)(0, src_addr2);
454     BIFROST_MAD_ACC_4(acc1, sum1, values, w2);
455     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
456     BIFROST_MAD_ACC_4(acc3, sum3, values, w0);
457
458     src_addr2 += src_stride_y;
459     values = VLOAD(VEC_SIZE)(0, src_addr2);
460     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
461     BIFROST_MAD_ACC_4(acc3, sum3, values, w1);
462
463     src_addr2 += src_stride_y;
464     values = VLOAD(VEC_SIZE)(0, src_addr2);
465     BIFROST_MAD_ACC_4(acc3, sum3, values, w2);
466
467 #if defined(HAS_BIAS)
468     acc0 += bias_values;
469     acc1 += bias_values;
470     acc2 += bias_values;
471     acc3 += bias_values;
472 #endif /* defined(HAS_BIAS) */
473
474 #if WEIGHTS_OFFSET != 0
475     acc0 += WEIGHTS_OFFSET * sum0;
476     acc1 += WEIGHTS_OFFSET * sum1;
477     acc2 += WEIGHTS_OFFSET * sum2;
478     acc3 += WEIGHTS_OFFSET * sum3;
479 #endif /* WEIGHTS_OFFSET != 0 */
480
481 #if INPUT_OFFSET != 0
482     VEC_INT offs = INPUT_OFFSET * sum_we;
483
484     acc0 += offs;
485     acc1 += offs;
486     acc2 += offs;
487     acc3 += offs;
488 #endif /* INPUT_OFFSET != 0 */
489
490 #if K_OFFSET != 0
491     acc0 += (VEC_INT)K_OFFSET;
492     acc1 += (VEC_INT)K_OFFSET;
493     acc2 += (VEC_INT)K_OFFSET;
494     acc3 += (VEC_INT)K_OFFSET;
495 #endif /* K_OFFSET != 0 */
496
497     acc0 = asymm_mult_by_quant_multiplier_less_than_one(acc0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
498     acc1 = asymm_mult_by_quant_multiplier_less_than_one(acc1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
499     acc2 = asymm_mult_by_quant_multiplier_less_than_one(acc2, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
500     acc3 = asymm_mult_by_quant_multiplier_less_than_one(acc3, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
501
502     acc0 += (VEC_INT)OUTPUT_OFFSET;
503     acc1 += (VEC_INT)OUTPUT_OFFSET;
504     acc2 += (VEC_INT)OUTPUT_OFFSET;
505     acc3 += (VEC_INT)OUTPUT_OFFSET;
506
507     VEC_UCHAR res0 = CONVERT_SAT(acc0, VEC_UCHAR);
508     VEC_UCHAR res1 = CONVERT_SAT(acc1, VEC_UCHAR);
509     VEC_UCHAR res2 = CONVERT_SAT(acc2, VEC_UCHAR);
510     VEC_UCHAR res3 = CONVERT_SAT(acc3, VEC_UCHAR);
511
512     res0 = CLAMP(res0, (VEC_UCHAR)0, (VEC_UCHAR)255);
513     res1 = CLAMP(res1, (VEC_UCHAR)0, (VEC_UCHAR)255);
514     res2 = CLAMP(res2, (VEC_UCHAR)0, (VEC_UCHAR)255);
515     res3 = CLAMP(res3, (VEC_UCHAR)0, (VEC_UCHAR)255);
516
517     VSTORE(VEC_SIZE)
518     (res0, 0, dst.ptr + 0 * dst_stride_y);
519     VSTORE(VEC_SIZE)
520     (res1, 0, dst.ptr + 1 * dst_stride_y);
521     VSTORE(VEC_SIZE)
522     (res2, 0, dst.ptr + 2 * dst_stride_y);
523     VSTORE(VEC_SIZE)
524     (res3, 0, dst.ptr + 3 * dst_stride_y);
525 }
526
527 /** This function computes the depthwise convolution quantized.
528  *
529  * @param[in] src_ptr                               Pointer to the source image. Supported data types: QASYMM8
530  * @param[in] src_stride_x                          Stride of the source image in X dimension (in bytes)
531  * @param[in] src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
532  * @param[in] src_stride_y                          Stride of the source image in Y dimension (in bytes)
533  * @param[in] src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
534  * @param[in] src_offset_first_element_in_bytes     The offset of the first element in the source image
535  * @param[in] src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
536  * @param[in] src_step_z                            src_stride_z * number of elements along Y processed per workitem(in bytes)
537  * @param[in] dst_ptr                               Pointer to the destination tensor. Supported data types: QASYMM8
538  * @param[in] dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
539  * @param[in] dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
540  * @param[in] dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
541  * @param[in] dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
542  * @param[in] dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
543  * @param[in] dst_step_z                            dst_stride_z * number of elements along Y processed per workitem(in bytes)
544  * @param[in] dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
545  * @param[in] weights_ptr                           Pointer to the weights tensor. Supported data types: QASYMM8
546  * @param[in] weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
547  * @param[in] weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
548  * @param[in] weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
549  * @param[in] weights_step_y                        weights_stride_y * number of elements along Y processed per workitem(in bytes)
550  * @param[in] weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
551  * @param[in] weights_step_z                        weights_stride_z * number of elements along Y processed per workitem(in bytes)
552  * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
553  * @param[in] biases_ptr                            (Optional) Pointer to the biases vector. Supported data types: QASYMM8
554  * @param[in] biases_stride_x                       (Optional) Stride of the biases vector in X dimension (in bytes)
555  * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
556  * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases vector
557  */
558
559 __kernel void depthwise_convolution_3x3_quantized_nhwc_stride2(
560     TENSOR3D_DECLARATION(src),
561     TENSOR3D_DECLARATION(dst),
562     TENSOR3D_DECLARATION(weights),
563 #if defined(HAS_BIAS)
564     VECTOR_DECLARATION(biases)
565 #endif /* defined(HAS_BIAS) */
566 )
567 {
568     Image  dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
569     Vector weights = CONVERT_TO_VECTOR_STRUCT(weights);
570 #if defined(HAS_BIAS)
571     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
572
573     VEC_INT bias_values = VLOAD(VEC_SIZE)(0, (__global int *)biases.ptr);
574 #endif /* defined(HAS_BIAS) */
575
576     __global uchar *first_elem = src_ptr + src_offset_first_element_in_bytes;
577
578     const int z         = get_global_id(2);
579     const int pad_offs  = -ROWS_READ * src_stride_y;
580     const int src_offs0 = get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + z * src_step_z - CONV_PAD_TOP * src_stride_z;
581     const int src_offs1 = src_offs0 + src_stride_z;
582     const int src_offs2 = src_offs1 + src_stride_z;
583
584     const int cond_top    = z - CONV_PAD_TOP < 0;
585     const int cond_bottom = z * (src_step_z / src_stride_z) + 2 > SRC_DEPTH;
586
587     __global uchar *src_addr0 = first_elem + select(src_offs0, pad_offs, cond_top);
588     __global uchar *src_addr1 = first_elem + src_offs1;
589     __global uchar *src_addr2 = first_elem + select(src_offs2, pad_offs, cond_bottom);
590
591     VEC_INT sum_we = 0;
592     VEC_INT acc0 = 0, acc2 = 0;
593     VEC_INT sum0 = 0, sum2 = 0;
594
595     // z == 0
596     VEC_UCHAR w0, w1, w2;
597     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
598     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
599     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
600
601 #if INPUT_OFFSET != 0
602     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
603 #endif /* INPUT_OFFSET != 0 */
604
605     VEC_UCHAR values = VLOAD(VEC_SIZE)(0, src_addr0);
606     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
607
608     src_addr0 += src_stride_y;
609     values = VLOAD(VEC_SIZE)(0, src_addr0);
610     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
611
612     src_addr0 += src_stride_y;
613     values = VLOAD(VEC_SIZE)(0, src_addr0);
614     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
615     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
616
617     src_addr0 += src_stride_y;
618     values = VLOAD(VEC_SIZE)(0, src_addr0);
619     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
620
621     src_addr0 += src_stride_y;
622     values = VLOAD(VEC_SIZE)(0, src_addr0);
623     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
624
625     weights.ptr += weights_stride_z;
626
627     // z == 1
628     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
629     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
630     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
631
632 #if INPUT_OFFSET != 0
633     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
634 #endif /* INPUT_OFFSET != 0 */
635
636     values = VLOAD(VEC_SIZE)(0, src_addr1);
637     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
638
639     src_addr1 += src_stride_y;
640     values = VLOAD(VEC_SIZE)(0, src_addr1);
641     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
642
643     src_addr1 += src_stride_y;
644     values = VLOAD(VEC_SIZE)(0, src_addr1);
645     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
646     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
647
648     src_addr1 += src_stride_y;
649     values = VLOAD(VEC_SIZE)(0, src_addr1);
650     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
651
652     src_addr1 += src_stride_y;
653     values = VLOAD(VEC_SIZE)(0, src_addr1);
654     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
655
656     weights.ptr += weights_stride_z;
657
658     // z == 2
659     w0 = VLOAD(VEC_SIZE)(0, weights.ptr + 0 * weights_stride_y);
660     w1 = VLOAD(VEC_SIZE)(0, weights.ptr + 1 * weights_stride_y);
661     w2 = VLOAD(VEC_SIZE)(0, weights.ptr + 2 * weights_stride_y);
662
663 #if INPUT_OFFSET != 0
664     sum_we += CONVERT(w0, VEC_INT) + CONVERT(w1, VEC_INT) + CONVERT(w2, VEC_INT);
665 #endif /* INPUT_OFFSET != 0 */
666
667     values = VLOAD(VEC_SIZE)(0, src_addr2);
668     BIFROST_MAD_ACC_4(acc0, sum0, values, w0);
669
670     src_addr2 += src_stride_y;
671     values = VLOAD(VEC_SIZE)(0, src_addr2);
672     BIFROST_MAD_ACC_4(acc0, sum0, values, w1);
673
674     src_addr2 += src_stride_y;
675     values = VLOAD(VEC_SIZE)(0, src_addr2);
676     BIFROST_MAD_ACC_4(acc0, sum0, values, w2);
677     BIFROST_MAD_ACC_4(acc2, sum2, values, w0);
678
679     src_addr2 += src_stride_y;
680     values = VLOAD(VEC_SIZE)(0, src_addr2);
681     BIFROST_MAD_ACC_4(acc2, sum2, values, w1);
682
683     src_addr2 += src_stride_y;
684     values = VLOAD(VEC_SIZE)(0, src_addr2);
685     BIFROST_MAD_ACC_4(acc2, sum2, values, w2);
686
687 #if defined(HAS_BIAS)
688     acc0 += bias_values;
689     acc2 += bias_values;
690 #endif /* defined(HAS_BIAS) */
691
692 #if WEIGHTS_OFFSET != 0
693     acc0 += WEIGHTS_OFFSET * sum0;
694     acc2 += WEIGHTS_OFFSET * sum2;
695 #endif /* WEIGHTS_OFFSET != 0 */
696
697 #if INPUT_OFFSET != 0
698     VEC_INT offs = INPUT_OFFSET * sum_we;
699
700     acc0 += offs;
701     acc2 += offs;
702 #endif /* INPUT_OFFSET != 0 */
703
704 #if K_OFFSET != 0
705     acc0 += (VEC_INT)K_OFFSET;
706     acc2 += (VEC_INT)K_OFFSET;
707 #endif /* K_OFFSET != 0 */
708
709     acc0 = asymm_mult_by_quant_multiplier_less_than_one(acc0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
710     acc2 = asymm_mult_by_quant_multiplier_less_than_one(acc2, OUTPUT_MULTIPLIER, OUTPUT_SHIFT);
711     acc0 += (VEC_INT)OUTPUT_OFFSET;
712     acc2 += (VEC_INT)OUTPUT_OFFSET;
713     VEC_UCHAR res0 = CONVERT_SAT(acc0, VEC_UCHAR);
714     VEC_UCHAR res2 = CONVERT_SAT(acc2, VEC_UCHAR);
715     res0           = CLAMP(res0, (VEC_UCHAR)0, (VEC_UCHAR)255);
716     res2           = CLAMP(res2, (VEC_UCHAR)0, (VEC_UCHAR)255);
717
718     VSTORE(VEC_SIZE)
719     (res0, 0, dst.ptr + 0 * dst_stride_y);
720     VSTORE(VEC_SIZE)
721     (res2, 0, dst.ptr + 1 * dst_stride_y);
722 }
723
724 #endif /* defined(VEC_SIZE) && defined(SRC_DEPTH) && defined(CONV_PAD_TOP) && defined(ROWS_READ) */
725
726 #endif /* defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT) */