// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/mmad.cl"
#define SUM_SCALE 0.11f
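// SUM_SCALE re-scales the int8 eltwise input back to the accumulator range in the
// lightweight path below; per output element it computes
//     out = sat_u8(acc * SCALE + (eltw * SUM_SCALE + bias)).
// SCALE itself is assumed to be supplied by the host as a JIT constant; its
// definition does not appear in this file.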
#ifdef LIGHTWEIGHT_QUANTIZATION

// Braces and the local declaration of tmp are assumed; the originals were lost.
#define QUANTIZATION(idx) \
{\
    float4 tmp;\
    for(uint z = 0; z < 4; z++)\
    {\
        tmp.s0 = (float)eltw_input_vals[z * 4 + 0] * SUM_SCALE + bias_f.s0;\
        tmp.s1 = (float)eltw_input_vals[z * 4 + 1] * SUM_SCALE + bias_f.s1;\
        tmp.s2 = (float)eltw_input_vals[z * 4 + 2] * SUM_SCALE + bias_f.s2;\
        tmp.s3 = (float)eltw_input_vals[z * 4 + 3] * SUM_SCALE + bias_f.s3;\
\
        regC_uchar16[z * 4 + 0] = convert_uchar_sat( (regC[0 * 4 + i][idx + z / 4]) * SCALE + tmp.s0);\
        regC_uchar16[z * 4 + 1] = convert_uchar_sat( (regC[1 * 4 + i][idx + z / 4]) * SCALE + tmp.s1);\
        regC_uchar16[z * 4 + 2] = convert_uchar_sat( (regC[2 * 4 + i][idx + z / 4]) * SCALE + tmp.s2);\
        regC_uchar16[z * 4 + 3] = convert_uchar_sat( (regC[3 * 4 + i][idx + z / 4]) * SCALE + tmp.s3);\
    }\
}
#elif NO_QUANTIZATION // assumed guard name for this pass-through variant; the original branch marker was lost

#define QUANTIZATION(idx) \
{\
    regC_uchar16.s0 = regC[0 * 4 + i][idx];\
    regC_uchar16.s1 = regC[1 * 4 + i][idx];\
    regC_uchar16.s2 = regC[2 * 4 + i][idx];\
    regC_uchar16.s3 = regC[3 * 4 + i][idx];\
\
    regC_uchar16.s4 = regC[0 * 4 + i][idx+1];\
    regC_uchar16.s5 = regC[1 * 4 + i][idx+1];\
    regC_uchar16.s6 = regC[2 * 4 + i][idx+1];\
    regC_uchar16.s7 = regC[3 * 4 + i][idx+1];\
\
    regC_uchar16.s8 = regC[0 * 4 + i][idx+2];\
    regC_uchar16.s9 = regC[1 * 4 + i][idx+2];\
    regC_uchar16.sa = regC[2 * 4 + i][idx+2];\
    regC_uchar16.sb = regC[3 * 4 + i][idx+2];\
\
    regC_uchar16.sc = regC[0 * 4 + i][idx+3];\
    regC_uchar16.sd = regC[1 * 4 + i][idx+3];\
    regC_uchar16.se = regC[2 * 4 + i][idx+3];\
    regC_uchar16.sf = regC[3 * 4 + i][idx+3];\
\
    int16 sum; /* assumed declaration: sum is used as a 16-wide int vector below */\
    for(uint s = 0; s < 16; s++)\
    {\
        sum[s] = (int)as_char(regC_uchar16[s]) + (int)as_char(eltw_input_vals[s]);\
    }\
\
    regC_uchar16.s0 = convert_uchar_sat( sum.s0 );\
    regC_uchar16.s1 = convert_uchar_sat( sum.s1 );\
    regC_uchar16.s2 = convert_uchar_sat( sum.s2 );\
    regC_uchar16.s3 = convert_uchar_sat( sum.s3 );\
\
    regC_uchar16.s4 = convert_uchar_sat( sum.s4 );\
    regC_uchar16.s5 = convert_uchar_sat( sum.s5 );\
    regC_uchar16.s6 = convert_uchar_sat( sum.s6 );\
    regC_uchar16.s7 = convert_uchar_sat( sum.s7 );\
\
    regC_uchar16.s8 = convert_uchar_sat( sum.s8 );\
    regC_uchar16.s9 = convert_uchar_sat( sum.s9 );\
    regC_uchar16.sa = convert_uchar_sat( sum.sa );\
    regC_uchar16.sb = convert_uchar_sat( sum.sb );\
\
    regC_uchar16.sc = convert_uchar_sat( sum.sc );\
    regC_uchar16.sd = convert_uchar_sat( sum.sd );\
    regC_uchar16.se = convert_uchar_sat( sum.se );\
    regC_uchar16.sf = convert_uchar_sat( sum.sf );\
}
#else

#define QUANTIZATION(idx) \
{\
    regC_uchar16.s0 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
    regC_uchar16.s1 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
    regC_uchar16.s2 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
    regC_uchar16.s3 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
\
    regC_uchar16.s4 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+1]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
    regC_uchar16.s5 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+1]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
    regC_uchar16.s6 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+1]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
    regC_uchar16.s7 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+1]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
\
    regC_uchar16.s8 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+2]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
    regC_uchar16.s9 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+2]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
    regC_uchar16.sa = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+2]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
    regC_uchar16.sb = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+2]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
\
    regC_uchar16.sc = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+3]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
    regC_uchar16.sd = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+3]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
    regC_uchar16.se = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+3]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
    regC_uchar16.sf = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+3]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
\
    int16 sum; /* assumed declaration: sum is used as a 16-wide int vector below */\
    for(uint s = 0; s < 16; s++)\
    {\
        sum[s] = (int)as_char(regC_uchar16[s]) + (int)as_char(eltw_input_vals[s]);\
    }\
\
    regC_uchar16.s0 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s0) * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s1 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s1) * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s2 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s2) * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s3 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s3) * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
\
    regC_uchar16.s4 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s4) * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s5 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s5) * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s6 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s6) * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s7 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s7) * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
\
    regC_uchar16.s8 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s8) * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.s9 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s9) * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.sa = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sa) * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.sb = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sb) * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
\
    regC_uchar16.sc = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sc) * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.sd = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sd) * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.se = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.se) * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
    regC_uchar16.sf = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sf) * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
}

#endif // LIGHTWEIGHT_QUANTIZATION
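// Both offset helpers below decode a linear cOffset laid out as
// (f_val, b_val, x, y, b_slice, f_slice), fastest index first: each block holds
// 32 features x 4 batches. For example, with OUTPUT_SIZE_X = 56 (an assumed value),
// cOffset = 293 = 5 + 32 * (1 + 4 * 2) decodes to f_val_idx = 5, b_val_idx = 1,
// x_idx = 2, which is then re-linearized using the padded pitches.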
inline uint FUNC(calculate_output_offset_to_account_padding)(uint cOffset)
{
#if OUT_WITH_PADDING == 1
    uint tmp_idx = cOffset;
    uint f_val_idx = tmp_idx % 32;
    tmp_idx /= 32;
    uint b_val_idx = tmp_idx % 4;
    tmp_idx /= 4;
    uint x_idx = tmp_idx % OUTPUT_SIZE_X;
    tmp_idx /= OUTPUT_SIZE_X;
    uint y_idx = tmp_idx % OUTPUT_SIZE_Y;
    tmp_idx /= OUTPUT_SIZE_Y;
    uint b_slice_idx = tmp_idx % (OUTPUT_BATCH_NUM / 4);
    tmp_idx /= (OUTPUT_BATCH_NUM / 4);
    uint f_slice_idx = tmp_idx % (OUTPUT_FEATURE_NUM / 32);

    uint padded_offset = f_slice_idx * OUT_F_BLOCK_PITCH;
    padded_offset += b_slice_idx * OUT_B_BLOCK_PITCH;
    padded_offset += y_idx * OUT_Y_PITCH;
    padded_offset += x_idx * OUT_X_PITCH;
    padded_offset += b_val_idx * 32;
    padded_offset += f_val_idx;
    padded_offset += OUT_OFFSET;

    return padded_offset;
#else
    return cOffset;
#endif
}
inline uint FUNC(calculate_eltw_input_offset_based_on_output_offset_account_padding)(uint cOffset, uint strideX, uint strideY)
{
#if ELTW_WITH_PADDING == 1 || ELTW_STRIDE_X != 1 || ELTW_STRIDE_Y != 1
    uint tmp_idx = cOffset;
    uint f_val_idx = tmp_idx % 32;
    tmp_idx /= 32;
    uint b_val_idx = tmp_idx % 4;
    tmp_idx /= 4;
    uint x_idx = tmp_idx % OUTPUT_SIZE_X;
    x_idx *= strideX; // scale the spatial coordinates by the eltwise stride (restored; assumed from the unused parameters)
    tmp_idx /= OUTPUT_SIZE_X;
    uint y_idx = tmp_idx % OUTPUT_SIZE_Y;
    y_idx *= strideY;
    tmp_idx /= OUTPUT_SIZE_Y;
    uint b_slice_idx = tmp_idx % (OUTPUT_BATCH_NUM / 4);
    tmp_idx /= (OUTPUT_BATCH_NUM / 4);
    uint f_slice_idx = tmp_idx % (OUTPUT_FEATURE_NUM / 32);

    uint padded_offset = f_slice_idx * IN2_F_BLOCK_PITCH;
    padded_offset += b_slice_idx * IN2_B_BLOCK_PITCH;
    padded_offset += y_idx * IN2_Y_PITCH;
    padded_offset += x_idx * IN2_X_PITCH;
    padded_offset += b_val_idx * 32;
    padded_offset += f_val_idx;
    padded_offset += IN2_OFFSET;

    return padded_offset;
#else
    return cOffset;
#endif
}
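// Unlike the output helper, this one maps an output coordinate onto the second
// (eltwise) input, so x and y are scaled by the eltwise stride before the IN2_*
// pitches are applied; when the second input is dense and unstrided, the linear
// offset is used unchanged.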
inline void FUNC(mmad_32x32_int8)( __local uint* l_tileA, const uint l_offsetTileA,
                                   __local int8* l_tileB, const uint l_offsetTileB_col0,
                                   const uint l_offsetTileB_col1, const uint l_offsetTileB_col2,
                                   const uint l_offsetTileB_col3, int8* rowA, int8* colB,
                                   int8* regC ) // regC parameter restored; assumed from the call sites below
{
    // Read tile A from SLM to regA
    uint l_offsetTileATemp = l_offsetTileA;
    __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
    for (uint j = 0; j < (SG_TILE_M / 8); ++j)
    {
        rowA[j] = as_int8(SLM_BLOCK_READ_8(&l_tileA[l_offsetTileATemp]));
        l_offsetTileATemp += 8 * SG_SIZE;
    }

    // Read tile B from SLM to regB and compute mmad
    colB[0] = l_tileB[l_offsetTileB_col0];
    colB[1] = l_tileB[l_offsetTileB_col1];
    __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
    for (uint j = 0; j < (SG_TILE_M / 8); ++j)
    {
        regC[0 * (SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[0], regC[0 * (SIMD_LANE_M / 8) + j]);
    }

    colB[0] = l_tileB[l_offsetTileB_col2];
    __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
    for (uint j = 0; j < (SG_TILE_M / 8); ++j)
    {
        regC[1 * (SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[1], regC[1 * (SIMD_LANE_M / 8) + j]);
    }

    colB[1] = l_tileB[l_offsetTileB_col3];
    __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
    for (uint j = 0; j < (SG_TILE_M / 8); ++j)
    {
        regC[2 * (SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[0], regC[2 * (SIMD_LANE_M / 8) + j]);
    }

    __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
    for (uint j = 0; j < (SG_TILE_M / 8); ++j)
    {
        regC[3 * (SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[1], regC[3 * (SIMD_LANE_M / 8) + j]);
    }
}
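// Note the interleaving above: while one 32x8 column block of tileB is being
// consumed by MMAD_8x8, the next block is loaded into the other colB register,
// hiding SLM read latency behind the dot-product work.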
/*!
 * \brief GEMM kernel to compute MxN matrix using SLM
 * \param g_inA  - Input matrix
 * \param g_inB  - Input matrix
 * \param g_outC - Output matrix
 */
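// In this fused-eltwise variant the int32 accumulators are quantized, combined
// with a second int8 input (input2, or the output buffer itself for in-place
// fusion), and stored as uchar, rather than written out as raw int32.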
__attribute__((intel_reqd_sub_group_size(SG_SIZE)))
KERNEL(Kernel_GEMM_MMAD8_32x32SG_224x128WG_SLM_INT8_fused_eltwise)(
    __global char* const g_inA,
    __global int* g_outC,
    __global char* const g_inB,
    __global BIAS_TYPE* biases,
    __global float* quantizations,
    __global float* calibrations,
    __global char* const input2,
    __global float* eltw_calibrations)
{
    __global int4* const g_matrixA = (__global int4*)g_inA;
    __global int4* const g_matrixB = (__global int4*)g_inB;
    __global int8* g_matrixC = (__global int8*)g_outC;
    // Each work-group works to compute 128x128 tile.
    // Each work-group contains 16 sub-groups.
    // Each sub-group within the work-group works to compute a 32x32 tile.
    // 1) All work-items in WG fill SLM with tileA (128x32) and tileB (32x128).
    // 2) Each sub-group works to compute 32x32 tileC (stored in regC).
    //    Note that each work-item in the sub-group computes a 32x4 chunk of tileC.
    // 3) Repeat until tileC is fully computed (while moving tileA and tileB "windows").
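    // The factor of 2 in the SLM tile sizes below implements double buffering:
    // while the tiles for step k are consumed, the tiles for step k+1 are written
    // into the other half (selected via the l_pingPongOffset* constants).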
    __local int8 l_workGroupTileA[2 * (WG_TILE_M * MATRIX_SMALL_K) / sizeof(int8)];
    __local int8 l_workGroupTileB[2 * (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int8)];

    __local uint* l_workGroupTileA_uint = (__local uint*)l_workGroupTileA;
    __local int4* l_workGroupTileA_int4 = (__local int4*)l_workGroupTileA;
    __local int4* l_workGroupTileB_int4 = (__local int4*)l_workGroupTileB;

    const uint l_groupSize = get_local_size(DIM_X) * get_local_size(DIM_Y);

    const uint l_pingPongOffsetA_uint = (WG_TILE_M * MATRIX_SMALL_K) / sizeof(uint);
    const uint l_pingPongOffsetB_int8 = (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int8);
    const uint l_pingPongOffsetA_int4 = (WG_TILE_M * MATRIX_SMALL_K) / sizeof(int4);
    const uint l_pingPongOffsetB_int4 = (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int4);
    const uint g_tidY = get_global_id(DIM_Y);
    const uint g_tidX = get_global_id(DIM_X);
    const uint l_tidX = get_local_id(DIM_X);
    const uint l_tidY = get_local_id(DIM_Y);
    const uint l_tid = l_tidY * get_local_size(DIM_X) + l_tidX;

    const uint sg_tid = get_sub_group_local_id();
    const uint sg_global_idX = (uint)(g_tidX / SG_SIZE);
    const uint sg_global_idY = g_tidY;
    const uint sg_local_idX = (uint)(l_tidX / SG_SIZE);
    const uint sg_local_idY = l_tidY;
    const uint sg_local_id = sg_local_idY * get_local_size(DIM_X) / SG_SIZE + sg_local_idX;

    const uint sub_group_id = get_sub_group_id();
    int8 regC[(SIMD_LANE_M / 8) * SIMD_LANE_N] = {0}; // each work-item accumulates a 32x4 int chunk of tileC
    int8 rowA[(SG_TILE_M * MATRIX_SMALL_K / SG_SIZE) / sizeof(int8)]; // each work-item holds 1/8 of tileA
    int8 colB[2]; // each lane stores a 32x4 piece of matrixB
    const uint l_offsetTileA = SG_TILE_M * (MATRIX_SMALL_K / sizeof(uint)) * sg_local_idY;
    const uint numElements32x32TileB = (MATRIX_SMALL_K * SG_TILE_N) / sizeof(int8);
    const uint numElements32x8TileB = numElements32x32TileB / 4;
    const uint l_offsetTileB = numElements32x32TileB * sg_local_idX;
    const uint l_offsetTileB_col0 = l_offsetTileB + sg_tid;
    const uint l_offsetTileB_col1 = l_offsetTileB + 1 * numElements32x8TileB + sg_tid;
    const uint l_offsetTileB_col2 = l_offsetTileB + 2 * numElements32x8TileB + sg_tid;
    const uint l_offsetTileB_col3 = l_offsetTileB + 3 * numElements32x8TileB + sg_tid;
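    // Each sub-group's 32x32 tileB is stored as four consecutive 32x8 column
    // blocks; col0..col3 give each lane its int8 element within those blocks,
    // matching the four MMAD passes inside mmad_32x32_int8.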
    uint g_idxA[2]; // global-memory read indices (two outstanding reads per tile); declarations assumed from usage
    uint g_idxB[2];
#ifdef TILED_GLOBAL_LAYOUT // 32-row major (matrixA) and 32-col major (matrixB)
    g_idxA[0] = ((MATRIX_SMALL_K / sizeof(int4)) * WG_TILE_M) * get_group_id(DIM_Y) + l_tid;
    g_idxB[0] = ((MATRIX_SMALL_K / sizeof(int4)) * WG_TILE_N) * get_group_id(DIM_X) + l_tid;
    g_idxA[1] = g_idxA[0] + l_groupSize;
    g_idxB[1] = g_idxB[0] + l_groupSize;
#else // Row (matrixA) and Col (matrixB) major layout
    g_idxA[0] = WG_TILE_M * (MATRIX_K / sizeof(int4)) * get_group_id(DIM_Y) +
                (l_tid / 2) * (MATRIX_K / sizeof(int4)) + (l_tid % 2);
    g_idxB[0] = WG_TILE_N * (MATRIX_K / sizeof(int4)) * get_group_id(DIM_X) +
                (l_tid / 2) * (MATRIX_K / sizeof(int4)) + (l_tid % 2);
    g_idxA[1] = g_idxA[0] + (l_groupSize / 2) * (MATRIX_K / sizeof(int4));
    g_idxB[1] = g_idxB[0] + (l_groupSize / 2) * (MATRIX_K / sizeof(int4));
#endif
    l_workGroupTileA_int4[l_tid] = g_matrixA[g_idxA[0]];
    l_workGroupTileB_int4[l_tid] = g_matrixB[g_idxB[0]];
    l_workGroupTileA_int4[l_tid + l_groupSize] = g_matrixA[g_idxA[1]];
    // Not all work-items will be needed to fetch the remaining matrix B
    l_workGroupTileB_int4[l_tid + l_groupSize] = g_matrixB[g_idxB[1]];
#ifdef TILED_GLOBAL_LAYOUT
    g_idxA[0] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
    g_idxB[0] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
    g_idxA[1] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
    g_idxB[1] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
#else
    g_idxA[0] += MATRIX_SMALL_K / sizeof(int4);
    g_idxB[0] += MATRIX_SMALL_K / sizeof(int4);
    g_idxA[1] += MATRIX_SMALL_K / sizeof(int4);
    g_idxB[1] += MATRIX_SMALL_K / sizeof(int4);
#endif
    barrier(CLK_LOCAL_MEM_FENCE);

    int4 hdcReadValueA[2];
    int4 hdcReadValueB[2];
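    // Software-pipelined main loop: the global reads for step k+1 land in the
    // hdcReadValue* registers while step k is computed from SLM; the barrier at
    // the end of each iteration publishes the freshly written ping-pong tiles.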
    __attribute__((opencl_unroll_hint(1)))
    for (uint k = 0; k < (MATRIX_K / MATRIX_SMALL_K) - 1; k++)
    {
        hdcReadValueA[0] = g_matrixA[g_idxA[0]];
        hdcReadValueB[0] = g_matrixB[g_idxB[0]];
        hdcReadValueA[1] = g_matrixA[g_idxA[1]];
        // Not all work-items will be needed to fetch the remaining matrix B
        hdcReadValueB[1] = g_matrixB[g_idxB[1]];

#ifdef TILED_GLOBAL_LAYOUT
        g_idxA[0] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
        g_idxB[0] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
        g_idxA[1] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
        g_idxB[1] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
#else
        g_idxA[0] += MATRIX_SMALL_K / sizeof(int4);
        g_idxB[0] += MATRIX_SMALL_K / sizeof(int4);
        g_idxA[1] += MATRIX_SMALL_K / sizeof(int4);
        g_idxB[1] += MATRIX_SMALL_K / sizeof(int4);
#endif
        FUNC_CALL(mmad_32x32_int8)(&l_workGroupTileA_uint[(k % 2) * l_pingPongOffsetA_uint],
                                   l_offsetTileA, &l_workGroupTileB[(k % 2) * l_pingPongOffsetB_int8],
                                   l_offsetTileB_col0, l_offsetTileB_col1, l_offsetTileB_col2,
                                   l_offsetTileB_col3, rowA, colB, regC);

        // SLM setup - SLM write only
        l_workGroupTileA_int4[((k + 1) % 2 * l_pingPongOffsetA_int4) + l_tid] = hdcReadValueA[0];
        l_workGroupTileB_int4[((k + 1) % 2 * l_pingPongOffsetB_int4) + l_tid] = hdcReadValueB[0];
        l_workGroupTileA_int4[((k + 1) % 2 * l_pingPongOffsetA_int4) + l_tid + l_groupSize] = hdcReadValueA[1];
        // Not all work-items will be needed to fetch the remaining matrix B
        l_workGroupTileB_int4[((k + 1) % 2 * l_pingPongOffsetB_int4) + l_tid + l_groupSize] = hdcReadValueB[1];

        barrier(CLK_LOCAL_MEM_FENCE);
    }
    // Last MMAD compute iteration (avoids branching in main loop)
    FUNC_CALL(mmad_32x32_int8)(
        &l_workGroupTileA_uint[(((MATRIX_K / MATRIX_SMALL_K) - 1) % 2) * l_pingPongOffsetA_uint],
        l_offsetTileA,
        &l_workGroupTileB[(((MATRIX_K / MATRIX_SMALL_K) - 1) % 2) * l_pingPongOffsetB_int8],
        l_offsetTileB_col0, l_offsetTileB_col1, l_offsetTileB_col2, l_offsetTileB_col3, rowA, colB,
        regC);
#ifdef OUTPUT_TILED_GLOBAL_LAYOUT
    // Write out in swizzled manner after quantizing
    __global uchar* g_outC_uchar = (__global uchar*)g_outC;
    uint cOffset = sg_global_idX * (MATRIX_M * SG_TILE_N / sizeof(uchar)) +
                   sg_global_idY * (SG_TILE_M * SG_TILE_N / sizeof(uchar));

    uchar16 regC_uchar16;
    uint offset_uc16 = 0;
    const uint workgroup_id_x = get_group_id(0);
    uint feature_off = 32 * (sub_group_id % (WG_TILE_N / 32)) + WG_TILE_N * workgroup_id_x; // = 32 * {0,1,2,3} + WG_TILE_N * workgroup_id_x
    uint feature = get_sub_group_local_id() * 4 + feature_off;

    float4 quant_f = vload4(0, quantizations + feature);
    float4 bias_f = vload4(0, biases + feature);
    float4 calib_f = vload4(0, calibrations + feature);
    float4 eltw_calib_f = vload4(0, eltw_calibrations + feature);
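    // Each lane owns 4 consecutive output features, so it loads the 4 matching
    // per-feature quantization, bias, and calibration values with one vload4;
    // quant_f.s0..s3 line up with .s0..s3 (and .s4..s7, etc.) of regC_uchar16.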
    // Prefetch the eltwise operand for the whole sub-group tile
    uchar16 eltw[(2 * SG_TILE_M) / (sizeof(int8) / sizeof(int))];
    uint tmpcOff = cOffset;
    __attribute__((opencl_unroll_hint( SG_TILE_M / (sizeof(int8) / sizeof(int)) )))
    for (uint i = 0; i < (2 * SG_TILE_M) / (sizeof(int8) / sizeof(int)); i++)
    {
        uint padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(tmpcOff);
#if IN_OUT_OPT == 1 // assumed guard name: in-place fusion reads the output buffer; the original marker was lost
        eltw[i] = as_uchar16(intel_sub_group_block_read4((__global uint*)(g_outC_uchar + padded_offset)));
#else
        const uint eltw_second_input_offset = FUNC_CALL(calculate_eltw_input_offset_based_on_output_offset_account_padding)(tmpcOff, ELTW_STRIDE_X, ELTW_STRIDE_Y);
        eltw[i] = as_uchar16(intel_sub_group_block_read4((__global uint*)(input2 + eltw_second_input_offset)));
#endif
        tmpcOff += sizeof(uchar16) * SG_SIZE;
    }
#if MMAD_SUPPORTED == 1
    __attribute__((opencl_unroll_hint( SG_TILE_M / (sizeof(int8) / sizeof(int)) )))
#endif
    for (uint i = 0; i < SG_TILE_M / (sizeof(int8) / sizeof(int)); i++)
    {
        uint padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(cOffset);
        {
            uchar16 eltw_input_vals = eltw[i * 2];
            QUANTIZATION(0); // assumed call site: quantize/fuse columns idx..idx+3 of this row chunk
        }
        intel_sub_group_block_write4((__global uint*)(g_outC_uchar + padded_offset), as_uint4(regC_uchar16));
        cOffset += sizeof(uchar16) * SG_SIZE;

        // now we need to calculate again for other x
        padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(cOffset);
        {
            uchar16 eltw_input_vals = eltw[i * 2 + 1];
            QUANTIZATION(4); // assumed call site for the upper four columns
        }
        intel_sub_group_block_write4((__global uint*)(g_outC_uchar + padded_offset), as_uint4(regC_uchar16));
        cOffset += sizeof(uchar16) * SG_SIZE;
    }
#else // OUTPUT_TILED_GLOBAL_LAYOUT
    // Write final accumulated values
    uint cOffset = sg_global_idX * ((MATRIX_M / 8) * SG_TILE_N) + sg_global_idY * (SG_TILE_M / 8) +
                   sg_tid * (MATRIX_M / 8);
    __attribute__((opencl_unroll_hint(SIMD_LANE_N)))
    for (uint i = 0; i < (SIMD_LANE_N); ++i)
    {
        __attribute__((opencl_unroll_hint(SIMD_LANE_M / 8)))
        for (uint j = 0; j < (SIMD_LANE_M / 8); ++j)
        {
            g_matrixC[cOffset + j] = regC[i * (SIMD_LANE_M / 8) + j];
        }
        cOffset += SG_SIZE * (MATRIX_M / 8);
    }
#endif // OUTPUT_TILED_GLOBAL_LAYOUT
}