/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"

#ifdef FIXED_POINT_POSITION
#include "fixed_point.h"
#endif // FIXED_POINT_POSITION

#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of the input matrix
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
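
/* Worked example (illustrative only): with TRANSPOSE_W=4 and
 * MULT_TRANSPOSE1XW_WIDTH=1, each work-item copies one 1x4 stripe of a source
 * row into consecutive memory, so a 2x8 source matrix
 *
 *     a0 a1 a2 a3 a4 a5 a6 a7
 *     b0 b1 b2 b3 b4 b5 b6 b7
 *
 * is reshaped into
 *
 *     a0 a1 a2 a3 b0 b1 b2 b3
 *     a4 a5 a6 a7 b4 b5 b6 b7
 *
 * i.e. output row i holds the i-th 1x4 block of every source row. This is the
 * matrix B layout consumed by the gemm_mm_interleaved_transposed_* kernels below.
 */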
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix, transposing each 4x4 block and interleaving the values
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QS8/QASYMM8/U16/S16/QS16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for source tensor
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));

    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
}
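
/* Worked example (illustrative only): with MULT_INTERLEAVE4X4_HEIGHT=1 each
 * work-item reads one 4x4 block of matrix A
 *
 *     a b c d
 *     e f g h
 *     i j k l
 *     m n o p
 *
 * and writes it out column by column as a single row:
 *
 *     a e i m  b f j n  c g k o  d h l p
 *
 * so the matrix multiplication kernels can fetch one column of the 4x4 block
 * (4 values of matrix A) with a single vload4.
 */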
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // This for loop performs 2 accumulations per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // This for loop performs the leftover accumulations
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
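
/* Host-side sketch (illustrative only, not part of this file): the kernel above
 * is expected to be built with the matrix sizes and reshape factors baked in as
 * build options, along these lines (program/device names are hypothetical):
 *
 *     // Hypothetical values: COLS_B=24, no extra reshape multipliers, alpha=1.5
 *     const char *opts = "-DCOLS_B=24 -DMULT_TRANSPOSE1XW_WIDTH=1 "
 *                        "-DMULT_INTERLEAVE4X4_HEIGHT=1 -DALPHA=1.5f";
 *     clBuildProgram(program, 1, &device, opts, NULL, NULL);
 *
 * Each work-item computes one 4x4 block of the destination, so (ignoring the
 * MULT_* factors) the global work size covers ceil(N/4) x ceil(M/4) x batches.
 */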

/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // This for loop performs 4 accumulations per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // This for loop performs the leftover accumulations
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
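
/* Loop-count example (illustrative only): COLS_MTX_B is the number of 4x4
 * column blocks each work-item accumulates. With COLS_B=64 and
 * MULT_TRANSPOSE1XW_WIDTH=1, COLS_MTX_B = 64 / 4 = 16, so the unrolled loop
 * above runs 4 times (4 accumulation steps per iteration) and the leftover
 * loop never executes; with COLS_B=72, COLS_MTX_B = 18 and the leftover loop
 * performs the remaining 2 accumulations. Keeping the 16 accumulators scalar
 * (rather than one float4 per row, as in the Midgard kernel) lets the
 * compiler map them onto registers and feed the scalar fma units on Bifrost.
 */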

// Undefine local defines
#undef COLS_MTX_B

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // This for loop performs 2 accumulations per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // This for loop performs the leftover accumulations
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
}
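
/* Output tiling note (illustrative only): unlike the F32 kernels above, which
 * store a 4x4 block, each work-item here stores a 4x8 half-precision block
 * (four vstore8 calls), so ignoring the MULT_* reshape factors a global work
 * size of ceil(N/8) x ceil(M/4) x batches covers an MxN destination. half
 * arithmetic requires the cl_khr_fp16 extension, which is why these kernels
 * are guarded by ARM_COMPUTE_OPENCL_FP16_ENABLED, set by the host only when
 * the device supports it.
 */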

/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // This for loop performs 4 accumulations per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // This for loop performs the leftover accumulations
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
}

// Undefine local defines
#undef COLS_MTX_B

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 8 bit fixed point precision
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note ALPHA must be passed in 8 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 16;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global char *src_addr_a = (__global char *)(src0_ptr + src0_addr_in_bytes);
    __global char *src_addr_b = (__global char *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global char *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    short8 c00 = 0;
    short8 c10 = 0;
    short8 c20 = 0;
    short8 c30 = 0;
    short8 c01 = 0;
    short8 c11 = 0;
    short8 c21 = 0;
    short8 c31 = 0;

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        char4  a0 = vload4(0, src_addr_a);
        char16 b0 = vload16(0, src_addr_b);

        c00 = mlal_sat_qs8x8(c00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs8x8(c10, (char8)a0.s1, b0.s01234567, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs8x8(c20, (char8)a0.s2, b0.s01234567, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs8x8(c30, (char8)a0.s3, b0.s01234567, FIXED_POINT_POSITION);

        c01 = mlal_sat_qs8x8(c01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c11 = mlal_sat_qs8x8(c11, (char8)a0.s1, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c21 = mlal_sat_qs8x8(c21, (char8)a0.s2, b0.s89ABCDEF, FIXED_POINT_POSITION);
        c31 = mlal_sat_qs8x8(c31, (char8)a0.s3, b0.s89ABCDEF, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Saturate the result to 8 bit fixed point
    char16 c00_qs8 = convert_char16_sat((short16)(c00, c01));
    char16 c10_qs8 = convert_char16_sat((short16)(c10, c11));
    char16 c20_qs8 = convert_char16_sat((short16)(c20, c21));
    char16 c30_qs8 = convert_char16_sat((short16)(c30, c31));

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs8 = mul_sat_qs8x16(c00_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c10_qs8 = mul_sat_qs8x16(c10_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c20_qs8 = mul_sat_qs8x16(c20_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
    c30_qs8 = mul_sat_qs8x16(c30_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x16 block
    vstore16(c00_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
    vstore16(c10_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
    vstore16(c20_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
    vstore16(c30_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
}

/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in 16 bit fixed point precision
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional value of alpha must be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note ALPHA must be passed in 16 bit fixed point format
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
                                                  uint dst_stride_z)
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global short *src_addr_a = (__global short *)(src0_ptr + src0_addr_in_bytes);
    __global short *src_addr_b = (__global short *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global short *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    int8 c00 = 0;
    int8 c10 = 0;
    int8 c20 = 0;
    int8 c30 = 0;

    // This for loop performs 1 accumulation for each iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        short4 a0 = vload4(0, src_addr_a);
        short8 b0 = vload8(0, src_addr_b);

        c00 = mlal_sat_qs16x8(c00, (short8)a0.s0, b0, FIXED_POINT_POSITION);
        c10 = mlal_sat_qs16x8(c10, (short8)a0.s1, b0, FIXED_POINT_POSITION);
        c20 = mlal_sat_qs16x8(c20, (short8)a0.s2, b0, FIXED_POINT_POSITION);
        c30 = mlal_sat_qs16x8(c30, (short8)a0.s3, b0, FIXED_POINT_POSITION);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Saturate the result to 16 bit fixed point
    short8 c00_qs16 = convert_short8_sat(c00);
    short8 c10_qs16 = convert_short8_sat(c10);
    short8 c20_qs16 = convert_short8_sat(c20);
    short8 c30_qs16 = convert_short8_sat(c30);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00_qs16 = mul_sat_qs16x8(c00_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c10_qs16 = mul_sat_qs16x8(c10_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c20_qs16 = mul_sat_qs16x8(c20_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
    c30_qs16 = mul_sat_qs16x8(c30_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
}
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
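/* For reference (helpers.h): VEC_DATA_TYPE(type, size) expands to the OpenCL
 * vector type "type##size", e.g. VEC_DATA_TYPE(float, 4) -> float4. So with
 * -DDATA_TYPE=float and -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4, VECTOR_TYPE is
 * float4 and each work-item below produces a 4-wide slice of up to
 * NUM_ELEMS_PROCESSED_PER_THREAD_Y destination rows.
 */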
/** This OpenCL kernel computes the matrix-by-matrix multiplication between matrix A (src0) and matrix B (src1) when the two matrices have not been reshaped
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional value of alpha must be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
1160 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
1161 IMAGE_DECLARATION(src1),
1162 IMAGE_DECLARATION(dst),
1167 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1169 // Compute starting address for matrix A and Matrix B
1170 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1172 // Update address for the matrix A
1173 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1175 // Update address for the matrix B
1176 src_addr.s1 += idx * sizeof(DATA_TYPE);
1178 // Add offset for batched GEMM
1179 src_addr.s0 += get_global_id(2) * src0_stride_z;
1181 #if defined(MATRIX_B_DEPTH)
1182 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1183 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1184 #else // defined(MATRIX_B_DEPTH)
1185 src_addr.s1 += get_global_id(2) * src1_stride_z;
1186 #endif // defined(MATRIX_B_DEPTH)
1188 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
1190 VECTOR_TYPE acc0 = 0.0f;
1191 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1192 VECTOR_TYPE acc1 = 0.0f;
1193 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1194 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1195 VECTOR_TYPE acc2 = 0.0f;
1196 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1197 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1198 VECTOR_TYPE acc3 = 0.0f;
1199 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1201 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
1203 // Load values from matrix A
1204 VEC_DATA_TYPE(DATA_TYPE, 2)
1205 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1206 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1207 VEC_DATA_TYPE(DATA_TYPE, 2)
1208 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1209 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1210 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1211 VEC_DATA_TYPE(DATA_TYPE, 2)
1212 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1213 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1214 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1215 VEC_DATA_TYPE(DATA_TYPE, 2)
1216 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1217 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1218 // Load values from matrix B
1219 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1220 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
1223 acc0 += b0 * (VECTOR_TYPE)a0.s0;
1224 acc0 += b1 * (VECTOR_TYPE)a0.s1;
1225 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1226 acc1 += b0 * (VECTOR_TYPE)a1.s0;
1227 acc1 += b1 * (VECTOR_TYPE)a1.s1;
1228 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1229 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1230 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1231 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1232 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1233 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1234 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1235 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1236 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1237 }
1239 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
1240 {
1241 // Load values from matrix A
1242 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1243 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1244 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1245 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1246 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1247 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1248 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1249 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1250 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1251 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1252 // Load values from matrix B
1253 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1256 acc0 += b0 * (VECTOR_TYPE)a0;
1257 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1258 acc1 += b0 * (VECTOR_TYPE)a1;
1259 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1260 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1261 acc2 += b0 * (VECTOR_TYPE)a2;
1262 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1263 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1264 acc3 += b0 * (VECTOR_TYPE)a3;
1265 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1266 }
1268 // Compute destination address
1269 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1271 // Compute dst address
1272 __global uchar *dst_addr = offset(&dst, 0, 0);
1274 // Add offset for batched GEMM
1275 dst_addr += get_global_id(2) * dst_stride_z;
1277 // Multiply by the weight of matrix-matrix product and store the result
1278 #if defined(ALPHA)
1279 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
1280 #endif // defined(ALPHA)
1281 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1282 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
1283 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1284 #if defined(ALPHA)
1285 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
1286 #endif // defined(ALPHA)
1287 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1288 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
1289 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1290 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1291 #if defined(ALPHA)
1292 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
1293 #endif // defined(ALPHA)
1294 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1295 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
1296 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1297 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1298 #if defined(ALPHA)
1299 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
1300 #endif // defined(ALPHA)
1301 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1302 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
1303 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1304 }
1305 #endif // defined(DATA_TYPE)
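/* Illustrative host-side usage (a sketch under assumed sizes, not library code):
 * the kernel above could be built for an F32 GEMM with 2000 columns in matrix A
 * through a clBuildProgram() options string such as
 *
 *   "-DDATA_TYPE=float -DCOLS_A=2000 -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 "
 *   "-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4 -DALPHA=0.5f"
 *
 * Omitting -DALPHA compiles out the alpha scaling entirely, as the
 * #if defined(ALPHA) guards above show.
 */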
1307 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1309 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1310 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1311 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1312 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1313 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
1314 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1315 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1317 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
1318 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1319 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1320 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1321 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1322 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1323 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1324 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1325 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1326 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1327 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1328 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1329 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
1330 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1331 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
1332 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1333 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
1334 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1336 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1337 IMAGE_DECLARATION(src1),
1338 IMAGE_DECLARATION(dst),
1339 uint src0_stride_z,
1340 uint src1_stride_z,
1341 uint dst_stride_z)
1342 {
1343 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1345 // Compute starting address for matrix A and matrix B
1346 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1348 // Update address for matrix A
1349 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1351 // Update address for matrix B
1352 src_addr.s1 += idx * sizeof(float);
1354 // Add offset for batched GEMM
1355 src_addr.s0 += get_global_id(2) * src0_stride_z;
1357 #if defined(MATRIX_B_DEPTH)
1358 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1359 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1360 #else // defined(MATRIX_B_DEPTH)
1361 src_addr.s1 += get_global_id(2) * src1_stride_z;
1362 #endif // defined(MATRIX_B_DEPTH)
1364 // Initialize accumulators
1366 float acc00 = 0.0f;
1367 float acc01 = 0.0f;
1368 float acc02 = 0.0f;
1369 float acc03 = 0.0f;
1370 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1371 float acc10 = 0.0f;
1372 float acc11 = 0.0f;
1373 float acc12 = 0.0f;
1374 float acc13 = 0.0f;
1375 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1377 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1378 float acc20 = 0.0f;
1379 float acc21 = 0.0f;
1380 float acc22 = 0.0f;
1381 float acc23 = 0.0f;
1382 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1384 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1385 float acc30 = 0.0f;
1386 float acc31 = 0.0f;
1387 float acc32 = 0.0f;
1388 float acc33 = 0.0f;
1389 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1391 // A and B src indices get incremented at the same time.
1392 int i = 0;
1393 for(; i <= ((int)COLS_A - 4); i += 4)
1394 {
1395 // Load values from matrix A and matrix B
1396 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1397 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1398 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1399 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1400 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1401 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1402 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1403 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1404 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1405 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1406 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1407 src_addr.s1 += src1_stride_y;
1409 // Multiply and accumulate
1410 acc00 = fma(a0.s0, b0.s0, acc00);
1411 acc01 = fma(a0.s0, b0.s1, acc01);
1412 acc02 = fma(a0.s0, b0.s2, acc02);
1413 acc03 = fma(a0.s0, b0.s3, acc03);
1415 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1417 acc10 = fma(a1.s0, b0.s0, acc10);
1418 acc11 = fma(a1.s0, b0.s1, acc11);
1419 acc12 = fma(a1.s0, b0.s2, acc12);
1420 acc13 = fma(a1.s0, b0.s3, acc13);
1422 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1423 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1425 acc20 = fma(a2.s0, b0.s0, acc20);
1426 acc21 = fma(a2.s0, b0.s1, acc21);
1427 acc22 = fma(a2.s0, b0.s2, acc22);
1428 acc23 = fma(a2.s0, b0.s3, acc23);
1430 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1431 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1433 acc30 = fma(a3.s0, b0.s0, acc30);
1434 acc31 = fma(a3.s0, b0.s1, acc31);
1435 acc32 = fma(a3.s0, b0.s2, acc32);
1436 acc33 = fma(a3.s0, b0.s3, acc33);
1437 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1439 // Load values from matrix A and matrix B
1440 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1441 src_addr.s1 += src1_stride_y;
1443 // Multiply and accumulate
1444 acc00 = fma(a0.s1, b0.s0, acc00);
1445 acc01 = fma(a0.s1, b0.s1, acc01);
1446 acc02 = fma(a0.s1, b0.s2, acc02);
1447 acc03 = fma(a0.s1, b0.s3, acc03);
1449 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1451 acc10 = fma(a1.s1, b0.s0, acc10);
1452 acc11 = fma(a1.s1, b0.s1, acc11);
1453 acc12 = fma(a1.s1, b0.s2, acc12);
1454 acc13 = fma(a1.s1, b0.s3, acc13);
1456 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1457 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1459 acc20 = fma(a2.s1, b0.s0, acc20);
1460 acc21 = fma(a2.s1, b0.s1, acc21);
1461 acc22 = fma(a2.s1, b0.s2, acc22);
1462 acc23 = fma(a2.s1, b0.s3, acc23);
1464 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1465 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1467 acc30 = fma(a3.s1, b0.s0, acc30);
1468 acc31 = fma(a3.s1, b0.s1, acc31);
1469 acc32 = fma(a3.s1, b0.s2, acc32);
1470 acc33 = fma(a3.s1, b0.s3, acc33);
1471 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1473 // Load values from matrix A and matrix B
1474 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1475 src_addr.s1 += src1_stride_y;
1477 // Multiply and accumulate
1478 acc00 = fma(a0.s2, b0.s0, acc00);
1479 acc01 = fma(a0.s2, b0.s1, acc01);
1480 acc02 = fma(a0.s2, b0.s2, acc02);
1481 acc03 = fma(a0.s2, b0.s3, acc03);
1483 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1485 acc10 = fma(a1.s2, b0.s0, acc10);
1486 acc11 = fma(a1.s2, b0.s1, acc11);
1487 acc12 = fma(a1.s2, b0.s2, acc12);
1488 acc13 = fma(a1.s2, b0.s3, acc13);
1490 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1491 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1493 acc20 = fma(a2.s2, b0.s0, acc20);
1494 acc21 = fma(a2.s2, b0.s1, acc21);
1495 acc22 = fma(a2.s2, b0.s2, acc22);
1496 acc23 = fma(a2.s2, b0.s3, acc23);
1498 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1499 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1501 acc30 = fma(a3.s2, b0.s0, acc30);
1502 acc31 = fma(a3.s2, b0.s1, acc31);
1503 acc32 = fma(a3.s2, b0.s2, acc32);
1504 acc33 = fma(a3.s2, b0.s3, acc33);
1505 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1507 // Load values from matrix A and matrix B
1508 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1509 src_addr.s1 += src1_stride_y;
1511 // Multiply and accumulate
1512 acc00 = fma(a0.s3, b0.s0, acc00);
1513 acc01 = fma(a0.s3, b0.s1, acc01);
1514 acc02 = fma(a0.s3, b0.s2, acc02);
1515 acc03 = fma(a0.s3, b0.s3, acc03);
1517 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1519 acc10 = fma(a1.s3, b0.s0, acc10);
1520 acc11 = fma(a1.s3, b0.s1, acc11);
1521 acc12 = fma(a1.s3, b0.s2, acc12);
1522 acc13 = fma(a1.s3, b0.s3, acc13);
1524 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1525 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1527 acc20 = fma(a2.s3, b0.s0, acc20);
1528 acc21 = fma(a2.s3, b0.s1, acc21);
1529 acc22 = fma(a2.s3, b0.s2, acc22);
1530 acc23 = fma(a2.s3, b0.s3, acc23);
1532 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1533 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1535 acc30 = fma(a3.s3, b0.s0, acc30);
1536 acc31 = fma(a3.s3, b0.s1, acc31);
1537 acc32 = fma(a3.s3, b0.s2, acc32);
1538 acc33 = fma(a3.s3, b0.s3, acc33);
1539 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1541 src_addr.s0 += 4 * sizeof(float);
1542 }
1544 for(; i < (int)COLS_A; ++i)
1545 {
1546 // Load values from matrix A
1547 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
1548 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1549 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1550 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1551 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1552 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1553 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1554 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1555 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1556 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1557 // Load values from matrix B
1558 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1559 src_addr.s1 += src1_stride_y;
1561 // Multiply and accumulate
1562 acc00 = fma(a0, b0.s0, acc00);
1563 acc01 = fma(a0, b0.s1, acc01);
1564 acc02 = fma(a0, b0.s2, acc02);
1565 acc03 = fma(a0, b0.s3, acc03);
1566 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1567 acc10 = fma(a1, b0.s0, acc10);
1568 acc11 = fma(a1, b0.s1, acc11);
1569 acc12 = fma(a1, b0.s2, acc12);
1570 acc13 = fma(a1, b0.s3, acc13);
1571 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1572 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1573 acc20 = fma(a2, b0.s0, acc20);
1574 acc21 = fma(a2, b0.s1, acc21);
1575 acc22 = fma(a2, b0.s2, acc22);
1576 acc23 = fma(a2, b0.s3, acc23);
1577 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1578 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1579 acc30 = fma(a3, b0.s0, acc30);
1580 acc31 = fma(a3, b0.s1, acc31);
1581 acc32 = fma(a3, b0.s2, acc32);
1582 acc33 = fma(a3, b0.s3, acc33);
1583 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1585 src_addr.s0 += sizeof(float);
1586 }
1588 // Compute destination address
1589 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1591 // Multiply by the weight of matrix-matrix product and store the result
1592 #if defined(ALPHA)
1593 acc00 = acc00 * ALPHA;
1594 acc01 = acc01 * ALPHA;
1595 acc02 = acc02 * ALPHA;
1596 acc03 = acc03 * ALPHA;
1597 #endif // defined(ALPHA)
1599 // Compute dst address
1600 __global uchar *dst_addr = offset(&dst, 0, 0);
1602 // Add offset for batched GEMM
1603 dst_addr += get_global_id(2) * dst_stride_z;
1605 float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
1606 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1608 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1609 #if defined(ALPHA)
1610 acc10 = acc10 * ALPHA;
1611 acc11 = acc11 * ALPHA;
1612 acc12 = acc12 * ALPHA;
1613 acc13 = acc13 * ALPHA;
1614 #endif // defined(ALPHA)
1615 float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
1616 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1617 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1618 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1619 #if defined(ALPHA)
1620 acc20 = acc20 * ALPHA;
1621 acc21 = acc21 * ALPHA;
1622 acc22 = acc22 * ALPHA;
1623 acc23 = acc23 * ALPHA;
1624 #endif // defined(ALPHA)
1625 float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
1626 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1627 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1628 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1629 #if defined(ALPHA)
1630 acc30 = acc30 * ALPHA;
1631 acc31 = acc31 * ALPHA;
1632 acc32 = acc32 * ALPHA;
1633 acc33 = acc33 * ALPHA;
1634 #endif // defined(ALPHA)
1635 float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
1636 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1637 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1638 }
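/* Sketch (illustrative, using hypothetical array names): each work item of the
 * kernel above owns up to a 4x4 tile of the destination, with accYX holding the
 * partial result for destination row Y and column X. One k-step of the unrolled
 * loop is equivalent to the scalar code
 *
 *   for(int y = 0; y < NUM_ELEMS_PROCESSED_PER_THREAD_Y; ++y)
 *       for(int x = 0; x < 4; ++x)
 *           acc[y][x] = fma(a[y], b[x], acc[y][x]); // a[y] = A(y, k), b[x] = B(k, x)
 */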
1640 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1642 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1643 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
1644 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1645 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
1646 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1647 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
1648 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1649 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1651 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
1652 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1653 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1654 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1655 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1656 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1657 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1658 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1659 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1660 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1661 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1662 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1663 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
1664 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1665 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
1666 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1667 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
1668 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1670 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
1671 IMAGE_DECLARATION(src1),
1672 IMAGE_DECLARATION(dst),
1673 uint src0_stride_z,
1674 uint src1_stride_z,
1675 uint dst_stride_z)
1676 {
1677 // Requires NUM_ELEMS_PROCESSED_PER_THREAD_X == 2: C vect2, A vect4, B (2 vload2); to be fixed for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1678 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1680 // Compute starting address for matrix A and Matrix B
1681 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1683 // Update address for the matrix A
1684 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1686 // Update address for the matrix B
1687 src_addr.s1 += idx * sizeof(float);
1689 // Add offset for batched GEMM
1690 src_addr.s0 += get_global_id(2) * src0_stride_z;
1692 #if defined(MATRIX_B_DEPTH)
1693 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1694 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1695 #else // defined(MATRIX_B_DEPTH)
1696 src_addr.s1 += get_global_id(2) * src1_stride_z;
1697 #endif // defined(MATRIX_B_DEPTH)
1699 // Initialize accumulators
1700 float acc00 = 0.0f;
1701 float acc01 = 0.0f;
1703 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1704 float acc10 = 0.0f;
1705 float acc11 = 0.0f;
1706 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1707 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1708 float acc20 = 0.0f;
1709 float acc21 = 0.0f;
1710 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1711 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1712 float acc30 = 0.0f;
1713 float acc31 = 0.0f;
1714 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1716 // A and B src indices get incremented at the same time.
1717 int i = 0;
1718 for(; i <= ((int)COLS_A - 8); i += 8)
1719 {
1720 // Load values from matrix A
1721 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
1723 // Load values from matrix B
1724 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1725 src_addr.s1 += src1_stride_y;
1726 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1727 src_addr.s1 += src1_stride_y;
1728 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1729 src_addr.s1 += src1_stride_y;
1730 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1731 src_addr.s1 += src1_stride_y;
1732 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1733 src_addr.s1 += src1_stride_y;
1734 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1735 src_addr.s1 += src1_stride_y;
1736 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1737 src_addr.s1 += src1_stride_y;
1738 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1739 src_addr.s1 += src1_stride_y;
1741 // Multiply and accumulate
1742 acc00 = fma(a0.s0, b0.s0, acc00);
1743 acc00 = fma(a0.s1, b1.s0, acc00);
1744 acc00 = fma(a0.s2, b2.s0, acc00);
1745 acc00 = fma(a0.s3, b3.s0, acc00);
1746 acc00 = fma(a0.s4, b4.s0, acc00);
1747 acc00 = fma(a0.s5, b5.s0, acc00);
1748 acc00 = fma(a0.s6, b6.s0, acc00);
1749 acc00 = fma(a0.s7, b7.s0, acc00);
1751 acc01 = fma(a0.s0, b0.s1, acc01);
1752 acc01 = fma(a0.s1, b1.s1, acc01);
1753 acc01 = fma(a0.s2, b2.s1, acc01);
1754 acc01 = fma(a0.s3, b3.s1, acc01);
1755 acc01 = fma(a0.s4, b4.s1, acc01);
1756 acc01 = fma(a0.s5, b5.s1, acc01);
1757 acc01 = fma(a0.s6, b6.s1, acc01);
1758 acc01 = fma(a0.s7, b7.s1, acc01);
1760 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1761 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1762 acc10 = fma(a0.s0, b0.s0, acc10);
1763 acc10 = fma(a0.s1, b1.s0, acc10);
1764 acc10 = fma(a0.s2, b2.s0, acc10);
1765 acc10 = fma(a0.s3, b3.s0, acc10);
1766 acc10 = fma(a0.s4, b4.s0, acc10);
1767 acc10 = fma(a0.s5, b5.s0, acc10);
1768 acc10 = fma(a0.s6, b6.s0, acc10);
1769 acc10 = fma(a0.s7, b7.s0, acc10);
1771 acc11 = fma(a0.s0, b0.s1, acc11);
1772 acc11 = fma(a0.s1, b1.s1, acc11);
1773 acc11 = fma(a0.s2, b2.s1, acc11);
1774 acc11 = fma(a0.s3, b3.s1, acc11);
1775 acc11 = fma(a0.s4, b4.s1, acc11);
1776 acc11 = fma(a0.s5, b5.s1, acc11);
1777 acc11 = fma(a0.s6, b6.s1, acc11);
1778 acc11 = fma(a0.s7, b7.s1, acc11);
1779 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1780 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1781 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1782 acc20 = fma(a0.s0, b0.s0, acc20);
1783 acc20 = fma(a0.s1, b1.s0, acc20);
1784 acc20 = fma(a0.s2, b2.s0, acc20);
1785 acc20 = fma(a0.s3, b3.s0, acc20);
1786 acc20 = fma(a0.s4, b4.s0, acc20);
1787 acc20 = fma(a0.s5, b5.s0, acc20);
1788 acc20 = fma(a0.s6, b6.s0, acc20);
1789 acc20 = fma(a0.s7, b7.s0, acc20);
1791 acc21 = fma(a0.s0, b0.s1, acc21);
1792 acc21 = fma(a0.s1, b1.s1, acc21);
1793 acc21 = fma(a0.s2, b2.s1, acc21);
1794 acc21 = fma(a0.s3, b3.s1, acc21);
1795 acc21 = fma(a0.s4, b4.s1, acc21);
1796 acc21 = fma(a0.s5, b5.s1, acc21);
1797 acc21 = fma(a0.s6, b6.s1, acc21);
1798 acc21 = fma(a0.s7, b7.s1, acc21);
1799 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1800 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1801 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1802 acc30 = fma(a0.s0, b0.s0, acc30);
1803 acc30 = fma(a0.s1, b1.s0, acc30);
1804 acc30 = fma(a0.s2, b2.s0, acc30);
1805 acc30 = fma(a0.s3, b3.s0, acc30);
1806 acc30 = fma(a0.s4, b4.s0, acc30);
1807 acc30 = fma(a0.s5, b5.s0, acc30);
1808 acc30 = fma(a0.s6, b6.s0, acc30);
1809 acc30 = fma(a0.s7, b7.s0, acc30);
1811 acc31 = fma(a0.s0, b0.s1, acc31);
1812 acc31 = fma(a0.s1, b1.s1, acc31);
1813 acc31 = fma(a0.s2, b2.s1, acc31);
1814 acc31 = fma(a0.s3, b3.s1, acc31);
1815 acc31 = fma(a0.s4, b4.s1, acc31);
1816 acc31 = fma(a0.s5, b5.s1, acc31);
1817 acc31 = fma(a0.s6, b6.s1, acc31);
1818 acc31 = fma(a0.s7, b7.s1, acc31);
1819 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1821 src_addr.s0 += sizeof(float) * 8;
1822 }
1823 // Left-over accumulations; src_addr.s0 advances by one float per iteration
1824 for(; i < (int)COLS_A; ++i)
1825 {
1826 // Load values from matrix A
1827 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1828 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1829 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1830 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1831 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1832 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1833 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1834 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1835 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1836 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1837 // Load values from matrix B
1838 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
1839 src_addr.s1 += src1_stride_y;
1841 // Multiply and accumulate
1842 acc00 = fma(a0, b0.s0, acc00);
1843 acc01 = fma(a0, b0.s1, acc01);
1844 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1845 acc10 = fma(a1, b0.s0, acc10);
1846 acc11 = fma(a1, b0.s1, acc11);
1847 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1848 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1849 acc20 = fma(a2, b0.s0, acc20);
1850 acc21 = fma(a2, b0.s1, acc21);
1851 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1852 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1853 acc30 = fma(a3, b0.s0, acc30);
1854 acc31 = fma(a3, b0.s1, acc31);
1855 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1857 src_addr.s0 += sizeof(float);
1858 }
1860 // Compute destination address
1861 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1863 // Compute dst address
1864 __global uchar *dst_addr = offset(&dst, 0, 0);
1866 // Add offset for batched GEMM
1867 dst_addr += get_global_id(2) * dst_stride_z;
1869 // Multiply by the weight of matrix-matrix product and store the result
1870 #if defined(ALPHA)
1871 acc00 = acc00 * ALPHA;
1872 acc01 = acc01 * ALPHA;
1873 #endif // defined(ALPHA)
1874 float2 acc0 = ((float2)(acc00, acc01));
1875 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
1876 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1877 #if defined(ALPHA)
1878 acc10 = acc10 * ALPHA;
1879 acc11 = acc11 * ALPHA;
1880 #endif // defined(ALPHA)
1881 float2 acc1 = ((float2)(acc10, acc11));
1882 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
1883 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1884 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1885 #if defined(ALPHA)
1886 acc20 = acc20 * ALPHA;
1887 acc21 = acc21 * ALPHA;
1888 #endif // defined(ALPHA)
1889 float2 acc2 = ((float2)(acc20, acc21));
1890 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
1891 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1892 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1893 #if defined(ALPHA)
1894 acc30 = acc30 * ALPHA;
1895 acc31 = acc31 * ALPHA;
1896 #endif // defined(ALPHA)
1897 float2 acc3 = (float2)(acc30, acc31);
1898 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
1899 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1900 }
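/* Design note (an inference from the code above, not an authoritative claim):
 * compared to gemm_mm_floating_point_f32_bifrost, this variant keeps only two
 * columns of matrix B per work item (vload2) but unrolls eight k-steps per
 * iteration (vload8 on matrix A), which matches the recommendation to use it
 * only when matrix B has at most 1000 columns. An illustrative options string
 * could be "-DCOLS_A=512 -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2
 * -DNUM_ELEMS_PROCESSED_PER_THREAD_Y=1".
 */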
1902 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1904 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
1905 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1906 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1907 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1908 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
1909 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1910 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1912 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1913 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1914 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1915 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1916 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1917 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1918 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1919 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1920 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1921 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1922 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1923 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1924 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
1925 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
1926 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
1927 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
1928 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
1929 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1931 __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
1932 IMAGE_DECLARATION(src1),
1933 IMAGE_DECLARATION(dst),
1934 uint src0_stride_z,
1935 uint src1_stride_z,
1936 uint dst_stride_z)
1937 {
1938 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1940 // Compute starting address for matrix A and Matrix B
1941 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1943 // Update address for the matrix A
1944 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1946 // Update address for the matrix B
1947 src_addr.s1 += idx * sizeof(half);
1949 // Add offset for batched GEMM
1950 src_addr.s0 += get_global_id(2) * src0_stride_z;
1952 #if defined(MATRIX_B_DEPTH)
1953 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1954 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1955 #else // defined(MATRIX_B_DEPTH)
1956 src_addr.s1 += get_global_id(2) * src1_stride_z;
1957 #endif // defined(MATRIX_B_DEPTH)
1959 half8 acc0 = 0.0h;
1960 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1961 half8 acc1 = 0.0h;
1962 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1963 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1964 half8 acc2 = 0.0h;
1965 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1966 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1967 half8 acc3 = 0.0h;
1968 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1970 int i = 0;
1971 for(; i <= ((int)COLS_A - 4); i += 4)
1972 {
1973 // Load values from matrix A
1974 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1975 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1976 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1977 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1978 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1979 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1980 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1981 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1982 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1983 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1984 // Load values from matrix B
1985 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
1986 src_addr.s1 += src1_stride_y;
1989 acc0 = fma(b0, (half8)a0.s0, acc0);
1990 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1991 acc1 = fma(b0, (half8)a1.s0, acc1);
1992 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1993 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1994 acc2 = fma(b0, (half8)a2.s0, acc2);
1995 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1996 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1997 acc3 = fma(b0, (half8)a3.s0, acc3);
1998 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2000 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2001 src_addr.s1 += src1_stride_y;
2002 acc0 = fma(b0, (half8)a0.s1, acc0);
2003 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2004 acc1 = fma(b0, (half8)a1.s1, acc1);
2005 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2006 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2007 acc2 = fma(b0, (half8)a2.s1, acc2);
2008 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2009 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2010 acc3 = fma(b0, (half8)a3.s1, acc3);
2011 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2013 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2014 src_addr.s1 += src1_stride_y;
2015 acc0 = fma(b0, (half8)a0.s2, acc0);
2016 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2017 acc1 = fma(b0, (half8)a1.s2, acc1);
2018 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2019 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2020 acc2 = fma(b0, (half8)a2.s2, acc2);
2021 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2022 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2023 acc3 = fma(b0, (half8)a3.s2, acc3);
2024 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2026 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2027 src_addr.s1 += src1_stride_y;
2028 acc0 = fma(b0, (half8)a0.s3, acc0);
2029 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2030 acc1 = fma(b0, (half8)a1.s3, acc1);
2031 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2032 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2033 acc2 = fma(b0, (half8)a2.s3, acc2);
2034 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2035 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2036 acc3 = fma(b0, (half8)a3.s3, acc3);
2037 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2039 src_addr.s0 += 4 * sizeof(half);
2040 }
2042 for(; i < (int)COLS_A; ++i)
2043 {
2044 // Load values from matrix A
2045 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2046 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2047 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2048 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2049 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2050 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2051 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2052 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2053 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2054 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2055 // Load values from matrix B
2056 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
2058 src_addr += (int2)(sizeof(half), src1_stride_y);
2061 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
2062 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2063 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
2064 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2065 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2066 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
2067 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2068 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2069 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
2070 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2071 }
2073 // Compute destination address
2074 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2076 // Compute dst address
2077 __global uchar *dst_addr = offset(&dst, 0, 0);
2079 // Add offset for batched GEMM
2080 dst_addr += get_global_id(2) * dst_stride_z;
2082 // Multiply by the weight of matrix-matrix product and store the result
2083 #if defined(ALPHA)
2084 acc0 = acc0 * (half8)ALPHA;
2085 #endif // defined(ALPHA)
2086 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2086 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2087 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2088 #if defined(ALPHA)
2089 acc1 = acc1 * (half8)ALPHA;
2090 #endif // defined(ALPHA)
2091 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2092 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2093 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2094 #if defined(ALPHA)
2095 acc2 = acc2 * (half8)ALPHA;
2096 #endif // defined(ALPHA)
2097 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2098 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2099 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2100 #if defined(ALPHA)
2101 acc3 = acc3 * (half8)ALPHA;
2102 #endif // defined(ALPHA)
2103 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
2104 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2105 }
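/* Portability note (an assumption about the host setup, not library code): the
 * half arithmetic above requires the cl_khr_fp16 extension, which is expected
 * to be enabled by a common header. A host-side capability check could look
 * like:
 *
 *   char ext[4096] = { 0 };
 *   clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, sizeof(ext), ext, NULL);
 *   int has_fp16 = strstr(ext, "cl_khr_fp16") != NULL; // otherwise use the F32 kernels
 */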
2107 #if defined(FIXED_POINT_POSITION)
2108 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2110 * @note This OpenCL kernel works with fixed point data types QS8
2111 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
2112 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the value of alpha need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
2113 * @note The fixed point position needs to be passed at compile time using -DFIXED_POINT_POSITION
2114 * @note The optional alpha value must be passed in 8 bit fixed point format using -DALPHA
2115 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2116 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2118 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS8
2119 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2120 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2121 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2122 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2123 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2124 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2125 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2126 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2127 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2128 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2129 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2130 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
2131 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
2132 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
2133 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
2134 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
2135 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2137 __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
2138 IMAGE_DECLARATION(src1),
2139 IMAGE_DECLARATION(dst),
2140 uint src0_stride_z,
2141 uint src1_stride_z,
2142 uint dst_stride_z)
2143 {
2144 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2146 // Compute starting address for matrix A and Matrix B
2147 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2149 // Update address for the matrix A
2150 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2152 // Update address for the matrix B
2153 src_addr.s1 += idx * sizeof(char);
2155 // Add offset for batched GEMM
2156 src_addr.s0 += get_global_id(2) * src0_stride_z;
2158 #if defined(MATRIX_B_DEPTH)
2159 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2160 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2161 #else // defined(MATRIX_B_DEPTH)
2162 src_addr.s1 += get_global_id(2) * src1_stride_z;
2163 #endif // defined(MATRIX_B_DEPTH)
2165 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(char));
2167 short8 acc00 = 0;
2168 short8 acc01 = 0;
2169 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2170 short8 acc10 = 0;
2171 short8 acc11 = 0;
2172 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2173 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2174 short8 acc20 = 0;
2175 short8 acc21 = 0;
2176 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2177 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2178 short8 acc30 = 0;
2179 short8 acc31 = 0;
2180 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2182 // This for loop performs 4 accumulations per iteration
2183 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
2184 {
2185 char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2186 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2187 char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2188 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2189 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2190 char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2191 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2192 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2193 char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2194 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2195 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2196 char16 b1 = vload16(0, (__global char *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
2198 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s0, b0.s01234567, FIXED_POINT_POSITION);
2199 acc00 = mlal_sat_qs8x8(acc00, (char8)a0.s1, b1.s01234567, FIXED_POINT_POSITION);
2200 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2201 acc01 = mlal_sat_qs8x8(acc01, (char8)a0.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2202 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2203 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s0, b0.s01234567, FIXED_POINT_POSITION);
2204 acc10 = mlal_sat_qs8x8(acc10, (char8)a1.s1, b1.s01234567, FIXED_POINT_POSITION);
2205 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2206 acc11 = mlal_sat_qs8x8(acc11, (char8)a1.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2207 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2208 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2209 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s0, b0.s01234567, FIXED_POINT_POSITION);
2210 acc20 = mlal_sat_qs8x8(acc20, (char8)a2.s1, b1.s01234567, FIXED_POINT_POSITION);
2211 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2212 acc21 = mlal_sat_qs8x8(acc21, (char8)a2.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2213 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2214 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2215 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s0, b0.s01234567, FIXED_POINT_POSITION);
2216 acc30 = mlal_sat_qs8x8(acc30, (char8)a3.s1, b1.s01234567, FIXED_POINT_POSITION);
2217 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2218 acc31 = mlal_sat_qs8x8(acc31, (char8)a3.s1, b1.s89ABCDEF, FIXED_POINT_POSITION);
2219 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2220 }
2222 // Left-over accumulations
2223 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
2224 {
2225 char a0 = *((__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2226 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2227 char a1 = *((__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2228 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2229 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2230 char a2 = *((__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2231 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2232 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2233 char a3 = *((__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2234 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2235 char16 b0 = vload16(0, (__global char *)(src1_ptr + src_addr.s1));
2237 acc00 = mlal_sat_qs8x8(acc00, (char8)a0, b0.s01234567, FIXED_POINT_POSITION);
2238 acc01 = mlal_sat_qs8x8(acc01, (char8)a0, b0.s89ABCDEF, FIXED_POINT_POSITION);
2239 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2240 acc10 = mlal_sat_qs8x8(acc10, (char8)a1, b0.s01234567, FIXED_POINT_POSITION);
2241 acc11 = mlal_sat_qs8x8(acc11, (char8)a1, b0.s89ABCDEF, FIXED_POINT_POSITION);
2242 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2243 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2244 acc20 = mlal_sat_qs8x8(acc20, (char8)a2, b0.s01234567, FIXED_POINT_POSITION);
2245 acc21 = mlal_sat_qs8x8(acc21, (char8)a2, b0.s89ABCDEF, FIXED_POINT_POSITION);
2246 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2247 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2248 acc30 = mlal_sat_qs8x8(acc30, (char8)a3, b0.s01234567, FIXED_POINT_POSITION);
2249 acc31 = mlal_sat_qs8x8(acc31, (char8)a3, b0.s89ABCDEF, FIXED_POINT_POSITION);
2250 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2251 }
2253 // Compute destination address
2254 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2256 // Compute dst address
2257 __global uchar *dst_addr = offset(&dst, 0, 0);
2259 // Add offset for batched GEMM
2260 dst_addr += get_global_id(2) * dst_stride_z;
2262 // Multiply by the weight of matrix product and store the result
2263 char16 acc_qs8;
2264 acc_qs8 = convert_char16_sat((short16)(acc00, acc01));
2265 #if defined(ALPHA)
2266 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2267 #endif // defined(ALPHA)
2268 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 0 * dst_stride_y));
2269 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2270 acc_qs8 = convert_char16_sat((short16)(acc10, acc11));
2271 #if defined(ALPHA)
2272 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2273 #endif // defined(ALPHA)
2274 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 1 * dst_stride_y));
2275 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2276 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2277 acc_qs8 = convert_char16_sat((short16)(acc20, acc21));
2278 #if defined(ALPHA)
2279 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2280 #endif // defined(ALPHA)
2281 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 2 * dst_stride_y));
2282 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2283 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2284 acc_qs8 = convert_char16_sat((short16)(acc30, acc31));
2285 #if defined(ALPHA)
2286 acc_qs8 = mul_sat_qs8x16(acc_qs8, (char16)ALPHA, FIXED_POINT_POSITION);
2287 #endif // defined(ALPHA)
2288 vstore16(acc_qs8, 0, (__global char *)(dst_addr + 3 * dst_stride_y));
2289 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2290 }
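/* Worked example (illustrative, assuming mlal_sat_qs8x8() accumulates
 * (a * b) >> FIXED_POINT_POSITION with saturation, as its use above suggests):
 * with -DFIXED_POINT_POSITION=3, the QS8 value 40 represents 40 / 2^3 = 5.0 and
 * 12 represents 1.5; one multiply-accumulate adds (40 * 12) >> 3 = 60, i.e. 7.5
 * in the same format. ALPHA must use the same encoding, e.g. -DALPHA=0x10 for
 * alpha = 2.0 when FIXED_POINT_POSITION is 3.
 */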
2292 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2294 * @note This OpenCL kernel works with fixed point data types QS16
2295 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
2296 * @note The number of matrix A columns, the number of elements processed per thread along the Y direction and the value of alpha need to be passed at compile time using -DCOLS_A, -DNUM_ELEMS_PROCESSED_PER_THREAD_Y and -DALPHA
2297 * @note The fixed point position needs to be passed at compile time using -DFIXED_POINT_POSITION
2298 * @note The optional alpha value must be passed in 16 bit fixed point format using -DALPHA
2299 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2300 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2302 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: QS16
2303 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2304 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2305 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2306 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2307 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2308 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2309 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2310 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2311 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2312 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2313 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2314 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
2315 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
2316 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
2317 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
2318 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
2319 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2321 __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
2322 IMAGE_DECLARATION(src1),
2323 IMAGE_DECLARATION(dst),
2324 uint src0_stride_z,
2325 uint src1_stride_z,
2326 uint dst_stride_z)
2327 {
2328 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2330 // Compute starting address for matrix A and Matrix B
2331 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2333 // Update address for the matrix A
2334 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2336 // Update address for the matrix B
2337 src_addr.s1 += idx * sizeof(short);
2339 // Add offset for batched GEMM
2340 src_addr.s0 += get_global_id(2) * src0_stride_z;
2342 #if defined(MATRIX_B_DEPTH)
2343 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2344 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2345 #else // defined(MATRIX_B_DEPTH)
2346 src_addr.s1 += get_global_id(2) * src1_stride_z;
2347 #endif // defined(MATRIX_B_DEPTH)
2349 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(short));
2351 int8 acc0 = 0;
2352 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2353 int8 acc1 = 0;
2354 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2355 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2356 int8 acc2 = 0;
2357 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2358 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2359 int8 acc3 = 0;
2360 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2362 // This for loop performs 4 accumulations per iteration
2363 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(short)); src_addr += (int2)(2 * sizeof(short), 2 * src1_stride_y))
2364 {
2365 short2 a0 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2366 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2367 short2 a1 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2368 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2369 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2370 short2 a2 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2371 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2372 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2373 short2 a3 = vload2(0, (__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2374 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2375 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
2376 short8 b1 = vload8(0, (__global short *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
2378 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s0, b0, FIXED_POINT_POSITION);
2379 acc0 = mlal_sat_qs16x8(acc0, (short8)a0.s1, b1, FIXED_POINT_POSITION);
2380 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2381 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s0, b0, FIXED_POINT_POSITION);
2382 acc1 = mlal_sat_qs16x8(acc1, (short8)a1.s1, b1, FIXED_POINT_POSITION);
2383 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2384 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2385 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s0, b0, FIXED_POINT_POSITION);
2386 acc2 = mlal_sat_qs16x8(acc2, (short8)a2.s1, b1, FIXED_POINT_POSITION);
2387 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2388 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2389 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s0, b0, FIXED_POINT_POSITION);
2390 acc3 = mlal_sat_qs16x8(acc3, (short8)a3.s1, b1, FIXED_POINT_POSITION);
2391 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2392 }
2394 // Left-over accumulations
2395 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(short), src1_stride_y))
2396 {
2397 short a0 = *((__global short *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2398 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2399 short a1 = *((__global short *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2400 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2401 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2402 short a2 = *((__global short *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2403 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2404 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2405 short a3 = *((__global short *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2406 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2407 short8 b0 = vload8(0, (__global short *)(src1_ptr + src_addr.s1));
2409 acc0 = mlal_sat_qs16x8(acc0, (short8)a0, b0, FIXED_POINT_POSITION);
2410 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2411 acc1 = mlal_sat_qs16x8(acc1, (short8)a1, b0, FIXED_POINT_POSITION);
2412 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2413 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2414 acc2 = mlal_sat_qs16x8(acc2, (short8)a2, b0, FIXED_POINT_POSITION);
2415 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2416 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2417 acc3 = mlal_sat_qs16x8(acc3, (short8)a3, b0, FIXED_POINT_POSITION);
2418 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2419 }
2421 // Compute destination address
2422 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2424 // Compute dst address
2425 __global uchar *dst_addr = offset(&dst, 0, 0);
2427 // Add offset for batched GEMM
2428 dst_addr += get_global_id(2) * dst_stride_z;
2430 // Multiply by the weight of matrix product and store the result
2432 acc_qs16 = convert_short8_sat(acc0);
2434 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2435 #endif // defined(ALPHA)
2436 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 0 * dst_stride_y));
2437 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2438 acc_qs16 = convert_short8_sat(acc1);
2440 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2441 #endif // defined(ALPHA)
2442 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 1 * dst_stride_y));
2443 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2444 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2445 acc_qs16 = convert_short8_sat(acc2);
2447 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2448 #endif // defined(ALPHA)
2449 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 2 * dst_stride_y));
2450 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2451 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2452 acc_qs16 = convert_short8_sat(acc3);
2454 acc_qs16 = mul_sat_qs16x8(acc_qs16, (short8)ALPHA, FIXED_POINT_POSITION);
2455 #endif // defined(ALPHA)
2456 vstore8(acc_qs16, 0, (__global short *)(dst_addr + 3 * dst_stride_y));
2457 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
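
// Note on the final store above (a hedged sketch, not in the original source):
// convert_short8_sat() narrows the 32-bit accumulators back to QS16 with
// saturation, and ALPHA is presumably expected in the same 16-bit fixed point
// format as the data (compare the BETA note on gemm_ma_qs16 below). For
// example, with FIXED_POINT_POSITION=10 an alpha of 1.5 would be passed as
// -DALPHA=1536 (1.5 * 2^10 = 1536).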
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)

#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices, taking into account that the second matrix might be weighted by a scalar value beta.
 *
 * @note The value of beta must be passed at compile time using -DBETA (e.g. -DBETA=1.5f)
 *
 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Compute alpha * AB + beta * C
    float4 out = alpha_ab + (float4)BETA * c;

    // Store the final result back into the AB matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
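
// Usage sketch (hedged; the build option value is an example, not from the
// original source): compiling with -DBETA=0.5f and launching one work-item
// per 4 columns makes each work-item perform, in place,
//   dst[y][4x .. 4x+3] += 0.5f * src[y][4x .. 4x+3]
// where dst already contains the (alpha-weighted) product A x B.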

/** This OpenCL kernel performs the in-place matrix addition between 2 matrices, taking into account that the second matrix might be weighted by a scalar value beta.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Compute alpha * AB + beta * C
    half8 out = alpha_ab + (half8)BETA * c;

    // Store the final result back into the AB matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}

#if defined(FIXED_POINT_POSITION)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 8 bit fixed point, taking into account that the second matrix might be weighted by a scalar value beta.
 *
 * @note The value of beta and the fixed point position must be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 8 bit fixed point format
 *
 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS8
 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
                          IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    char16 alpha_ab = vload16(0, (__global char *)dst.ptr);

    // Load values from Matrix C
    char16 c = vload16(0, (__global char *)src.ptr);

    // Compute alpha * AB + beta * C with saturation
    char16 out = mla_sat_qs8x16(alpha_ab, (char16)BETA, c, FIXED_POINT_POSITION);

    // Store the final result back into the AB matrix
    vstore16(out, 0, (__global char *)dst.ptr);
}
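
// Worked example (an illustrative assumption, not from the original source):
// in QS8 with -DFIXED_POINT_POSITION=5 every value is scaled by 2^5, so
// beta = 0.5 is passed as -DBETA=16 (0.5 * 2^5 = 16); mla_sat_qs8x16() then
// computes alpha_ab + BETA * c with the product rescaled by
// FIXED_POINT_POSITION and saturated to the signed 8-bit range.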

/** This OpenCL kernel performs the in-place matrix addition between 2 matrices in 16 bit fixed point, taking into account that the second matrix might be weighted by a scalar value beta.
 *
 * @note The value of beta and the fixed point position must be passed at compile time using -DBETA and -DFIXED_POINT_POSITION
 *
 * @note BETA must be passed in 16 bit fixed point format
 *
 * @param[in] src_ptr Pointer to the source matrix. Supported data types: QS16
 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
                           IMAGE_DECLARATION(dst))
{
    // Compute source and destination addresses
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Load values from A x B
    short8 alpha_ab = vload8(0, (__global short *)dst.ptr);

    // Load values from Matrix C
    short8 c = vload8(0, (__global short *)src.ptr);

    // Compute alpha * AB + beta * C with saturation
    short8 out = mla_sat_qs16x8(alpha_ab, (short8)BETA, c, FIXED_POINT_POSITION);

    // Store the final result back into the AB matrix
    vstore8(out, 0, (__global short *)dst.ptr);
}
#endif // defined(FIXED_POINT_POSITION)
#endif // defined(BETA)

#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector-by-matrix multiplication between each row of A (src0) and matrix B (src1), used for the locally connected layer.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A (e.g. -DWIDTH_VECTOR_A=64)
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src1_step_z src1_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // This for loop performs 2 accumulations per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Left-over accumulations
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
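
// Shape sketch (illustrative, not from the original source): for a locally
// connected layer, src1 holds one weight matrix per output row, selected by
// the Z-slice offset src1_stride_z * idy above. With WIDTH_VECTOR_A = 3 the
// work-item at (idx, idy) computes
//   dst[idy][4*idx .. 4*idx+3] = sum over k in [0, 2] of
//       A[idy][k] * B[idy][k][4*idx .. 4*idx+3]
// i.e. each row of A is multiplied by its own matrix B[idy].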
#endif // defined(WIDTH_VECTOR_A)

/** This kernel adds the biases vector to each row of the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr Pointer to the accumulate tensor. Supported data types: U8/S8/QS8/U16/S16/F16/U32/S32/F32
 * @param[in] accum_stride_x Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in] accum_step_x accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] accum_stride_y Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in] accum_step_y accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] accum_offset_first_element_in_bytes The offset of the first element in the accumulate tensor
 * @param[in] biases_ptr Pointer to the biases vector. Supported data types: same as @p accum_ptr
 * @param[in] biases_stride_x Stride of the biases vector in X dimension (in bytes)
 * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Each work-item processes VECTOR_SIZE elements of one row
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
#ifdef FIXED_POINT_POSITION
    // Saturating addition for the fixed point types
    accum_value = ADD_SAT_OP_EXPAND(biases_value, accum_value, DATA_TYPE, VECTOR_SIZE);
#else  // FIXED_POINT_POSITION
    accum_value = biases_value + accum_value;
#endif // FIXED_POINT_POSITION

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
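
// Macro-expansion sketch (hedged; the helper macros are defined elsewhere in
// the library): building with -DDATA_TYPE=float -DVECTOR_SIZE=16 makes
// VEC_DATA_TYPE(float, 16) expand to float16 and VLOAD(16)/VSTORE(16) expand
// to vload16/vstore16, so each work-item loads, adds and stores 16 floats of
// one accumulator row.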
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)