arm_compute v18.02
[platform/upstream/armcl.git] / src / core / CL / cl_kernels / depthwise_convolution_quantized.cl
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24
25 #include "helpers_asymm.h"
26
27 #if defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y) && defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT)
28
29 #if CONV_STRIDE_X > 3
30 #error "Stride X not supported"
31 #endif /* CONV_STRIDE_X > 3 */
32
33 #if CONV_STRIDE_X == 1
34 #define GET_VALUES(first_value, left, middle, right)                              \
35     ({                                                                            \
36         int8 temp0 = CONVERT(vload8(0, first_value), int8);                       \
37         int2 temp1 = CONVERT(vload2(0, (first_value + 8 * sizeof(uchar))), int2); \
38         \
39         left   = CONVERT(temp0.s01234567, int8);                                  \
40         middle = CONVERT((int8)(temp0.s1234, temp0.s567, temp1.s0), int8);        \
41         right  = CONVERT((int8)(temp0.s2345, temp0.s67, temp1.s01), int8);        \
42     })
43 #elif CONV_STRIDE_X == 2
44 #define GET_VALUES(first_value, left, middle, right)                     \
45     ({                                                                   \
46         int16 temp0 = CONVERT(vload16(0, first_value), int16);           \
47         int   temp1 = CONVERT(*(first_value + 16 * sizeof(uchar)), int); \
48         \
49         left   = CONVERT(temp0.s02468ace, int8);                         \
50         middle = CONVERT(temp0.s13579bdf, int8);                         \
51         right  = CONVERT((int8)(temp0.s2468, temp0.sace, temp1), int8);  \
52     })
53 #else /* CONV_STRIDE_X */
54 #define GET_VALUES(first_value, left, middle, right)                                \
55     ({                                                                              \
56         int16 temp0 = CONVERT(vload16(0, first_value), int16);                      \
57         int8  temp1 = CONVERT(vload8(0, (first_value + 16 * sizeof(uchar))), int8); \
58         \
59         left   = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8);          \
60         middle = CONVERT((int8)(temp0.s147a, temp0.sd, temp1.s036), int8);          \
61         right  = CONVERT((int8)(temp0.s258b, temp0.se, temp1.s147), int8);          \
62     })
63 #endif /* CONV_STRIDE_X */
64
65 /** This function computes the horizontal integral of the image and adds offsets.
66  *
67  * @param[in] src_ptr                               Pointer to the source image. Supported data types: QASYMM8
68  * @param[in] src_stride_x                          Stride of the source image in X dimension (in bytes)
69  * @param[in] src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
70  * @param[in] src_stride_y                          Stride of the source image in Y dimension (in bytes)
71  * @param[in] src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
72  * @param[in] src_offset_first_element_in_bytes     The offset of the first element in the source image
73  * @param[in] src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
74  * @param[in] src_step_z                            src_stride_z * number of elements along Y processed per workitem(in bytes)
75  * @param[in] dst_ptr                               Pointer to the destination tensor. Supported data types: QASYMM8
76  * @param[in] dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
77  * @param[in] dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
78  * @param[in] dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
79  * @param[in] dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
80  * @param[in] dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
81  * @param[in] dst_step_z                            dst_stride_z * number of elements along Y processed per workitem(in bytes)
82  * @param[in] dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
83  * @param[in] weights_ptr                           Pointer to the weights tensor. Supported data types: QASYMM8
84  * @param[in] weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
85  * @param[in] weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
86  * @param[in] weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
87  * @param[in] weights_step_y                        weights_stride_y * number of elements along Y processed per workitem(in bytes)
88  * @param[in] weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
89  * @param[in] weights_step_z                        weights_stride_z * number of elements along Y processed per workitem(in bytes)
90  * @param[in] weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
91  * @param[in] biases_ptr                            (Optional) Pointer to the biases vector. Supported data types: QASYMM8
92  * @param[in] biases_stride_x                       (Optional) Stride of the biases vector in X dimension (in bytes)
93  * @param[in] biases_step_x                         (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
94  * @param[in] biases_offset_first_element_in_bytes  (Optional) The offset of the first element in the biases vector
95  */
96
97 __kernel void depthwise_convolution_3x3_quantized(
98     TENSOR3D_DECLARATION(src),
99     TENSOR3D_DECLARATION(dst),
100     TENSOR3D_DECLARATION(weights)
101 #if defined(HAS_BIAS)
102     ,
103     VECTOR_DECLARATION(biases)
104 #endif //defined(HAS_BIAS)
105 )
106 {
107     Image    src     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
108     Image    dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
109     Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
110 #if defined(HAS_BIAS)
111     Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
112
113     int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
114 #endif //defined(HAS_BIAS)
115
116     uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
117     uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
118     uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
119
120     int8 values0 = 0;
121     int8 sum0    = 0;
122 #if CONV_STRIDE_Y == 1
123     int8 values1 = 0;
124     int8 sum1    = 0;
125 #endif /* CONV_STRIDE_Y */
126
127     // Row0
128     int8 left, middle, right;
129     GET_VALUES(src.ptr + 0 * src_stride_y, left, middle, right);
130     values0 += left * (int8)(w0.s0);
131     values0 += middle * (int8)(w0.s1);
132     values0 += right * (int8)(w0.s2);
133
134 #if WEIGHTS_OFFSET != 0
135     sum0 += left + middle + right;
136 #endif /* WEIGHTS_OFFSET != 0 */
137
138     // Row1
139     GET_VALUES(src.ptr + 1 * src_stride_y, left, middle, right);
140     values0 += left * (int8)(w1.s0);
141     values0 += middle * (int8)(w1.s1);
142     values0 += right * (int8)(w1.s2);
143 #if CONV_STRIDE_Y == 1
144     values1 += left * (int8)(w0.s0);
145     values1 += middle * (int8)(w0.s1);
146     values1 += right * (int8)(w0.s2);
147 #endif /* CONV_STRIDE_Y == 1 */
148
149 #if WEIGHTS_OFFSET != 0
150     int8 tmp = left + middle + right;
151     sum0 += tmp;
152 #if CONV_STRIDE_Y == 1
153     sum1 += tmp;
154 #endif /* CONV_STRIDE_Y == 1 */
155 #endif /* WEIGHTS_OFFSET != 0 */
156
157     // Row2
158     GET_VALUES(src.ptr + 2 * src_stride_y, left, middle, right);
159     values0 += left * (int8)(w2.s0);
160     values0 += middle * (int8)(w2.s1);
161     values0 += right * (int8)(w2.s2);
162 #if CONV_STRIDE_Y == 1
163     values1 += left * (int8)(w1.s0);
164     values1 += middle * (int8)(w1.s1);
165     values1 += right * (int8)(w1.s2);
166 #endif /* CONV_STRIDE_Y == 1 */
167
168 #if WEIGHTS_OFFSET != 0
169     tmp = left + middle + right;
170     sum0 += tmp;
171 #if CONV_STRIDE_Y == 1
172     sum1 += tmp;
173 #endif /* CONV_STRIDE_Y == 1 */
174 #endif /* WEIGHTS_OFFSET != 0 */
175
176 #if CONV_STRIDE_Y == 1
177     // Row3
178     GET_VALUES(src.ptr + 3 * src_stride_y, left, middle, right);
179     values1 += left * (int8)(w2.s0);
180     values1 += middle * (int8)(w2.s1);
181     values1 += right * (int8)(w2.s2);
182
183 #if WEIGHTS_OFFSET != 0
184     sum1 += left + middle + right;
185 #endif /* WEIGHTS_OFFSET != 0 */
186 #endif /* CONV_STRIDE_Y == 1 */
187
188 #if defined(HAS_BIAS)
189     values0 += (int8)(bias_value);
190 #if CONV_STRIDE_Y == 1
191     values1 += (int8)(bias_value);
192 #endif /* CONV_STRIDE_Y == 1 */
193 #endif //defined(HAS_BIAS)
194
195 #if WEIGHTS_OFFSET != 0
196     values0 += sum0 * (int8)(WEIGHTS_OFFSET);
197 #if CONV_STRIDE_Y == 1
198     values1 += sum1 * (int8)(WEIGHTS_OFFSET);
199 #endif /* CONV_STRIDE_Y == 1 */
200 #endif /* WEIGHTS_OFFSET != 0 */
201
202 #if INPUT_OFFSET != 0
203     ushort  sum_weights = 0;
204     ushort3 tmp_we      = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
205     sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
206     values0 += sum_weights * (int8)(INPUT_OFFSET);
207 #if CONV_STRIDE_Y == 1
208     values1 += sum_weights * (int8)(INPUT_OFFSET);
209 #endif /* CONV_STRIDE_Y == 1 */
210 #endif /* INPUT_OFFSET != 0 */
211
212 #if K_OFFSET != 0
213     values0 += (int8)(K_OFFSET);
214 #if CONV_STRIDE_Y == 1
215     values1 += (int8)(K_OFFSET);
216 #endif /* CONV_STRIDE_Y == 1 */
217 #endif /* K_OFFSET != 0 */
218
219     values0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
220     values0 += (int8)OUTPUT_OFFSET;
221     uchar8 res0 = convert_uchar8_sat(values0);
222     res0        = max(res0, (uchar8)0);
223     res0        = min(res0, (uchar8)255);
224
225     vstore8(res0, 0, dst.ptr);
226 #if CONV_STRIDE_Y == 1
227
228     values1 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
229     values1 += (int8)OUTPUT_OFFSET;
230     uchar8 res1 = convert_uchar8_sat(values1);
231     res1        = max(res1, (uchar8)0);
232     res1        = min(res1, (uchar8)255);
233
234     vstore8(res1, 0, dst.ptr + dst_stride_y);
235 #endif /* CONV_STRIDE_Y == 1 */
236 }
237
238 #endif /* defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y) && defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT) */