// Winograd filter transform: turns three float3 filter vectors into a 4x4 set of
// transformed values (out0..out3, lanes s0..s3) and scatters them to the destination.
// The section is only compiled when NUM_CHANNELS is defined at build time.
26 #if defined(NUM_CHANNELS) 51 __kernel
void winograd_filter_transform_2x2_3x3_nchw(
// Read three float3 vectors from src_addr at byte offsets 0, 1 and 2 times src_stride_y
// (presumably the three rows of one 3x3 filter -- confirm against the host-side setup).
60 float3 w0 = vload3(0, (__global
float *)(src_addr + 0 * src_stride_y));
61 float3 w1 = vload3(0, (__global
float *)(src_addr + 1 * src_stride_y));
62 float3 w2 = vload3(0, (__global
float *)(src_addr + 2 * src_stride_y));
// Each out<r>.s<c> assigned below equals element (r, c) of G * w * G^T, with w = {w0, w1, w2} and
//   G = { {1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1} }.
// The 0.5f / 0.25f factors are the products of the row and column coefficients of G.
72 out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
73 out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
77 out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
78 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
79 out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
80 out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
83 out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
84 out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
85 out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
86 out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
90 out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
91 out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
// Split the third global id by NUM_CHANNELS: the quotient (x0) steps the destination
// by dst_stride_x, the remainder (y0) steps it by dst_stride_y.
94 int z = get_global_id(2);
95 int x0 = z / NUM_CHANNELS;
96 int y0 = z % NUM_CHANNELS;
// Byte address of the first value written by this work-item.
99 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
// Write the 16 values as single floats, one per multiple of dst_stride_z (0..15),
// in row-major order: out0.s0..s3, out1.s0..s3, out2.s0..s3, out3.s0..s3.
102 *(__global
float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
103 *(__global
float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
104 *(__global
float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
105 *(__global
float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
106 *(__global
float *)(dst_addr + 4 * dst_stride_z) = out1.s0;
107 *(__global
float *)(dst_addr + 5 * dst_stride_z) = out1.s1;
108 *(__global
float *)(dst_addr + 6 * dst_stride_z) = out1.s2;
109 *(__global
float *)(dst_addr + 7 * dst_stride_z) = out1.s3;
110 *(__global
float *)(dst_addr + 8 * dst_stride_z) = out2.s0;
111 *(__global
float *)(dst_addr + 9 * dst_stride_z) = out2.s1;
112 *(__global
float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
113 *(__global
float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
114 *(__global
float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
115 *(__global
float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
116 *(__global
float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
117 *(__global
float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
// Winograd filter transform: turns three float3 filter vectors into a 6x6 set of
// transformed values (out0..out5, lanes s0..s5) and scatters them to the destination.
143 __kernel
void winograd_filter_transform_4x4_3x3_nchw(
// Read three float3 vectors from src_addr at byte offsets 0, 1 and 2 times src_stride_y
// (presumably the three rows of one 3x3 filter -- confirm against the host-side setup).
152 float3 w0 = vload3(0, (__global
float *)(src_addr + 0 * src_stride_y));
153 float3 w1 = vload3(0, (__global
float *)(src_addr + 1 * src_stride_y));
154 float3 w2 = vload3(0, (__global
float *)(src_addr + 2 * src_stride_y));
// Each out<r>.s<c> assigned below equals element (r, c) of G * w * G^T, with w = {w0, w1, w2} and
//   G = { { 1/4,     0,    0 },
//         {-1/6,  -1/6, -1/6 },
//         {-1/6,   1/6, -1/6 },
//         { 1/24,  1/12, 1/6 },
//         { 1/24, -1/12, 1/6 },
//         { 0,     0,    1   } }.
// The divisors (16, 24, 36, 96, 144, 576, 4, 6) are the products of the row and column
// denominators of G; signs are folded into the numerators.
// Row 0 of G applied to the filter rows (only w0 contributes).
165 out0.s0 = (w0.s0) / 16.f;
166 out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
167 out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
168 out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
169 out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
170 out0.s5 = (w0.s2) / 4.f;
// Row 1 of G: -(w0 + w1 + w2) / 6.
173 out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
174 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
175 out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
176 out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
177 out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
178 out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
// Row 2 of G: (-w0 + w1 - w2) / 6.
181 out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
182 out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
183 out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
184 out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
185 out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
186 out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
// Row 3 of G: (w0 + 2 * w1 + 4 * w2) / 24.
189 out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
190 out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
191 out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
192 out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
193 out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
194 out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
// Row 4 of G: (w0 - 2 * w1 + 4 * w2) / 24.
197 out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
198 out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
199 out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
200 out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
201 out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
202 out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
// Row 5 of G (only w2 contributes).
205 out5.s0 = (w2.s0) / 4.f;
206 out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
207 out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
208 out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
209 out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
// Split the third global id by NUM_CHANNELS: the quotient (x0) steps the destination
// by dst_stride_x, the remainder (y0) steps it by dst_stride_y.
212 int z = get_global_id(2);
213 int x0 = z / NUM_CHANNELS;
214 int y0 = z % NUM_CHANNELS;
// Byte address of the first value written by this work-item.
217 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
// Write the 36 values as single floats, one per multiple of dst_stride_z (0..35),
// in row-major order: out0.s0..s5, out1.s0..s5, ..., out5.s0..s5.
220 *(__global
float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
221 *(__global
float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
222 *(__global
float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
223 *(__global
float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
224 *(__global
float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
225 *(__global
float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
226 *(__global
float *)(dst_addr + 6 * dst_stride_z) = out1.s0;
227 *(__global
float *)(dst_addr + 7 * dst_stride_z) = out1.s1;
228 *(__global
float *)(dst_addr + 8 * dst_stride_z) = out1.s2;
229 *(__global
float *)(dst_addr + 9 * dst_stride_z) = out1.s3;
230 *(__global
float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
231 *(__global
float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
232 *(__global
float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
233 *(__global
float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
234 *(__global
float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
235 *(__global
float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
236 *(__global
float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
237 *(__global
float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
238 *(__global
float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
239 *(__global
float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
240 *(__global
float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
241 *(__global
float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
242 *(__global
float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
243 *(__global
float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
244 *(__global
float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
245 *(__global
float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
246 *(__global
float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
247 *(__global
float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
248 *(__global
float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
249 *(__global
float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
250 *(__global
float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
251 *(__global
float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
252 *(__global
float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
253 *(__global
float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
254 *(__global
float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
255 *(__global
float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
// Winograd filter transform: turns five 5-element filter vectors into an 8x8 set of
// transformed values (out0..out7, lanes s0..s7) and scatters them to the destination.
281 __kernel
void winograd_filter_transform_4x4_5x5_nchw(
// Byte offset of the fifth float within a row (the first four are fetched with vload4).
290 const char stride_x = 4 *
sizeof(float);
// Byte offsets of rows 0..4 (lanes s0..s4); lanes s5..s7 are unused zeros.
// NOTE(review): the offsets live in uchar lanes, so each k * src_stride_y product is reduced
// modulo 256. This presumably relies on 4 * src_stride_y fitting in 8 bits -- confirm against
// the tensor layout the host passes in.
291 const uchar8 stride_y = (uchar8)(0, 1, 2, 3, 4, 0, 0, 0) * (uchar8)src_stride_y;
// Each row r is read as a float4 (w<r>0, elements 0..3) plus one scalar (w<r>1, element 4).
293 float4 w00 = vload4(0, (__global
float *)(src_addr + stride_y.s0));
294 float w01 = *((__global
float *)(src_addr + stride_y.s0 + stride_x));
295 float4 w10 = vload4(0, (__global
float *)(src_addr + stride_y.s1));
296 float w11 = *((__global
float *)(src_addr + stride_y.s1 + stride_x));
297 float4 w20 = vload4(0, (__global
float *)(src_addr + stride_y.s2));
298 float w21 = *((__global
float *)(src_addr + stride_y.s2 + stride_x));
299 float4 w30 = vload4(0, (__global
float *)(src_addr + stride_y.s3));
300 float w31 = *((__global
float *)(src_addr + stride_y.s3 + stride_x));
301 float4 w40 = vload4(0, (__global
float *)(src_addr + stride_y.s4));
302 float w41 = *((__global
float *)(src_addr + stride_y.s4 + stride_x));
// Each out<r>.s<c> assigned below equals element (r, c) of G * w * G^T, where w is the 5x5
// matrix whose row i is (w<i>0.s0, w<i>0.s1, w<i>0.s2, w<i>0.s3, w<i>1) and
//   G = { { 1,      0,     0,     0,      0     },
//         {-2/9,   -2/9,  -2/9,  -2/9,   -2/9   },
//         {-2/9,    2/9,  -2/9,   2/9,   -2/9   },
//         { 1/90,   2/90,  4/90,  8/90,  16/90  },
//         { 1/90,  -2/90,  4/90, -8/90,  16/90  },
//         { 16/180, 8/180, 4/180, 2/180,  1/180 },
//         { 16/180,-8/180, 4/180,-2/180,  1/180 },
//         { 0,      0,     0,     0,      1     } }.
// The divisors (9, 81, 90, 180, 405, 810, 8100, 16200, 32400) are products of the row and
// column scale factors of G; signs are applied as a leading minus on the whole expression.
// Row 0 of G (only filter row 0 contributes).
316 out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
317 out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
318 out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
319 out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
320 out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
321 out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
// Row 1 of G: plain sum of the five filter rows, scaled by -2/9.
325 out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
326 out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
327 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
328 out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
329 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
330 out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
331 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
332 out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
333 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
334 out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
335 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
336 out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
337 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
338 out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
// Row 2 of G: alternating-sign sum of the five filter rows, scaled by -2/9.
341 out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
342 out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
343 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
344 out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
345 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
346 out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
347 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
348 out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
349 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
350 out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
351 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
352 out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
353 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
354 out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
// Row 3 of G: filter rows weighted (1, 2, 4, 8, 16), scaled by 1/90.
357 out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
358 out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
359 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
360 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
361 out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
362 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
363 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
364 out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
365 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
366 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
367 out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
368 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
369 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
370 out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
371 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
372 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
373 out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
374 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
375 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
376 out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
// Row 4 of G: filter rows weighted (1, -2, 4, -8, 16), scaled by 1/90.
379 out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
380 out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
381 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
382 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
383 out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
384 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
385 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
386 out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
387 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
388 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
389 out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
390 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
391 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
392 out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
393 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
394 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
395 out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
396 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
397 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
398 out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
// Row 5 of G: filter rows weighted (16, 8, 4, 2, 1), scaled by 1/180.
401 out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
402 out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
403 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
404 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
405 out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
406 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
407 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
408 out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
409 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
410 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
411 out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
412 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
413 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
414 out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
415 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
416 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
417 out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
418 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
419 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
420 out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
// Row 6 of G: filter rows weighted (16, -8, 4, -2, 1), scaled by 1/180.
423 out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
424 out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
425 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
426 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
427 out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
428 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
429 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
430 out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
431 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
432 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
433 out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
434 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
435 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
436 out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
437 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
438 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
439 out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
440 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
441 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
442 out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
// Row 7 of G (only filter row 4 contributes).
446 out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
447 out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
448 out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
449 out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
450 out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
451 out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
// Split the third global id by NUM_CHANNELS: the quotient (x0) steps the destination
// by dst_stride_x, the remainder (y0) steps it by dst_stride_y.
454 int z = get_global_id(2);
455 int x0 = z / NUM_CHANNELS;
456 int y0 = z % NUM_CHANNELS;
// Byte address of the first value written by this work-item.
459 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
// Write the 64 values as single floats, one per multiple of dst_stride_z (0..63),
// in row-major order: out0.s0..s7, out1.s0..s7, ..., out7.s0..s7.
462 *(__global
float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
463 *(__global
float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
464 *(__global
float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
465 *(__global
float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
466 *(__global
float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
467 *(__global
float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
468 *(__global
float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
469 *(__global
float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
470 *(__global
float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
471 *(__global
float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
472 *(__global
float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
473 *(__global
float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
474 *(__global
float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
475 *(__global
float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
476 *(__global
float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
477 *(__global
float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
478 *(__global
float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
479 *(__global
float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
480 *(__global
float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
481 *(__global
float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
482 *(__global
float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
483 *(__global
float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
484 *(__global
float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
485 *(__global
float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
486 *(__global
float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
487 *(__global
float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
488 *(__global
float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
489 *(__global
float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
490 *(__global
float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
491 *(__global
float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
492 *(__global
float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
493 *(__global
float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
494 *(__global
float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
495 *(__global
float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
496 *(__global
float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
497 *(__global
float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
498 *(__global
float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
499 *(__global
float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
500 *(__global
float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
501 *(__global
float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
502 *(__global
float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
503 *(__global
float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
504 *(__global
float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
505 *(__global
float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
506 *(__global
float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
507 *(__global
float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
508 *(__global
float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
509 *(__global
float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
510 *(__global
float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
511 *(__global
float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
512 *(__global
float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
513 *(__global
float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
514 *(__global
float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
515 *(__global
float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
516 *(__global
float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
517 *(__global
float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
518 *(__global
float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
519 *(__global
float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
520 *(__global
float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
521 *(__global
float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
522 *(__global
float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
523 *(__global
float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
524 *(__global
float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
525 *(__global
float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
// Winograd input transform: each work-item reads one 4x4 float tile from the source,
// applies a two-pass add/subtract transform, and writes the 16 results to the destination.
// The section is only compiled when NUM_TILES_X, PAD_LEFT and PAD_TOP are defined at build time.
527 #endif // defined(NUM_CHANNELS) 529 #if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) 552 __kernel
void winograd_input_transform_2x2_3x3_stepz1_nchw(
// One work-item per (x, y, z) global id.
556 int x = get_global_id(0);
557 int y = get_global_id(1);
558 int z = get_global_id(2);
// Tiles start every 2 elements in x and y while 4x4 elements are read below,
// so neighbouring tiles overlap by 2 rows/columns.
561 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
// Shift the tile origin back by the left/top padding (PAD_LEFT elements, PAD_TOP rows).
563 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
// Load the four rows of the tile, four floats each.
565 float4 in_row0 = vload4(0, (__global
float *)(src_addr + 0 * src_stride_y));
566 float4 in_row1 = vload4(0, (__global
float *)(src_addr + 1 * src_stride_y));
567 float4 in_row2 = vload4(0, (__global
float *)(src_addr + 2 * src_stride_y));
568 float4 in_row3 = vload4(0, (__global
float *)(src_addr + 3 * src_stride_y));
// First pass (across rows): tmp = B^T * d, with
//   B^T = { {1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1} }.
570 float4 tmp0 = in_row0 - in_row2;
571 float4 tmp1 = in_row1 + in_row2;
572 float4 tmp2 = in_row2 - in_row1;
573 float4 tmp3 = in_row1 - in_row3;
// Second pass (across the lanes of each tmp row): out = tmp * B, i.e. the same four
// combinations applied to lanes s0..s3, giving out<r><c> = (B^T * d * B)[r][c].
575 float out00 = tmp0.s0 - tmp0.s2;
576 float out01 = tmp0.s1 + tmp0.s2;
577 float out02 = tmp0.s2 - tmp0.s1;
578 float out03 = tmp0.s1 - tmp0.s3;
580 float out10 = tmp1.s0 - tmp1.s2;
581 float out11 = tmp1.s1 + tmp1.s2;
582 float out12 = tmp1.s2 - tmp1.s1;
583 float out13 = tmp1.s1 - tmp1.s3;
585 float out20 = tmp2.s0 - tmp2.s2;
586 float out21 = tmp2.s1 + tmp2.s2;
587 float out22 = tmp2.s2 - tmp2.s1;
588 float out23 = tmp2.s1 - tmp2.s3;
590 float out30 = tmp3.s0 - tmp3.s2;
591 float out31 = tmp3.s1 + tmp3.s2;
592 float out32 = tmp3.s2 - tmp3.s1;
593 float out33 = tmp3.s1 - tmp3.s3;
// Destination: z steps by dst_stride_x and the linear tile index (x + y * NUM_TILES_X)
// steps by dst_stride_y.
595 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
// Write the 16 values as single floats, one per multiple of dst_stride_z (0..15),
// in row-major order: out00..out03, out10..out13, out20..out23, out30..out33.
597 *((__global
float *)(dst_addr + 0 * dst_stride_z)) = out00;
598 *((__global
float *)(dst_addr + 1 * dst_stride_z)) = out01;
599 *((__global
float *)(dst_addr + 2 * dst_stride_z)) = out02;
600 *((__global
float *)(dst_addr + 3 * dst_stride_z)) = out03;
601 *((__global
float *)(dst_addr + 4 * dst_stride_z)) = out10;
602 *((__global
float *)(dst_addr + 5 * dst_stride_z)) = out11;
603 *((__global
float *)(dst_addr + 6 * dst_stride_z)) = out12;
604 *((__global
float *)(dst_addr + 7 * dst_stride_z)) = out13;
605 *((__global
float *)(dst_addr + 8 * dst_stride_z)) = out20;
606 *((__global
float *)(dst_addr + 9 * dst_stride_z)) = out21;
607 *((__global
float *)(dst_addr + 10 * dst_stride_z)) = out22;
608 *((__global
float *)(dst_addr + 11 * dst_stride_z)) = out23;
609 *((__global
float *)(dst_addr + 12 * dst_stride_z)) = out30;
610 *((__global
float *)(dst_addr + 13 * dst_stride_z)) = out31;
611 *((__global
float *)(dst_addr + 14 * dst_stride_z)) = out32;
612 *((__global
float *)(dst_addr + 15 * dst_stride_z)) = out33;
// Winograd input transform (2x2 output tile / 3x3 filter, per the kernel name)
// that processes two consecutive z-planes per work-item.
// For each plane it loads a 4x4 float patch, combines whole rows and then the
// lanes of each combined row with the +/- pattern below, and stores the 16
// results as float2 pairs (first component: plane z, second component:
// plane z + 1), one pair per dst z-plane.
// NOTE(review): the parameter list and braces are not visible in this excerpt;
// the src/dst pointers and strides, PAD_LEFT, PAD_TOP and NUM_TILES_X are
// presumably kernel arguments and build-time defines -- confirm in the full file.
637 __kernel
void winograd_input_transform_2x2_3x3_stepz2_nchw(
641 int x = get_global_id(0);
642 int y = get_global_id(1);
// Each work-item covers planes z and z + 1.
643 int z = get_global_id(2) * 2;
// Patch start: steps of 2 in x and y, plane z.
646 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
// Shift the start back by the left/top padding.
648 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
// Four rows of four floats from plane z.
650 float4 in_row0 = vload4(0, (__global
float *)(src_addr + 0 * src_stride_y));
651 float4 in_row1 = vload4(0, (__global
float *)(src_addr + 1 * src_stride_y));
652 float4 in_row2 = vload4(0, (__global
float *)(src_addr + 2 * src_stride_y));
653 float4 in_row3 = vload4(0, (__global
float *)(src_addr + 3 * src_stride_y));
// Advance to plane z + 1 and load the same four rows there.
655 src_addr += src_stride_z;
656 float4 in_row4 = vload4(0, (__global
float *)(src_addr + 0 * src_stride_y));
657 float4 in_row5 = vload4(0, (__global
float *)(src_addr + 1 * src_stride_y));
658 float4 in_row6 = vload4(0, (__global
float *)(src_addr + 2 * src_stride_y));
659 float4 in_row7 = vload4(0, (__global
float *)(src_addr + 3 * src_stride_y));
// Row combinations for plane z.
661 float4 tmp0 = in_row0 - in_row2;
662 float4 tmp1 = in_row1 + in_row2;
663 float4 tmp2 = in_row2 - in_row1;
664 float4 tmp3 = in_row1 - in_row3;
// Same row combinations for plane z + 1.
666 float4 tmp4 = in_row4 - in_row6;
667 float4 tmp5 = in_row5 + in_row6;
668 float4 tmp6 = in_row6 - in_row5;
669 float4 tmp7 = in_row5 - in_row7;
// Lane combinations: the pattern used on rows above (a0 - a2, a1 + a2, a2 - a1, a1 - a3)
// is applied to the lanes of each tmp row. Each float2 pairs plane z with plane z + 1.
671 float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
672 float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
673 float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
674 float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
676 float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
677 float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
678 float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
679 float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
681 float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
682 float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
683 float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
684 float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
686 float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
687 float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
688 float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
689 float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
// Destination: z selects the x offset, the tile index (x + y * NUM_TILES_X) selects the y offset.
691 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
// Each of the 16 outputs goes to its own dst z-plane; vstore2 writes the two
// per-plane values contiguously.
693 vstore2(out00, 0, (__global
float *)(dst_addr + 0 * dst_stride_z));
694 vstore2(out01, 0, (__global
float *)(dst_addr + 1 * dst_stride_z));
695 vstore2(out02, 0, (__global
float *)(dst_addr + 2 * dst_stride_z));
696 vstore2(out03, 0, (__global
float *)(dst_addr + 3 * dst_stride_z));
697 vstore2(out10, 0, (__global
float *)(dst_addr + 4 * dst_stride_z));
698 vstore2(out11, 0, (__global
float *)(dst_addr + 5 * dst_stride_z));
699 vstore2(out12, 0, (__global
float *)(dst_addr + 6 * dst_stride_z));
700 vstore2(out13, 0, (__global
float *)(dst_addr + 7 * dst_stride_z));
701 vstore2(out20, 0, (__global
float *)(dst_addr + 8 * dst_stride_z));
702 vstore2(out21, 0, (__global
float *)(dst_addr + 9 * dst_stride_z));
703 vstore2(out22, 0, (__global
float *)(dst_addr + 10 * dst_stride_z));
704 vstore2(out23, 0, (__global
float *)(dst_addr + 11 * dst_stride_z));
705 vstore2(out30, 0, (__global
float *)(dst_addr + 12 * dst_stride_z));
706 vstore2(out31, 0, (__global
float *)(dst_addr + 13 * dst_stride_z));
707 vstore2(out32, 0, (__global
float *)(dst_addr + 14 * dst_stride_z));
708 vstore2(out33, 0, (__global
float *)(dst_addr + 15 * dst_stride_z));
// Winograd input transform (4x4 output tile / 3x3 filter, per the kernel name),
// one z-plane per work-item.
// Reads rows 0..5 of a patch, six floats per row (a vload4 for elements 0..3 and
// a vload2 at vector offset 2 for elements 4..5), and builds the outputs from
// weighted sums of those values.
// NOTE(review): the parameter list, braces, the declarations of k0..k5 and
// out0..out29, and the stores between the dst_addr increments are not visible in
// this excerpt -- confirm the details below against the full file.
733 __kernel
void winograd_input_transform_4x4_3x3_stepz1_nchw(
737 int x = get_global_id(0);
738 int y = get_global_id(1);
739 int z = get_global_id(2);
// Patch start: steps of 4 in x and y, plane z; then shifted back by the left/top padding.
742 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
744 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
// Row 4: elements 0..3 in d40, elements 4..5 in d41.
747 float4 d40 = vload4(0, (__global
float *)(src_addr + 4 * src_stride_y));
748 float2 d41 = vload2(2, (__global
float *)(src_addr + 4 * src_stride_y));
// Accumulate row-4 terms into k0..k5 (their initial values are not visible here).
757 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
758 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
759 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
760 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
761 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
762 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
// Rows 0 and 2, loaded the same way as row 4.
765 float4 d00 = vload4(0, (__global
float *)(src_addr + 0 * src_stride_y));
766 float2 d01 = vload2(2, (__global
float *)(src_addr + 0 * src_stride_y));
769 float4 d20 = vload4(0, (__global
float *)(src_addr + 2 * src_stride_y));
770 float2 d21 = vload2(2, (__global
float *)(src_addr + 2 * src_stride_y));
// Float-typed destination pointer: z selects the x offset, the tile index
// (x + y * NUM_TILES_X) selects the y offset.
773 __global
float *dst_addr = (__global
float *)(dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (
int)NUM_TILES_X) * dst_stride_y);
// dst_addr is a float pointer, so the z stride is converted from bytes to floats.
775 uint dst_plane_stride = dst_stride_z /
sizeof(float);
// Add the row-0 and row-2 terms to out0..out5 (declarations not visible here;
// presumably seeded from k0..k5 -- TODO confirm).
809 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 - 20.0f * d20.s0 + 25.0f * d20.s2 + 4.0f * d01.s0 - 5.0f * d21.s0;
810 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
811 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 - 20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
812 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
813 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 - 10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
814 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 - 20.0f * d20.s1 + 4.0f * d01.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
// Step dst_addr one z-plane at a time. NOTE(review): the stores that presumably
// sit between these increments are not visible in this excerpt.
817 dst_addr += dst_plane_stride;
819 dst_addr += dst_plane_stride;
821 dst_addr += dst_plane_stride;
823 dst_addr += dst_plane_stride;
825 dst_addr += dst_plane_stride;
827 dst_addr += dst_plane_stride;
// Rows 1 and 3.
830 float4 d10 = vload4(0, (__global
float *)(src_addr + 1 * src_stride_y));
831 float2 d11 = vload2(2, (__global
float *)(src_addr + 1 * src_stride_y));
834 float4 d30 = vload4(0, (__global
float *)(src_addr + 3 * src_stride_y));
835 float2 d31 = vload2(2, (__global
float *)(src_addr + 3 * src_stride_y));
// Shared partial sums for out6..out17. part0, part2, part3, part6, part9 and
// part10 use row 2 only; the remaining parts combine rows 1 and 3.
840 float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
841 float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
842 float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
843 float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
844 float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
845 float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
846 float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
847 float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
848 float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
849 float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
850 float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
851 float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
// Partial sums for out18..out29. The row-2 terms are reused scaled by 0.25f
// (part13, part14, part17, part18, part21, part22); the row-1/row-3 terms are
// recomputed with different weights, and part23 is derived from part11.
855 float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
856 float part13 = part0 * 0.25f;
857 float part14 = part2 * 0.25f;
858 float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
859 float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
860 float part17 = part3 * 0.25f;
861 float part18 = part6 * 0.25f;
862 float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
863 float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
864 float part21 = part9 * 0.25f;
865 float part22 = part10 * 0.25f;
866 float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1;
// out6..out17: sums and differences of part0..part11.
868 out6 += part0 - part1;
869 out12 += part0 + part1;
870 out7 += part2 + part3 + part4 + part5;
871 out8 += part2 - part3 + part4 - part5;
872 out13 += part2 + part3 - part4 - part5;
873 out14 += part2 - part3 - part4 + part5;
874 out9 += part6 + part7 + part8 + part9;
875 out10 += part6 - part7 + part8 - part9;
876 out15 += part6 - part7 - part8 + part9;
877 out16 += part6 + part7 - part8 - part9;
878 out11 += part10 + part11;
879 out17 += part10 - part11;
// out18..out29: sums and differences of part12..part23.
881 out18 += part13 - part12;
882 out24 += part13 + part12;
883 out19 += part14 + part15 + part16 + part17;
884 out20 += part14 - part15 + part16 - part17;
885 out25 += part14 - part15 - part16 + part17;
886 out26 += part14 + part15 - part16 - part17;
887 out21 += part18 + part19 + part20 + part21;
888 out22 += part18 - part19 + part20 - part21;
889 out27 += part18 - part19 - part20 + part21;
890 out28 += part18 + part19 - part20 - part21;
891 out23 += part22 + part23;
892 out29 += part22 - part23;
// 24 more z-plane steps follow (stores not visible in this excerpt).
895 dst_addr += dst_plane_stride;
897 dst_addr += dst_plane_stride;
899 dst_addr += dst_plane_stride;
901 dst_addr += dst_plane_stride;
903 dst_addr += dst_plane_stride;
905 dst_addr += dst_plane_stride;
907 dst_addr += dst_plane_stride;
909 dst_addr += dst_plane_stride;
911 dst_addr += dst_plane_stride;
913 dst_addr += dst_plane_stride;
915 dst_addr += dst_plane_stride;
917 dst_addr += dst_plane_stride;
920 dst_addr += dst_plane_stride;
922 dst_addr += dst_plane_stride;
924 dst_addr += dst_plane_stride;
926 dst_addr += dst_plane_stride;
928 dst_addr += dst_plane_stride;
930 dst_addr += dst_plane_stride;
932 dst_addr += dst_plane_stride;
934 dst_addr += dst_plane_stride;
936 dst_addr += dst_plane_stride;
938 dst_addr += dst_plane_stride;
940 dst_addr += dst_plane_stride;
942 dst_addr += dst_plane_stride;
// Row 5.
945 float4 d50 = vload4(0, (__global
float *)(src_addr + 5 * src_stride_y));
946 float2 d51 = vload2(2, (__global
float *)(src_addr + 5 * src_stride_y));
// out0..out5 are overwritten (plain assignment) with combinations of rows 1, 3 and 5.
949 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
950 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
951 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
952 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
953 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
954 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
// Final six z-plane steps (stores not visible in this excerpt).
957 dst_addr += dst_plane_stride;
959 dst_addr += dst_plane_stride;
961 dst_addr += dst_plane_stride;
963 dst_addr += dst_plane_stride;
965 dst_addr += dst_plane_stride;
967 dst_addr += dst_plane_stride;
// OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact): fills the 8 lanes of `out` from the
// 8 lanes of `tmp`. Lanes s0..s6 of `comm_fact` are overwritten and used as
// scratch for terms shared between pairs of output lanes.
// NOTE(review): in this excerpt the macro's continuation lines and the
// `__kernel` qualifier of the next kernel are fused onto a single line.
970 #define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \ 972 comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6; \ 973 comm_fact.s1 = tmp.s1 - 4.25f * tmp.s3 + tmp.s5; \ 974 comm_fact.s2 = 2.5f * tmp.s3; \ 975 comm_fact.s3 = 0.5f * tmp.s1 + 2.f * tmp.s5 - comm_fact.s2; \ 976 comm_fact.s4 = 0.25f * tmp.s2 - 1.25f * tmp.s4 + tmp.s6; \ 977 comm_fact.s5 = 4.f * tmp.s2 + tmp.s6 - 5.f * tmp.s4; \ 978 comm_fact.s6 = 2.f * tmp.s1 + 0.5f * tmp.s5 - comm_fact.s2; \ 980 out.s0 = tmp.s0 - tmp.s6 + 5.25f * tmp.s4 - 5.25f * tmp.s2; \ 981 out.s1 = comm_fact.s0 + comm_fact.s1; \ 982 out.s2 = comm_fact.s0 - comm_fact.s1; \ 983 out.s3 = comm_fact.s3 + comm_fact.s4; \ 984 out.s4 = comm_fact.s4 - comm_fact.s3; \ 985 out.s5 = comm_fact.s5 + comm_fact.s6; \ 986 out.s6 = comm_fact.s5 - comm_fact.s6; \ 987 out.s7 = tmp.s7 - tmp.s1 + 5.25f * tmp.s3 - 5.25f * tmp.s5; \ 1012 __kernel
// Winograd input transform (4x4 output tile / 5x5 filter, per the kernel name),
// one z-plane per work-item.
// Loads an 8x8 float patch, combines whole rows into tmp0..tmp7, applies
// OUTPUT_ROW_4x4_5x5 to the lanes of each tmp row, and writes the 64 results to
// dst z-planes 0..63.
// NOTE(review): the parameter list and braces are not visible in this excerpt;
// PAD_LEFT, PAD_TOP and NUM_TILES_X are presumably build-time defines -- confirm.
void winograd_input_transform_4x4_5x5_stepz1_nchw(
1016 int x = get_global_id(0);
1017 int y = get_global_id(1);
1018 int z = get_global_id(2);
// Patch start: steps of 4 in x and y, plane z; then shifted back by the left/top padding.
1021 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
1023 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
// Eight rows of eight floats.
1026 const float8 in_row0 = vload8(0, (__global
float *)(src_addr + 0 * src_stride_y));
1027 const float8 in_row1 = vload8(0, (__global
float *)(src_addr + 1 * src_stride_y));
1028 const float8 in_row2 = vload8(0, (__global
float *)(src_addr + 2 * src_stride_y));
1029 const float8 in_row3 = vload8(0, (__global
float *)(src_addr + 3 * src_stride_y));
1030 const float8 in_row4 = vload8(0, (__global
float *)(src_addr + 4 * src_stride_y));
1031 const float8 in_row5 = vload8(0, (__global
float *)(src_addr + 5 * src_stride_y));
1032 const float8 in_row6 = vload8(0, (__global
float *)(src_addr + 6 * src_stride_y));
1033 const float8 in_row7 = vload8(0, (__global
float *)(src_addr + 7 * src_stride_y));
// First pass: element-wise combinations of whole rows. comm_fact0..2 hold terms
// shared by a pair of tmp rows (even-row term, odd-row term).
1036 float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
1037 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
1038 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
1041 const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
1042 const float8 tmp1 = comm_fact0 + comm_fact1;
1043 const float8 tmp2 = comm_fact0 - comm_fact1;
// comm_fact0 and comm_fact1 are reused for the tmp3/tmp4 pair; comm_fact2 was set above.
1045 comm_fact0 = 2.5f * in_row3;
1046 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
1048 const float8 tmp3 = comm_fact1 + comm_fact2;
1049 const float8 tmp4 = comm_fact2 - comm_fact1;
// comm_fact1 and comm_fact2 are reused again for the tmp5/tmp6 pair.
1051 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
1052 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
1054 const float8 tmp5 = comm_fact1 + comm_fact2;
1055 const float8 tmp6 = comm_fact2 - comm_fact1;
1056 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
// Second pass: combine the lanes within each tmp row. comm_fact0 is only scratch
// for the macro here; its previous value is not read again.
1059 float8 out0, out1, out2, out3, out4, out5, out6, out7;
1061 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1062 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
1063 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
1064 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
1065 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
1066 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
1067 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
1068 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
// Destination: z selects the x offset, the tile index (x + y * NUM_TILES_X) selects the y offset.
1071 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
// One scalar per dst z-plane: out<r>.s<c> is written to plane r * 8 + c.
1073 *((__global
float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1074 *((__global
float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1075 *((__global
float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1076 *((__global
float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1077 *((__global
float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1078 *((__global
float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1079 *((__global
float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1080 *((__global
float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1081 *((__global
float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1082 *((__global
float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1083 *((__global
float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1084 *((__global
float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1085 *((__global
float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1086 *((__global
float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1087 *((__global
float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1088 *((__global
float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1089 *((__global
float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1090 *((__global
float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1091 *((__global
float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1092 *((__global
float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1093 *((__global
float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1094 *((__global
float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1095 *((__global
float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1096 *((__global
float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1097 *((__global
float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1098 *((__global
float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1099 *((__global
float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1100 *((__global
float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1101 *((__global
float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1102 *((__global
float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1103 *((__global
float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1104 *((__global
float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1105 *((__global
float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1106 *((__global
float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1107 *((__global
float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1108 *((__global
float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1109 *((__global
float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1110 *((__global
float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1111 *((__global
float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1112 *((__global
float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1113 *((__global
float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1114 *((__global
float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1115 *((__global
float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1116 *((__global
float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1117 *((__global
float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1118 *((__global
float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1119 *((__global
float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1120 *((__global
float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1121 *((__global
float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1122 *((__global
float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1123 *((__global
float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1124 *((__global
float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1125 *((__global
float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1126 *((__global
float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1127 *((__global
float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1128 *((__global
float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1129 *((__global
float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1130 *((__global
float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1131 *((__global
float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1132 *((__global
float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1133 *((__global
float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1134 *((__global
float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1135 *((__global
float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1136 *((__global
float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
1138 #endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) 1140 #if defined(NUM_TILES_X) 1162 __kernel
// Winograd output transform (2x2 output tile / 3x3 filter, per the kernel name).
// Reads 16 floats, one from each of src z-planes 0..15, reduces them to a 2x2
// tile (out00, out01, out10, out11) and stores the tile as two float2 rows.
// NOTE(review): the parameter list, braces, the computation of src_addr and the
// initial values of out00..out11 are not visible in this excerpt; in this
// excerpt the preprocessor lines are fused with neighbouring statements.
void winograd_output_transform_2x2_3x3_nchw(
1165 #
if defined(HAS_BIAS)
// Load the 4x4 input block: dRC is read from src z-plane R * 4 + C.
1177 float d00 = *((__global
float *)(src_addr + 0 * src_stride_z));
1178 float d01 = *((__global
float *)(src_addr + 1 * src_stride_z));
1179 float d02 = *((__global
float *)(src_addr + 2 * src_stride_z));
1180 float d03 = *((__global
float *)(src_addr + 3 * src_stride_z));
1182 float d10 = *((__global
float *)(src_addr + 4 * src_stride_z));
1183 float d11 = *((__global
float *)(src_addr + 5 * src_stride_z));
1184 float d12 = *((__global
float *)(src_addr + 6 * src_stride_z));
1185 float d13 = *((__global
float *)(src_addr + 7 * src_stride_z));
1187 float d20 = *((__global
float *)(src_addr + 8 * src_stride_z));
1188 float d21 = *((__global
float *)(src_addr + 9 * src_stride_z));
1189 float d22 = *((__global
float *)(src_addr + 10 * src_stride_z));
1190 float d23 = *((__global
float *)(src_addr + 11 * src_stride_z));
1192 float d30 = *((__global
float *)(src_addr + 12 * src_stride_z));
1193 float d31 = *((__global
float *)(src_addr + 13 * src_stride_z));
1194 float d32 = *((__global
float *)(src_addr + 14 * src_stride_z));
1195 float d33 = *((__global
float *)(src_addr + 15 * src_stride_z));
// Column-1 and column-2 sums shared by the outputs: k0/k1 feed the top output
// row, k2/k3 the bottom one.
1198 float k0 = d01 + d11 + d21;
1199 float k1 = d02 + d12 + d22;
1200 float k2 = d11 - d21 - d31;
1201 float k3 = d12 - d22 - d32;
// Accumulate into the outputs (their initial values are set on lines not visible here).
1213 out00 += d00 + d20 + k0 + k1;
1214 out01 += k0 - k1 - (d03 + d23);
1215 out10 += -d20 - d30 + k2 + k3;
1216 out11 += k2 - k3 + d23 + d33;
// Global id 1 is the tile index; convert it to the tile's top-left output
// coordinate (2 outputs per tile in x and y). Global id 0 selects the output z-plane.
1218 int y_in = get_global_id(1);
1219 int x_out = (y_in % NUM_TILES_X) * 2;
1220 int y_out = (y_in / NUM_TILES_X) * 2;
1221 int z_out = get_global_id(0);
// With HAS_BIAS, read the bias value at index z_out.
// NOTE(review): the statements that apply b to the outputs are not visible in this excerpt.
1223 #if defined(HAS_BIAS) 1227 float b = (float) * ((__global
float *)(
vector_offset(&bias, z_out)));
// Address of the tile's top-left element in dst.
1233 #endif // defined(HAS_BIAS) 1236 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
// Store the two rows of the 2x2 tile.
1239 vstore2((float2)(out00, out01), 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
1240 vstore2((float2)(out10, out11), 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
// Winograd output transform (4x4 output tile / 3x3 filter, per the kernel name).
// Reads 36 floats, one from each of src z-planes 0..35, reduces them to a 4x4
// tile and stores the tile as four float4 rows.
// NOTE(review): the parameter list, braces, the computation of src_addr and the
// statements that apply the bias are not visible in this excerpt.
1264 __kernel
void winograd_output_transform_4x4_3x3_nchw(
1267 #
if defined(HAS_BIAS)
// Load the 6x6 input block: dRC is read from src z-plane R * 6 + C.
1279 float d00 = *((__global
float *)(src_addr + 0 * src_stride_z));
1280 float d01 = *((__global
float *)(src_addr + 1 * src_stride_z));
1281 float d02 = *((__global
float *)(src_addr + 2 * src_stride_z));
1282 float d03 = *((__global
float *)(src_addr + 3 * src_stride_z));
1283 float d04 = *((__global
float *)(src_addr + 4 * src_stride_z));
1284 float d05 = *((__global
float *)(src_addr + 5 * src_stride_z));
1286 float d10 = *((__global
float *)(src_addr + 6 * src_stride_z));
1287 float d11 = *((__global
float *)(src_addr + 7 * src_stride_z));
1288 float d12 = *((__global
float *)(src_addr + 8 * src_stride_z));
1289 float d13 = *((__global
float *)(src_addr + 9 * src_stride_z));
1290 float d14 = *((__global
float *)(src_addr + 10 * src_stride_z));
1291 float d15 = *((__global
float *)(src_addr + 11 * src_stride_z));
1293 float d20 = *((__global
float *)(src_addr + 12 * src_stride_z));
1294 float d21 = *((__global
float *)(src_addr + 13 * src_stride_z));
1295 float d22 = *((__global
float *)(src_addr + 14 * src_stride_z));
1296 float d23 = *((__global
float *)(src_addr + 15 * src_stride_z));
1297 float d24 = *((__global
float *)(src_addr + 16 * src_stride_z));
1298 float d25 = *((__global
float *)(src_addr + 17 * src_stride_z));
1300 float d30 = *((__global
float *)(src_addr + 18 * src_stride_z));
1301 float d31 = *((__global
float *)(src_addr + 19 * src_stride_z));
1302 float d32 = *((__global
float *)(src_addr + 20 * src_stride_z));
1303 float d33 = *((__global
float *)(src_addr + 21 * src_stride_z));
1304 float d34 = *((__global
float *)(src_addr + 22 * src_stride_z));
1305 float d35 = *((__global
float *)(src_addr + 23 * src_stride_z));
1307 float d40 = *((__global
float *)(src_addr + 24 * src_stride_z));
1308 float d41 = *((__global
float *)(src_addr + 25 * src_stride_z));
1309 float d42 = *((__global
float *)(src_addr + 26 * src_stride_z));
1310 float d43 = *((__global
float *)(src_addr + 27 * src_stride_z));
1311 float d44 = *((__global
float *)(src_addr + 28 * src_stride_z));
1312 float d45 = *((__global
float *)(src_addr + 29 * src_stride_z));
1314 float d50 = *((__global
float *)(src_addr + 30 * src_stride_z));
1315 float d51 = *((__global
float *)(src_addr + 31 * src_stride_z));
1316 float d52 = *((__global
float *)(src_addr + 32 * src_stride_z));
1317 float d53 = *((__global
float *)(src_addr + 33 * src_stride_z));
1318 float d54 = *((__global
float *)(src_addr + 34 * src_stride_z));
1319 float d55 = *((__global
float *)(src_addr + 35 * src_stride_z));
// Output row 0: every entry starts from the same sum of column 1 over input rows 0..4.
1322 float out00 = d01 + d21 + d41 + d11 + d31;
1323 float out01 = d01 + d21 + d41 + d11 + d31;
1324 float out02 = d01 + d21 + d41 + d11 + d31;
1325 float out03 = d01 + d21 + d41 + d11 + d31;
// k0: columns 3 and 4 summed over rows 0..4; k1: twice (column 3 - column 4) over the same rows.
1327 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
1328 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
1330 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
1331 out01 += k1 - d02 - d12 - d22 - d32 - d42;
1332 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
1333 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
// Output row 1: same scheme with input rows 1..4 weighted (+1, -1, +2, -2).
1336 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
1337 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
1338 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
1339 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
1341 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
1342 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
1344 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
1345 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
1346 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
1347 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
// Output row 2: input rows 1..4 weighted (+1, +1, +4, +4).
1350 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
1351 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
1352 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
1353 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
1355 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
1356 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
1358 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
1359 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
1360 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
1361 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
// Output row 3: input rows 1..4 weighted (+1, -1, +8, -8), plus input row 5 with weight 1.
1364 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
1365 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
1366 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
1367 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
1369 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
1370 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
1372 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
1373 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
1374 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
1375 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
// Global id 1 is the tile index; convert it to the tile's top-left output
// coordinate (4 outputs per tile in x and y). Global id 0 selects the output z-plane.
1377 int y_in = get_global_id(1);
1378 int x_out = (y_in % NUM_TILES_X) * 4;
1379 int y_out = (y_in / NUM_TILES_X) * 4;
1380 int z_out = get_global_id(0);
// With HAS_BIAS, read the bias value at index z_out.
// NOTE(review): the statements that apply b to the outputs are not visible in this excerpt.
1382 #if defined(HAS_BIAS) 1386 float b = (float) * ((__global
float *)(
vector_offset(&bias, z_out)));
// Address of the tile's top-left element in dst.
1408 #endif // defined(HAS_BIAS) 1411 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
// Store the four rows of the 4x4 tile.
1414 vstore4((float4)(out00, out01, out02, out03), 0, (__global
float *)(dst_addr + 0 * dst_stride_y));
1415 vstore4((float4)(out10, out11, out12, out13), 0, (__global
float *)(dst_addr + 1 * dst_stride_y));
1416 vstore4((float4)(out20, out21, out22, out23), 0, (__global
float *)(dst_addr + 2 * dst_stride_y));
1417 vstore4((float4)(out30, out31, out32, out33), 0, (__global
float *)(dst_addr + 3 * dst_stride_y));
1420 #define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \ 1422 comm_fact.s0 = d1 + d2; \ 1423 comm_fact.s1 = d3 + d4; \ 1424 comm_fact.s2 = d5 + d6; \ 1426 col.s0 = comm_fact.s0 + comm_fact.s1 + 8.f * comm_fact.s2 + d0; \ 1427 col.s2 = comm_fact.s0 + 4.f * comm_fact.s1 + 2.f * comm_fact.s2; \ 1429 comm_fact.s0 = d1 - d2; \ 1430 comm_fact.s1 = d3 - d4; \ 1431 comm_fact.s2 = d5 - d6; \ 1433 col.s1 = comm_fact.s0 + 2.f * comm_fact.s1 + 4.f * comm_fact.s2; \ 1434 col.s3 = comm_fact.s0 + 8.f * comm_fact.s1 + comm_fact.s2 + d7; \ 1458 __kernel
void winograd_output_transform_4x4_5x5_nchw(
1461 #
if defined(HAS_BIAS)
1473 float d00 = *((__global
float *)(src_addr + 0 * src_stride_z));
1474 float d01 = *((__global
float *)(src_addr + 1 * src_stride_z));
1475 float d02 = *((__global
float *)(src_addr + 2 * src_stride_z));
1476 float d03 = *((__global
float *)(src_addr + 3 * src_stride_z));
1477 float d04 = *((__global
float *)(src_addr + 4 * src_stride_z));
1478 float d05 = *((__global
float *)(src_addr + 5 * src_stride_z));
1479 float d06 = *((__global
float *)(src_addr + 6 * src_stride_z));
1480 float d07 = *((__global
float *)(src_addr + 7 * src_stride_z));
1482 float d10 = *((__global
float *)(src_addr + 8 * src_stride_z));
1483 float d11 = *((__global
float *)(src_addr + 9 * src_stride_z));
1484 float d12 = *((__global
float *)(src_addr + 10 * src_stride_z));
1485 float d13 = *((__global
float *)(src_addr + 11 * src_stride_z));
1486 float d14 = *((__global
float *)(src_addr + 12 * src_stride_z));
1487 float d15 = *((__global
float *)(src_addr + 13 * src_stride_z));
1488 float d16 = *((__global
float *)(src_addr + 14 * src_stride_z));
1489 float d17 = *((__global
float *)(src_addr + 15 * src_stride_z));
1491 float d20 = *((__global
float *)(src_addr + 16 * src_stride_z));
1492 float d21 = *((__global
float *)(src_addr + 17 * src_stride_z));
1493 float d22 = *((__global
float *)(src_addr + 18 * src_stride_z));
1494 float d23 = *((__global
float *)(src_addr + 19 * src_stride_z));
1495 float d24 = *((__global
float *)(src_addr + 20 * src_stride_z));
1496 float d25 = *((__global
float *)(src_addr + 21 * src_stride_z));
1497 float d26 = *((__global
float *)(src_addr + 22 * src_stride_z));
1498 float d27 = *((__global
float *)(src_addr + 23 * src_stride_z));
1500 float d30 = *((__global
float *)(src_addr + 24 * src_stride_z));
1501 float d31 = *((__global
float *)(src_addr + 25 * src_stride_z));
1502 float d32 = *((__global
float *)(src_addr + 26 * src_stride_z));
1503 float d33 = *((__global
float *)(src_addr + 27 * src_stride_z));
1504 float d34 = *((__global
float *)(src_addr + 28 * src_stride_z));
1505 float d35 = *((__global
float *)(src_addr + 29 * src_stride_z));
1506 float d36 = *((__global
float *)(src_addr + 30 * src_stride_z));
1507 float d37 = *((__global
float *)(src_addr + 31 * src_stride_z));
1509 float d40 = *((__global
float *)(src_addr + 32 * src_stride_z));
1510 float d41 = *((__global
float *)(src_addr + 33 * src_stride_z));
1511 float d42 = *((__global
float *)(src_addr + 34 * src_stride_z));
1512 float d43 = *((__global
float *)(src_addr + 35 * src_stride_z));
1513 float d44 = *((__global
float *)(src_addr + 36 * src_stride_z));
1514 float d45 = *((__global
float *)(src_addr + 37 * src_stride_z));
1515 float d46 = *((__global
float *)(src_addr + 38 * src_stride_z));
1516 float d47 = *((__global
float *)(src_addr + 39 * src_stride_z));
1518 float d50 = *((__global
float *)(src_addr + 40 * src_stride_z));
1519 float d51 = *((__global
float *)(src_addr + 41 * src_stride_z));
1520 float d52 = *((__global
float *)(src_addr + 42 * src_stride_z));
1521 float d53 = *((__global
float *)(src_addr + 43 * src_stride_z));
1522 float d54 = *((__global
float *)(src_addr + 44 * src_stride_z));
1523 float d55 = *((__global
float *)(src_addr + 45 * src_stride_z));
1524 float d56 = *((__global
float *)(src_addr + 46 * src_stride_z));
1525 float d57 = *((__global
float *)(src_addr + 47 * src_stride_z));
1527 float d60 = *((__global
float *)(src_addr + 48 * src_stride_z));
1528 float d61 = *((__global
float *)(src_addr + 49 * src_stride_z));
1529 float d62 = *((__global
float *)(src_addr + 50 * src_stride_z));
1530 float d63 = *((__global
float *)(src_addr + 51 * src_stride_z));
1531 float d64 = *((__global
float *)(src_addr + 52 * src_stride_z));
1532 float d65 = *((__global
float *)(src_addr + 53 * src_stride_z));
1533 float d66 = *((__global
float *)(src_addr + 54 * src_stride_z));
1534 float d67 = *((__global
float *)(src_addr + 55 * src_stride_z));
1536 float d70 = *((__global
float *)(src_addr + 56 * src_stride_z));
1537 float d71 = *((__global
float *)(src_addr + 57 * src_stride_z));
1538 float d72 = *((__global
float *)(src_addr + 58 * src_stride_z));
1539 float d73 = *((__global
float *)(src_addr + 59 * src_stride_z));
1540 float d74 = *((__global
float *)(src_addr + 60 * src_stride_z));
1541 float d75 = *((__global
float *)(src_addr + 61 * src_stride_z));
1542 float d76 = *((__global
float *)(src_addr + 62 * src_stride_z));
1543 float d77 = *((__global
float *)(src_addr + 63 * src_stride_z));
1546 float4 comm_fact0, comm_fact1, comm_fact2;
1547 float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
1549 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1550 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1551 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1552 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1553 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1554 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1555 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1556 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
1559 comm_fact0 = tmp_col1 + tmp_col2;
1560 comm_fact1 = tmp_col3 + tmp_col4;
1561 comm_fact2 = tmp_col5 + tmp_col6;
1563 float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
1564 float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
1566 comm_fact0 = tmp_col1 - tmp_col2;
1567 comm_fact1 = tmp_col3 - tmp_col4;
1568 comm_fact2 = tmp_col5 - tmp_col6;
1570 float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
1571 float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
1573 int y_in = get_global_id(1);
1574 int x_out = (y_in % NUM_TILES_X) * 4;
1575 int y_out = (y_in / NUM_TILES_X) * 4;
1576 int z_out = get_global_id(0);
1578 #if defined(HAS_BIAS) 1582 float b = (float) * ((__global
float *)(
vector_offset(&bias, z_out)));
1584 out_col0 += (float4)b;
1585 out_col1 += (float4)b;
1586 out_col2 += (float4)b;
1587 out_col3 += (float4)b;
1588 #endif // defined(HAS_BIAS) 1591 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
1594 *(__global
float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0;
1595 *(__global
float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0;
1596 *(__global
float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0;
1597 *(__global
float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0;
1598 *(__global
float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1;
1599 *(__global
float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1;
1600 *(__global
float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1;
1601 *(__global
float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1;
1602 *(__global
float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2;
1603 *(__global
float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2;
1604 *(__global
float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2;
1605 *(__global
float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2;
1606 *(__global
float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3;
1607 *(__global
float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3;
1608 *(__global
float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3;
1609 *(__global
float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3;
1611 #endif // defined(NUM_TILES_X)
Structure to hold Vector information.
#define CONVERT_TO_TENSOR3D_STRUCT(name)
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)
Structure to hold 3D tensor information.
Structure to hold 4D tensor information.
__global const uchar * tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
Get the pointer position of a Tensor4D.
__global const uchar * tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
Get the pointer position of a Tensor3D.
#define VECTOR_DECLARATION(name)
#define TENSOR3D_DECLARATION(name)
#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name)
__global const uchar * vector_offset(const Vector *vec, int x)
Get the pointer position of a Vector.
#define TENSOR4D_DECLARATION(name)
convolution configure & src