Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_convolution_winograd.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
19 #endif
20
21 #include "mkldnn_types.h"
22
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 #include "jit_avx512_common_convolution_winograd.hpp"
29
30 #ifndef _MSC_VER
31 #define pragma_unroll _Pragma("unroll")
32 #else
33 #define pragma_unroll
34 #endif
35
36 namespace mkldnn {
37 namespace impl {
38 namespace cpu {
39
40 using namespace memory_tracking::names;
41
42 namespace {
43
/* Value returned by get_cache_size(3, false), evaluated once during static
 * initialization. Presumably the last-level (L3) cache size -- confirm the
 * units and the meaning of both arguments against get_cache_size(). */
44 unsigned int LLC_cache_size = get_cache_size(3, false);
45
46 void inline load_ps(float *dest, const float *src_mem) {
47 #ifdef __INTEL_COMPILER
48     __m512 *Iv512 = (__m512 *)dest;
49     Iv512[0] = _mm512_load_ps(src_mem);
50 #else
51     PRAGMA_OMP_SIMD()
52     for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v];
53 #endif
54 }
55
56 void inline store_output(float *dest, const float *data, bool streamout) {
57 #ifdef __INTEL_COMPILER
58     if (streamout)
59         _mm512_stream_ps(dest, *((__m512 *)data));
60     else
61         _mm512_store_ps(dest, *((__m512 *)data));
62 #else
63     PRAGMA_OMP_SIMD()
64     for (int v = 0; v < simd_w; v++)
65         dest[v] = data[v];
66 #endif
67 }
68
69 void inline accum_output(
70         float *dest, float *data, bool streamout, bool with_relu_postsum) {
71 #ifdef __INTEL_COMPILER
72     __m512 _data = _mm512_loadu_ps(data);
73     __m512 _dest = _mm512_loadu_ps(dest);
74     _data = _mm512_add_ps(_data, _dest);
75     if (with_relu_postsum)
76         _data = _mm512_max_ps(_data, _mm512_setzero_ps());
77     if (streamout)
78         _mm512_stream_ps(dest, _data);
79     else
80         _mm512_store_ps(dest, _data);
81 #else
82     PRAGMA_OMP_SIMD()
83     for (int v = 0; v < simd_w; v++)
84         data[v] += dest[v];
85
86     if (with_relu_postsum) {
87         PRAGMA_OMP_SIMD()
88         for (int v = 0; v < simd_w; v++)
89             if (data[v] < 0.f)
90                 data[v] = 0.f;
91     }
92
93     PRAGMA_OMP_SIMD()
94     for (int v = 0; v < simd_w; v++)
95         dest[v] = data[v];
96 #endif
97 }
98 }
99
100 using namespace mkldnn::impl::status;
101 using namespace mkldnn::impl::memory_format;
102 using namespace mkldnn::impl::utils;
103
104 void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) {
105     float Fw[6][16];
106     float T[6][3][16];
107     float t0[16];
108     float t1[16];
109     float t2[16];
110
111     for (int j = 0; j < 16; j++) {
112 #pragma unroll
113         for (int i = 0; i < 3; i++) {
114             PRAGMA_OMP_SIMD()
115             for (int k = 0; k < 16; k++) {
116                 t0[k] = 0.26890756302521f * F[2][i][j][k];
117                 t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k];
118                 t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k];
119
120                 T[0][i][k] = 1.13777777777778f * F[0][i][j][k];
121                 T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k];
122                 T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k];
123                 T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k];
124                 T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k];
125                 T[5][i][k] = F[2][i][j][k];
126             }
127         }
128 #pragma unroll
129         for (int i = 0; i < 6; i++) {
130             PRAGMA_OMP_SIMD()
131             for (int k = 0; k < 16; k++) {
132                 t0[k] = 0.26890756302521f * T[i][2][k];
133                 t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k];
134                 t2[k] = t0[k] + 0.119514472455649f * T[i][0][k];
135
136                 Fw[0][k] = 1.13777777777778f * T[i][0][k];
137                 Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k];
138                 Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k];
139                 Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k];
140                 Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k];
141                 Fw[5][k] = T[i][2][k];
142 #pragma unroll
143                 for (int l = 0; l < 6; l++) {
144                     Fw_[i][l][j][k] = Fw[l][k];
145                 }
146             }
147         }
148     }
149 }
150
151 void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) {
152     float T[4][6][16];
153     float t0[16];
154     float t1[16];
155     float t2[16];
156     float t3[16];
157
158 #pragma unroll
159     for (int i = 0; i < 6; i++) {
160         PRAGMA_OMP_SIMD()
161         for (int v = 0; v < 16; v++) {
162             t0[v] = Mw[1][i][v] + Mw[2][i][v];
163             t1[v] = Mw[3][i][v] + Mw[4][i][v];
164             t2[v] = Mw[1][i][v] - Mw[2][i][v];
165             t3[v] = Mw[3][i][v] - Mw[4][i][v];
166
167             T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v];
168             T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f;
169             T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
170             T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v];
171         }
172     }
173 #pragma unroll
174     for (int i = 0; i < 4; i++) {
175         PRAGMA_OMP_SIMD()
176         for (int v = 0; v < 16; v++) {
177             t0[v] = T[i][1][v] + T[i][2][v];
178             t1[v] = T[i][3][v] + T[i][4][v];
179             t2[v] = T[i][1][v] - T[i][2][v];
180             t3[v] = T[i][3][v] - T[i][4][v];
181
182             O[i][0][v] = t0[v] + t1[v] + T[i][0][v];
183             O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f;
184             O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
185             O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v];
186         }
187     }
188 }
189
190
191 void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16])
192 {
193     const float rcp3 = 1.0f / 3.0f;
194     const float rcp4 = 1.0f / 4.0f;
195     const float rcp6 = 1.0f / 6.0f;
196     const float rcp12 = 1.0f / 12.0f;
197     const float rcp24 = 1.0f / 24.0f;
198     float t0[16];
199     float t1[16];
200     float t2[16];
201     float t3[16];
202     float t4[16];
203     float T[6][4][16];
204
205 pragma_unroll
206     for (int i = 0; i < 4; i++) {
207         PRAGMA_OMP_SIMD()
208         for (int j = 0; j < 16; j++) {
209             t0[j] = F[2][i][j] * rcp6;
210             t1[j] = F[0][i][j] * -rcp6 - t0[j];
211             t2[j] = F[0][i][j] * rcp24 + t0[j];
212             t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6;
213             t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3;
214
215             T[0][i][j] = F[0][i][j] * rcp4;
216             T[1][i][j] = t1[j] - t3[j];
217             T[2][i][j] = t1[j] + t3[j];
218             T[3][i][j] = t2[j] + t4[j];
219             T[4][i][j] = t2[j] - t4[j];
220             T[5][i][j] = F[3][i][j];
221         }
222     }
223 pragma_unroll
224     for (int i = 0; i < 6; i++) {
225         PRAGMA_OMP_SIMD()
226         for (int j = 0; j < 16; j++) {
227             t0[j] = T[i][2][j] * rcp6;
228             t1[j] = T[i][0][j] * -rcp6 - t0[j];
229             t2[j] = T[i][0][j] * rcp24 + t0[j];
230             t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6;
231             t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3;
232
233             Fw[i][0][j] = T[i][0][j] * rcp4;
234             Fw[i][1][j] = t1[j] - t3[j];
235             Fw[i][2][j] = t1[j] + t3[j];
236             Fw[i][3][j] = t2[j] + t4[j];
237             Fw[i][4][j] = t2[j] - t4[j];
238             Fw[i][5][j] = T[i][3][j];
239         }
240     }
241 }
242
243 void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16])
244 {
245     float T[4][6][16];
246     float M_[3][16];
247     float t0[16];
248     float t1[16];
249     float t2[16];
250
251     for (int j = 0; j < 16; j++) {
252 pragma_unroll
253         for (int i = 0; i < 6; i++) {
254             PRAGMA_OMP_SIMD()
255             for (int l = 0; l < 16; l++) {
256                 t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l];
257                 t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l];
258                 t2[l] = t1[l] * 4.0f + Mw[5][i][j][l];
259
260                 T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l];
261                 T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) +
262                              2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]);
263                 T[2][i][l] = t0[l] + t2[l];
264             }
265         }
266 pragma_unroll
267         for (int i = 0; i < 3; i++) {
268             PRAGMA_OMP_SIMD()
269             for (int l = 0; l < 16; l++) {
270                 t0[l] = T[i][1][l] + T[i][2][l];
271                 t1[l] = T[i][3][l] + T[i][4][l];
272                 t2[l] = t1[l] * 4.0f + T[i][5][l];
273
274                 M_[0][l] = T[i][0][l] + t0[l] + t1[l];
275                 M_[1][l] = (T[i][1][l] - T[i][2][l]) +
276                            2.0f * (T[i][3][l] - T[i][4][l]);
277                 M_[2][l] = t0[l] + t2[l];
278
279                 for (int k = 0; k < 3; k++) {
280                     M[i][k][j][l] = M_[k][l];
281                 }
282             }
283         }
284     }
285 }
286
287 void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16])
288 {
289     float T[6][6][16];
290     float t0[16];
291     float t1[16];
292     float t2[16];
293     float t3[16];
294     float t4[16];
295     float t5[16];
296
297 pragma_unroll
298     for (int i = 0; i < 6; i++) {
299         PRAGMA_OMP_SIMD()
300         for (int v = 0; v < 16; v++) {
301             t0[v] = I[2][i][v] * -2.25f + I[4][i][v];
302             t1[v] = I[1][i][v] * -2.25f + I[3][i][v];
303             t2[v] = I[2][i][v] * -0.390625f + I[4][i][v];
304             t3[v] = I[1][i][v] * -0.390625f + I[3][i][v];
305             t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v];
306             t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v];
307
308             T[0][i][v] = I[2][i][v] * -2.640625f + t4[v];
309             T[1][i][v] = t1[v] * 0.625f + t0[v];
310             T[2][i][v] = t1[v] * -0.625f + t0[v];
311             T[3][i][v] = t3[v] * 1.5f + t2[v];
312             T[4][i][v] = t3[v] * -1.5f + t2[v];
313             T[5][i][v] = I[3][i][v] * -2.640625f + t5[v];
314         }
315     }
316
317 pragma_unroll
318     for (int i = 0; i < 6; i++) {
319         PRAGMA_OMP_SIMD()
320         for (int v = 0; v < 16; v++) {
321             t0[v] = T[i][2][v] * -2.25f + T[i][4][v];
322             t1[v] = T[i][1][v] * -2.25f + T[i][3][v];
323             t2[v] = T[i][2][v] * -0.390625f + T[i][4][v];
324             t3[v] = T[i][1][v] * -0.390625f + T[i][3][v];
325             t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v];
326             t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v];
327
328             Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v];
329             Iw[i][1][v] = t1[v] * 0.625f + t0[v];
330             Iw[i][2][v] = t1[v] * -0.625f + t0[v];
331             Iw[i][3][v] = t3[v] * 1.5f + t2[v];
332             Iw[i][4][v] = t3[v] * -1.5f + t2[v];
333             Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v];
334         }
335     }
336 }
337
338 void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16])
339 {
340     float T[6][4][16];
341     float t0[16];
342     float t1[16];
343     float t2[16];
344     float t3[16];
345     float t4[16];
346
347 pragma_unroll
348     for (int i = 0; i < 4; i++) {
349         PRAGMA_OMP_SIMD()
350         for (int v = 0; v < 16; v++) {
351             t0[v] = F[2][i][v] * 0.26890756302521f;
352             t1[v] = F[0][i][v] * -0.688403361344538f - t0[v];
353             t2[v] = F[0][i][v] * 0.119514472455649f + t0[v];
354             t3[v] = F[1][i][v] * 0.430252100840336f +
355                     F[3][i][v] * 0.168067226890756f;
356             t4[v] = F[1][i][v] * 0.179271708683473f +
357                     F[3][i][v] * 0.403361344537815f;
358
359             T[0][i][v] = F[0][i][v] * 1.13777777777778f;
360             T[1][i][v] = t1[v] - t3[v];
361             T[2][i][v] = t1[v] + t3[v];
362             T[3][i][v] = t2[v] + t4[v];
363             T[4][i][v] = t2[v] - t4[v];
364             T[5][i][v] = F[3][i][v];
365         }
366     }
367 pragma_unroll
368     for (int i = 0; i < 6; i++) {
369         for (int v = 0; v < 16; v++) {
370             t0[v] = T[i][2][v] * 0.26890756302521f;
371             t1[v] = T[i][0][v] * -0.688403361344538f - t0[v];
372             t2[v] = T[i][0][v] * 0.119514472455649f + t0[v];
373             t3[v] = T[i][1][v] * 0.430252100840336f +
374                     T[i][3][v] * 0.168067226890756f;
375             t4[v] = T[i][1][v] * 0.179271708683473f +
376                     T[i][3][v] * 0.403361344537815f;
377
378             Fw[i][0][v] = T[i][0][v] * 1.13777777777778f;
379             Fw[i][1][v] = t1[v] - t3[v];
380             Fw[i][2][v] = t1[v] + t3[v];
381             Fw[i][3][v] = t2[v] + t4[v];
382             Fw[i][4][v] = t2[v] - t4[v];
383             Fw[i][5][v] = T[i][3][v];
384         }
385     }
386 }
387
388 void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16])
389 {
390     float T[3][6][16];
391     float t0[16];
392     float t1[16];
393     float t2[16];
394     float M_[3][16];
395
396     for (int j = 0; j < 16; j++) {
397 pragma_unroll
398         for (int i = 0; i < 6; i++) {
399             PRAGMA_OMP_SIMD()
400             for (int v = 0; v < 16; v++) {
401                 t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v];
402                 t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v];
403                 t2[v] = t1[v] * 2.25f + Mw[5][i][j][v];
404
405                 T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v];
406                 T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) +
407                              1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]);
408                 T[2][i][v] = t0[v] * 0.390625f + t2[v];
409             }
410         }
411 pragma_unroll
412         for (int i = 0; i < 3; i++) {
413             PRAGMA_OMP_SIMD()
414             for (int v = 0; v < 16; v++) {
415                 t0[v] = T[i][1][v] + T[i][2][v];
416                 t1[v] = T[i][3][v] + T[i][4][v];
417                 t2[v] = t1[v] * 2.25f + T[i][5][v];
418
419                 M_[0][v] = T[i][0][v] + t0[v] + t1[v];
420                 M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) +
421                            1.5f * (T[i][3][v] - T[i][4][v]);
422                 M_[2][v] = t0[v] * 0.390625f + t2[v];
423             }
424
425 pragma_unroll
426             for (int k = 0; k < 3; k++) {
427                 PRAGMA_OMP_SIMD()
428                 for (int v = 0; v < 16; v++) {
429                     M[i][k][j][v] = M_[k][v];
430                 }
431             }
432         }
433     }
434 }
435
436 template <bool is_fwd>
437 void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
438         float *inp, float *tinp, bool streamout = true)
439 {
440     const int inpw = is_fwd ? jcp.iw : jcp.ow;
441     const int inph = is_fwd ? jcp.ih : jcp.oh;
442     const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
443     const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
444     const int wp_max = inpw + l_pad;
445     const int hp_max = inph + t_pad;
446     float Iw[alpha][alpha][simd_w];
447     float I[alpha][alpha][simd_w];
448
449     array_offset_calculator<float, 5> input(inp,
450             jcp.mb, jcp.dimK/simd_w, inph, inpw,
451             simd_w);
452     array_offset_calculator<float, 8> output(tinp,
453             jcp.dimN_nb_block, alpha, alpha,
454             jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
455             jcp.dimN_reg_block, jcp.dimK_reg_block);
456
457     int tile_base_index = image * jcp.itiles * jcp.jtiles;
458     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
459     int nb_tile_block_ur =
460         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
461     int tile_block =
462         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
463
464     for (int tj = 0; tj < jcp.jtiles; tj++) {
465         for (int ti = 0; ti < jcp.itiles; ti++) {
466             for (int j = 0; j < alpha; j++) {
467                 int ydim = tj * tile_size + j;
468                 if ((t_pad <= ydim) && (ydim < hp_max)) {
469                     float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ;
470                     for (int i = 0; i < alpha; i++) {
471                         int xdim = ti * tile_size + i;
472                         if ((l_pad <= xdim) && (xdim < wp_max)) {
473                             float *pinp_i = pinp_j + (xdim - l_pad) * 16;
474                             load_ps(I[j][i], pinp_i);
475                         } else {
476                             PRAGMA_OMP_SIMD()
477                             for (int v = 0; v < simd_w; v++) {
478                                 I[j][i][v] = 0.0f;
479                             }
480                         }
481                     }
482                 } else {
483                     for (int i = 0; i < alpha; i++) {
484                         PRAGMA_OMP_SIMD()
485                         for (int v = 0; v < simd_w; v++) {
486                             I[j][i][v] = 0.0f;
487                         }
488                     }
489                 }
490             }
491
492             trans_I_4x4_3x3(Iw, I);
493
494             for (int j = 0; j < alpha; j++) {
495                 for (int i = 0; i < alpha; i++) {
496                     store_output(&(output(tile_block, j, i,
497                                     nb_tile_block_ur, 0, 0,
498                                     tile_block_ur, 0)),
499                                  Iw[j][i], streamout);
500                 }
501             }
502             tile_block_ur++;
503             if (tile_block_ur >= jcp.tile_block_ur) {
504                 tile_block_ur = 0;
505                 nb_tile_block_ur++;
506             }
507             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
508                 nb_tile_block_ur = 0;
509                 tile_block++;
510             }
511         }
512     }
513 }
514
515 template <bool is_fwd>
516 void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
517         float *wp, float *twp)
518 {
519     const int kh = 3;
520     const int kw = 3;
521     array_offset_calculator<float, 6> input(wp,
522             jcp.oc/jcp.oc_simd_block,
523             jcp.ic/jcp.ic_simd_block,
524             jcp.kh, jcp.kw,
525             simd_w, simd_w);
526     array_offset_calculator<float, 8> output(twp,
527             jcp.dimM_nb_block,
528             alpha, alpha,
529             jcp.dimK_nb_block,
530             jcp.dimM_block, jcp.dimK_block,
531             simd_w, simd_w);
532     float Fw[alpha][alpha][simd_w][simd_w];
533     float F[kh][kw][simd_w][simd_w];
534
535     for (int j = 0; j < kh; j++) {
536         for (int i = 0; i < kw; i++) {
537             for (int v1 = 0; v1 < simd_w; v1++) {
538                 float *base_inp = is_fwd
539                                 ? &(input(0, 0, j, i, v1, 0))
540                                 : &(input(0, 0, 2 - j, 2 - i, v1, 0));
541                 PRAGMA_OMP_SIMD()
542                 for (int v2 = 0; v2 < simd_w; v2++) {
543                     if (is_fwd)
544                         F[j][i][v1][v2] = *(base_inp + v2);
545                     else
546                         F[j][i][v2][v1] = *(base_inp + v2);
547                 }
548             }
549         }
550     }
551
552     trans_W_4x4_3x3(Fw, F);
553
554     for (int j = 0; j < alpha; j++) {
555         for (int i = 0; i < alpha; i++) {
556             for (int v1 = 0; v1 < simd_w; v1++) {
557                 PRAGMA_OMP_SIMD()
558                 for (int v2 = 0; v2 < simd_w; v2++) {
559                     output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2];
560                 }
561             }
562         }
563     }
564 }
565
566 template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
567 void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
568         const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias,
569         bool streamout = true) {
570     float Ow[alpha][alpha][simd_w];
571     float O[tile_size][tile_size][simd_w];
572     int outw = is_fwd ? jcp.ow : jcp.iw;
573     int outh = is_fwd ? jcp.oh : jcp.ih;
574
575     /* Prepare for PostOps */
576     bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
577
578     array_offset_calculator<float, 8> input(toutp,
579             jcp.dimN_nb_block, jcp.dimM_nb_block,
580             alpha, alpha,
581             jcp.dimN_block, jcp.dimM_block,
582             jcp.dimN_reg_block, jcp.dimM_simd_block);
583
584     int tile_base_index = image * jcp.itiles * jcp.jtiles;
585     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
586     int nb_tile_block_ur =
587         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
588     int tile_block =
589         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
590
591     for (int tj = 0; tj < jcp.jtiles; tj++) {
592         for (int ti = 0; ti < jcp.itiles; ti++) {
593             for (int j = 0; j < alpha; j++) {
594                 for (int i = 0; i < alpha; i++) {
595                     PRAGMA_OMP_SIMD()
596                     for (int v = 0; v < simd_w; v++) {
597                         Ow[j][i][v] = input(tile_block, 0,
598                                 j, i,
599                                 nb_tile_block_ur, 0,
600                                 tile_block_ur, v);
601                     }
602                 }
603             }
604
605             trans_O_4x4_3x3(Ow, O);
606
607             for (int j = 0; j < tile_size; j++) {
608                 int ydim = tj * tile_size + j;
609                 if (ydim < outh) {
610                     float *pout_j = pout_b + ydim * outw * simd_w;
611                     for (int i = 0; i < tile_size; i++) {
612                         int xdim = ti * tile_size + i;
613                         if (xdim < outw) {
614                             float *pout_i = pout_j + xdim * simd_w;
615                             if (is_fwd) {
616                                 PRAGMA_OMP_SIMD()
617                                 for (int v = 0; v < simd_w; v++) {
618                                     O[j][i][v] += with_bias ? bias[v] : 0.f;
619                                     O[j][i][v] = true
620                                         && with_relu_presum && O[j][i][v] < 0.f
621                                                 ? O[j][i][v]
622                                                 * jcp.eltwise.alpha
623                                                 : O[j][i][v];
624                                 }
625                             }
626                             if (with_sum)
627                                 accum_output(pout_i, O[j][i], streamout,
628                                         with_relu_postsum);
629                             else
630                                 store_output(pout_i, O[j][i], streamout);
631                         }
632                     }
633                 }
634             }
635             tile_block_ur++;
636             if (tile_block_ur >= jcp.tile_block_ur) {
637                 tile_block_ur = 0;
638                 nb_tile_block_ur++;
639             }
640             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
641                 nb_tile_block_ur = 0;
642                 tile_block++;
643             }
644         }
645     }
646 }
647
648 template <bool ver_4fma>
649 void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
650         float *inp, float *tinp, float *Iw_temp,
651         void (*transpose_4fma_ker)(float *, float *))
652 {
653
654     const int ifwp = conv.iw + conv.l_pad;
655     const int ifhp = conv.ih + conv.t_pad;
656     float I[alpha][alpha][simd_w];
657     float Iw[alpha][alpha][simd_w];
658
659     array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp,
660             alpha, alpha, conv.tile_4fma, simd_w);
661     array_offset_calculator<float, 5> input(inp,
662             conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w);
663     array_offset_calculator<float, 8> output(tinp,
664             conv.nb_ic, alpha, alpha,
665             conv.tile_block, conv.ic_block,
666             conv.nb_tile_block_ur, conv.tile_block_ur,
667             conv.ic_simd_block * conv.tile_4fma);
668
669     int tile_base_index =
670         image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding);
671     int tile_4fma = 0;
672     int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur;
673     int nb_tile_block_ur =
674         (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
675         % conv.nb_tile_block_ur;
676     int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
677             / conv.nb_tile_block_ur;
678
679     for (int tj = 0; tj < conv.jtiles; tj++) {
680         for (int ti = 0; ti < conv.itiles; ti++) {
681             for (int j = 0; j < alpha; j++) {
682                 int ydim = tj * tile_size + j;
683                 if ((conv.t_pad <= ydim) && ydim < ifhp) {
684                     for (int i = 0; i < alpha; i++) {
685                         int xdim = ti * tile_size + i;
686                         if ((conv.l_pad <= xdim) && xdim < ifwp) {
687                             PRAGMA_OMP_SIMD()
688                             for (int v = 0; v < simd_w; v++) {
689                                 I[j][i][v] = input(0, 0,
690                                         ydim - conv.t_pad,
691                                         xdim - conv.l_pad, v);
692                             }
693                         } else {
694                             PRAGMA_OMP_SIMD()
695                             for (int v = 0; v < simd_w; v++) {
696                                 I[j][i][v] = 0.0f;
697                             }
698                         }
699                     }
700                 } else {
701                     for (int i = 0; i < alpha; i++) {
702                         PRAGMA_OMP_SIMD()
703                         for (int v = 0; v < simd_w; v++) {
704                             I[j][i][v] = 0.0f;
705                         }
706                     }
707                 }
708             }
709             trans_I_4x4_3x3(Iw, I);
710
711             if (ver_4fma) {
712                 for (int j = 0; j < alpha; j++) {
713                     for (int i = 0; i < alpha; i++) {
714                         float *Iw_temp_base = &(Iw_trans_temp(j, i,
715                                                         tile_4fma, 0));
716                         PRAGMA_OMP_SIMD()
717                         for (int v = 0; v < simd_w; v++) {
718                             Iw_temp_base[v] = Iw[j][i][v];
719                         }
720                     }
721                 }
722                 tile_4fma++;
723                 if (tile_4fma == conv.tile_4fma) {
724                     float *outp = &(output(0, 0, 0,
725                                 tile_block, 0,
726                                 nb_tile_block_ur, tile_block_ur, 0));
727                     transpose_4fma_ker(outp, (float *)Iw_temp);
728                     tile_4fma = 0;
729                     tile_block_ur++;
730                 }
731             } else {
732                 for (int j = 0; j < alpha; j++) {
733                     for (int i = 0; i < alpha; i++) {
734                         store_output(&(output(0, j, i,
735                                         tile_block, 0,
736                                         nb_tile_block_ur, tile_block_ur, 0)),
737                                      Iw[j][i], true);
738                     }
739                 }
740                 tile_block_ur++;
741             }
742
743             if (tile_block_ur == conv.tile_block_ur) {
744                 tile_block_ur = 0;
745                 ++nb_tile_block_ur;
746             }
747             if (nb_tile_block_ur == conv.nb_tile_block_ur) {
748                 nb_tile_block_ur = 0;
749                 tile_block++;
750             }
751         }
752     }
753
754     if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) {
755
756         for (int j = 0; j < alpha; j++) {
757             for (int i = 0; i < alpha; i++) {
758                 for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) {
759                     float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0));
760                     PRAGMA_OMP_SIMD()
761                     for (int v = 0; v < simd_w; v++) {
762                         Iw_temp_base[v] = 0;
763                     }
764                 }
765             }
766         }
767         float *outp = &(output(0, 0, 0,
768                     tile_block, 0,
769                     nb_tile_block_ur, tile_block_ur, 0));
770         transpose_4fma_ker(outp, (float *)Iw_temp);
771     }
772 }
773
774 template <bool with_bias>
775 void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
776         float *inp, float *tinp, float *dbias)
777 {
778
779     const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding;
780     float I[alpha][alpha][simd_w];
781     float Iw[alpha][alpha][simd_w];
782
783     array_offset_calculator<float, 5> input(inp,
784             conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block);
785     array_offset_calculator<float, 8> output(tinp,
786             conv.nb_oc, alpha, alpha,
787             conv.tile_block, conv.oc_block,
788             conv.nb_tile_block_ur,
789             conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
790
791     int tile_base_index = image * total_tiles;
792     int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma);
793     int nb_tile_block_ur =
794         (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
795             % conv.nb_tile_block_ur;
796     int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
797             / conv.nb_tile_block_ur;
798
799     for (int tj = 0; tj < conv.jtiles; tj++) {
800         for (int ti = 0; ti < conv.itiles; ti++) {
801             for (int j = 0; j < alpha; j++) {
802                 int ydim = tj * tile_size + j;
803                 if (ydim < conv.oh) {
804                     for (int i = 0; i < alpha; i++) {
805                         int xdim = ti * tile_size + i;
806                         if (xdim < conv.ow) {
807                             float *input_base = &(input(0, 0, ydim, xdim, 0));
808
809                             PRAGMA_OMP_SIMD()
810                             for (int v = 0; v < simd_w; v++) {
811                                 I[j][i][v] = input_base[v];
812                             }
813                             if (with_bias && j < tile_size && i < tile_size) {
814                                 PRAGMA_OMP_SIMD()
815                                 for (int v = 0; v < simd_w; v++) {
816                                     dbias[v] += input_base[v];
817                                 }
818                             }
819                         } else {
820                             PRAGMA_OMP_SIMD()
821                             for (int v = 0; v < simd_w; v++) {
822                                 I[j][i][v] = 0.0f;
823                             }
824                         }
825                     }
826                 } else {
827                     for (int i = 0; i < alpha; i++) {
828                         PRAGMA_OMP_SIMD()
829                         for (int v = 0; v < simd_w; v++) {
830                             I[j][i][v] = 0.0f;
831                         }
832                     }
833                 }
834             }
835
836             trans_W_3x3_4x4_wu(Iw, I);
837
838             for (int j = 0; j < alpha; j++) {
839                 for (int i = 0; i < alpha; i++) {
840                     store_output(&(output(0, j, i,
841                                     tile_block, 0,
842                                     nb_tile_block_ur,
843                                     tile_block_ur, 0)),
844                                  Iw[j][i], true);
845                 }
846             }
847             tile_block_ur++;
848             if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) {
849                 tile_block_ur = 0;
850                 nb_tile_block_ur++;
851             }
852             if (nb_tile_block_ur >= conv.nb_tile_block_ur) {
853                 nb_tile_block_ur = 0;
854                 tile_block++;
855             }
856         }
857     }
858 }
859
860 void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
861         float *wp, float *twp)
862 {
863     const int kh = 3;
864     const int kw = 3;
865     float Fw[alpha][alpha][simd_w][simd_w];
866     float F[kh][kw][simd_w][simd_w];
867
868     array_offset_calculator<float, 8> input(twp,
869             conv.nb_ic, conv.nb_oc,
870             alpha, alpha,
871             conv.oc_block, conv.ic_block,
872             conv.ic_simd_block, conv.oc_simd_block);
873     array_offset_calculator<float, 6> output(wp,
874             conv.oc/simd_w, conv.ic/simd_w,
875             conv.kh, conv.kw,
876             conv.ic_simd_block, conv.oc_simd_block);
877
878     for (int j = 0; j < alpha; j++) {
879         for (int i = 0; i < alpha; i++) {
880             for (int v = 0; v < conv.ic_simd_block; v++) {
881                 PRAGMA_OMP_SIMD()
882                 for (int k = 0; k < conv.oc_simd_block; k++) {
883                     Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k);
884                 }
885             }
886         }
887     }
888
889     trans_O_3x3_4x4_wu(Fw, F);
890
891     for (int j = 0; j < kh; j++) {
892         for (int i = 0; i < kw; i++) {
893             for (int v = 0; v < conv.ic_simd_block; v++) {
894                 store_output(&(output(0, 0, j, i, v, 0)),
895                              F[j][i][v], true);
896             }
897         }
898     }
899 }
900
901 template <bool is_fwd>
902 void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
903         const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
904         const memory_tracking::grantor_t &scratchpad) const{
905     const auto &jcp = kernel_->jcp;
906     const auto &p_ops = attr_->post_ops_;
907
908     const int inph = is_fwd ? jcp.ih : jcp.oh;
909     const int inpw = is_fwd ? jcp.iw : jcp.ow;
910     const int outh = is_fwd ? jcp.oh : jcp.ih;
911     const int outw = is_fwd ? jcp.ow : jcp.iw;
912
913     /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
914      * and conv primitive with PostOps with relu before sum
915      * (PostOps relu after sum is handled later) */
916     auto output_transform = jcp.with_bias
917             ? (jcp.with_eltwise
918                 ? (jcp.with_sum
919                     ? output_transform_data<is_fwd, true, true, true>
920                     : output_transform_data<is_fwd, true, true, false>)
921                 : (jcp.with_sum
922                     ? output_transform_data<is_fwd, true, false, true>
923                     : output_transform_data<is_fwd, true, false, false>))
924             : (jcp.with_eltwise
925                 ? (jcp.with_sum
926                     ? output_transform_data<is_fwd, false, true, true>
927                     : output_transform_data<is_fwd, false, true, false>)
928                 : (jcp.with_sum
929                     ? output_transform_data<is_fwd, false, false, true>
930                     : output_transform_data<is_fwd, false, false, false>));
931
932     /* Notation:
933        FWD: dimM:oc, dimN:ntiles, dimK:ic,
934        BWD: dimM:ic, dimN:ntiles, dimK:oc,
935        FWD/BWD: V: src/diff_dst transform, U:weight transform,
936                 M:dst/diff_src transform  */
937     array_offset_calculator<float, 5> input(inp_ptr,
938             MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
939             jcp.dimK_reg_block);
940     array_offset_calculator<float, 5> output(out_ptr,
941             MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
942             jcp.dimM_simd_block);
943     array_offset_calculator<float, 6> weights(wei_ptr,
944             jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
945             jcp.ic_simd_block, jcp.oc_simd_block);
946     array_offset_calculator<float, 2> bias(bias_ptr,
947             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
948
949     array_offset_calculator<float, 8> M(is_fwd
950             ? scratchpad.template get<float>(key_wino_M)
951             : scratchpad.template get<float>(key_wino_V),
952             jcp.dimN_nb_block, jcp.dimM_nb_block,
953             alpha, alpha,
954             jcp.dimN_block, jcp.dimM_block,
955             jcp.dimN_reg_block, jcp.dimM_simd_block);
956     array_offset_calculator<float, 8> U(
957             scratchpad.template get<float>(key_wino_U),
958             jcp.dimM_nb_block,
959             alpha, alpha,
960             jcp.dimK_nb_block,
961             jcp.dimM_block, jcp.dimK_block,
962             jcp.dimK_reg_block, jcp.dimM_simd_block);
963     array_offset_calculator<float, 8> V(is_fwd
964             ? scratchpad.template get<float>(key_wino_V)
965             : scratchpad.template get<float>(key_wino_M),
966             jcp.dimN_nb_block, alpha, alpha,
967             jcp.dimN_block, jcp.dimK_nb_block,
968             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
969
970     bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
971         > 2 * LLC_cache_size ? true : false;
972
973     const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
974
975     const bool wants_padded_bias = jcp.with_bias
976         && jcp.oc_without_padding != jcp.oc;
977     float last_slice_bias[simd_w] = {0};
978     if (wants_padded_bias) {
979         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
980             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
981     }
982
983 PRAGMA_OMP(parallel)
984     {
985         parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
986             [&](int img, int K_blk1, int K_blk2) {
987             input_transform_data<is_fwd>(img, jcp,
988                 &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
989                 &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout);
990         });
991
992         parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
993             [&](int ofm1, int ifm1, int ofm2, int ifm2) {
994             float *U_base_ptr = is_fwd
995                 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
996                 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
997             weight_transform_data<is_fwd>(jcp,
998                 &(weights(ofm1 * jcp.oc_block + ofm2,
999                 ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
1000         });
1001
1002 PRAGMA_OMP(barrier)
1003
1004         parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
1005             [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
1006
1007             kernel_->gemm_loop_ker_first_iter(
1008                     (float *)&(M(N_blk1, M_blk1, oj, oi,
1009                             N_blk2, 0, 0, 0)),
1010                     (const float *)&(U(M_blk1, oj, oi,
1011                             0, 0, 0, 0, 0)),
1012                     (const float *)&(V(N_blk1, oj, oi,
1013                             N_blk2, 0, 0, 0, 0)));
1014             for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
1015                 kernel_->gemm_loop_ker(
1016                         (float *)&(M(N_blk1, M_blk1, oj, oi,
1017                                 N_blk2, 0, 0, 0)),
1018                         (const float *)&(U(M_blk1, oj, oi,
1019                                 K_blk1, 0, 0, 0, 0)),
1020                         (const float *)&(V(N_blk1, oj, oi,
1021                                 N_blk2, K_blk1,
1022                                 0, 0, 0)));
1023             }
1024
1025         });
1026
1027
1028 PRAGMA_OMP(barrier)
1029
1030         parallel_nd_in_omp(MB, jcp.dimM_nb_block, jcp.dimM_block,
1031                     [&](int img, int M_blk1, int M_blk2) {
1032
1033             const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
1034
1035             float *bias_ptr = wants_padded_bias
1036                 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
1037                 ? last_slice_bias : &bias(M_blk, 0);
1038
1039             output_transform(img, jcp, p_ops,
1040                     &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
1041                     &(output(img, M_blk, 0, 0, 0)),
1042                     bias_ptr, output_is_aligned);
1043
1044        });
1045     }
1046 }
1047
/* Explicit instantiations for both values of the is_fwd template parameter
 * (true and false), so the member definitions above are emitted in this
 * translation unit. */
1048 template struct _jit_avx512_common_convolution_winograd_t<true>;
1049 template struct _jit_avx512_common_convolution_winograd_t<false>;
1050
1051 void jit_avx512_common_convolution_winograd_bwd_weights_t::
1052 _maybe_execute_diff_bias_copy(
1053         const memory_tracking::grantor_t &scratchpad) const {
1054     if (pd()->wants_padded_bias()) {
1055         auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
1056         float *diff_bias = (float *)this->memory(1);
1057         for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
1058             diff_bias[oc] = padded_bias[oc];
1059     }
1060 }
1061
1062 void jit_avx512_common_convolution_winograd_bwd_weights_t::
1063 _execute_backward_weights_S_D_G_W(
1064         const memory_tracking::grantor_t &scratchpad) const {
1065     const auto &jcp = kernel_->jcp;
1066     const int nthreads = jcp.nthr;
1067
1068     auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
1069             diff_src_transform_bwd_weights<true> :
1070             diff_src_transform_bwd_weights<false>;
1071     auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
1072                                             ? diff_dst_transform_bwd_weights<true>
1073                                             : diff_dst_transform_bwd_weights<false>;
1074
1075     array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
1076             jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w);
1077     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
1078             jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
1079     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
1080             jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
1081     array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
1082             ? scratchpad.get<float>(key_conv_padded_bias)
1083             : (float *)this->memory(1), jcp.oc/simd_w, simd_w);
1084
1085     array_offset_calculator<float, 8> U(
1086             scratchpad.get<float>(key_wino_U),
1087             jcp.nb_ic, jcp.nb_oc,
1088             alpha, alpha,
1089             jcp.oc_block, jcp.ic_block,
1090             jcp.ic_simd_block, jcp.oc_simd_block);
1091
1092     array_offset_calculator<float, 8> M(
1093             scratchpad.get<float>(key_wino_M),
1094             jcp.nb_oc, alpha, alpha,
1095             jcp.tile_block, jcp.oc_block,
1096             jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
1097             jcp.oc_simd_block);
1098     array_offset_calculator<float, 8> V(
1099             scratchpad.get<float>(key_wino_V),
1100             jcp.nb_ic, alpha, alpha,
1101             jcp.tile_block, jcp.ic_block,
1102             jcp.nb_tile_block_ur, jcp.tile_block_ur,
1103             jcp.ic_simd_block * jcp.tile_4fma);
1104
1105     const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
1106                                 * jcp.ic_simd_block;
1107     array_offset_calculator<float, 2> trans_buffer(
1108             scratchpad.get<float>(key_conv_tr_src),
1109             nthreads,
1110             trans_buffer_size);
1111
1112     array_offset_calculator<float, 2> diff_bias_prv(
1113             scratchpad.get<float>(key_conv_bia_reduction),
1114             nthreads,
1115             jcp.oc);
1116
1117 PRAGMA_OMP(parallel num_threads(nthreads))
1118     {
1119         if (jcp.with_bias) {
1120             parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
1121                 diff_bias_prv(ithr, ofm) = 0.0f;
1122             });
1123
1124 PRAGMA_OMP(for nowait)
1125             for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
1126                 PRAGMA_OMP_SIMD()
1127                 for (int v = 0; v < simd_w; v++)
1128                     diff_bias(bofm, v) = 0.0f;
1129             }
1130         }
1131
1132         const int ithread = mkldnn_get_thread_num();
1133
1134         parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
1135             [&](int img, int ifm1, int ifm2) {
1136             float *transb = jcp.ver == ver_4fma
1137                ? &(trans_buffer(ithread, 0))
1138                : NULL;
1139             diff_src_transform_bwd_weights_ver(img, jcp,
1140                &(diff_src(img, ifm1 * jcp.ic_block + ifm2,
1141                        0, 0, 0)),
1142                &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
1143                transb,
1144                kernel_->transpose_4fma_ker);
1145         });
1146
1147         parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
1148             [&](int img, int ofm1, int ofm2) {
1149             float *dbias = jcp.with_bias
1150                    ? &(diff_bias_prv(ithread,
1151                                simd_w * (ofm1 * jcp.oc_block + ofm2)))
1152                    : NULL;
1153             diff_dst_transform_bwd_weights_ver(img, jcp,
1154                     &(diff_dst(img, ofm1 * jcp.oc_block + ofm2,
1155                             0, 0, 0)),
1156                     &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)),
1157                     dbias);
1158         });
1159
1160 PRAGMA_OMP(barrier)
1161
1162         for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
1163             parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
1164                 [&](int oj, int oi, int ofm1) {
1165                 kernel_->gemm_loop_ker_first_iter(
1166                     (float *)&(U(ifm1, ofm1, oj, oi,
1167                             0, 0, 0, 0)),
1168                     (const float *)&(M(ofm1, oj, oi,
1169                             0, 0, 0, 0, 0)),
1170                     (const float *)&(V(ifm1, oj, oi,
1171                             0, 0, 0, 0, 0)));
1172                 for (int tile_block = 1; tile_block < jcp.tile_block;
1173                      tile_block++) {
1174                     kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1,
1175                                 oj, oi,
1176                                 0, 0, 0, 0)),
1177                         (const float *)&(M(ofm1, oj, oi, tile_block,
1178                                 0, 0, 0, 0)),
1179                         (const float *)&(V(ifm1, oj, oi, tile_block,
1180                                 0, 0, 0, 0)));
1181                 }
1182             });
1183         }
1184
1185 PRAGMA_OMP(barrier)
1186
1187         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
1188             [&](int ifm1, int ofm1, int ofm2, int ifm2) {
1189             diff_weights_transform_bwd_weights(jcp,
1190                     &(diff_weights(ofm1 * jcp.oc_block + ofm2,
1191                             ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
1192                     &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
1193         });
1194
1195         if (jcp.with_bias) {
1196 PRAGMA_OMP(for)
1197             for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
1198                 for (int ithr = 0; ithr < nthreads; ithr++) {
1199                     float* base_bias_ptr = &(diff_bias(ofm1, 0));
1200                     float* base_bias_prv_ptr = &(diff_bias_prv(
1201                                 ithr * jcp.oc + ofm1 * simd_w));
1202                     PRAGMA_OMP_SIMD()
1203                     for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
1204                         base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
1205                     }
1206                 }
1207             }
1208         }
1209     }
1210
1211     _maybe_execute_diff_bias_copy(scratchpad);
1212 }
1213
1214 }
1215 }
1216 }
1217 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s