Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / cl_kernels / fused_conv_eltwise_gpu_mmad_32x32sg_128x128wg_slm_int8.cl
1 // Copyright (c) 2018 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "include/mmad.cl"
16
17 #define SUM_SCALE 0.11f
18 #define SCALE 0.11f
19
20 #ifdef LIGHTWEIGHT_QUANTIZATION
21
22 #define QUANTIZATION(idx) \
23     {\
24         float4 tmp;\
25         for(uint z = 0; z < 4; z++)\
26         {\
27             tmp.s0 = (float)eltw_input_vals[z * 4 + 0] * SUM_SCALE + bias_f.s0;\
28             tmp.s1 = (float)eltw_input_vals[z * 4 + 1] * SUM_SCALE + bias_f.s1;\
29             tmp.s2 = (float)eltw_input_vals[z * 4 + 2] * SUM_SCALE + bias_f.s2;\
30             tmp.s3 = (float)eltw_input_vals[z * 4 + 3] * SUM_SCALE + bias_f.s3;\
31             \
32             regC_uchar16[z * 4 + 0] = convert_uchar_sat( (regC[0 * 4 + i][idx + z / 4]) * SCALE + tmp.s0);\
33             regC_uchar16[z * 4 + 1] = convert_uchar_sat( (regC[1 * 4 + i][idx + z / 4]) * SCALE + tmp.s1);\
34             regC_uchar16[z * 4 + 2] = convert_uchar_sat( (regC[2 * 4 + i][idx + z / 4]) * SCALE + tmp.s2);\
35             regC_uchar16[z * 4 + 3] = convert_uchar_sat( (regC[3 * 4 + i][idx + z / 4]) * SCALE + tmp.s3);\
36         }\
37     }
38
39 #elif NO_QUANTIZATION
40
41 #define QUANTIZATION(idx) \
42     regC_uchar16.s0 = regC[0 * 4 + i][idx];\
43     regC_uchar16.s1 = regC[1 * 4 + i][idx];\
44     regC_uchar16.s2 = regC[2 * 4 + i][idx];\
45     regC_uchar16.s3 = regC[3 * 4 + i][idx];\
46     \
47     regC_uchar16.s4 = regC[0 * 4 + i][idx+1];\
48     regC_uchar16.s5 = regC[1 * 4 + i][idx+1];\
49     regC_uchar16.s6 = regC[2 * 4 + i][idx+1];\
50     regC_uchar16.s7 = regC[3 * 4 + i][idx+1];\
51     \
52     regC_uchar16.s8 = regC[0 * 4 + i][idx+2];\
53     regC_uchar16.s9 = regC[1 * 4 + i][idx+2];\
54     regC_uchar16.sa = regC[2 * 4 + i][idx+2];\
55     regC_uchar16.sb = regC[3 * 4 + i][idx+2];\
56     \
57     regC_uchar16.sc = regC[0 * 4 + i][idx+3];\
58     regC_uchar16.sd = regC[1 * 4 + i][idx+3];\
59     regC_uchar16.se = regC[2 * 4 + i][idx+3];\
60     regC_uchar16.sf = regC[3 * 4 + i][idx+3];\
61     {\
62         int16 sum;\
63         for(uint s = 0; s <16; s++)\
64         {\
65             sum[s] = (int)as_char(regC_uchar16[s]) + (int)as_char(eltw_input_vals[s]);\
66         }\
67         regC_uchar16.s0 = convert_uchar_sat( sum.s0 );\
68         regC_uchar16.s1 = convert_uchar_sat( sum.s1 );\
69         regC_uchar16.s2 = convert_uchar_sat( sum.s2 );\
70         regC_uchar16.s3 = convert_uchar_sat( sum.s3 );\
71         \
72         regC_uchar16.s4 = convert_uchar_sat( sum.s4 );\
73         regC_uchar16.s5 = convert_uchar_sat( sum.s5 );\
74         regC_uchar16.s6 = convert_uchar_sat( sum.s6 );\
75         regC_uchar16.s7 = convert_uchar_sat( sum.s7 );\
76         \
77         regC_uchar16.s8 = convert_uchar_sat( sum.s8 );\
78         regC_uchar16.s9 = convert_uchar_sat( sum.s9 );\
79         regC_uchar16.sa = convert_uchar_sat( sum.sa );\
80         regC_uchar16.sb = convert_uchar_sat( sum.sb );\
81         \
82         regC_uchar16.sc = convert_uchar_sat( sum.sc );\
83         regC_uchar16.sd = convert_uchar_sat( sum.sd );\
84         regC_uchar16.se = convert_uchar_sat( sum.se );\
85         regC_uchar16.sf = convert_uchar_sat( sum.sf );\
86     }
87
88 #else
89
90 #define QUANTIZATION(idx) \
91     regC_uchar16.s0 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
92     regC_uchar16.s1 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
93     regC_uchar16.s2 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
94     regC_uchar16.s3 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
95     \
96     regC_uchar16.s4 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+1]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
97     regC_uchar16.s5 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+1]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
98     regC_uchar16.s6 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+1]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
99     regC_uchar16.s7 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+1]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
100     \
101     regC_uchar16.s8 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+2]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
102     regC_uchar16.s9 = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+2]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
103     regC_uchar16.sa = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+2]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
104     regC_uchar16.sb = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+2]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
105     \
106     regC_uchar16.sc = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[0 * 4 + i][idx+3]) * quant_f.s0 * I_QF + bias_f.s0) * calib_f.s0)), NL_M, NL_N));\
107     regC_uchar16.sd = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[1 * 4 + i][idx+3]) * quant_f.s1 * I_QF + bias_f.s1) * calib_f.s1)), NL_M, NL_N));\
108     regC_uchar16.se = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[2 * 4 + i][idx+3]) * quant_f.s2 * I_QF + bias_f.s2) * calib_f.s2)), NL_M, NL_N));\
109     regC_uchar16.sf = as_uchar(ACTIVATION( convert_char(round(( (float)(regC[3 * 4 + i][idx+3]) * quant_f.s3 * I_QF + bias_f.s3) * calib_f.s3)), NL_M, NL_N));\
110     {\
111         int16 sum;\
112         for(uint s = 0; s <16; s++)\
113         {\
114             sum[s] = (int)as_char(regC_uchar16[s]) + (int)as_char(eltw_input_vals[s]);\
115         }\
116         regC_uchar16.s0 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s0)  * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
117         regC_uchar16.s1 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s1)  * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
118         regC_uchar16.s2 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s2)  * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
119         regC_uchar16.s3 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s3)  * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
120         \
121         regC_uchar16.s4 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s4)  * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
122         regC_uchar16.s5 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s5)  * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
123         regC_uchar16.s6 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s6)  * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
124         regC_uchar16.s7 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s7)  * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
125         \
126         regC_uchar16.s8 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s8)  * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
127         regC_uchar16.s9 = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.s9)  * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
128         regC_uchar16.sa = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sa)  * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
129         regC_uchar16.sb = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sb)  * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
130         \
131         regC_uchar16.sc = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sc)  * eltw_calib_f.s0)), NL_M_ELTW, NL_N_ELTW));\
132         regC_uchar16.sd = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sd)  * eltw_calib_f.s1)), NL_M_ELTW, NL_N_ELTW));\
133         regC_uchar16.se = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.se)  * eltw_calib_f.s2)), NL_M_ELTW, NL_N_ELTW));\
134         regC_uchar16.sf = as_uchar(ACTIVATION_ELTW( convert_char((int)round( (float)(sum.sf)  * eltw_calib_f.s3)), NL_M_ELTW, NL_N_ELTW));\
135     }
136 #endif
137
138
139 inline uint FUNC(calculate_output_offset_to_account_padding)(uint cOffset)
140 {
141 #if OUT_WITH_PADDING == 1
142     uint tmp_idx = cOffset;
143     uint f_val_idx = tmp_idx % 32;
144     tmp_idx /= 32;
145     uint b_val_idx = tmp_idx % 4;
146     tmp_idx /= 4;
147     uint x_idx = tmp_idx % OUTPUT_SIZE_X;
148     tmp_idx /= OUTPUT_SIZE_X;
149     uint y_idx = tmp_idx % OUTPUT_SIZE_Y;
150     tmp_idx /= OUTPUT_SIZE_Y;
151     uint b_slice_idx = tmp_idx % (OUTPUT_BATCH_NUM / 4);
152     tmp_idx /= (OUTPUT_BATCH_NUM / 4);
153     uint f_slice_idx = tmp_idx % (OUTPUT_FEATURE_NUM / 32);
154
155     uint padded_offset = f_slice_idx * OUT_F_BLOCK_PITCH;
156     padded_offset += b_slice_idx * OUT_B_BLOCK_PITCH;
157     padded_offset += y_idx * OUT_Y_PITCH;
158     padded_offset += x_idx * OUT_X_PITCH;
159     padded_offset += b_val_idx * 32;
160     padded_offset += f_val_idx;
161     padded_offset += OUT_OFFSET;
162
163     return padded_offset;
164 #else
165     return cOffset;
166 #endif
167 }
168
169 #if IN_OUT_OPT != 1
170 inline uint FUNC(calculate_eltw_input_offset_based_on_output_offset_account_padding)(uint cOffset, uint strideX, uint strideY)
171 {
172 #if ELTW_WITH_PADDING == 1 || ELTW_STRIDE_X != 1 || ELTW_STRIDE_Y != 1
173     uint tmp_idx = cOffset;
174     uint f_val_idx = tmp_idx % 32;
175     tmp_idx /= 32;
176     uint b_val_idx = tmp_idx % 4;
177     tmp_idx /= 4;
178     uint x_idx = tmp_idx % OUTPUT_SIZE_X;
179     x_idx *= strideX;
180     tmp_idx /= OUTPUT_SIZE_X;
181     uint y_idx = tmp_idx % OUTPUT_SIZE_Y;
182     y_idx *= strideY;
183     tmp_idx /= OUTPUT_SIZE_Y;
184     uint b_slice_idx = tmp_idx % (OUTPUT_BATCH_NUM / 4);
185     tmp_idx /= (OUTPUT_BATCH_NUM / 4);
186     uint f_slice_idx = tmp_idx % (OUTPUT_FEATURE_NUM / 32);
187
188     uint padded_offset = f_slice_idx * IN2_F_BLOCK_PITCH;
189     padded_offset += b_slice_idx * IN2_B_BLOCK_PITCH;
190     padded_offset += y_idx * IN2_Y_PITCH;
191     padded_offset += x_idx * IN2_X_PITCH;
192     padded_offset += b_val_idx * 32;
193     padded_offset += f_val_idx;
194     padded_offset += IN2_OFFSET;
195
196     return padded_offset;
197 #else
198     return cOffset;
199 #endif
200 }
201 #endif
202
203 inline void FUNC(mmad_32x32_int8)(  __local uint* l_tileA, const uint l_offsetTileA,
204                                     __local int8* l_tileB, const uint l_offsetTileB_col0,
205                                     const uint l_offsetTileB_col1, const uint l_offsetTileB_col2,
206                                     const uint l_offsetTileB_col3, int8* rowA, int8* colB,
207                                     int8* regC)
208 {
209     // Read tile A from SLM to regA
210     uint l_offsetTileATemp = l_offsetTileA;
211     __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
212     for (uint j = 0; j < (SG_TILE_M / 8); ++j)
213     {
214         rowA[j] = as_int8(SLM_BLOCK_READ_8(&l_tileA[l_offsetTileATemp]));
215         l_offsetTileATemp += 8 * SG_SIZE;
216     }
217     // Read tile B from SLM to regB and compute mmad
218     colB[0] = l_tileB[l_offsetTileB_col0];
219     colB[1] = l_tileB[l_offsetTileB_col1];
220     __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
221     for (uint j = 0; j < (SG_TILE_M / 8); ++j)
222     {
223         // Compute partial C
224         regC[0*(SIMD_LANE_M / 8) + j] = MMAD_8x8( rowA[j], colB[0], regC[0*(SIMD_LANE_M / 8) + j]);
225     }
226     colB[0] = l_tileB[l_offsetTileB_col2];
227     __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
228     for (uint j = 0; j < (SG_TILE_M / 8); ++j)
229     {
230         // Compute partial C
231         regC[1*(SIMD_LANE_M / 8) + j] = MMAD_8x8( rowA[j], colB[1], regC[1*(SIMD_LANE_M / 8) + j] );
232     }
233     colB[1] = l_tileB[l_offsetTileB_col3];
234     __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
235     for (uint j = 0; j < (SG_TILE_M / 8); ++j)
236     {
237         // Compute partial C
238         regC[2*(SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[0], regC[2*(SIMD_LANE_M / 8) + j]);
239     }
240     __attribute__((opencl_unroll_hint(SG_TILE_M / 8)))
241     for (uint j = 0; j < (SG_TILE_M / 8); ++j)
242     {
243         // Compute partial C
244         regC[3*(SIMD_LANE_M / 8) + j] = MMAD_8x8(rowA[j], colB[1], regC[3*(SIMD_LANE_M / 8) + j]);
245     }
246 }
247
248 /*
249  *  \brief GEMM kernel to compute MxN matrix using SLM
250  *  \param g_inA  - Input matrix 
251  *  \param g_inB  - Input matrix 
252  *  \param g_outC - Output matrix
253  */
254
255 __attribute__((intel_reqd_sub_group_size(SG_SIZE)))   
256 KERNEL(Kernel_GEMM_MMAD8_32x32SG_128x128WG_SLM_INT8_fused_eltwise)
257   (
258   __global char* const g_inA,
259   __global int* g_outC,
260   __global char* const g_inB,
261     #if BIAS_TERM
262         __global BIAS_TYPE* biases,
263     #endif
264         __global float* quantizations,
265     #if CALIBRATION_TERM
266         __global float* calibrations,
267     #endif
268         uint split_idx,
269   __global char* const input2,
270   __global float* eltw_calibrations
271    )
272 {
273
274     __global int4* const g_matrixA = (__global int4*)g_inA;
275     __global int4* const g_matrixB = (__global int4*)g_inB;
276     __global int8* g_matrixC = (__global int8*)g_outC;
277
278     // Each work-group works to compute 128x128 tile.
279     // Each work-group contains 16 sub-groups.
280     // Each sub-group within the work-group works to compute a 32x32 tile.
281     // 1) All work-items in WG fill SLM with tileA (128x32) and tileB (32x128).
282     // 2) Each sub-group works to compute 32x32 tileC (stored in regC).
283     //    Note that each work-item in the sub-group computes a 32x4 chunk of tileC.
284     // 3) Repeat until tileC is fully computed (while moving tileA and tileB "windows")
285     __local int8 l_workGroupTileA[2 * (WG_TILE_M * MATRIX_SMALL_K) / sizeof(int8)]; // [2*128*32/8] = 1024 
286     __local int8 l_workGroupTileB[2 * (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int8)]; // [2*128*32/8] = 1024 
287
288     __local uint* l_workGroupTileA_uint = (__local uint*)l_workGroupTileA;
289     __local int4* l_workGroupTileA_int4 = (__local int4*)l_workGroupTileA;
290     __local int4* l_workGroupTileB_int4 = (__local int4*)l_workGroupTileB;
291
292     const uint l_groupSize = get_local_size(DIM_X) * get_local_size(DIM_Y);
293
294     const uint l_pingPongOffsetA_uint = (WG_TILE_M * MATRIX_SMALL_K) / sizeof(uint);
295     const uint l_pingPongOffsetB_int8 = (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int8);
296     const uint l_pingPongOffsetA_int4 = (WG_TILE_M * MATRIX_SMALL_K) / sizeof(int4);
297     const uint l_pingPongOffsetB_int4 = (WG_TILE_N * MATRIX_SMALL_K) / sizeof(int4);
298
299     // Thread IDs
300     const uint g_tidY = get_global_id(DIM_Y); // 0,...,all_wi_inY
301     const uint g_tidX = get_global_id(DIM_X); // 0,...,all_wi_inX
302     const uint l_tidX = get_local_id(DIM_X);  // 0,...,31 in WG
303     const uint l_tidY = get_local_id(DIM_Y);  // 0,1,2,3  in WG
304     const uint l_tid = l_tidY * get_local_size(DIM_X) + l_tidX; // 0,1,2,...127
305
306     // SubGroup IDs
307     const uint sg_tid = get_sub_group_local_id();            // 0,1,...,8
308     const uint sg_global_idX = (uint)(g_tidX / SG_SIZE);     //{0}/8
309     const uint sg_global_idY = g_tidY;                       //{0}
310
311     const uint sg_local_idX = (uint)(l_tidX / SG_SIZE);      // {0,...,31}/8={0,0,0,0,0...,1,1,1,...,3,3,3}
312     const uint sg_local_idY = l_tidY;                        // 0,1,2,3
313     const uint sg_local_id = sg_local_idY * get_local_size(DIM_X) / SG_SIZE + sg_local_idX;  // get_local_size(DIM_X) / SG_SIZE = 32/8 = 4
314
315     const uint sub_group_id = get_sub_group_id();
316
317
318     // Registers
319     int8 regC[(SIMD_LANE_M / 8) * SIMD_LANE_N] = {0}; // Each work-item responsible for 32x4 ints elts   // (32/8)*4
320     int8 rowA[(SG_TILE_M * MATRIX_SMALL_K / SG_SIZE) / sizeof(int8)]; // each work-item will hold 1/8 of matrixA 
321     int8 colB[2];  // each lane will store 32x4 piece of matrixB
322
323     // SLM indices
324     const uint l_offsetTileA = SG_TILE_M * (MATRIX_SMALL_K / sizeof(uint)) * sg_local_idY;
325     const uint numElements32x32TileB = (MATRIX_SMALL_K * SG_TILE_N) / sizeof(int8);
326     const uint numElements32x8TileB = numElements32x32TileB / 4;
327     const uint l_offsetTileB = numElements32x32TileB * sg_local_idX;
328     const uint l_offsetTileB_col0 = l_offsetTileB + sg_tid;
329     const uint l_offsetTileB_col1 = l_offsetTileB + 1 * numElements32x8TileB + sg_tid;
330     const uint l_offsetTileB_col2 = l_offsetTileB + 2 * numElements32x8TileB + sg_tid;
331     const uint l_offsetTileB_col3 = l_offsetTileB + 3 * numElements32x8TileB + sg_tid;
332
333     // Global indices
334     uint g_idxA[2];
335     uint g_idxB[2];
336 #ifdef TILED_GLOBAL_LAYOUT // 32-row major (matrixA) and 32-col major (matrixB)
337     g_idxA[0] = ((MATRIX_SMALL_K / sizeof(int4)) * WG_TILE_M) * get_group_id(DIM_Y) + l_tid;
338     g_idxB[0] = ((MATRIX_SMALL_K / sizeof(int4)) * WG_TILE_N) * get_group_id(DIM_X) + l_tid;
339     g_idxA[1] = g_idxA[0] + l_groupSize;
340     g_idxB[1] = g_idxB[0] + l_groupSize;
341 #else // Row (matrixA) and Col (matrixB) major layout
342     g_idxA[0] = WG_TILE_M * (MATRIX_K / sizeof(int4)) * get_group_id(DIM_Y) +
343                (l_tid / 2) * (MATRIX_K / sizeof(int4)) + (l_tid % 2);
344     g_idxB[0] = WG_TILE_N * (MATRIX_K / sizeof(int4)) * get_group_id(DIM_X) +
345                (l_tid / 2) * (MATRIX_K / sizeof(int4)) + (l_tid % 2);
346     g_idxA[1] = g_idxA[0] + (l_groupSize / 2) * (MATRIX_K / sizeof(int4));
347     g_idxB[1] = g_idxB[0] + (l_groupSize / 2) * (MATRIX_K / sizeof(int4));
348 #endif
349
350     // Initial SLM setup
351     {
352         l_workGroupTileA_int4[l_tid] = g_matrixA[g_idxA[0]];
353         l_workGroupTileB_int4[l_tid] = g_matrixB[g_idxB[0]];
354         l_workGroupTileA_int4[l_tid + l_groupSize] = g_matrixA[g_idxA[1]];
355         l_workGroupTileB_int4[l_tid + l_groupSize] = g_matrixB[g_idxB[1]];
356
357 #ifdef TILED_GLOBAL_LAYOUT
358         g_idxA[0] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
359         g_idxB[0] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
360         g_idxA[1] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
361         g_idxB[1] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
362 #else
363         g_idxA[0] += MATRIX_SMALL_K / sizeof(int4);
364         g_idxB[0] += MATRIX_SMALL_K / sizeof(int4);
365         g_idxA[1] += MATRIX_SMALL_K / sizeof(int4);
366         g_idxB[1] += MATRIX_SMALL_K / sizeof(int4);
367 #endif
368
369         barrier(CLK_LOCAL_MEM_FENCE);
370     }
371
372     int4 hdcReadValueA[2];
373     int4 hdcReadValueB[2];
374
375     __attribute__((opencl_unroll_hint(1)))
376     for (uint k = 0; k < (MATRIX_K / MATRIX_SMALL_K) - 1; k++)
377     {
378         /*
379          * SLM setup - HDC read only
380          */
381         // Overlap HDC reads with mmad compute
382         hdcReadValueA[0] = g_matrixA[g_idxA[0]];
383         hdcReadValueB[0] = g_matrixB[g_idxB[0]];
384         hdcReadValueA[1] = g_matrixA[g_idxA[1]];
385         hdcReadValueB[1] = g_matrixB[g_idxB[1]];
386
387 #ifdef TILED_GLOBAL_LAYOUT
388         g_idxA[0] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
389         g_idxB[0] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
390         g_idxA[1] += MATRIX_M * MATRIX_SMALL_K / sizeof(int4);
391         g_idxB[1] += MATRIX_N * MATRIX_SMALL_K / sizeof(int4);
392 #else
393         g_idxA[0] += MATRIX_SMALL_K / sizeof(int4);
394         g_idxB[0] += MATRIX_SMALL_K / sizeof(int4);
395         g_idxA[1] += MATRIX_SMALL_K / sizeof(int4);
396         g_idxB[1] += MATRIX_SMALL_K / sizeof(int4);
397 #endif
398
399         /*
400          * mmad compute
401          */
402         FUNC_CALL(mmad_32x32_int8)(&l_workGroupTileA_uint[(k % 2) * l_pingPongOffsetA_uint],
403                                 l_offsetTileA, &l_workGroupTileB[(k % 2) * l_pingPongOffsetB_int8],
404                                 l_offsetTileB_col0, l_offsetTileB_col1, l_offsetTileB_col2,
405                                 l_offsetTileB_col3, rowA, colB, regC);
406
407         /*
408          * SLM setup - SLM write only
409          */
410         l_workGroupTileA_int4[((k + 1) % 2 * l_pingPongOffsetA_int4) + l_tid] = hdcReadValueA[0];
411         l_workGroupTileB_int4[((k + 1) % 2 * l_pingPongOffsetB_int4) + l_tid] = hdcReadValueB[0];
412         l_workGroupTileA_int4[((k + 1) % 2 * l_pingPongOffsetA_int4) + l_tid + l_groupSize] = hdcReadValueA[1];
413         l_workGroupTileB_int4[((k + 1) % 2 * l_pingPongOffsetB_int4) + l_tid + l_groupSize] = hdcReadValueB[1];
414
415         barrier(CLK_LOCAL_MEM_FENCE);
416     } // main outer loop
417
418     /*
419      * Last mmad compute iteration (avoids branching in main loop)
420      */
421
422     FUNC_CALL(mmad_32x32_int8)(
423         &l_workGroupTileA_uint[(((MATRIX_K / MATRIX_SMALL_K) - 1) % 2) * l_pingPongOffsetA_uint],
424         l_offsetTileA,
425         &l_workGroupTileB[(((MATRIX_K / MATRIX_SMALL_K) - 1) % 2) * l_pingPongOffsetB_int8],
426         l_offsetTileB_col0, l_offsetTileB_col1, l_offsetTileB_col2, l_offsetTileB_col3, rowA, colB,
427         regC);
428
429 #ifdef OUTPUT_TILED_GLOBAL_LAYOUT
430     // Write out in swizzled manner after quantizing
431     __global uchar* g_outC_uchar = (__global uchar*)g_outC;
432     uint cOffset = sg_global_idX * (MATRIX_M * SG_TILE_N / sizeof(uchar)) +
433                    sg_global_idY * (SG_TILE_M * SG_TILE_N / sizeof(uchar));
434
435     uchar16 regC_uchar16;
436     uint offset_uc16 = 0;
437
438     const uint workgroup_id_x = get_group_id(0); 
439     uint feature_off = 32*(sub_group_id % (WG_TILE_N / 32)) + WG_TILE_N*workgroup_id_x; //=32*{0,1,2,3} + WG_TILE_N * workgroup_id_x 
440     uint feature = get_sub_group_local_id()*4 + feature_off;
441
442     float4 quant_f = vload4(0, quantizations + feature);
443     float4 bias_f = vload4(0, biases + feature);
444     float4 calib_f = vload4(0, calibrations + feature);
445
446     // eltwise calibs
447     float4 eltw_calib_f = vload4(0, eltw_calibrations + feature);
448
449     uchar16 eltw[(2*SG_TILE_M) / (sizeof(int8) / sizeof(int))];
450     uint tmpcOff = cOffset;
451     __attribute__((opencl_unroll_hint( SG_TILE_M / (sizeof(int8) / sizeof(int)) )))
452     for (uint i = 0; i < (2*SG_TILE_M) / (sizeof(int8) / sizeof(int)); i++)
453     {
454         uint padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(tmpcOff);
455 #if IN_OUT_OPT == 1
456         eltw[i] = as_uchar16(intel_sub_group_block_read4((__global uint*)(g_outC_uchar + padded_offset)));
457 #else
458         const uint eltw_second_input_offset = FUNC_CALL(calculate_eltw_input_offset_based_on_output_offset_account_padding)(tmpcOff, ELTW_STRIDE_X, ELTW_STRIDE_Y);
459         eltw[i] = as_uchar16(intel_sub_group_block_read4((__global uint*)(input2 + eltw_second_input_offset)));
460 #endif
461         tmpcOff += sizeof(uchar16) * SG_SIZE;
462     }
463
464 #if MMAD_SUPPORTED == 1
465     __attribute__((opencl_unroll_hint( SG_TILE_M / (sizeof(int8) / sizeof(int)) )))
466 #endif
467     for (uint i = 0; i < SG_TILE_M / (sizeof(int8) / sizeof(int)); i++)
468     {
469         uint padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(cOffset);
470         {
471             uchar16 eltw_input_vals = eltw[i * 2];
472             // B0..3, F0..31
473             QUANTIZATION(0);
474         }
475
476         intel_sub_group_block_write4((__global uint*)(g_outC_uchar + padded_offset), as_uint4(regC_uchar16));
477         cOffset += sizeof(uchar16) * SG_SIZE;
478
479         // now we need to calculate again for other x
480         padded_offset = FUNC_CALL(calculate_output_offset_to_account_padding)(cOffset);
481         {
482             uchar16 eltw_input_vals = eltw[i * 2 + 1];
483             // B0..3, F0..31
484             QUANTIZATION(4);
485         }
486
487         intel_sub_group_block_write4( (__global uint*)(g_outC_uchar + padded_offset), as_uint4(regC_uchar16) );
488         cOffset += sizeof(uchar16) * SG_SIZE;
489     }
490 #else
491     // Write final accumulated values
492     uint cOffset = sg_global_idX * ((MATRIX_M / 8) * SG_TILE_N) + sg_global_idY * (SG_TILE_M / 8) +
493                    sg_tid * (MATRIX_M / 8);
494     __attribute__((opencl_unroll_hint(SIMD_LANE_N)))
495     for (uint i = 0; i < (SIMD_LANE_N); ++i)
496     {
497         __attribute__((opencl_unroll_hint(SIMD_LANE_M / 8)))
498         for (uint j = 0; j < (SIMD_LANE_M / 8); ++j)
499         {
500             g_matrixC[cOffset + j] = regC[i*(SIMD_LANE_M / 8) + j];
501         }
502         cOffset += SG_SIZE * (MATRIX_M / 8);
503     }
504 #endif
505 }
506
507 #undef SUM_SCALE
508 #undef SCALE
509 #undef QUANTIZATION