Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / src / core / CL / cl_kernels / gemmlowp_ex.cl
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * Copyright (c) 2017-2019 ARM Limited.
19  *
20  * SPDX-License-Identifier: MIT
21  *
22  * Permission is hereby granted, free of charge, to any person obtaining a copy
23  * of this software and associated documentation files (the "Software"), to
24  * deal in the Software without restriction, including without limitation the
25  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26  * sell copies of the Software, and to permit persons to whom the Software is
27  * furnished to do so, subject to the following conditions:
28  *
29  * The above copyright notice and this permission notice shall be included in all
30  * copies or substantial portions of the Software.
31  *
32  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38  * SOFTWARE.
39  */
40
41 #include "helpers.h"
42
43 #if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && \
44   defined(COLS_A)
45 #define VECTOR_CHAR VEC_DATA_TYPE(char, NUM_ELEMS_PROCESSED_PER_THREAD_X)
46 #define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
47 #define VECTOR_FLOAT VEC_DATA_TYPE(float, NUM_ELEMS_PROCESSED_PER_THREAD_X)
48 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B
49  * (src1) in case both matrices have not beed reshaped
50  *
51  * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
52  *
53  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following
54  * information must be passed at compile time:
55  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
56  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
57  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D
58  * tensor.
59  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
60  *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
61  *
62  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data type:
63  * QASYMM8
64  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in
65  * bytes)
66  * @param[in]  src0_step_x                        src_stride_x * number of elements along X
67  * processed per workitem(in bytes)
68  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in
69  * bytes)
70  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y
71  * processed per workitem(in bytes)
72  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source
73  * matrix
74  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data type:
75  * same as @p src0_ptr
76  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in
77  * bytes)
78  * @param[in]  src1_step_x                        src_stride_x * number of elements along X
79  * processed per workitem(in bytes)
80  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in
81  * bytes)
82  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y
83  * processed per workitem(in bytes)
84  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source
85  * matrix
86  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data
87  * type: S32
88  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension
89  * (in bytes)
90  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X
91  * processed per workitem(in bytes)
92  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension
93  * (in bytes)
94  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y
95  * processed per workitem(in bytes)
96  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination
97  * matrix
98  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in
99  * bytes)
100  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in
101  * bytes)
102  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension
103  * (in bytes)
104  * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for
105  * the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
106  * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for
107  * the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
108  */
109 __kernel void gemmlowp_mm_midgard_ex(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1),
110                                      IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z,
111                                      uint dst_stride_z
112 #if defined(REINTERPRET_INPUT_AS_3D)
113                                      ,
114                                      uint src_cross_plane_pad
115 #endif // REINTERPRET_INPUT_AS_3D
116 #if defined(REINTERPRET_OUTPUT_AS_3D)
117                                      ,
118                                      uint dst_cross_plane_pad
119 #endif // REINTERPRET_OUTPUT_AS_3D
120 )
121 {
122   int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
123
124   // Compute starting address for matrix A and Matrix B
125   int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
126
127   // Update address for the matrix A
128   src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
129
130   // Update address for the matrix B
131   src_addr.s1 += idx;
132
133 #if defined(REINTERPRET_INPUT_AS_3D)
134   // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across
135   // the z dimension
136   // in order to take into account the presence of possible cross plane paddings
137   //
138   //  |                  |
139   //  |      plane0      |
140   //  |                  |
141   //  |__________________|
142   //  |******************|
143   //  |  cross_plane_pad |
144   //  |******************|
145   //  |                  |
146   //  |      plane1      |
147   //  |                  |
148   //  |__________________|
149
150   // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)
151   // by HEIGHT_GEMM3D
152   uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) /
153               (uint4)HEIGHT_GEMM3D;
154   zin = min(DEPTH_GEMM3D - 1, zin);
155
156   // Add offset due to the cross plane paddings
157   zin *= (src_cross_plane_pad * src0_stride_y);
158
159   // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
160   // multiply src0_stride_z by DEPTH_GEMM3D
161   src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
162
163 #else // defined(REINTERPRET_INPUT_AS_3D)
164
165   // Add offset for batched GEMM
166   src_addr.s0 += get_global_id(2) * src0_stride_z;
167
168 #endif // defined(REINTERPRET_INPUT_AS_3D)
169
170 #if defined(MATRIX_B_DEPTH)
171   // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
172   src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
173 #else  // defined(MATRIX_B_DEPTH)
174   src_addr.s1 += get_global_id(2) * src1_stride_z;
175 #endif // defined(MATRIX_B_DEPTH)
176
177   int end_row_vec_a = src_addr.s0 + COLS_A;
178
179   VECTOR_INT acc0 = 0;
180 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
181   VECTOR_INT acc1 = 0;
182 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
183 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
184   VECTOR_INT acc2 = 0;
185 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
186 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
187   VECTOR_INT acc3 = 0;
188 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
189 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
190   VECTOR_INT acc4 = 0;
191 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
192
193   for (; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
194   {
195     // Load values from matrix A
196     char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
197 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
198     char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
199 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
200 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
201     char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
202 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
203 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
204     char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
205 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
206 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
207     char2 a4 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y));
208 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
209     // Load values from matrix B
210     VECTOR_CHAR b0 =
211       VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global char *)(src1_ptr + src_addr.s1));
212     VECTOR_CHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(
213       0, (__global char *)(src1_ptr + src_addr.s1 + src1_stride_y));
214
215     // Accumulate
216     acc0 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a0.s0;
217     acc0 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a0.s1;
218 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
219     acc1 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a1.s0;
220     acc1 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a1.s1;
221 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
222 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
223     acc2 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a2.s0;
224     acc2 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a2.s1;
225 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
226 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
227     acc3 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a3.s0;
228     acc3 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a3.s1;
229 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
230 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
231     acc4 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a4.s0;
232     acc4 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a4.s1;
233 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
234   }
235
236   for (; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
237   {
238     // Load values from matrix A
239     char a0 = *(__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
240 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
241     char a1 = *(__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
242 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
243 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
244     char a2 = *(__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
245 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
246 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
247     char a3 = *(__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
248 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
249 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
250     char a4 = *(__global char *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
251 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
252     // Load values from matrix B
253     VECTOR_CHAR b0 =
254       VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global char *)(src1_ptr + src_addr.s1));
255
256     // Accumulate
257     acc0 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a0;
258 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
259     acc1 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a1;
260 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
261 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
262     acc2 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a2;
263 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
264 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
265     acc3 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a3;
266 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
267 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
268     acc4 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a4;
269 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
270   }
271
272   const int z = get_global_id(2);
273
274   // Compute destination address
275   Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
276
277 #if defined(REINTERPRET_OUTPUT_AS_3D)
278   // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across
279   // the z dimension
280   // in order to take into account the presence of possible cross plane paddings
281   //
282   //  |                  |
283   //  |      plane0      |
284   //  |                  |
285   //  |__________________|
286   //  |******************|
287   //  |  cross_plane_pad |
288   //  |******************|
289   //  |                  |
290   //  |      plane1      |
291   //  |                  |
292   //  |__________________|
293
294   // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)
295   // by HEIGHT_GEMM3D
296   uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) +
297                 (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) /
298                (uint8)HEIGHT_GEMM3D;
299   zout = min(DEPTH_GEMM3D - 1, zout);
300
301   // Add offset due to the cross plane paddings
302   zout *= (dst_cross_plane_pad * dst_stride_y);
303
304   // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
305   // multiply dst_stride_z by DEPTH_GEMM3D
306   dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
307
308   // Store the result
309   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
310   (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
311 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
312   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
313   (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
314 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
315 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
316   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
317   (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
318 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
319 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
320   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
321   (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
322 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
323 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
324   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
325   (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
326 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
327
328 #else // defined(REINTERPRET_OUTPUT_AS_3D)
329   // Add offset for batched GEMM
330   dst.ptr += z * dst_stride_z;
331
332   // Store the result
333   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
334   (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
335 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
336   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
337   (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
338 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
339 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
340   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
341   (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
342 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
343 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
344   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
345   (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
346 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
347 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
348   VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
349   (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
350 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
351 #endif // defined(REINTERPRET_OUTPUT_AS_3D)
352 }
353 #endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) &&
354        // defined(COLS_A)