arm_compute v18.02
[platform/upstream/armcl.git] / src / core / CL / cl_kernels / gemm.cl
1 /*
2  * Copyright (c) 2017-2018 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "helpers.h"
25
26 #ifdef FIXED_POINT_POSITION
27 #include "fixed_point.h"
28 #endif // FIXED_POINT_POSITION
29
30 #if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
31
32 #if ELEMENT_SIZE == 1
33 #define DATA_TYPE uchar
34 #elif ELEMENT_SIZE == 2
35 #define DATA_TYPE ushort
36 #elif ELEMENT_SIZE == 4
37 #define DATA_TYPE uint
38 #else // ELEMENT_SIZE == 1
39 #error "Element size not supported"
40 #endif // ELEMENT_SIZE
41
42 /** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
43  *
44  * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W)
45  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
46  *
47  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
48  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
49  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
50  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
51  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
52  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
53  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
54  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
55  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
56  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
57  * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
58  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
59  */
60 __kernel void gemm_transpose1xW(IMAGE_DECLARATION(src),
61                                 IMAGE_DECLARATION(dst))
62 {
63     uint x = get_global_id(0);
64     uint y = get_global_id(1);
65
66     // Compute address for Matrix B - source
67     Image src = CONVERT_TO_IMAGE_STRUCT(src);
68
69     // Compute address for Matrix B transposed - destination. X and Y are swapped
70     uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
71                              (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);
72
73     VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
74     b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);
75
76     VSTORE(TRANSPOSE_W)
77     (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
78 }
79 #endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
80
81 #if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
82
83 /** This OpenCL kernel reshapes the input matrix transposing each 4x4 block and interleaving the values
84  *
85  * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
86  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
87  *
88  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
89  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
90  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
91  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
92  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
93  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
94  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
95  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
96  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
97  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
98  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
99  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
100  */
101 __kernel void gemm_interleave4x4(IMAGE_DECLARATION(src),
102                                  IMAGE_DECLARATION(dst))
103 {
104     // Compute source and destination addresses
105     uint x = get_global_id(0);
106     uint y = get_global_id(1);
107
108     // Compute address for Matrix B - source
109     Image src = CONVERT_TO_IMAGE_STRUCT(src);
110
111     // Compute address for Matrix B transposed - destination. X and Y are swapped
112     uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
113                              (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);
114
115     // Load values from Matrix A
116     VEC_DATA_TYPE(DATA_TYPE, 4)
117     a0 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 0)));
118     VEC_DATA_TYPE(DATA_TYPE, 4)
119     a1 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 1)));
120     VEC_DATA_TYPE(DATA_TYPE, 4)
121     a2 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 2)));
122     VEC_DATA_TYPE(DATA_TYPE, 4)
123     a3 = vload4(0, (__global DATA_TYPE *)(offset(&src, 0, 3)));
124
125     VEC_DATA_TYPE(DATA_TYPE, 4)
126     val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
127     vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));
128
129     val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
130     vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));
131
132     val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
133     vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));
134
135     val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
136     vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
137 }
138 #endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
139
140 #if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
141 /** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
142  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
143  *
144  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
145  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
146  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
147  *
148  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
149  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
150  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
151  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
152  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
153  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
154  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
155  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
156  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
157  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
158  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
159  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
160  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
161  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
162  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
163  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
164  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
165  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
166  */
167 __kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
168                                                          IMAGE_DECLARATION(src1),
169                                                          IMAGE_DECLARATION(dst))
170 {
171     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
172     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
173
174     // Offset
175     const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
176     const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
177
178     // src_addr_a = address of matrix A
179     // src_addr_b = address of matrix B
180     __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
181     __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
182
183     // Compute end row address for matrix B
184     __global float *src_end_addr_b = src_addr_b + COLS_B;
185
186     src_addr_a += offset_row_a;
187     src_addr_b += offset_row_b;
188
189     // Reset accumulators
190     float4 c00 = 0.0f;
191     float4 c10 = 0.0f;
192     float4 c20 = 0.0f;
193     float4 c30 = 0.0f;
194
195     for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
196     {
197         // Load values from matrix A (interleaved) and matrix B (transposed)
198         float4 a0 = vload4(0, src_addr_a);
199         float4 b0 = vload4(0, src_addr_b);
200
201         c00 += (float4)a0.s0 * b0;
202         c10 += (float4)a0.s1 * b0;
203         c20 += (float4)a0.s2 * b0;
204         c30 += (float4)a0.s3 * b0;
205
206         // Load values from matrix A (interleaved) and matrix B (transposed)
207         a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
208         b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
209
210         c00 += (float4)a0.s0 * b0;
211         c10 += (float4)a0.s1 * b0;
212         c20 += (float4)a0.s2 * b0;
213         c30 += (float4)a0.s3 * b0;
214     }
215
216     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
217     {
218         // Load values from matrix A (interleaved) and matrix B (transposed)
219         float4 a0 = vload4(0, src_addr_a);
220         float4 b0 = vload4(0, src_addr_b);
221
222         c00 += (float4)a0.s0 * b0;
223         c10 += (float4)a0.s1 * b0;
224         c20 += (float4)a0.s2 * b0;
225         c30 += (float4)a0.s3 * b0;
226     }
227
228     // Compute destination address
229     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
230
231 #if defined(ALPHA)
232     // Multiply by the weight of matrix product
233     c00 = c00 * (float4)ALPHA;
234     c10 = c10 * (float4)ALPHA;
235     c20 = c20 * (float4)ALPHA;
236     c30 = c30 * (float4)ALPHA;
237 #endif // defined(ALPHA)
238
239     // Store 4x4 block
240     vstore4(c00, 0, (__global float *)(offset(&dst, 0, 0)));
241     vstore4(c10, 0, (__global float *)(offset(&dst, 0, 1)));
242     vstore4(c20, 0, (__global float *)(offset(&dst, 0, 2)));
243     vstore4(c30, 0, (__global float *)(offset(&dst, 0, 3)));
244 }
245
246 /** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
247  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
248  *
249  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
250  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
251  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
252  *
253  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
254  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
255  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
256  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
257  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
258  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
259  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
260  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
261  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
262  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
263  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
264  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
265  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
266  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
267  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
268  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
269  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
270  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
271  */
272 __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
273                                                          IMAGE_DECLARATION(src1),
274                                                          IMAGE_DECLARATION(dst))
275 {
276     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
277     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
278
279     // Offset
280     const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
281     const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
282
283     // src_addr_a = address of matrix A
284     // src_addr_b = address of matrix B
285     __global float *src_addr_a = (__global float *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
286     __global float *src_addr_b = (__global float *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
287
288     // Compute end row address for matrix B
289     __global float *src_end_addr_b = src_addr_b + COLS_B;
290
291     src_addr_a += offset_row_a;
292     src_addr_b += offset_row_b;
293
294     // Reset accumulators
295     float c00 = 0.0f;
296     float c01 = 0.0f;
297     float c02 = 0.0f;
298     float c03 = 0.0f;
299     float c10 = 0.0f;
300     float c11 = 0.0f;
301     float c12 = 0.0f;
302     float c13 = 0.0f;
303     float c20 = 0.0f;
304     float c21 = 0.0f;
305     float c22 = 0.0f;
306     float c23 = 0.0f;
307     float c30 = 0.0f;
308     float c31 = 0.0f;
309     float c32 = 0.0f;
310     float c33 = 0.0f;
311
312     for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
313     {
314         // Load values from matrix A (interleaved) and matrix B (transposed)
315         float4 a0 = vload4(0, src_addr_a);
316         float4 b0 = vload4(0, src_addr_b);
317
318         c00 = fma(a0.s0, b0.s0, c00);
319         c01 = fma(a0.s0, b0.s1, c01);
320         c02 = fma(a0.s0, b0.s2, c02);
321         c03 = fma(a0.s0, b0.s3, c03);
322
323         c10 = fma(a0.s1, b0.s0, c10);
324         c11 = fma(a0.s1, b0.s1, c11);
325         c12 = fma(a0.s1, b0.s2, c12);
326         c13 = fma(a0.s1, b0.s3, c13);
327
328         c20 = fma(a0.s2, b0.s0, c20);
329         c21 = fma(a0.s2, b0.s1, c21);
330         c22 = fma(a0.s2, b0.s2, c22);
331         c23 = fma(a0.s2, b0.s3, c23);
332
333         c30 = fma(a0.s3, b0.s0, c30);
334         c31 = fma(a0.s3, b0.s1, c31);
335         c32 = fma(a0.s3, b0.s2, c32);
336         c33 = fma(a0.s3, b0.s3, c33);
337
338         // Load values from matrix A (interleaved) and matrix B (transposed)
339         a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
340         b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
341
342         c00 = fma(a0.s0, b0.s0, c00);
343         c01 = fma(a0.s0, b0.s1, c01);
344         c02 = fma(a0.s0, b0.s2, c02);
345         c03 = fma(a0.s0, b0.s3, c03);
346
347         c10 = fma(a0.s1, b0.s0, c10);
348         c11 = fma(a0.s1, b0.s1, c11);
349         c12 = fma(a0.s1, b0.s2, c12);
350         c13 = fma(a0.s1, b0.s3, c13);
351
352         c20 = fma(a0.s2, b0.s0, c20);
353         c21 = fma(a0.s2, b0.s1, c21);
354         c22 = fma(a0.s2, b0.s2, c22);
355         c23 = fma(a0.s2, b0.s3, c23);
356
357         c30 = fma(a0.s3, b0.s0, c30);
358         c31 = fma(a0.s3, b0.s1, c31);
359         c32 = fma(a0.s3, b0.s2, c32);
360         c33 = fma(a0.s3, b0.s3, c33);
361
362         // Load values from matrix A (interleaved) and matrix B (transposed)
363         a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
364         b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
365
366         c00 = fma(a0.s0, b0.s0, c00);
367         c01 = fma(a0.s0, b0.s1, c01);
368         c02 = fma(a0.s0, b0.s2, c02);
369         c03 = fma(a0.s0, b0.s3, c03);
370
371         c10 = fma(a0.s1, b0.s0, c10);
372         c11 = fma(a0.s1, b0.s1, c11);
373         c12 = fma(a0.s1, b0.s2, c12);
374         c13 = fma(a0.s1, b0.s3, c13);
375
376         c20 = fma(a0.s2, b0.s0, c20);
377         c21 = fma(a0.s2, b0.s1, c21);
378         c22 = fma(a0.s2, b0.s2, c22);
379         c23 = fma(a0.s2, b0.s3, c23);
380
381         c30 = fma(a0.s3, b0.s0, c30);
382         c31 = fma(a0.s3, b0.s1, c31);
383         c32 = fma(a0.s3, b0.s2, c32);
384         c33 = fma(a0.s3, b0.s3, c33);
385
386         // Load values from matrix A (interleaved) and matrix B (transposed)
387         a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
388         b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
389
390         c00 = fma(a0.s0, b0.s0, c00);
391         c01 = fma(a0.s0, b0.s1, c01);
392         c02 = fma(a0.s0, b0.s2, c02);
393         c03 = fma(a0.s0, b0.s3, c03);
394
395         c10 = fma(a0.s1, b0.s0, c10);
396         c11 = fma(a0.s1, b0.s1, c11);
397         c12 = fma(a0.s1, b0.s2, c12);
398         c13 = fma(a0.s1, b0.s3, c13);
399
400         c20 = fma(a0.s2, b0.s0, c20);
401         c21 = fma(a0.s2, b0.s1, c21);
402         c22 = fma(a0.s2, b0.s2, c22);
403         c23 = fma(a0.s2, b0.s3, c23);
404
405         c30 = fma(a0.s3, b0.s0, c30);
406         c31 = fma(a0.s3, b0.s1, c31);
407         c32 = fma(a0.s3, b0.s2, c32);
408         c33 = fma(a0.s3, b0.s3, c33);
409     }
410
411     for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
412     {
413         // Load values from matrix A (interleaved) and matrix B (transposed)
414         float4 a0 = vload4(0, src_addr_a);
415         float4 b0 = vload4(0, src_addr_b);
416
417         c00 = fma(a0.s0, b0.s0, c00);
418         c01 = fma(a0.s0, b0.s1, c01);
419         c02 = fma(a0.s0, b0.s2, c02);
420         c03 = fma(a0.s0, b0.s3, c03);
421
422         c10 = fma(a0.s1, b0.s0, c10);
423         c11 = fma(a0.s1, b0.s1, c11);
424         c12 = fma(a0.s1, b0.s2, c12);
425         c13 = fma(a0.s1, b0.s3, c13);
426
427         c20 = fma(a0.s2, b0.s0, c20);
428         c21 = fma(a0.s2, b0.s1, c21);
429         c22 = fma(a0.s2, b0.s2, c22);
430         c23 = fma(a0.s2, b0.s3, c23);
431
432         c30 = fma(a0.s3, b0.s0, c30);
433         c31 = fma(a0.s3, b0.s1, c31);
434         c32 = fma(a0.s3, b0.s2, c32);
435         c33 = fma(a0.s3, b0.s3, c33);
436     }
437
438     // Compute destination address
439     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
440
441 #if defined(ALPHA)
442     // Multiply by the weight of matrix product
443     c00 = c00 * ALPHA;
444     c01 = c01 * ALPHA;
445     c02 = c02 * ALPHA;
446     c03 = c03 * ALPHA;
447     c10 = c10 * ALPHA;
448     c11 = c11 * ALPHA;
449     c12 = c12 * ALPHA;
450     c13 = c13 * ALPHA;
451     c20 = c20 * ALPHA;
452     c21 = c21 * ALPHA;
453     c22 = c22 * ALPHA;
454     c23 = c23 * ALPHA;
455     c30 = c30 * ALPHA;
456     c31 = c31 * ALPHA;
457     c32 = c32 * ALPHA;
458     c33 = c33 * ALPHA;
459 #endif // defined(ALPHA)
460
461     // Store 4x4 block
462     vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(offset(&dst, 0, 0)));
463     vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(offset(&dst, 0, 1)));
464     vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(offset(&dst, 0, 2)));
465     vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(offset(&dst, 0, 3)));
466 }
467
468 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
469 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
470  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
471  *
472  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
473  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
474  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
475  *
476  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
477  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
478  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
479  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
480  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
481  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
482  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
483  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
484  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
485  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
486  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
487  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
488  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
489  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
490  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
491  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
492  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
493  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
494  */
495 __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
496                                                  IMAGE_DECLARATION(src1),
497                                                  IMAGE_DECLARATION(dst))
498 {
499     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
500     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
501
502     // Offset
503     const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
504     const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
505
506     // src_addr_a = address of matrix A
507     // src_addr_b = address of matrix B
508     __global half *src_addr_a = (__global half *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
509     __global half *src_addr_b = (__global half *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
510
511     // Compute end row address for matrix B
512     __global half *src_end_addr_b = src_addr_b + COLS_B;
513
514     src_addr_a += offset_row_a;
515     src_addr_b += offset_row_b;
516
517     // Reset accumulators
518     half8 c00 = 0.0f;
519     half8 c10 = 0.0f;
520     half8 c20 = 0.0f;
521     half8 c30 = 0.0f;
522
523     for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
524     {
525         // Load values from matrix A (interleaved) and matrix B (transposed)
526         half4 a0 = vload4(0, src_addr_a);
527         half8 b0 = vload8(0, src_addr_b);
528
529         c00 += (half8)a0.s0 * b0;
530         c10 += (half8)a0.s1 * b0;
531         c20 += (half8)a0.s2 * b0;
532         c30 += (half8)a0.s3 * b0;
533
534         // Load values from matrix A (interleaved) and matrix B (transposed)
535         a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
536         b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
537
538         c00 += (half8)a0.s0 * b0;
539         c10 += (half8)a0.s1 * b0;
540         c20 += (half8)a0.s2 * b0;
541         c30 += (half8)a0.s3 * b0;
542     }
543
544     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
545     {
546         // Load values from matrix A (interleaved) and matrix B (transposed)
547         half4 a0 = vload4(0, src_addr_a);
548         half8 b0 = vload8(0, src_addr_b);
549
550         c00 += (half8)a0.s0 * b0;
551         c10 += (half8)a0.s1 * b0;
552         c20 += (half8)a0.s2 * b0;
553         c30 += (half8)a0.s3 * b0;
554     }
555
556     // Compute destination address
557     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
558
559 #if defined(ALPHA)
560     // Multiply by the weight of matrix product
561     c00 = c00 * (half8)ALPHA;
562     c10 = c10 * (half8)ALPHA;
563     c20 = c20 * (half8)ALPHA;
564     c30 = c30 * (half8)ALPHA;
565 #endif // defined(ALPHA)
566
567     // Store 4x8 block
568     vstore8(c00, 0, (__global half *)(offset(&dst, 0, 0)));
569     vstore8(c10, 0, (__global half *)(offset(&dst, 0, 1)));
570     vstore8(c20, 0, (__global half *)(offset(&dst, 0, 2)));
571     vstore8(c30, 0, (__global half *)(offset(&dst, 0, 3)));
572 }
573 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
574
575 #if defined(FIXED_POINT_POSITION)
576 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
577  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_8bit and @ref gemm_transpose1x16 before running the matrix multiplication
578  *
579  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
580  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
581  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
582  *
583  * @note: ALPHA must be passed in 8 bit fixed point format
584  *
585  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
586  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
587  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
588  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
589  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
590  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
591  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
592  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
593  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
594  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
595  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
596  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
597  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
598  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
599  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
600  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
601  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
602  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
603  */
604 __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
605                                                  IMAGE_DECLARATION(src1),
606                                                  IMAGE_DECLARATION(dst))
607 {
608     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
609     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
610
611     // Offset
612     const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
613     const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;
614
615     // src_addr_a = address of matrix A
616     // src_addr_b = address of matrix B
617     __global char *src_addr_a = src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes;
618     __global char *src_addr_b = src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes;
619
620     // Compute end row address for matrix B
621     __global char *src_end_addr_b = src_addr_b + COLS_B;
622
623     src_addr_a += offset_row_a;
624     src_addr_b += offset_row_b;
625
626     // Reset accumulators
627     short8 c00 = 0.0f;
628     short8 c10 = 0.0f;
629     short8 c20 = 0.0f;
630     short8 c30 = 0.0f;
631     short8 c01 = 0.0f;
632     short8 c11 = 0.0f;
633     short8 c21 = 0.0f;
634     short8 c31 = 0.0f;
635
636     // This for loop performs 1 accumulation for each iteration
637     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
638     {
639         // Load values from matrix A (interleaved) and matrix B (transposed)
640         char4  a0 = vload4(0, src_addr_a);
641         char16 b0 = vload16(0, src_addr_b);
642
643         c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
644         c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
645         c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
646         c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);
647
648         c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
649         c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
650         c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
651         c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
652     }
653
654     // Compute destination address
655     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
656
657     // Multiply by the weight of matrix product
658     char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
659     char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
660     char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
661     char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));
662
663 #if defined(ALPHA)
664     c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
665     c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
666     c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
667     c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
668 #endif // defined(ALPHA)
669
670     // Store 16x4 block
671     vstore16(c00_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
672     vstore16(c10_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
673     vstore16(c20_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
674     vstore16(c30_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
675 }
676
677 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
678  *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
679  *
680  * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
681  * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
682  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
683  *
684  * @note: ALPHA must be passed in 16 bit fixed point format
685  *
686  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
687  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
688  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
689  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
690  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
691  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
692  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
693  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
694  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
695  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
696  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
697  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
698  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
699  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
700  * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
701  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
702  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
703  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
704  */
705 __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
706                                                   IMAGE_DECLARATION(src1),
707                                                   IMAGE_DECLARATION(dst))
708 {
709     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
710     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
711
712     // Offset
713     const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
714     const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
715
716     // src_addr_a = address of matrix A
717     // src_addr_b = address of matrix B
718     __global short *src_addr_a = (__global short *)(src0_ptr + y * src0_stride_y + src0_offset_first_element_in_bytes);
719     __global short *src_addr_b = (__global short *)(src1_ptr + x * src1_stride_y + src1_offset_first_element_in_bytes);
720
721     // Compute end row address for matrix B
722     __global short *src_end_addr_b = src_addr_b + COLS_B;
723
724     src_addr_a += offset_row_a;
725     src_addr_b += offset_row_b;
726
727     // Reset accumulators
728     int8 c00 = 0.0f;
729     int8 c10 = 0.0f;
730     int8 c20 = 0.0f;
731     int8 c30 = 0.0f;
732
733     // This for loop performs 1 accumulation for each iteration
734     for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
735     {
736         /* Load values from matrix A (interleaved) and matrix B (transposed) */
737         short4 a0 = vload4(0, src_addr_a);
738         short8 b0 = vload8(0, src_addr_b);
739
740         c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
741         c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
742         c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
743         c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
744     }
745
746     // Compute destination address
747     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
748
749     // Multiply by the weight of matrix product
750     short8 c00_qs16 = convert_short8_sat(c00);
751     short8 c10_qs16 = convert_short8_sat(c10);
752     short8 c20_qs16 = convert_short8_sat(c20);
753     short8 c30_qs16 = convert_short8_sat(c30);
754
755 #if defined(ALPHA)
756     c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
757     c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
758     c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
759     c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
760 #endif // defined(ALPHA)
761
762     // Store 8x4 block
763     vstore8(c00_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
764     vstore8(c10_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
765     vstore8(c20_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
766     vstore8(c30_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
767 }
768 #endif // defined(FIXED_POINT_POSITION)
769 #endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
770
771 #if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
772 #if defined(DATA_TYPE)
773 #define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
774 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
775  *
776  * @note This OpenCL kernel works with floating point data types (F16/F32)
777  * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
778  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
779  * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
780  *
781  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
782  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
783  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
784  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
785  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
786  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
787  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
788  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
789  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
790  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
791  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
792  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
793  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
794  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
795  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
796  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
797  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
798  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
799  */
800 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
801                                      IMAGE_DECLARATION(src1),
802                                      IMAGE_DECLARATION(dst))
803 {
804     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
805
806     // Compute starting address for matrix A and Matrix B
807     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
808
809     // Update address for the matrix A
810     src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
811
812     // Update address for the matrix B
813     src_addr.s1 += idx * sizeof(DATA_TYPE);
814
815     int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
816
817     VECTOR_TYPE acc0 = 0.0f;
818 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
819     VECTOR_TYPE acc1 = 0.0f;
820 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
821 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
822     VECTOR_TYPE acc2 = 0.0f;
823 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
824 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
825     VECTOR_TYPE acc3 = 0.0f;
826 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
827
828     for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
829     {
830         // Load values from matrix A
831         VEC_DATA_TYPE(DATA_TYPE, 2)
832         a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
833 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
834         VEC_DATA_TYPE(DATA_TYPE, 2)
835         a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
836 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
837 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
838         VEC_DATA_TYPE(DATA_TYPE, 2)
839         a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
840 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
841 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
842         VEC_DATA_TYPE(DATA_TYPE, 2)
843         a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
844 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
845         // Load values from matrix B
846         VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
847         VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
848
849         // Accumulate
850         acc0 += b0 * (VECTOR_TYPE)a0.s0;
851         acc0 += b1 * (VECTOR_TYPE)a0.s1;
852 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
853         acc1 += b0 * (VECTOR_TYPE)a1.s0;
854         acc1 += b1 * (VECTOR_TYPE)a1.s1;
855 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
856 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
857         acc2 += b0 * (VECTOR_TYPE)a2.s0;
858         acc2 += b1 * (VECTOR_TYPE)a2.s1;
859 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
860 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
861         acc3 += b0 * (VECTOR_TYPE)a3.s0;
862         acc3 += b1 * (VECTOR_TYPE)a3.s1;
863 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
864     }
865
866     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
867     {
868         // Load values from matrix A
869         DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
870 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
871         DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
872 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
873 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
874         DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
875 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
876 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
877         DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
878 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
879         // Load values from matrix B
880         VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
881
882         // Accumulate
883         acc0 += b0 * (VECTOR_TYPE)a0;
884 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
885         acc1 += b0 * (VECTOR_TYPE)a1;
886 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
887 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
888         acc2 += b0 * (VECTOR_TYPE)a2;
889 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
890 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
891         acc3 += b0 * (VECTOR_TYPE)a3;
892 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
893     }
894
895     // Compute destination address
896     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
897
898     // Multiply by the weight of matrix-matrix product and store the result
899 #if defined(ALPHA)
900     acc0 = acc0 * (VECTOR_TYPE)ALPHA;
901 #endif // defined(ALPHA)
902     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
903     (acc0, 0, (__global DATA_TYPE *)(offset(&dst, 0, 0)));
904 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
905 #if defined(ALPHA)
906     acc1 = acc1 * (VECTOR_TYPE)ALPHA;
907 #endif // defined(ALPHA)
908     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
909     (acc1, 0, (__global DATA_TYPE *)(offset(&dst, 0, 1)));
910 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
911 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
912 #if defined(ALPHA)
913     acc2 = acc2 * (VECTOR_TYPE)ALPHA;
914 #endif // defined(ALPHA)
915     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
916     (acc2, 0, (__global DATA_TYPE *)(offset(&dst, 0, 2)));
917 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
918 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
919 #if defined(ALPHA)
920     acc3 = acc3 * (VECTOR_TYPE)ALPHA;
921 #endif // defined(ALPHA)
922     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
923     (acc3, 0, (__global DATA_TYPE *)(offset(&dst, 0, 3)));
924 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
925 }
926 #endif // defined(DATA_TYPE)
927
928 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
929  *
930  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
931  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
932  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
933  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
934  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
935  *
936  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
937  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
938  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
939  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
940  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
941  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
942  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
943  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
944  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
945  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
946  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
947  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
948  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
949  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
950  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
951  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
952  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
953  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
954  */
955 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
956                                                  IMAGE_DECLARATION(src1),
957                                                  IMAGE_DECLARATION(dst))
958 {
959     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
960
961     // Compute starting address for matrix A and matrix B
962     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
963
964     // Update address for matrix A
965     src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
966
967     // Update address for matrix B
968     src_addr.s1 += idx * sizeof(float);
969
970     // Address boundary for matrix A
971     int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
972
973     // Initialize accumulators
974     float acc00 = 0.0f;
975     float acc01 = 0.0f;
976     float acc02 = 0.0f;
977     float acc03 = 0.0f;
978
979 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
980     float acc10 = 0.0f;
981     float acc11 = 0.0f;
982     float acc12 = 0.0f;
983     float acc13 = 0.0f;
984 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
985
986 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
987     float acc20 = 0.0f;
988     float acc21 = 0.0f;
989     float acc22 = 0.0f;
990     float acc23 = 0.0f;
991 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
992
993 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
994     float acc30 = 0.0f;
995     float acc31 = 0.0f;
996     float acc32 = 0.0f;
997     float acc33 = 0.0f;
998 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
999
1000     // A and B src indices get incremented at the same time.
1001     for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
1002     {
1003         // Load values from matrix A
1004         float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1005 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1006         float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1007 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1008 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1009         float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1010 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1011 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1012         float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1013 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1014         // Load values from matrix B
1015         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1016         float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1017
1018         // Multiply and accumulate
1019         acc00 = fma(a0.s0, b0.s0, acc00);
1020         acc00 = fma(a0.s1, b1.s0, acc00);
1021         acc01 = fma(a0.s0, b0.s1, acc01);
1022         acc01 = fma(a0.s1, b1.s1, acc01);
1023         acc02 = fma(a0.s0, b0.s2, acc02);
1024         acc02 = fma(a0.s1, b1.s2, acc02);
1025         acc03 = fma(a0.s1, b1.s3, acc03);
1026         acc03 = fma(a0.s0, b0.s3, acc03);
1027
1028 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1029         acc10 = fma(a1.s0, b0.s0, acc10);
1030         acc11 = fma(a1.s0, b0.s1, acc11);
1031         acc12 = fma(a1.s0, b0.s2, acc12);
1032         acc13 = fma(a1.s0, b0.s3, acc13);
1033
1034         acc10 = fma(a1.s1, b1.s0, acc10);
1035         acc11 = fma(a1.s1, b1.s1, acc11);
1036         acc12 = fma(a1.s1, b1.s2, acc12);
1037         acc13 = fma(a1.s1, b1.s3, acc13);
1038 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1039 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1040         acc20 = fma(a2.s0, b0.s0, acc20);
1041         acc21 = fma(a2.s0, b0.s1, acc21);
1042         acc22 = fma(a2.s0, b0.s2, acc22);
1043         acc23 = fma(a2.s0, b0.s3, acc23);
1044
1045         acc20 = fma(a2.s1, b1.s0, acc20);
1046         acc21 = fma(a2.s1, b1.s1, acc21);
1047         acc22 = fma(a2.s1, b1.s2, acc22);
1048         acc23 = fma(a2.s1, b1.s3, acc23);
1049 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1050 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1051         acc30 = fma(a3.s0, b0.s0, acc30);
1052         acc31 = fma(a3.s0, b0.s1, acc31);
1053         acc32 = fma(a3.s0, b0.s2, acc32);
1054         acc33 = fma(a3.s0, b0.s3, acc33);
1055
1056         acc30 = fma(a3.s1, b1.s0, acc30);
1057         acc31 = fma(a3.s1, b1.s1, acc31);
1058         acc32 = fma(a3.s1, b1.s2, acc32);
1059         acc33 = fma(a3.s1, b1.s3, acc33);
1060 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1061     }
1062
1063     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
1064     {
1065         // Load values from matrix A
1066         float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1067 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1068         float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1069 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1070 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1071         float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1072 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1073 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1074         float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1075 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1076         // Load values from matrix B
1077         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1078
1079         // Multiply and accumulate
1080         acc00 = fma(a0, b0.s0, acc00);
1081         acc01 = fma(a0, b0.s1, acc01);
1082         acc02 = fma(a0, b0.s2, acc02);
1083         acc03 = fma(a0, b0.s3, acc03);
1084 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1085         acc10 = fma(a1, b0.s0, acc10);
1086         acc11 = fma(a1, b0.s1, acc11);
1087         acc12 = fma(a1, b0.s2, acc12);
1088         acc13 = fma(a1, b0.s3, acc13);
1089 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1090 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1091         acc20 = fma(a2, b0.s0, acc20);
1092         acc21 = fma(a2, b0.s1, acc21);
1093         acc22 = fma(a2, b0.s2, acc22);
1094         acc23 = fma(a2, b0.s3, acc23);
1095 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1096 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1097         acc30 = fma(a3, b0.s0, acc30);
1098         acc31 = fma(a3, b0.s1, acc31);
1099         acc32 = fma(a3, b0.s2, acc32);
1100         acc33 = fma(a3, b0.s3, acc33);
1101 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1102     }
1103
1104     // Compute destination address
1105     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1106
1107     // Multiply by the weight of matrix-matrix product and store the result
1108 #if defined(ALPHA)
1109     acc00 = acc00 * ALPHA;
1110     acc01 = acc01 * ALPHA;
1111     acc02 = acc02 * ALPHA;
1112     acc03 = acc03 * ALPHA;
1113 #endif // defined(ALPHA)
1114
1115     float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
1116     vstore4(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
1117
1118 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1119 #if defined(ALPHA)
1120     acc10 = acc10 * ALPHA;
1121     acc11 = acc11 * ALPHA;
1122     acc12 = acc12 * ALPHA;
1123     acc13 = acc13 * ALPHA;
1124 #endif // defined(ALPHA)
1125     float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
1126     vstore4(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
1127 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1128 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1129 #if defined(ALPHA)
1130     acc20 = acc20 * ALPHA;
1131     acc21 = acc21 * ALPHA;
1132     acc22 = acc22 * ALPHA;
1133     acc23 = acc23 * ALPHA;
1134 #endif // defined(ALPHA)
1135     float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
1136     vstore4(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
1137 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1138 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1139 #if defined(ALPHA)
1140     acc30 = acc30 * ALPHA;
1141     acc31 = acc31 * ALPHA;
1142     acc32 = acc32 * ALPHA;
1143     acc33 = acc33 * ALPHA;
1144 #endif // defined(ALPHA)
1145     float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
1146     vstore4(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
1147 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1148 }
1149
1150 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1151  *
1152  * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1153  * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
1154  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1155  * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1156  * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1157  * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
1158  *
1159  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
1160  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
1161  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1162  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1163  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1164  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1165  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
1166  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
1167  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1168  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1169  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1170  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1171  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
1172  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1173  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1174  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1175  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1176  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1177  */
1178 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1179                                                       IMAGE_DECLARATION(src1),
1180                                                       IMAGE_DECLARATION(dst))
1181 {
1182     // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1183     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1184
1185     // Compute starting address for matrix A and Matrix B
1186     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1187
1188     // Update address for the matrix A
1189     src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1190
1191     // Update address for the matrix B
1192     src_addr.s1 += idx * sizeof(float);
1193
1194     // Address boundary for the matrix A
1195     int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
1196
1197     // Initialize accumulators
1198     float acc00 = 0.0f;
1199     float acc01 = 0.0f;
1200
1201 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1202     float acc10 = 0.0f;
1203     float acc11 = 0.0f;
1204 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1205 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1206     float acc20 = 0.0f;
1207     float acc21 = 0.0f;
1208 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1209 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1210     float acc30 = 0.0f;
1211     float acc31 = 0.0f;
1212 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1213
1214     // A and B src indices get incremented at the same time.
1215     for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
1216     {
1217         // Load values from matrix A
1218         float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1219
1220         // Load values from matrix B
1221         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1222         float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1223         float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
1224         float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
1225
1226         // Multiply and accumulate
1227         acc00 = fma(a0.s0, b0.s0, acc00);
1228         acc00 = fma(a0.s1, b1.s0, acc00);
1229         acc00 = fma(a0.s2, b2.s0, acc00);
1230         acc00 = fma(a0.s3, b3.s0, acc00);
1231
1232         acc01 = fma(a0.s0, b0.s1, acc01);
1233         acc01 = fma(a0.s1, b1.s1, acc01);
1234         acc01 = fma(a0.s2, b2.s1, acc01);
1235         acc01 = fma(a0.s3, b3.s1, acc01);
1236
1237 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1238         a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1239         acc10 = fma(a0.s0, b0.s0, acc10);
1240         acc10 = fma(a0.s1, b1.s0, acc10);
1241         acc10 = fma(a0.s2, b2.s0, acc10);
1242         acc10 = fma(a0.s3, b3.s0, acc10);
1243
1244         acc11 = fma(a0.s0, b0.s1, acc11);
1245         acc11 = fma(a0.s1, b1.s1, acc11);
1246         acc11 = fma(a0.s2, b2.s1, acc11);
1247         acc11 = fma(a0.s3, b3.s1, acc11);
1248 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1249 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1250         a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1251         acc20 = fma(a0.s0, b0.s0, acc20);
1252         acc20 = fma(a0.s1, b1.s0, acc20);
1253         acc20 = fma(a0.s2, b2.s0, acc20);
1254         acc20 = fma(a0.s3, b3.s0, acc20);
1255
1256         acc21 = fma(a0.s0, b0.s1, acc21);
1257         acc21 = fma(a0.s1, b1.s1, acc21);
1258         acc21 = fma(a0.s2, b2.s1, acc21);
1259         acc21 = fma(a0.s3, b3.s1, acc21);
1260 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1261 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1262         a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1263         acc30 = fma(a0.s0, b0.s0, acc30);
1264         acc30 = fma(a0.s1, b1.s0, acc30);
1265         acc30 = fma(a0.s2, b2.s0, acc30);
1266         acc30 = fma(a0.s3, b3.s0, acc30);
1267
1268         acc31 = fma(a0.s0, b0.s1, acc31);
1269         acc31 = fma(a0.s1, b1.s1, acc31);
1270         acc31 = fma(a0.s2, b2.s1, acc31);
1271         acc31 = fma(a0.s3, b3.s1, acc31);
1272 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1273     }
1274     // float size increment
1275     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
1276     {
1277         // Load values from matrix A
1278         float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1279 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1280         float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1281 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1282 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1283         float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1284 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1285 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1286         float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1287 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1288         // Load values from matrix B
1289         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1290
1291         // Multiply and accumulate
1292         acc00 = fma(a0, b0.s0, acc00);
1293         acc01 = fma(a0, b0.s1, acc01);
1294 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1295         acc10 = fma(a1, b0.s0, acc10);
1296         acc11 = fma(a1, b0.s1, acc11);
1297 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1298 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1299         acc20 = fma(a2, b0.s0, acc20);
1300         acc21 = fma(a2, b0.s1, acc21);
1301 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1302 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1303         acc30 = fma(a3, b0.s0, acc30);
1304         acc31 = fma(a3, b0.s1, acc31);
1305 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1306     }
1307
1308     // Compute destination address
1309     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1310
1311     // Multiply by the weight of matrix-matrix product and store the result
1312 #if defined(ALPHA)
1313     acc00 = acc00 * ALPHA;
1314     acc01 = acc01 * ALPHA;
1315 #endif // defined(ALPHA)
1316     float2 acc0 = ((float2)(acc00, acc01));
1317     vstore2(acc0, 0, (__global float *)(offset(&dst, 0, 0)));
1318 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1319 #if defined(ALPHA)
1320     acc10 = acc10 * ALPHA;
1321     acc11 = acc11 * ALPHA;
1322 #endif // defined(ALPHA)
1323     float2 acc1 = ((float2)(acc10, acc11));
1324     vstore2(acc1, 0, (__global float *)(offset(&dst, 0, 1)));
1325 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1326 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1327 #if defined(ALPHA)
1328     acc20 = acc20 * ALPHA;
1329     acc21 = acc21 * ALPHA;
1330 #endif // defined(ALPHA)
1331     float2 acc2 = ((float2)(acc20, acc21));
1332     vstore2(acc2, 0, (__global float *)(offset(&dst, 0, 2)));
1333 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1334 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1335 #if defined(ALPHA)
1336     acc30 = acc30 * ALPHA;
1337     acc31 = acc31 * ALPHA;
1338 #endif // defined(ALPHA)
1339     float2 acc3 = (float2)(acc30, acc31);
1340     vstore2(acc3, 0, (__global float *)(offset(&dst, 0, 3)));
1341 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1342 }
1343
1344 #if defined(FIXED_POINT_POSITION)
1345 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
1346  *
1347  * @note This OpenCL kernel works with fixed point data types QS8
1348  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
1349  * @note The number matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
1350  * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
1351  * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
1352  *
1353  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8/QS16
1354  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
1355  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1356  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1357  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1358  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1359  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
1360  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
1361  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1362  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1363  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1364  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1365  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
1366  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1367  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1368  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1369  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1370  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1371  */
1372 __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
1373                           IMAGE_DECLARATION(src1),
1374                           IMAGE_DECLARATION(dst))
1375 {
1376     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1377
1378     // Compute starting address for matrix A and Matrix B
1379     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1380
1381     // Update address for the matrix A
1382     src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1383
1384     // Update address for the matrix B
1385     src_addr.s1 += idx * sizeof(char);
1386
1387     int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
1388
1389     short8 acc00 = 0;
1390     short8 acc01 = 0;
1391 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1392     short8 acc10 = 0;
1393     short8 acc11 = 0;
1394 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1395 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1396     short8 acc20 = 0;
1397     short8 acc21 = 0;
1398 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1399 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1400     short8 acc30 = 0;
1401     short8 acc31 = 0;
1402 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1403
1404     // This for loop performs 4 accumulations per iteration
1405     for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
1406     {
1407         char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1408 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1409         char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1410 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1411 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1412         char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1413 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1414 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1415         char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1416 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1417         char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1418         char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1419
1420         acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
1421         acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
1422         acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
1423         acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
1424 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1425         acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
1426         acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
1427         acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
1428         acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
1429 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1430 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1431         acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
1432         acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
1433         acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
1434         acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
1435 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1436 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1437         acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
1438         acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
1439         acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
1440         acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
1441 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1442     }
1443
1444     // Left-over accumulations
1445     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
1446     {
1447         char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1448 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1449         char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1450 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1451 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1452         char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1453 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1454 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1455         char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1456 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1457         char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
1458
1459         acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
1460         acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
1461 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1462         acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
1463         acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
1464 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1465 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1466         acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
1467         acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
1468 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1469 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1470         acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
1471         acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
1472 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1473     }
1474
1475     // Compute destination address
1476     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1477
1478     // Multiply by the weight of matrix product and store the result
1479     char16 acc_qs8;
1480     acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
1481 #if defined(ALPHA)
1482     acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
1483 #endif // defined(ALPHA)
1484     vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 0)));
1485 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1486     acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
1487 #if defined(ALPHA)
1488     acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
1489 #endif // defined(ALPHA)
1490     vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 1)));
1491 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1492 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1493     acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
1494 #if defined(ALPHA)
1495     acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
1496 #endif // defined(ALPHA)
1497     vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 2)));
1498 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1499 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1500     acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
1501 #if defined(ALPHA)
1502     acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
1503 #endif // defined(ALPHA)
1504     vstore16(acc_qs8, 0, (__global char *)(offset(&dst, 0, 3)));
1505 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1506 }
1507
1508 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
1509  *
1510  * @note This OpenCL kernel works with fixed point data types QS16
1511  * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
1512  * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the alpha's value need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
1513  * @note The fixed point position need to be passed at compile time using -DFIXED_POINT_POSITION
1514  * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
1515  *
1516  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8/QS16
1517  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
1518  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1519  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1520  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1521  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1522  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
1523  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
1524  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1525  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1526  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1527  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1528  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
1529  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1530  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1531  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1532  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1533  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1534  */
1535 __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
1536                            IMAGE_DECLARATION(src1),
1537                            IMAGE_DECLARATION(dst))
1538 {
1539     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1540
1541     // Compute starting address for matrix A and Matrix B
1542     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1543
1544     // Update address for the matrix A
1545     src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1546
1547     // Update address for the matrix B
1548     src_addr.s1 += idx * sizeof(short);
1549
1550     int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
1551
1552     int8 acc0 = 0;
1553 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1554     int8 acc1 = 0;
1555 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1556 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1557     int8 acc2 = 0;
1558 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1559 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1560     int8 acc3 = 0;
1561 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1562
1563     // This for loop performs 4 accumulations per iteration
1564     for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
1565     {
1566         short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1567 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1568         short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1569 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1570 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1571         short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1572 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1573 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1574         short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1575 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1576         short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
1577         short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
1578
1579         acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
1580         acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
1581 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1582         acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
1583         acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
1584 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1585 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1586         acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
1587         acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
1588 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1589 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1590         acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
1591         acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
1592 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1593     }
1594
1595     // Left-over accumulations
1596     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
1597     {
1598         short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1599 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1600         short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1601 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1602 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1603         short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1604 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1605 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1606         short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1607 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1608         short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));
1609
1610         acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
1611 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1612         acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
1613 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1614 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1615         acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
1616 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1617 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1618         acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
1619 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1620     }
1621
1622     // Compute destination address
1623     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1624
1625     // Multiply by the weight of matrix product and store the result
1626     short8 acc_qs16;
1627     acc_qs16 = convert_short8_sat(acc0);
1628 #if defined(ALPHA)
1629     acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1630 #endif // defined(ALPHA)
1631     vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 0)));
1632 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1633     acc_qs16 = convert_short8_sat(acc1);
1634 #if defined(ALPHA)
1635     acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1636 #endif // defined(ALPHA)
1637     vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 1)));
1638 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1639 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1640     acc_qs16 = convert_short8_sat(acc2);
1641 #if defined(ALPHA)
1642     acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1643 #endif // defined(ALPHA)
1644     vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 2)));
1645 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1646 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1647     acc_qs16 = convert_short8_sat(acc3);
1648 #if defined(ALPHA)
1649     acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
1650 #endif // defined(ALPHA)
1651     vstore8(acc_qs16, 0, (__global short *)(offset(&dst, 0, 3)));
1652 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1653 }
1654 #endif // defined(FIXED_POINT_POSITION)
1655 #endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
1656
1657 #if defined(BETA)
1658 /** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
1659  *
1660  * @note The beta's value need to be passed at compile time using -DBETA
1661  *
1662  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F32
1663  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
1664  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1665  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1666  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1667  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
1668  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
1669  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
1670  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1671  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
1672  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1673  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1674  */
1675 __kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
1676                           IMAGE_DECLARATION(dst))
1677 {
1678     // Compute source and destination addresses
1679     Image src = CONVERT_TO_IMAGE_STRUCT(src);
1680     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1681
1682     // Load values from A x B
1683     float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
1684
1685     // Load values from Matrix C
1686     float4 c = vload4(0, (__global float *)src.ptr);
1687
1688     // Computes alpha * axb + beta * c
1689     float4 out = alpha_ab + (float4)BETA * c;
1690
1691     // Store final result in axb matrix
1692     vstore4(out, 0, (__global float *)dst.ptr);
1693 }
1694
1695 /** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
1696  *
1697  * @note The beta's value need to be passed at compile time using -DBETA
1698  *
1699  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F16
1700  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
1701  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1702  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1703  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1704  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
1705  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
1706  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
1707  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1708  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
1709  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1710  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1711  */
1712 __kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
1713                           IMAGE_DECLARATION(dst))
1714 {
1715     // Compute source and destination addresses
1716     Image src = CONVERT_TO_IMAGE_STRUCT(src);
1717     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1718
1719     // Load values from A x B
1720     half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
1721
1722     // Load values from Matrix C
1723     half8 c = vload8(0, (__global half *)src.ptr);
1724
1725     // Computes alpha * axb + beta * c
1726     half8 out = alpha_ab + (half8)BETA * c;
1727
1728     // Store final result in axb matrix
1729     vstore8(out, 0, (__global half *)dst.ptr);
1730 }
1731
1732 #if defined(FIXED_POINT_POSITION)
1733 /** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
1734  *
1735  * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
1736  *
1737  * @note: BETA must be passed in 8 bit fixed point format
1738  *
1739  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: QS8
1740  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
1741  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1742  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1743  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1744  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
1745  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
1746  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
1747  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1748  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
1749  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1750  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1751  */
1752 __kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
1753                           IMAGE_DECLARATION(dst))
1754 {
1755     // Compute source and destination addresses
1756     Image src = CONVERT_TO_IMAGE_STRUCT(src);
1757     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1758
1759     // Load values from A x B
1760     char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
1761
1762     // Load values from Matrix C
1763     char16 c = vload16(0, (__global char *)src.ptr);
1764
1765     // Computes alpha * axb + beta * c
1766     char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);
1767
1768     // Store final result in axb matrix
1769     vstore16(out, 0, (__global char *)dst.ptr);
1770 }
1771
1772 /** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point taking into account that the second matrix might be weighted by a scalar value beta:
1773  *
1774  * @note The beta's value and the fixed point position need to be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
1775  *
1776  * @note: BETA must be passed in 16 bit fixed point format
1777  *
1778  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: QS16
1779  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
1780  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1781  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1782  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1783  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
1784  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
1785  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
1786  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1787  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
1788  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1789  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1790  */
1791 __kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
1792                            IMAGE_DECLARATION(dst))
1793 {
1794     // Compute source and destination addresses
1795     Image src = CONVERT_TO_IMAGE_STRUCT(src);
1796     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1797
1798     // Load values from A x B
1799     short8 alpha_ab = vload8(0, (__global short *)dst.ptr);
1800
1801     // Load values from Matrix C
1802     short8 c = vload8(0, (__global short *)src.ptr);
1803
1804     // Computes alpha * axb + beta * c
1805     short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);
1806
1807     // Store final result in axb matrix
1808     vstore8(out, 0, (__global short *)dst.ptr);
1809 }
1810 #endif // defined(FIXED_POINT_POSITION)
1811 #endif // defined(BETA)
1812
1813 #if defined(WIDTH_VECTOR_A)
1814 /** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
1815  *
1816  * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
1817  *
1818  * @note The input A and matrix B must not be reshaped
1819  *
1820  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
1821  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
1822  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1823  * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1824  * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1825  * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1826  * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
1827  * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
1828  * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
1829  * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
1830  * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
1831  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
1832  * @param[in]  src1_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
1833  * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1834  * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
1835  * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1836  * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1837  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1838  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1839  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
1840  */
1841 __kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
1842                              TENSOR3D_DECLARATION(src1),
1843                              IMAGE_DECLARATION(dst))
1844 {
1845     int idx = get_global_id(0) * 4;
1846     int idy = get_global_id(1);
1847
1848     // Compute the address for the vector A and matrix B
1849     int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
1850     src_addr.s1 += idx * sizeof(float);
1851
1852     int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
1853
1854     float4 acc = 0.0f;
1855
1856     for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
1857     {
1858         float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
1859         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1860         float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
1861
1862         acc += b0 * (float4)a0.s0;
1863         acc += b1 * (float4)a0.s1;
1864     }
1865
1866     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
1867     {
1868         float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
1869         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1870
1871         acc += b0 * (float4)a0;
1872     }
1873
1874     // Compute destination address
1875     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1876
1877     vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
1878 }
1879 #endif // defined(WIDTH_VECTOR_A)
1880
1881 /** This kernel accumulates each row with the biases vector.
1882  *
1883  * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
1884  * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
1885  *
1886  * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/QS8/U16/S16/F16/U32/S32/F32
1887  * @param[in]      accum_stride_x                       Stride of the accmulate tensor in X dimension (in bytes)
1888  * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
1889  * @param[in]      accum_stride_y                       Stride of the accumlulate tensor in Y dimension (in bytes)
1890  * @param[in]      accum_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
1891  * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
1892  * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
1893  * @param[in]      biases_stride_x                      Stride of the destination tensor in X dimension (in bytes)
1894  * @param[in]      biases_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
1895  * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the destination tensor
1896  */
1897 #if defined(DATA_TYPE) && defined(VECTOR_SIZE)
1898 __kernel void gemm_accumulate_biases(
1899     IMAGE_DECLARATION(accum),
1900     VECTOR_DECLARATION(biases))
1901 {
1902     Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
1903     Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
1904
1905     // Vector size, i.e. number of vector elements.
1906     VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
1907     accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
1908     VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
1909     biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
1910 #ifdef FIXED_POINT_POSITION
1911     accum_value = ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
1912 #else  // FIXED_POINT_POSITION
1913     accum_value = biases_value + accum_value;
1914 #endif // FIXED_POINT_POSITION
1915     // Store result in the accumulate buffer
1916     VSTORE(VECTOR_SIZE)
1917     (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
1918 }
1919 #endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)