Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_4x3.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
19 #endif
20
21 #include "mkldnn_types.h"
22
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 #include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
29
30 #ifndef _MSC_VER
31 #define pragma_unroll _Pragma("unroll")
32 #else
33 #define pragma_unroll
34 #endif
35
36
37 namespace mkldnn {
38 namespace impl {
39 namespace cpu {
40
41 using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::memory_tracking::names;
44 using namespace mkldnn::impl::utils;
45
46 template <bool is_fwd>
47 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49         float *wp, float *twp) const
50 {
51     float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52                  1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
53     const int kh = 3;
54     const int kw = 3;
55     float Fw[alpha][alpha][simd_w][simd_w];
56     float F[kh][kw][simd_w][simd_w];
57     float T[alpha][3][simd_w];
58     auto p = jit_wino_transform_call_s();
59
60     p.src = wp;
61     p.dst = twp;
62     p.G = G;
63     p.M = F;
64     p.Mw = Fw;
65     p.T = T;
66
67     kernel_->weights_transform_data_ker(&p);
68 }
69
70 template<bool is_fwd>
71 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73     const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
74
75     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
76     float Ow[alpha][alpha][simd_w];
77     float O[tile_size][tile_size][simd_w];
78     float T[tile_size][alpha][simd_w];
79
80     auto p = jit_wino_transform_call_s();
81     p.src = toutp;
82     p.dst = pout_b;
83     p.G = G;
84     p.M = O;
85     p.Mw = Ow;
86     p.T = T;
87     p.bias = bias;
88
89     int tile_base_index = image * jcp.itiles * jcp.jtiles;
90     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91     int nb_tile_block_ur =
92         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
93     int tile_block =
94         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
95
96     for (int tj = 0; tj < jcp.jtiles; tj++) {
97         for (int ti = 0; ti < jcp.itiles; ti++) {
98
99             p.tile_block_ur = tile_block_ur;
100             p.nb_tile_block_ur = nb_tile_block_ur;
101             p.tile_block = tile_block;
102             p.tj = tj;
103             p.ti = ti;
104
105             kernel_->output_transform_data_ker(&p);
106
107             tile_block_ur++;
108             if (tile_block_ur >= jcp.tile_block_ur) {
109                 tile_block_ur = 0;
110                 nb_tile_block_ur++;
111             }
112             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113                 nb_tile_block_ur = 0;
114                 tile_block++;
115             }
116         }
117     }
118 }
119
120 template<bool is_fwd>
121 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123     const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124     float *toutp, float *outp, float *bias) const {
125
126     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
127     float Ow[alpha][alpha][simd_w];
128     float O[tile_size][tile_size][simd_w];
129     float T[tile_size][alpha][simd_w];
130
131     auto p = jit_wino_transform_call_s();
132     p.src = toutp;
133     p.dst = outp;
134     p.G = G;
135     p.M = O;
136     p.Mw = Ow;
137     p.T = T;
138     p.bias = bias;
139
140     int outw = is_fwd ? jcp.ow : jcp.iw;
141     int outh = is_fwd ? jcp.oh : jcp.ih;
142
143     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
144
145     for (int nb_tile_block_ur = 0;
146         nb_tile_block_ur < jcp.nb_tile_block_ur;
147         nb_tile_block_ur++) {
148
149         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
150             tile_block_ur++) {
151             int img = tile_index / (jcp.jtiles * jcp.itiles);
152             int ti = tile_index % jcp.itiles;
153             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
154
155             p.tile_block_ur = tile_block_ur;
156             p.nb_tile_block_ur = nb_tile_block_ur;
157             p.tile_block = tile_block;
158             p.tj = tj;
159             p.ti = ti;
160             p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161                                * outh * outw * jcp.dimM_simd_block;
162
163             kernel_->output_transform_data_ker(&p);
164
165             tile_index++;
166         }
167     }
168 }
169
170
171 template<bool is_fwd>
172 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
173     ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174         float *inp, float *tinp) const
175 {
176     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177                  0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
178
179     float Iw[alpha][alpha][simd_w];
180     float I[alpha][alpha][simd_w];
181     float T[alpha][alpha][simd_w];
182
183     auto p = jit_wino_transform_call_s();
184
185     p.src = inp;
186     p.dst = tinp;
187     p.G = G;
188     p.M = I;
189     p.Mw = Iw;
190     p.T = T;
191
192     int tile_base_index = image * jcp.itiles * jcp.jtiles;
193     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194     int nb_tile_block_ur =
195         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
196     int tile_block =
197         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
198
199     for (int tj = 0; tj < jcp.jtiles; tj++) {
200         for (int ti = 0; ti < jcp.itiles; ti++) {
201
202             p.tile_block_ur = tile_block_ur;
203             p.nb_tile_block_ur = nb_tile_block_ur;
204             p.tile_block = tile_block;
205             p.tj = tj;
206             p.ti = ti;
207
208             kernel_->input_transform_data_ker(&p);
209
210             tile_block_ur++;
211             if (tile_block_ur >= jcp.tile_block_ur) {
212                 tile_block_ur = 0;
213                 nb_tile_block_ur++;
214             }
215             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216                 nb_tile_block_ur = 0;
217                 tile_block++;
218             }
219         }
220     }
221 }
222
223 template <bool is_fwd>
224 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
225     ::input_transform_tileblock_data(int tile_block,
226         const jit_conv_winograd_conf_t &jcp,
227         float *inp, float *tinp) const
228 {
229     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
231     float Iw[alpha][alpha][simd_w];
232     float I[alpha][alpha][simd_w];
233     float T[alpha][alpha][simd_w];
234
235     const int inph = is_fwd ? jcp.ih : jcp.oh;
236     const int inpw = is_fwd ? jcp.iw : jcp.ow;
237
238     array_offset_calculator<float, 5> input(inp,
239         jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
240     array_offset_calculator<float, 7> output(tinp,
241         alpha, alpha,
242         jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243         jcp.dimN_reg_block, jcp.dimK_reg_block);
244
245     auto p = jit_wino_transform_call_s();
246
247     p.dst = tinp;
248     p.G = G;
249     p.M = I;
250     p.Mw = Iw;
251     p.T = T;
252
253
254     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
255
256     for (int nb_tile_block_ur = 0;
257             nb_tile_block_ur < jcp.nb_tile_block_ur;
258             nb_tile_block_ur++) {
259
260         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
261                 tile_block_ur++) {
262
263             int img = tile_index / (jcp.jtiles * jcp.itiles);
264             int ti = tile_index % jcp.itiles;
265             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
266             float *pinp_b = &(input(img, 0, 0, 0, 0));
267
268             p.src = pinp_b;
269             p.tile_block_ur = tile_block_ur;
270             p.nb_tile_block_ur = nb_tile_block_ur;
271             p.tj = tj;
272             p.ti = ti;
273
274             kernel_->input_transform_data_ker(&p);
275
276             tile_index++;
277         }
278     }
279 }
280
/* Data-path driver (forward when is_fwd, backward-data otherwise) that runs
 * the whole computation as separate, barrier-delimited phases inside a
 * single OpenMP parallel region:
 *   1. input transform of every image into V, plus the weight transform
 *      into U unless prop_kind is forward_inference;
 *   2. GEMM over (N_blk1, oj, oi, M_blk1) accumulating into M;
 *   3. output transform of M into the destination, applying bias.
 *
 *   MB         - number of images to process
 *   inp_ptr    - source tensor (src for fwd, diff_dst for bwd)
 *   out_ptr    - destination tensor (dst for fwd, diff_src for bwd)
 *   wei_ptr    - weights
 *   bias_ptr   - bias; only dereferenced for the padded-bias copy, otherwise
 *                just forwarded as a pointer to the output transform
 *   scratchpad - provides the key_wino_U / key_wino_V / key_wino_M buffers */
281 template <bool is_fwd>
282 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
283         const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
284         const memory_tracking::grantor_t &scratchpad) const {
285     const auto &jcp = kernel_->jcp;
286     const auto &p_ops = attr_->post_ops_;
287
    // Source/destination spatial sizes swap roles between fwd and bwd.
288     const int inph = is_fwd ? jcp.ih : jcp.oh;
289     const int inpw = is_fwd ? jcp.iw : jcp.ow;
290     const int outh = is_fwd ? jcp.oh : jcp.ih;
291     const int outw = is_fwd ? jcp.ow : jcp.iw;
292
293     /* Notation:
294        FWD: dimM:oc, dimN:ntiles, dimK:ic,
295        BWD: dimM:ic, dimN:ntiles, dimK:oc,
296        FWD/BWD: V: src/diff_dst transform, U:weight transform,
297                 M:dst/diff_src transform  */
298     array_offset_calculator<float, 5> input(inp_ptr,
299             MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
300             jcp.dimK_reg_block);
301     array_offset_calculator<float, 5> output(out_ptr,
302             MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
303             jcp.dimM_simd_block);
304     array_offset_calculator<float, 6> weights(wei_ptr,
305             jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
306             jcp.ic_simd_block, jcp.oc_simd_block);
307     array_offset_calculator<float, 2> bias(bias_ptr,
308             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
309
    // For bwd the M and V views are backed by the opposite scratchpad keys
    // (M uses key_wino_V, V uses key_wino_M).
310     array_offset_calculator<float, 8> M(is_fwd
311             ? scratchpad.template get<float>(key_wino_M)
312             : scratchpad.template get<float>(key_wino_V),
313             jcp.dimN_nb_block, jcp.dimM_nb_block,
314             alpha, alpha,
315             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316             jcp.dimN_reg_block, jcp.dimM_simd_block);
317
    // For forward_inference wei_ptr itself is viewed as U and the weight
    // transform below is skipped (presumably the weights arrive already
    // transformed -- confirm against the primitive descriptor).
318     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
319             ? wei_ptr
320             : scratchpad.template get<float>(key_wino_U);
321
322     array_offset_calculator<float, 8> U(wino_wei,
323             jcp.dimM_nb_block,
324             alpha, alpha,
325             jcp.dimK_nb_block,
326             jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
327             jcp.dimK_reg_block, jcp.dimM_simd_block);
328     array_offset_calculator<float, 8> V(is_fwd
329             ? scratchpad.template get<float>(key_wino_V)
330             : scratchpad.template get<float>(key_wino_M),
331             jcp.dimN_nb_block, alpha, alpha,
332             jcp.dimN_block, jcp.dimK_nb_block,
333             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
334
    // When oc is padded, the last bias row is copied into a zero-filled
    // local buffer so that only the first oc_without_padding % oc_simd_block
    // entries come from the user's bias.
335     const bool wants_padded_bias = jcp.with_bias
336         && jcp.oc_without_padding != jcp.oc;
337     float last_slice_bias[simd_w] = {0};
338     if (wants_padded_bias) {
339         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
340             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
341     }
342
343 PRAGMA_OMP(parallel)
344     {
        // Phase 1a: input transform, one work item per (image, K block).
345         parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
346                 [&](int img, int K_blk1, int K_blk2) {
347                 input_transform_data(img, jcp,
348                     &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
349                             0, 0, 0)),
350                         &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
351                 });
352
        // Phase 1b: weight transform into U. The U index order is swapped
        // between fwd (oc-major) and bwd (ic-major).
353         if (jcp.prop_kind != prop_kind::forward_inference) {
354             parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
355                 (jcp.ic_block * jcp.ic_reg_block),
356                 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
357                     float *U_base_ptr = is_fwd
358                         ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
359                         : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
360                     weight_transform_data(jcp,
361                         &(weights(
362                                 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
363                                 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
364                                 0, 0, 0, 0)),
365                         U_base_ptr);
366             });
367         }
368
        // All of U and V must be written before any GEMM reads them.
369 PRAGMA_OMP(barrier)
370
        // Phase 2: GEMM. K_blk1 is passed to the kernel as its last argument
        // (presumably so it can tell the first K block from later ones --
        // confirm in gemm_loop_ker).
371         parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
372             [&](int N_blk1, int oj, int oi, int M_blk1) {
373             for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
374                  K_blk1++)
375             for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
376                 kernel_->gemm_loop_ker(
377                         (float *)&(M(N_blk1, M_blk1, oj, oi,
378                             N_blk2, 0, 0, 0)),
379                         (const float *)&(U(M_blk1, oj, oi,
380                             K_blk1, 0, 0, 0, 0)),
381                         (const float *)&(V(N_blk1, oj, oi,
382                             N_blk2, K_blk1, 0, 0, 0)), K_blk1);
383         });
384
        // All of M must be complete before the output transform reads it.
385 PRAGMA_OMP(barrier)
386
        // Phase 3: output transform, one work item per (image, M block).
387         parallel_nd_in_omp(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
388                     [&](int img, int M_blk1, int M_blk2) {
389             const int M_blk =
390                 M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
391
            // The last M block uses the zero-padded bias copy when needed.
392             float *bias_ptr = wants_padded_bias
393                 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
394                 ? last_slice_bias : &bias(M_blk, 0);
395             output_transform_data(img, jcp, p_ops,
396                     &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
397                     &(output(img, M_blk, 0, 0, 0)), bias_ptr);
398         });
399     }
400 }
401
/* Data-path driver (forward when is_fwd, backward-data otherwise) that
 * transforms the weights up front and then, per tile block, runs input
 * transform, GEMM and output transform back to back on the same thread.
 *
 *   MB         - number of images (used for the tensor views)
 *   inp_ptr    - source tensor (src for fwd, diff_dst for bwd)
 *   out_ptr    - destination tensor (dst for fwd, diff_src for bwd)
 *   wei_ptr    - weights
 *   bias_ptr   - bias; only dereferenced for the padded-bias copy, otherwise
 *                just forwarded as a pointer to the output transform
 *   scratchpad - provides the key_wino_U / key_wino_V / key_wino_M buffers
 *
 * The M and V views are indexed by thread number in their leading dimension,
 * so each thread works in its own slice of the scratch buffers. */
402 template <bool is_fwd>
403 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(const int MB,
404         float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
405         const memory_tracking::grantor_t &scratchpad) const {
406
407     const auto &jcp = kernel_->jcp;
408     const auto &p_ops = attr_->post_ops_;
409
    // Source/destination spatial sizes swap roles between fwd and bwd.
410     const int inph = is_fwd ? jcp.ih : jcp.oh;
411     const int inpw = is_fwd ? jcp.iw : jcp.ow;
412     const int outh = is_fwd ? jcp.oh : jcp.ih;
413     const int outw = is_fwd ? jcp.ow : jcp.iw;
414
415     array_offset_calculator<float, 5> input(inp_ptr,
416         MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
417     array_offset_calculator<float, 5> output(out_ptr,
418         MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
419     array_offset_calculator<float, 6> weights(wei_ptr,
420         jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
421         jcp.ic_simd_block, jcp.oc_simd_block);
422     array_offset_calculator<float, 2> bias(bias_ptr,
423         jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
424
    // For forward_inference wei_ptr itself is viewed as U and the weight
    // transform below is skipped (presumably the weights arrive already
    // transformed -- confirm against the primitive descriptor).
425     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
426                 ? wei_ptr
427                 : scratchpad.template get<float>(key_wino_U);
428
429     array_offset_calculator<float, 8> U(wino_wei,
430             jcp.dimM_nb_block,
431             alpha, alpha,
432             jcp.dimK_nb_block,
433             jcp.dimM_block  * jcp.dimM_reg_block, jcp.dimK_block,
434             jcp.dimK_reg_block, jcp.dimM_simd_block);
435
    // M and V are declared with a leading extent of 0 and indexed by ithr
    // below (presumably the leading extent does not take part in the offset
    // computation -- confirm in array_offset_calculator). For bwd they are
    // backed by the opposite scratchpad keys.
436     array_offset_calculator<float, 8> M(is_fwd
437             ? scratchpad.template get<float>(key_wino_M)
438             : scratchpad.template get<float>(key_wino_V),
439             0, jcp.dimM_nb_block, alpha, alpha,
440             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
441             jcp.dimN_reg_block, jcp.dimM_simd_block);
442     array_offset_calculator<float, 8> V(is_fwd
443             ? scratchpad.template get<float>(key_wino_V)
444             : scratchpad.template get<float>(key_wino_M),
445             0, alpha, alpha, jcp.dimN_block,
446             jcp.dimK_nb_block, jcp.dimK_block,
447             jcp.dimN_reg_block, jcp.dimK_reg_block);
448
    // When oc is padded, the last bias row is copied into a zero-filled
    // local buffer so that only the first oc_without_padding % oc_simd_block
    // entries come from the user's bias.
449     const bool wants_padded_bias = jcp.with_bias
450         && jcp.oc_without_padding != jcp.oc;
451     float last_slice_bias[simd_w] = {0};
452     if (wants_padded_bias) {
453         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
454             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
455     }
456
    // Weight transform runs in its own parallel_nd call, before the
    // per-tile-block parallel region starts. The U index order is swapped
    // between fwd (oc-major) and bwd (ic-major).
457     if (jcp.prop_kind != prop_kind::forward_inference) {
458
459         parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
460                     [&](int ofm1, int ifm1, int ofm2, int ifm2) {
461             float *U_base_ptr = is_fwd
462                               ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
463                               : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
464             weight_transform_data(jcp,
465                     &(weights(
466                         ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
467                         ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
468                         0, 0, 0, 0)),
469                     U_base_ptr);
470         });
471     }
472
473 PRAGMA_OMP(parallel)
474     {
475
    // Thread number selects this thread's slice of V and M.
476     int ithr = mkldnn_get_thread_num();
477
478 PRAGMA_OMP(for schedule(static))
479     for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
        // Step 1: input transform of this tile block into V(ithr, ...).
        // The source pointer is for image 0; the callee adds the image offset.
480         for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
481             for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
482
483                 input_transform_tileblock_data(
484                         tile_block, jcp,
485                         &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
486                         &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
487             }
488         }
489
        // Step 2: GEMM for every (oj, oi) position into M(ithr, ...).
490         for (int oj = 0; oj < alpha; oj++) {
491             for (int oi = 0; oi < alpha; oi++) {
492                 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
493                 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
494                 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
495                     kernel_->gemm_loop_ker(
496                             (float *)&(M(ithr, M_blk1, oj, oi,
497                                     N_blk, 0, 0, 0)),
498                             (const float *)&(U(M_blk1, oj, oi, K_blk1,
499                                     0, 0, 0, 0)),
500                             (const float *)&(V(ithr, oj, oi,
501                                     N_blk, K_blk1, 0, 0, 0)), K_blk1);
502             }
503         }
504
        // Step 3: output transform of M(ithr, ...) into the destination.
        // The destination pointer is for image 0; the callee adds the
        // image offset.
505         for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
506             for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
507                   M_blk2++) {
508                 const int M_blk =
509                     M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
510
                // The last M block uses the zero-padded bias copy when needed.
511                 float *bias_ptr = wants_padded_bias
512                     && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
513                     ? last_slice_bias : &bias(M_blk, 0);
514
515                 output_transform_tileblock_data(tile_block, jcp, p_ops,
516                         &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
517                         &(output(0, M_blk, 0, 0, 0)), bias_ptr);
518             }
519         }
520     }
521     }
522 }
523
// Explicit instantiations for both values of is_fwd, so the member
// definitions above are emitted in this translation unit.
524 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
525 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
526
527 namespace {
528
529 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
530         float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
531     using namespace nstl;
532     const size_t block_size = 16 * 1024 / sizeof(float);
533     const size_t blocks_number = nelems / block_size;
534     const size_t tail = nelems % block_size;
535
536 PRAGMA_OMP(parallel)
537     {
538         const int ithr = mkldnn_get_thread_num();
539         const int nthr = mkldnn_get_num_threads();
540         size_t start{ 0 }, end{ 0 };
541         balance211(blocks_number, nthr, ithr, start, end);
542
543         for (size_t nb = start; nb < end; ++nb) {
544             size_t start_e = nb * block_size;
545             size_t end_e = start_e + block_size;
546             size_t input_start = max(start_e, min(input_starts[0], end_e));
547             size_t input_end = max(start_e, min(input_ends[0], end_e));
548
549             PRAGMA_OMP_SIMD()
550             for (size_t e = start_e; e < input_start; e++) {
551                 output[e] = 0.f;
552             }
553
554             PRAGMA_OMP_SIMD()
555             for (size_t e = input_start; e < input_end; e++) {
556                 output[e] = input_ptrs[0][e];
557             }
558
559             PRAGMA_OMP_SIMD()
560             for (size_t e = input_end; e < end_e; e++) {
561                 output[e] = 0.f;
562             }
563
564             for (size_t a = 1; a < num_arrs; a++) {
565                 input_start = max(start_e, input_starts[a]);
566                 input_end = min(input_ends[a], end_e);
567
568                 PRAGMA_OMP_SIMD()
569                 for (size_t e = input_start; e < input_end; e++) {
570                     output[e] += input_ptrs[a][e];
571                 }
572             }
573         }
574
575         if (tail != 0 && ithr == nthr - 1) {
576             size_t start_e = nelems - tail;
577             size_t end_e = nelems;
578             size_t input_start = max(start_e, min(input_starts[0], end_e));
579             size_t input_end = max(start_e, min(input_ends[0], end_e));
580
581             PRAGMA_OMP_SIMD()
582             for (size_t e = start_e; e < input_start; e++) {
583                 output[e] = 0.f;
584             }
585
586             PRAGMA_OMP_SIMD()
587             for (size_t e = input_start; e < input_end; e++) {
588                 output[e] = input_ptrs[0][e];
589             }
590
591             PRAGMA_OMP_SIMD()
592             for (size_t e = input_end; e < end_e; e++) {
593                 output[e] = 0.f;
594             }
595
596             for (size_t a = 1; a < num_arrs; a++) {
597                 input_start = max(start_e, input_starts[a]);
598                 input_end = min(input_ends[a], end_e);
599
600                 PRAGMA_OMP_SIMD()
601                 for (size_t e = input_start; e < input_end; e++) {
602                     output[e] += input_ptrs[a][e];
603                 }
604             }
605         }
606     }
607 }
608
// Capacity of the on-stack array of per-thread buffer pointers that the
// backward-weights reduction builds before calling array_sum.
609 const int max_threads_number = 1024;
610
611 // Sum to the first buffer array
612 void array_sum(size_t num_arrs, float *output,
613     size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
614     const size_t block_size = 16 * 1024 / sizeof(float);
615     const size_t blocks_number = nelems / block_size;
616     const size_t tail = nelems % block_size;
617
618 PRAGMA_OMP(parallel)
619     {
620         const size_t ithr = mkldnn_get_thread_num();
621         const size_t nthr = mkldnn_get_num_threads();
622         size_t start{ 0 }, end{ 0 };
623         balance211(blocks_number, nthr, ithr, start, end);
624
625         for (size_t nb = start; nb < end; ++nb) {
626             size_t start_e = nb * block_size;
627             size_t end_e = start_e + block_size;
628             if (!reduce_to_first) {
629                 PRAGMA_OMP_SIMD()
630                 for (size_t e = start_e; e < end_e; e++) {
631                     output[e] = input_ptrs[0][e];
632                 }
633             }
634             for (size_t a = 1; a < num_arrs; a++) {
635                 PRAGMA_OMP_SIMD()
636                 for (size_t e = start_e; e < end_e; e++) {
637                     output[e] += input_ptrs[a][e];
638                 }
639             }
640         }
641
642         if (tail != 0 && ithr == nthr - 1) {
643             size_t start_e = nelems - tail;
644             size_t end_e = nelems;
645             if (!reduce_to_first) {
646                 PRAGMA_OMP_SIMD()
647                 for (size_t e = start_e; e < end_e; e++) {
648                     output[e] = input_ptrs[0][e];
649                 }
650             }
651             for (size_t a = 1; a < num_arrs; a++) {
652                 PRAGMA_OMP_SIMD()
653                 for (size_t e = start_e; e < end_e; e++) {
654                     output[e] += input_ptrs[a][e];
655                 }
656             }
657         }
658     }
659 }
660 } //bwdw namespace
661
/* Backward-weights driver in which every thread runs the full
 * src-transform -> diff_dst-transform -> GEMM -> diff_weights-transform
 * chain for the tile blocks it owns, accumulating into a private copy of
 * diff_weights (and, with bias, a private copy of diff_bias). The private
 * copies are summed into the user-visible outputs at the end.
 *
 *   scratchpad - provides key_wino_U (per-thread U followed by the
 *                per-thread diff_weights copies), key_wino_M, key_wino_V
 *                and key_conv_bia_reduction
 *
 * NOTE(review): a thread's diff_weights copy is only written inside the
 * tile-block loop. If a thread receives no tile block, its copy is summed
 * below without having been written here -- verify that jcp.nthr never
 * exceeds jcp.tile_block or that the scratchpad is pre-zeroed.
 * NOTE(review): input_ptrs below holds max_threads_number entries; this
 * assumes jcp.nthr <= max_threads_number. */
662 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
663 _execute_backward_weights_SDGtWo(
664         const memory_tracking::grantor_t &scratchpad) const {
665     const auto &jcp = kernel_->jcp;
666     const int nthreads = jcp.nthr;
667
668     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
669             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
670     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
671             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
    // Not referenced again in this function; the final reduction writes
    // through this->memory(0) directly.
672     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
673             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
674
    // Us, M and V are declared with a leading extent of 0 and indexed by
    // ithr below, i.e. one slice per thread.
675     array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
676             0, alpha, alpha,
677             jcp.oc_block, jcp.ic_block,
678             jcp.ic_simd_block,
679             jcp.oc_reg_block,
680             jcp.oc_simd_block);
681
    // The per-thread diff_weights copies live in the same key_wino_U
    // buffer, U_sz floats past its start.
682     const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
683         * jcp.ic / jcp.nb_ic;
684     array_offset_calculator<float, 7>diff_weights_prv(
685             scratchpad.get<float>(key_wino_U) + U_sz,
686             0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
687
688     array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
689             0, alpha, alpha,
690             jcp.oc_block,
691             jcp.nb_tile_block_ur,
692             jcp.tile_block_ur,
693             jcp.oc_reg_block,
694             jcp.oc_simd_block);
695
696     array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
697             0, alpha, alpha,
698             jcp.ic_block,
699             jcp.nb_tile_block_ur,
700             jcp.tile_block_ur,
701             jcp.ic_simd_block);
702
703     array_offset_calculator<float, 2> diff_bias_prv(
704             scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
705
    // Kernel call structure, scratch tiles and the constant tables passed
    // to the three transforms through the G field.
706     auto trans_ker_p = jit_wino_transform_call_s();
707     float I[alpha][alpha][simd_w];
708     float T[alpha][alpha][simd_w];
709     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
710                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
711     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
712        0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
713        1.13777777777778f};
714     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
715
// firstprivate gives every thread its own copy of the call structure and of
// the I/T scratch tiles.
716 PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
717 {
    // Zero the per-thread bias accumulators.
718     if (jcp.with_bias) {
719         parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
720             [&](int ithr, int ofm){
721                 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
722                 PRAGMA_OMP_SIMD()
723                 for (int v = 0; v < simd_w; v++) {
724                     pdbias[v] = 0.0f;
725                 }
726         });
727     }
728
729     int ithr = mkldnn_get_thread_num();
730     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
        // Counts the tile blocks this thread has processed for the current
        // ifm1; zero means the thread's diff_weights copy must be
        // overwritten rather than accumulated into.
731         int first_tblk = 0;
732 PRAGMA_OMP(for)
733         for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
            // Image and tile coordinates are derived from the first tile
            // of the tile block.
734             int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
735             int img = tile_index / (jcp.itiles * jcp.jtiles);
736             trans_ker_p.ti = tile_index % jcp.itiles;
737             trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
738             trans_ker_p.M = I;
739             trans_ker_p.T = T;
740             trans_ker_p.G = G_I_3x3_4x4;
            // Source transform into this thread's V slice.
741             for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
742                 int ifm = ifm1 * jcp.ic_block + ifm2;
743                 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
744                 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
745                 kernel_->src_transform(&trans_ker_p);
746             }
747
748             for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
                // diff_dst transform into this thread's M slice. The bias
                // variant of the kernel is used only during the ifm1 == 0
                // pass, so each diff_dst block feeds the bias once.
749                 trans_ker_p.G = G_W_3x3_4x4;
750                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
751                     int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
752                     trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
753                     trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
754                     if (jcp.with_bias && ifm1 == 0) {
755                         trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
756                         kernel_->diff_dst_transform_wbias(&trans_ker_p);
757                     } else {
758                         kernel_->diff_dst_transform(&trans_ker_p);
759                     }
760                 }
761
                // GEMM of M and V into this thread's Us slice, once per
                // (oj, oi) position.
762                 for (int oj = 0; oj < alpha; ++oj) {
763                     for (int oi = 0; oi < alpha; ++oi) {
764                         kernel_->gemm_loop_ker_first_iter(
765                                 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
766                                 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
767                                 &(V(ithr, oj, oi, 0, 0, 0, 0)));
768                     }
769                 }
                // diff_weights transform from Us into the thread's private
                // diff_weights copy: plain store for the thread's first
                // tile block of this ifm1, accumulate afterwards.
770                 trans_ker_p.G = G_O_3x3_4x4;
771                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
772                     for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
773                         int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
774                                 + ofm3;
775                         for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
776                             int ifm = ifm1 * jcp.ic_block + ifm2;
777                             trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
778                                         ofm2, ifm2, 0, ofm3, 0));
779                             trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
780                                         ofm, ifm, 0, 0, 0, 0));
781                             if (first_tblk == 0) {
782                                 kernel_->diff_weights_transform(&trans_ker_p);
783                             } else {
784                                 kernel_->diff_weights_transform_accum(&trans_ker_p);
785                             }
786                         }
787                     }
788                 }
789             }
790             ++first_tblk;
791         }
792     }
793 }
794
795     // Reduce diff-weights
796     {
797         float *output = (float *)(this->memory(0));
798         float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
799         int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
800         float *input_ptrs[max_threads_number];
801         for (int i = 0; i < nthreads; ++i) {
802             input_ptrs[i] = input_base + nelems * i;
803         }
        // reduce_to_first = false: output is first overwritten with thread
        // 0's copy, then the remaining copies are added.
804         array_sum(nthreads, output, nelems, input_ptrs, false);
805
        // Same reduction for the bias; only oc_without_padding elements
        // are written to the user-visible bias.
806         if (jcp.with_bias) {
807             output = (float *)(this->memory(1));
808             input_base = scratchpad.get<float>(key_conv_bia_reduction);
809             for (int i = 0; i < nthreads; ++i) {
810                 input_ptrs[i] = input_base + jcp.oc * i;
811             }
812             array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
813                     false);
814         }
815     }
816 }
817
// Backward-by-weights driver for the "S_D_Giot_W" schedule.
//
// As written below, the work is done in four stages:
//   1. Inside one OpenMP parallel region: zero the per-thread bias buffer
//      (only when jcp.with_bias), run kernel_->src_transform into V, run
//      kernel_->diff_dst_transform / diff_dst_transform_wbias into M, and,
//      after a barrier, run the GEMM kernels that write into a per-thread
//      slice of Us.
//   2. Sum the per-thread Us slices into U with subarray_sum.
//   3. Run kernel_->diff_weights_transform from U into the diff_weights memory.
//   4. When jcp.with_bias, sum the per-thread bias rows into diff_bias.
//
// scratchpad provides the key_wino_U, key_wino_M, key_wino_V and
// key_conv_bia_reduction buffers that hold every intermediate used here.
818 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
819 _execute_backward_weights_S_D_Giot_W(
820         const memory_tracking::grantor_t &scratchpad) const {
821     const auto &jcp = kernel_->jcp;
822     const int nthreads = jcp.nthr;
823
        // Indexed views over the primitive's inputs (src, diff_dst) and outputs
        // (diff_weights, diff_bias); channels are split into groups of simd_w.
824     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
825             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
826     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
827             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
828     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
829             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
830     array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
831
        // U: the reduced result, located at the start of the key_wino_U buffer.
832     array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
833             jcp.nb_ic, jcp.nb_oc,
834             alpha, alpha,
835             jcp.oc_block, jcp.ic_block,
836             jcp.ic_simd_block,
837             jcp.oc_reg_block,
838             jcp.oc_simd_block);
839
        // Us: per-thread copies with the same trailing dimensions as U, stored
        // in the same buffer immediately after U (offset U_size). The leading
        // dimension is indexed by thread id below; its extent is passed as 0
        // (presumably the leading extent is not needed to compute offsets --
        // TODO confirm against array_offset_calculator).
840     const int U_size = jcp.oc * jcp.ic * alpha * alpha;
841     array_offset_calculator<float, 10> Us(
842             scratchpad.get<float>(key_wino_U) + U_size,
843             0, jcp.nb_ic, jcp.nb_oc,
844             alpha, alpha,
845             jcp.oc_block, jcp.ic_block,
846             jcp.ic_simd_block,
847             jcp.oc_reg_block,
848             jcp.oc_simd_block);
849
        // M: destination of the diff_dst transform and an input of the GEMM.
850     array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
851             jcp.nb_oc,
852             jcp.tile_block,
853             alpha, alpha,
854             jcp.oc_block,
855             jcp.nb_tile_block_ur,
856             jcp.tile_block_ur ,
857             jcp.oc_reg_block,
858             jcp.oc_simd_block);
859
        // V: destination of the src transform and an input of the GEMM.
860     array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
861             jcp.nb_ic,
862             jcp.tile_block,
863             alpha, alpha,
864             jcp.ic_block,
865             jcp.nb_tile_block_ur, jcp.tile_block_ur,
866             jcp.ic_simd_block);
867
        // One row of jcp.oc bias partial sums per thread.
868     array_offset_calculator<float, 2> diff_bias_prv(
869             scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
870
        // Per-thread [start, end) element offsets into that thread's Us slice,
        // filled in by the GEMM loop and handed to subarray_sum afterwards.
871     size_t input_starts[max_threads_number] = {0};
872     size_t input_ends[max_threads_number] = {0};
        // Copied into each thread via firstprivate; there it counts how many
        // GEMM-loop iterations that thread has executed so far.
873     size_t first_tblk = 0;
874
        // Argument struct for the transform kernels, plus the constant tables
        // that are handed to them through trans_ker_p.G in the stages below.
875     auto trans_ker_p = jit_wino_transform_call_s();
876     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
877                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
878     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
879         0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
880         0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
881     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
        // Work arrays passed to the kernels as trans_ker_p.M / trans_ker_p.T;
        // firstprivate below gives every thread its own copy.
882     float I[alpha][alpha][simd_w];
883     float T[alpha][alpha][simd_w];
884
885 PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
886 {
        // Clear every thread's bias partial sums before any accumulation.
        // NOTE(review): no barrier separates this zeroing from the later
        // diff_dst_transform_wbias calls that use diff_bias_prv(ithr, ...);
        // presumably the work split hands each thread its own row -- confirm.
887     if (jcp.with_bias) {
888         parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
889             diff_bias_prv(ithr, ofm) = 0.0f;
890         });
891     }
892
893     trans_ker_p.G = G_I_3x3_4x4;
894     trans_ker_p.M = I;
895     trans_ker_p.T = T;
896
        // Source transform: one kernel call per (input-channel group, image),
        // writing into V.
897     parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
898         [&](int ifm1, int ifm2, int img){
899          size_t ifm = ifm1 * jcp.ic_block + ifm2;
             // Split the index of the image's first tile into the three tile
             // coordinates (tblk1, tblk2, tblk3) used by V; tblk3 varies fastest.
900          size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
901          size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
902          size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
903              % jcp.nb_tile_block_ur;
904          size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
905              / jcp.nb_tile_block_ur;
             // The kernel gets the tblk1 block base as dst and the tile's
             // position inside that block as tile_count.
906          trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
907          trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
908          trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
909          kernel_->src_transform(&trans_ker_p);
910     });
911
        // Declared inside the parallel region, so each thread holds its own id.
912     int ithr = mkldnn_get_thread_num();
913     trans_ker_p.G = G_W_3x3_4x4;
        // diff_dst transform: same tile indexing as above, writing into M. With
        // bias enabled the wbias kernel variant also receives this thread's
        // bias row at the first channel of the current group.
914     parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
915         [&](int ofm1, int ofm2, int img){
916         int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
917         size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
918         size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
919         size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
920             % jcp.nb_tile_block_ur;
921         size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
922             / jcp.nb_tile_block_ur;
923         trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
924         trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
925         trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
926         if (jcp.with_bias) {
927             trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
928             kernel_->diff_dst_transform_wbias(&trans_ker_p);
929         } else {
930             kernel_->diff_dst_transform(&trans_ker_p);
931         }
932     });
933
        // The GEMM below reads M and V produced by other threads, so all
        // transforms must be complete first.
934     PRAGMA_OMP(barrier)
935
        // GEMM stage: each iteration handles one (ifm1, ofm1, oj, oi, tblk1)
        // combination and writes the matching block of this thread's Us slice.
936     parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
937         [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
            // First iteration run by this thread: record where its output range
            // starts (offset of the current block inside the thread's slice)
            // and set the end one block past that start.
938         if (first_tblk == 0) {
939             input_starts[ithr] =
940                 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
941                             0, 0))
942                 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
943                             0, 0, 0));
944             input_ends[ithr] = input_starts[ithr]
945                     + jcp.oc_block * jcp.ic_block
946                       * jcp.ic_simd_block * jcp.oc_reg_block
947                       * jcp.oc_simd_block;
948         }
            // Later iteration with tblk1 == 0: a new output block begins, so
            // the recorded range grows by one block.
949         else if (tblk1 == 0) {
950             input_ends[ithr] += jcp.oc_block * jcp.ic_block
951                 * jcp.ic_simd_block * jcp.oc_reg_block
952                 * jcp.oc_simd_block;
953         }
954
            // The first_iter kernel is used whenever this thread touches an
            // output block for the first time; otherwise the regular kernel is
            // used (presumably first_iter overwrites and the regular kernel
            // accumulates -- TODO confirm in the kernel generator).
955         if (first_tblk == 0 || tblk1 == 0) {
956             kernel_->gemm_loop_ker_first_iter(
957                     &(Us(ithr, ifm1, ofm1, oj, oi,
958                             0, 0, 0, 0, 0)),
959                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
960                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
961         } else {
962             kernel_->gemm_loop_ker(
963                     &(Us(ithr, ifm1, ofm1, oj, oi,
964                             0, 0, 0, 0, 0)),
965                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
966                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
967         }
968         ++first_tblk;
969     });
970 }
971
972     // Reduce diff-weights
        // Thread i's slice starts nelems * (i + 1) floats after the start of U,
        // which matches the Us base (key_wino_U + U_size) used above. The
        // recorded per-thread [start, end) ranges are passed along with them
        // (presumably so only the written part of each slice is summed --
        // TODO confirm against subarray_sum).
973     {
974         float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
975         size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
976         float *input_ptrs[max_threads_number];
977         for (int i = 0; i < nthreads; ++i)
978             input_ptrs[i] = output + nelems * (i + 1);
979         subarray_sum(nthreads, output, nelems, input_ptrs,
980                 input_starts, input_ends);
981     }
982
        // Final transform: one kernel call per (output channel group, input
        // channel group) pair, reading the reduced U and writing diff_weights.
        // G is set on the outer trans_ker_p before the region so that the
        // firstprivate copies carry it.
983     trans_ker_p.G = G_O_3x3_4x4;
984 PRAGMA_OMP(parallel firstprivate(trans_ker_p))
985     {
986         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
987             [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
988             int ofm = (ofm1 * jcp.oc_block + ofm2)
989                 * jcp.oc_reg_block + ofm3;
990             int ifm = ifm1 * jcp.ic_block + ifm2;
991             trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
992                         ofm2, ifm2, 0, ofm3, 0));
993             trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
994                         0, 0, 0, 0));
995             kernel_->diff_weights_transform(&trans_ker_p);
996         });
997     }
998
         // Bias reduction: for every group of simd_w output channels, copy
         // thread 0's partial sums into diff_bias and add the rows of threads
         // 1 .. nthreads - 1 on top.
999     if (jcp.with_bias) {
1000         parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
1001             float* pbias = &(diff_bias(ofm1 * simd_w));
1002             float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
1003
                 // The last group is cut at jcp.oc_without_padding, so channels
                 // beyond that count are never written to diff_bias.
1004             const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1005                 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1006
1007             PRAGMA_OMP_SIMD()
1008             for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1009                 pbias[ofm2] = pbias_prv[ofm2];
1010             }
1011
1012             for (int ithr = 1; ithr < nthreads; ++ithr) {
1013                 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1014                 PRAGMA_OMP_SIMD()
1015                 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1016                     pbias[ofm2] += pbias_prv[ofm2];
1017                 }
1018             }
1019         });
1020     }
1021 }
1022
1023 }
1024 }
1025 }
1026 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s