/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "helpers.h"

#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && \
  defined(COLS_A)
#define VECTOR_CHAR VEC_DATA_TYPE(char, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_FLOAT VEC_DATA_TYPE(float, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B
 * (src1) in case both matrices have not been reshaped
 *
 * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following
 * information must be passed at compile time:
 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D
 *    tensor
 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *    (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
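 *    For example (illustrative sizes, not taken from this file): with -DHEIGHT_GEMM3D=24 and
 *    -DDEPTH_GEMM3D=2, a 48-row 2D result is written back as two planes of 24 rows each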
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data
 *                                                type: S8
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in
 *                                                bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X
 *                                                processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in
 *                                                bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y
 *                                                processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source
 *                                                matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data
 *                                                type: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in
 *                                                bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X
 *                                                processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in
 *                                                bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y
 *                                                processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source
 *                                                matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported
 *                                                data type: S32
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension
 *                                                (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X
 *                                                processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension
 *                                                (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y
 *                                                processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the
 *                                                destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in
 *                                                bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in
 *                                                bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension
 *                                                (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements
 *                                                for the input tensor (only if defined
 *                                                REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements
 *                                                for the output tensor (only if defined
 *                                                REINTERPRET_OUTPUT_AS_3D)
 */
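
// A minimal usage sketch (illustrative values only, not taken from this file): for a GEMM
// where matrix A has 256 columns and each work-item computes a 4x4 output tile, the kernel
// could be built with:
//
//   -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 -DNUM_ELEMS_PROCESSED_PER_THREAD_Y=4 -DCOLS_A=256
//
// and enqueued over roughly ceil(N / 4) x ceil(M / 4) x (number of batches) work-items,
// since get_global_id(0) / get_global_id(1) / get_global_id(2) index output columns, output
// rows and batches respectively. The 3D options (-DREINTERPRET_INPUT_AS_3D,
// -DREINTERPRET_OUTPUT_AS_3D, -DHEIGHT_GEMM3D, -DDEPTH_GEMM3D) and -DMATRIX_B_DEPTH are
// only added when those reinterpretations are needed.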
__kernel void gemmlowp_mm_midgard_ex(IMAGE_DECLARATION(src0), IMAGE_DECLARATION(src1),
                                     IMAGE_DECLARATION(dst), uint src0_stride_z, uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
)
{
  int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
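  // Here idx is the first column of matrix B (and of the output tile) handled by this work-item.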

  // Compute starting address for matrix A and Matrix B
  int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

  // Update address for the matrix A
  src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

  // Update address for the matrix B
  src_addr.s1 += idx * sizeof(char);

#if defined(REINTERPRET_INPUT_AS_3D)
  // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across
  // the z dimension in order to take into account the presence of possible cross plane paddings
  //
  //  |                  |
  //  |      plane0      |
  //  |                  |
  //  |__________________|
  //  |******************|
  //  |  cross_plane_pad |
  //  |******************|
  //  |                  |
  //  |      plane1      |
  //  |                  |
  //  |__________________|
  //
  // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)
  // by HEIGHT_GEMM3D
  uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) /
              (uint4)HEIGHT_GEMM3D;
  zin = min(DEPTH_GEMM3D - 1, zin);
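  // Illustrative example: with HEIGHT_GEMM3D = 4, NUM_ELEMS_PROCESSED_PER_THREAD_Y = 4 and
  // DEPTH_GEMM3D >= 2, the work-item at get_global_id(1) == 1 covers rows 4..7, so
  // zin == (uint4)(1, 1, 1, 1): every row of its tile lives in plane 1.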

  // Add offset due to the cross plane paddings
  zin *= (src_cross_plane_pad * src0_stride_y);

  // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
  // multiply src0_stride_z by DEPTH_GEMM3D
  src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

  // Add offset for batched GEMM
  src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
  // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
  src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
  src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
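  // Illustrative example: if matrix A is batched over 6 planes while matrix B has only
  // MATRIX_B_DEPTH == 3 planes, batches 0..5 read B planes 0, 1, 2, 0, 1, 2.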

  int end_row_vec_a = src_addr.s0 + COLS_A;

  VECTOR_INT acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
  VECTOR_INT acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
  VECTOR_INT acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
  VECTOR_INT acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
  VECTOR_INT acc4 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

  for (; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
  {
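    // The loop over matrix A columns (K) is unrolled by 2: each iteration consumes two
    // columns of A (a char2 per row) and the two matching rows of B, advancing src_addr
    // by 2 bytes in A and by 2 * src1_stride_y bytes in B.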
    // Load values from matrix A
    char2 a0 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    char2 a1 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    char2 a2 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    char2 a3 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    char2 a4 = vload2(0, (__global char *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    // Load values from matrix B
    VECTOR_CHAR b0 =
        VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global char *)(src1_ptr + src_addr.s1));
    VECTOR_CHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(
        0, (__global char *)(src1_ptr + src_addr.s1 + src1_stride_y));

    // Accumulate
    acc0 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a0.s0;
    acc0 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a1.s0;
    acc1 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a2.s0;
    acc2 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a3.s0;
    acc3 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    acc4 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a4.s0;
    acc4 += CONVERT(b1, VECTOR_INT) * (VECTOR_INT)a4.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
  }

  for (; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
  {
    // Load values from matrix A
    char a0 = *(__global char *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    char a1 = *(__global char *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    char a2 = *(__global char *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    char a3 = *(__global char *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    char a4 = *(__global char *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    // Load values from matrix B
    VECTOR_CHAR b0 =
        VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global char *)(src1_ptr + src_addr.s1));

    // Accumulate
    acc0 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
    acc4 += CONVERT(b0, VECTOR_INT) * (VECTOR_INT)a4;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
  }

  const int z = get_global_id(2);

  // Compute destination address
  Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(REINTERPRET_OUTPUT_AS_3D)
  // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across
  // the z dimension in order to take into account the presence of possible cross plane paddings
  //
  //  |                  |
  //  |      plane0      |
  //  |                  |
  //  |__________________|
  //  |******************|
  //  |  cross_plane_pad |
  //  |******************|
  //  |                  |
  //  |      plane1      |
  //  |                  |
  //  |__________________|
  //
  // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)
  // by HEIGHT_GEMM3D
  uint8 zout = ((uint8)(0, 1, 2, 3, 4, 5, 6, 7) +
                (uint8)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) /
               (uint8)HEIGHT_GEMM3D;
  zout = min(DEPTH_GEMM3D - 1, zout);

  // Add offset due to the cross plane paddings
  zout *= (dst_cross_plane_pad * dst_stride_y);
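  // zout now holds, for each output row, the extra byte offset introduced by the bottom
  // padding of every plane boundary crossed before that row.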

  // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
  // multiply dst_stride_z by DEPTH_GEMM3D
  dst.ptr += z * dst_stride_z * DEPTH_GEMM3D;
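
  // Store the result: one VSTORE of NUM_ELEMS_PROCESSED_PER_THREAD_X int32 values per row
  // of the output tile, each shifted by its per-plane offset zout.sN.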
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y + zout.s4));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4

#else // defined(REINTERPRET_OUTPUT_AS_3D)
  // Add offset for batched GEMM
  dst.ptr += z * dst_stride_z;
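
  // Store the result: one VSTORE of NUM_ELEMS_PROCESSED_PER_THREAD_X int32 values per row
  // of the output tile.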
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc0, VECTOR_INT), 0, (__global int *)(dst.ptr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc1, VECTOR_INT), 0, (__global int *)(dst.ptr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc2, VECTOR_INT), 0, (__global int *)(dst.ptr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc3, VECTOR_INT), 0, (__global int *)(dst.ptr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
  VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
  (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(dst.ptr + 4 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) &&
       // defined(COLS_A)