// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "include/include_all.cl"
#include "include/sub_group.cl"
#define TILE_M          2
#define TILE_K          FILTER_SIZE_X
#define TILE_N          32
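// The convolution is mapped onto a GEMM-style tiled multiply: atile (input patches)
// is M rows x K columns, btile (filter) is K rows x N columns, and ctile (output) is
// M rows x N columns, walked in TILE_K = FILTER_SIZE_X steps. The TILE_M/TILE_N values
// follow the accumulator layout used below (2 output pixels x 32 output channels per
// work item). FILTER_SIZE_X_DIV2, ALIGNED_OFM and the INPUT0_*/OUTPUT_* macros are
// presumably injected as JIT build constants by the host.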
__attribute__((intel_reqd_sub_group_size(8)))
KERNEL(convolution_f32)(
    const __global float *src0,
    __global float *dst,
    const __global float *src1,
#if BIAS_TERM
    const __global float *bias,
#endif
    uint split_idx)
{
#include "include/vec_typedefs.cl"
    const unsigned group_x = get_group_id(0);
    const unsigned group_y = get_group_id(1);
    const unsigned global_x = get_global_id(0);
    const unsigned global_y = get_global_id(1);
    const unsigned global_z = get_global_id(2);
    unsigned interleaved_y;
    unsigned kernel_y;
    unsigned kernel_idx;
    // Result ctile (*dst) is M rows x N columns.
    // LWG size is 1x8; the 8 work items of one HW thread (the 8-wide sub-group)
    // together calculate 8*M rows x N cols of ctile.
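    // Work-item mapping: dimension 0 selects the TILE_N-wide channel tile (local size 1,
    // so global_x == group_x within a sub-group), while the 8 sub-group lanes run along
    // dimension 1, so each lane owns a different pair of output pixels.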
    float8 blockC00 = 0.f;
    float8 blockC10 = 0.f;
    float8 blockC20 = 0.f;
    float8 blockC30 = 0.f;
    float8 blockC01 = 0.f;
    float8 blockC11 = 0.f;
    float8 blockC21 = 0.f;
    float8 blockC31 = 0.f;
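    // Accumulators blockC<n><m>: n selects the 8-channel group (channels 8n..8n+7 of
    // the tile, matching the store offsets at the end of the kernel) and m selects
    // which of the TILE_M output pixels the sums belong to (m == 0 -> out0, m == 1 -> out1).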
    const uint in_split_offset = split_idx * INPUT0_FEATURE_PITCH * INPUT0_FEATURE_NUM;
    // Src0 (patch input) is directly used as atile.
    // Each work item points to the start of a different patch.
    // atile is M rows x K columns.
    const uint src0_read_offset0_const = INPUT0_OFFSET_WITH_PADDING + in_split_offset
     + INPUT0_BATCH_PITCH * global_z                                                        // batch offset
     + ( ( ( global_y * TILE_M + 0 ) / OUTPUT_SIZE_X ) * STRIDE_SIZE_Y * INPUT0_Y_PITCH )   // y offset
     + ( ( ( global_y * TILE_M + 0 ) % OUTPUT_SIZE_X ) * STRIDE_SIZE_X );                   // x offset
    const uint src0_read_offset1_const = INPUT0_OFFSET_WITH_PADDING + in_split_offset
     + INPUT0_BATCH_PITCH * global_z                                                        // batch offset
     + ( ( ( global_y * TILE_M + 1 ) / OUTPUT_SIZE_X ) * STRIDE_SIZE_Y * INPUT0_Y_PITCH )   // y offset
     + ( ( ( global_y * TILE_M + 1 ) % OUTPUT_SIZE_X ) * STRIDE_SIZE_X );                   // x offset
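    // Output pixels are linearized: pixel p = global_y * TILE_M + m lives at
    // y = p / OUTPUT_SIZE_X, x = p % OUTPUT_SIZE_X; scaling by the strides and input
    // pitches turns that into the top-left corner of the corresponding input patch.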
    // Src1 (filter) is directly used as btile.
    // It starts at the top of src1 and walks down.
    // btile is K rows x N columns.
    uint src0_read_offset0 = src0_read_offset0_const;
    uint src0_read_offset1 = src0_read_offset1_const;
    uint src1_read_offset = ( global_x * TILE_N * 2 );
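    // The factor of 2 in src1_read_offset comes from the filter interleaving described
    // below: each pair of filter rows is stored as one run of ALIGNED_OFM * 2 floats,
    // so a TILE_N-channel tile starts TILE_N * 2 floats into that run.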
#define DOT_PRODUCT_8( _result, _rowA, colB )    \
{   \
        _result.s0 = mad( _rowA, sub_group_broadcast( colB, 0 ), _result.s0 );  \
        _result.s1 = mad( _rowA, sub_group_broadcast( colB, 1 ), _result.s1 );  \
        _result.s2 = mad( _rowA, sub_group_broadcast( colB, 2 ), _result.s2 );  \
        _result.s3 = mad( _rowA, sub_group_broadcast( colB, 3 ), _result.s3 );  \
        _result.s4 = mad( _rowA, sub_group_broadcast( colB, 4 ), _result.s4 );  \
        _result.s5 = mad( _rowA, sub_group_broadcast( colB, 5 ), _result.s5 );  \
        _result.s6 = mad( _rowA, sub_group_broadcast( colB, 6 ), _result.s6 );  \
        _result.s7 = mad( _rowA, sub_group_broadcast( colB, 7 ), _result.s7 );  \
}
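// DOT_PRODUCT_8 performs a rank-1 GEMM update: colB holds 8 different output-channel
// weights spread across the 8 sub-group lanes; broadcasting lanes 0..7 hands every
// lane all 8 weights, which it multiplies against its own atile element _rowA.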
    // Walk DOWN src0 (patch 0, 1, 2, ...) and DOWN src1.
    // Inner loop loads and FMADs one row (FILTER_SIZE_X) of each input patch
    // and FILTER_SIZE_X/2 rows of interleaved filter.
    unsigned patch_depth = 0;
    do
    {
        unsigned patch_row = 0;
        do
        {
            // Load atile and btile.
            // Kernel data is partially interleaved. Every 2 rows are interleaved at float8 granularity.
            // The exception is that if FILTER_SIZE_X is odd the last row is not interleaved. The non
            // interleaved row is padded with zero to ensure same size as interleaved rows. This
            // interleaving is done to ensure 0% GRF bank conflicts. For example, this is how the
            // kernel data would be arranged before/after interleaving for FILTER_SIZE_X=3.
            // (0, 0) (8, 0) (16, 0) (24, 0) ...       (0, 0) (0, 1) (8, 0) (8, 1) (16, 0) (16, 1) (24, 0) (24, 1) ...
            // (0, 1) (8, 1) (16, 1) (24, 1) ...  =>   (0, 2) (8, 2) (16, 2) (24, 2) ...
            // (0, 2) (8, 2) (16, 2) (24, 2) ...       ...
            const bool kernel_width_is_odd = FILTER_SIZE_X % 2 == 1;

            float blockA00[FILTER_SIZE_X];
            float blockA01[FILTER_SIZE_X];
            // in case the data is not aligned to sizeof(T)*FILTER_SIZE_X we need to use vload or set the data in a loop
            unsigned i = 0;
            LOOP(FILTER_SIZE_X, i,
            {
                // Bounds check is only needed when this patch could read past the end
                // of the input buffer (the last rows of the last batch).
                if(src0_read_offset0_const + (FILTER_SIZE_Y - 1) * INPUT0_Y_PITCH + (INPUT0_FEATURE_NUM - 1) * (INPUT0_FEATURE_PITCH - ( FILTER_SIZE_Y * INPUT0_Y_PITCH )) >= INPUT0_BATCH_NUM * INPUT0_BATCH_PITCH)
                {
                    if(src0_read_offset0 + i < INPUT0_BATCH_NUM * INPUT0_BATCH_PITCH)
                        blockA00[i] = src0[src0_read_offset0 + i];
                }
                else
                {
                    blockA00[i] = src0[src0_read_offset0 + i];
                }

                if(src0_read_offset1_const + (FILTER_SIZE_Y - 1) * INPUT0_Y_PITCH + (INPUT0_FEATURE_NUM - 1) * (INPUT0_FEATURE_PITCH - ( FILTER_SIZE_Y * INPUT0_Y_PITCH )) >= INPUT0_BATCH_NUM * INPUT0_BATCH_PITCH)
                {
                    if(src0_read_offset1 + i < INPUT0_BATCH_NUM * INPUT0_BATCH_PITCH)
                        blockA01[i] = src0[src0_read_offset1 + i];
                }
                else
                {
                    blockA01[i] = src0[src0_read_offset1 + i];
                }
            } )
            float*  pblockA00 = (float*)(&blockA00);
            float*  pblockA01 = (float*)(&blockA01);

            src0_read_offset0 += INPUT0_Y_PITCH;
            src0_read_offset1 += INPUT0_Y_PITCH;

            float blockB00[FILTER_SIZE_X*4];
            float8* p8BlockB00 = (float8*)blockB00;
            float4* p4BlockB00 = (float4*)blockB00;
            float*  pBlockB00  = (float* )blockB00;
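            // One btile chunk per work item: FILTER_SIZE_X*4 floats, viewed as float8s
            // for the interleaved block reads, as float4 for the odd tail row, and as
            // scalars for the MAD loop below.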
            interleaved_y = 0;
            LOOP(FILTER_SIZE_X_DIV2, interleaved_y,
            {
                p8BlockB00[interleaved_y] = as_float8( intel_sub_group_block_read8( (const __global uint*)src1 + src1_read_offset ) );
                src1_read_offset += ALIGNED_OFM * 2;
            } )
            if ( kernel_width_is_odd )
            {
                p4BlockB00[FILTER_SIZE_X - 1] = as_float4( intel_sub_group_block_read4( (const __global uint*)src1 + src1_read_offset ) );
                src1_read_offset += ALIGNED_OFM * 2;
            }
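            // intel_sub_group_block_read8/4 load 8 (or 4) uints per work item from a
            // sub-group-uniform address, fetching a whole 8-lane-wide tile of filter
            // data in one coalesced transaction.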
            // Perform the multiply-accumulates over the interleaved filter row pairs.
            kernel_idx = 0;
            interleaved_y = 0;
            LOOP(FILTER_SIZE_X_DIV2, interleaved_y,
            {
                kernel_y = interleaved_y * 2;
                DOT_PRODUCT_8( blockC00, pblockA00[kernel_y    ], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC01, pblockA01[kernel_y    ], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC00, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC01, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC10, pblockA00[kernel_y    ], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC11, pblockA01[kernel_y    ], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC10, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC11, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC20, pblockA00[kernel_y    ], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC21, pblockA01[kernel_y    ], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC20, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC21, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC30, pblockA00[kernel_y    ], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC31, pblockA01[kernel_y    ], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC30, pblockA00[kernel_y + 1], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC31, pblockA01[kernel_y + 1], pBlockB00[kernel_idx] ); kernel_idx++;
            } )
            if ( kernel_width_is_odd )
            {
                kernel_y = interleaved_y * 2;
                DOT_PRODUCT_8( blockC00, pblockA00[kernel_y], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC01, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC10, pblockA00[kernel_y], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC11, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC20, pblockA00[kernel_y], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC21, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;
                DOT_PRODUCT_8( blockC30, pblockA00[kernel_y], pBlockB00[kernel_idx] );
                DOT_PRODUCT_8( blockC31, pblockA01[kernel_y], pBlockB00[kernel_idx] ); kernel_idx++;
            }
        }

        //while( ++patch_row < 1 );      //debug
        while( ++patch_row < FILTER_SIZE_Y );
        src0_read_offset0 += INPUT0_FEATURE_PITCH - ( FILTER_SIZE_Y * INPUT0_Y_PITCH ); // reset to start of next slice of patch
        src0_read_offset1 += INPUT0_FEATURE_PITCH - ( FILTER_SIZE_Y * INPUT0_Y_PITCH ); // reset to start of next slice of patch
    }
    //while ( ++patch_depth < 1 );  //debug
    while ( ++patch_depth < INPUT0_FEATURE_NUM );
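    // Accumulation finished: each work item now holds its complete TILE_M x TILE_N
    // results in the blockC* registers; the remainder of the kernel applies bias and
    // activation and scatters the results to dst.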
    const uint out_split_offset = split_idx * OUTPUT_FEATURE_PITCH * OUTPUT_FEATURE_NUM;
    // Dst resembles a cube of width x height x (output channel * batches). Each tile writes:
    // (SIMD * TILE_M) x 1 x TILE_N. Partial writes most likely generated if padding used.
    __global float *out0 = dst + OUTPUT_OFFSET + out_split_offset
     + global_z * OUTPUT_BATCH_PITCH                                // batch offset
     + ( group_x * TILE_N ) * OUTPUT_FEATURE_PITCH                  // channel offset
     + ( ( global_y * TILE_M ) / OUTPUT_SIZE_X ) * OUTPUT_Y_PITCH   // y offset
     + ( ( global_y * TILE_M ) % OUTPUT_SIZE_X );                   // x offset
    __global float *out1 = dst + OUTPUT_OFFSET + out_split_offset
     + global_z * OUTPUT_BATCH_PITCH                                    // batch offset
     + ( group_x * TILE_N ) * OUTPUT_FEATURE_PITCH                      // channel offset
     + ( ( global_y * TILE_M + 1 ) / OUTPUT_SIZE_X ) * OUTPUT_Y_PITCH   // y offset
     + ( ( global_y * TILE_M + 1 ) % OUTPUT_SIZE_X );                   // x offset
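    // Stores are split by how many of the TILE_N output channels really exist: full
    // tiles write all four float8 groups, and the last tile in x falls through the
    // >= 24 / >= 16 / >= 8 ladder below, storing whole 8-channel groups first and the
    // OUTPUT_FEATURE_NUM % 8 leftovers one scalar at a time.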
#if BIAS_TERM
    __global float8* biasPtr = (__global float8*) (bias + group_x * TILE_N);
#endif
    if( global_y * TILE_M < OUTPUT_SIZE_X * OUTPUT_SIZE_Y )
    {
        if ( ( OUTPUT_FEATURE_NUM % TILE_N ) == 0 )
        {
#if BIAS_TERM
            blockC00 += *biasPtr;
            blockC10 += *(biasPtr + 1);
            blockC20 += *(biasPtr + 2);
            blockC30 += *(biasPtr + 3);
#endif

            blockC00 = ACTIVATION(blockC00, NL_M, NL_N);
            blockC10 = ACTIVATION(blockC10, NL_M, NL_N);
            blockC20 = ACTIVATION(blockC20, NL_M, NL_N);
            blockC30 = ACTIVATION(blockC30, NL_M, NL_N);

            for( unsigned i = 0; i < 8; i++ )
            {
                out0[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC00[i];
                out0[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC10[i];
                out0[(16+i) * OUTPUT_FEATURE_PITCH] = blockC20[i];
                out0[(24+i) * OUTPUT_FEATURE_PITCH] = blockC30[i];
            }
        }
        else
        {
            if ( ( global_x + 1 ) < get_global_size(0) )
            {
#if BIAS_TERM
                blockC00 += *biasPtr;
                blockC10 += *(biasPtr + 1);
                blockC20 += *(biasPtr + 2);
                blockC30 += *(biasPtr + 3);
#endif

                blockC00 = ACTIVATION(blockC00, NL_M, NL_N);
                blockC10 = ACTIVATION(blockC10, NL_M, NL_N);
                blockC20 = ACTIVATION(blockC20, NL_M, NL_N);
                blockC30 = ACTIVATION(blockC30, NL_M, NL_N);

                for ( unsigned i = 0; i < 8; i++ )
                {
                    out0[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC00[i];
                    out0[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC10[i];
                    out0[(16+i) * OUTPUT_FEATURE_PITCH] = blockC20[i];
                    out0[(24+i) * OUTPUT_FEATURE_PITCH] = blockC30[i];
                }
            }
            else
            {
                if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 24 )
                {
#if BIAS_TERM
                    blockC00 += *biasPtr;
                    blockC10 += *(biasPtr + 1);
                    blockC20 += *(biasPtr + 2);
                    if (( OUTPUT_FEATURE_NUM % TILE_N) > 24 ) blockC30 += *(biasPtr + 3);
#endif

                    blockC00 = ACTIVATION(blockC00, NL_M, NL_N);
                    blockC10 = ACTIVATION(blockC10, NL_M, NL_N);
                    blockC20 = ACTIVATION(blockC20, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out0[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC00[i];
                        out0[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC10[i];
                        out0[(16+i) * OUTPUT_FEATURE_PITCH] = blockC20[i];
                    }

                    // remaining output channels
                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out0[(24+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC30[i], NL_M, NL_N);
                    }
                }
                else if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 16 )
                {
#if BIAS_TERM
                    blockC00 += *biasPtr;
                    blockC10 += *(biasPtr + 1);
                    if (( OUTPUT_FEATURE_NUM % TILE_N) > 16 )
                        blockC20 += *(biasPtr + 2);
#endif

                    blockC00 = ACTIVATION(blockC00, NL_M, NL_N);
                    blockC10 = ACTIVATION(blockC10, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out0[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC00[i];
                        out0[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC10[i];
                    }

                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out0[(16+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC20[i], NL_M, NL_N);
                    }
                }
                else if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 8 )
                {
#if BIAS_TERM
                    blockC00 += *biasPtr;
                    if (( OUTPUT_FEATURE_NUM % TILE_N) > 8 )
                        blockC10 += *(biasPtr + 1);
#endif

                    blockC00 = ACTIVATION(blockC00, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out0[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC00[i];
                    }

                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out0[(8+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC10[i], NL_M, NL_N);
                    }
                }
                else
                {
#if BIAS_TERM
                    blockC00 += *biasPtr;
#endif
                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out0[( 0+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC00[i], NL_M, NL_N);
                    }
                }
            }
        }
    }
    if ((global_y * TILE_M + 1) < OUTPUT_SIZE_X * OUTPUT_SIZE_Y )
    {
        if ( ( OUTPUT_FEATURE_NUM % TILE_N ) == 0 )
        {
#if BIAS_TERM
            blockC01 += *biasPtr;
            blockC11 += *(biasPtr + 1);
            blockC21 += *(biasPtr + 2);
            blockC31 += *(biasPtr + 3);
#endif

            blockC01 = ACTIVATION(blockC01, NL_M, NL_N);
            blockC11 = ACTIVATION(blockC11, NL_M, NL_N);
            blockC21 = ACTIVATION(blockC21, NL_M, NL_N);
            blockC31 = ACTIVATION(blockC31, NL_M, NL_N);

            for( unsigned i = 0; i < 8; i++ )
            {
                out1[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC01[i];
                out1[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC11[i];
                out1[(16+i) * OUTPUT_FEATURE_PITCH] = blockC21[i];
                out1[(24+i) * OUTPUT_FEATURE_PITCH] = blockC31[i];
            }
        }
        else
        {
            if ( ( global_x + 1 ) < get_global_size(0) )
            {
#if BIAS_TERM
                blockC01 += *biasPtr;
                blockC11 += *(biasPtr + 1);
                blockC21 += *(biasPtr + 2);
                blockC31 += *(biasPtr + 3);
#endif

                blockC01 = ACTIVATION(blockC01, NL_M, NL_N);
                blockC11 = ACTIVATION(blockC11, NL_M, NL_N);
                blockC21 = ACTIVATION(blockC21, NL_M, NL_N);
                blockC31 = ACTIVATION(blockC31, NL_M, NL_N);

                for ( unsigned i = 0; i < 8; i++ )
                {
                    out1[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC01[i];
                    out1[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC11[i];
                    out1[(16+i) * OUTPUT_FEATURE_PITCH] = blockC21[i];
                    out1[(24+i) * OUTPUT_FEATURE_PITCH] = blockC31[i];
                }
            }
            else
            {
                if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 24 )
                {
#if BIAS_TERM
                    blockC01 += *biasPtr;
                    blockC11 += *(biasPtr + 1);
                    blockC21 += *(biasPtr + 2);
                    if ( ( OUTPUT_FEATURE_NUM % TILE_N ) > 24 ) blockC31 += *(biasPtr + 3);
#endif

                    blockC01 = ACTIVATION(blockC01, NL_M, NL_N);
                    blockC11 = ACTIVATION(blockC11, NL_M, NL_N);
                    blockC21 = ACTIVATION(blockC21, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out1[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC01[i];
                        out1[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC11[i];
                        out1[(16+i) * OUTPUT_FEATURE_PITCH] = blockC21[i];
                    }

                    // Remaining channels
                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out1[(24+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC31[i], NL_M, NL_N);
                    }
                }
                else if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 16 )
                {
#if BIAS_TERM
                    blockC01 += *biasPtr;
                    blockC11 += *(biasPtr + 1);
                    if ( ( OUTPUT_FEATURE_NUM % TILE_N ) > 16 ) blockC21 += *(biasPtr + 2);
#endif

                    blockC01 = ACTIVATION(blockC01, NL_M, NL_N);
                    blockC11 = ACTIVATION(blockC11, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out1[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC01[i];
                        out1[( 8+i) * OUTPUT_FEATURE_PITCH] = blockC11[i];
                    }

                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out1[(16+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC21[i], NL_M, NL_N);
                    }
                }
                else if ( ( OUTPUT_FEATURE_NUM % TILE_N ) >= 8 )
                {
#if BIAS_TERM
                    blockC01 += *biasPtr;
                    if ( ( OUTPUT_FEATURE_NUM % TILE_N ) > 8 ) blockC11 += *(biasPtr + 1);
#endif

                    blockC01 = ACTIVATION(blockC01, NL_M, NL_N);

                    for (unsigned i = 0; i < 8; i++)
                    {
                        out1[( 0+i) * OUTPUT_FEATURE_PITCH] = blockC01[i];
                    }

                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out1[(8+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC11[i], NL_M, NL_N);
                    }
                }
                else
                {
#if BIAS_TERM
                    blockC01 += *biasPtr;
#endif
                    for (unsigned i = 0; i < OUTPUT_FEATURE_NUM % 8; i++)
                    {
                        out1[( 0+i) * OUTPUT_FEATURE_PITCH] = ACTIVATION(blockC01[i], NL_M, NL_N);
                    }
                }
            }
        }
    }
}