updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_4x3.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
19 #endif
20
21 #include "mkldnn_types.h"
22
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 #include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
29
// Loop-unroll hint: expands to _Pragma("unroll") on every compiler except
// MSVC, where it expands to nothing.
30 #ifndef _MSC_VER
31 #define pragma_unroll _Pragma("unroll")
32 #else
33 #define pragma_unroll
34 #endif
35
36
37 namespace mkldnn {
38 namespace impl {
39 namespace cpu {
40
41 using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::memory_tracking::names;
44 using namespace mkldnn::impl::utils;
45
46 template <bool is_fwd>
47 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49         float *wp, float *twp) const
50 {
51     float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52                  1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
53     const int kh = 3;
54     const int kw = 3;
55     float Fw[alpha][alpha][simd_w][simd_w];
56     float F[kh][kw][simd_w][simd_w];
57     float T[alpha][3][simd_w];
58     auto p = jit_wino_transform_call_s();
59
60     p.src = wp;
61     p.dst = twp;
62     p.G = G;
63     p.M = F;
64     p.Mw = Fw;
65     p.T = T;
66
67     kernel_->weights_transform_data_ker(&p);
68 }
69
70 template<bool is_fwd>
71 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73     const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
74
75     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
76     float Ow[alpha][alpha][simd_w];
77     float O[tile_size][tile_size][simd_w];
78     float T[tile_size][alpha][simd_w];
79
80     auto p = jit_wino_transform_call_s();
81     p.src = toutp;
82     p.dst = pout_b;
83     p.G = G;
84     p.M = O;
85     p.Mw = Ow;
86     p.T = T;
87     p.bias = bias;
88
89     int tile_base_index = image * jcp.itiles * jcp.jtiles;
90     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91     int nb_tile_block_ur =
92         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
93     int tile_block =
94         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
95
96     for (int tj = 0; tj < jcp.jtiles; tj++) {
97         for (int ti = 0; ti < jcp.itiles; ti++) {
98
99             p.tile_block_ur = tile_block_ur;
100             p.nb_tile_block_ur = nb_tile_block_ur;
101             p.tile_block = tile_block;
102             p.tj = tj;
103             p.ti = ti;
104
105             kernel_->output_transform_data_ker(&p);
106
107             tile_block_ur++;
108             if (tile_block_ur >= jcp.tile_block_ur) {
109                 tile_block_ur = 0;
110                 nb_tile_block_ur++;
111             }
112             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113                 nb_tile_block_ur = 0;
114                 tile_block++;
115             }
116         }
117     }
118 }
119
120 template<bool is_fwd>
121 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123     const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124     float *toutp, float *outp, float *bias) const {
125
126     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
127     float Ow[alpha][alpha][simd_w];
128     float O[tile_size][tile_size][simd_w];
129     float T[tile_size][alpha][simd_w];
130
131     auto p = jit_wino_transform_call_s();
132     p.src = toutp;
133     p.dst = outp;
134     p.G = G;
135     p.M = O;
136     p.Mw = Ow;
137     p.T = T;
138     p.bias = bias;
139
140     int outw = is_fwd ? jcp.ow : jcp.iw;
141     int outh = is_fwd ? jcp.oh : jcp.ih;
142
143     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
144
145     for (int nb_tile_block_ur = 0;
146         nb_tile_block_ur < jcp.nb_tile_block_ur;
147         nb_tile_block_ur++) {
148
149         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
150             tile_block_ur++) {
151             int img = tile_index / (jcp.jtiles * jcp.itiles);
152             int ti = tile_index % jcp.itiles;
153             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
154
155             p.tile_block_ur = tile_block_ur;
156             p.nb_tile_block_ur = nb_tile_block_ur;
157             p.tile_block = tile_block;
158             p.tj = tj;
159             p.ti = ti;
160             p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161                                * outh * outw * jcp.dimM_simd_block;
162
163             kernel_->output_transform_data_ker(&p);
164
165             tile_index++;
166         }
167     }
168 }
169
170
171 template<bool is_fwd>
172 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
173     ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174         float *inp, float *tinp) const
175 {
176     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177                  0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
178
179     float Iw[alpha][alpha][simd_w];
180     float I[alpha][alpha][simd_w];
181     float T[alpha][alpha][simd_w];
182
183     auto p = jit_wino_transform_call_s();
184
185     p.src = inp;
186     p.dst = tinp;
187     p.G = G;
188     p.M = I;
189     p.Mw = Iw;
190     p.T = T;
191
192     int tile_base_index = image * jcp.itiles * jcp.jtiles;
193     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194     int nb_tile_block_ur =
195         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
196     int tile_block =
197         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
198
199     for (int tj = 0; tj < jcp.jtiles; tj++) {
200         for (int ti = 0; ti < jcp.itiles; ti++) {
201
202             p.tile_block_ur = tile_block_ur;
203             p.nb_tile_block_ur = nb_tile_block_ur;
204             p.tile_block = tile_block;
205             p.tj = tj;
206             p.ti = ti;
207
208             kernel_->input_transform_data_ker(&p);
209
210             tile_block_ur++;
211             if (tile_block_ur >= jcp.tile_block_ur) {
212                 tile_block_ur = 0;
213                 nb_tile_block_ur++;
214             }
215             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216                 nb_tile_block_ur = 0;
217                 tile_block++;
218             }
219         }
220     }
221 }
222
223 template <bool is_fwd>
224 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
225     ::input_transform_tileblock_data(int tile_block,
226         const jit_conv_winograd_conf_t &jcp,
227         float *inp, float *tinp) const
228 {
229     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
231     float Iw[alpha][alpha][simd_w];
232     float I[alpha][alpha][simd_w];
233     float T[alpha][alpha][simd_w];
234
235     const int inph = is_fwd ? jcp.ih : jcp.oh;
236     const int inpw = is_fwd ? jcp.iw : jcp.ow;
237
238     array_offset_calculator<float, 5> input(inp,
239         jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
240     array_offset_calculator<float, 7> output(tinp,
241         alpha, alpha,
242         jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243         jcp.dimN_reg_block, jcp.dimK_reg_block);
244
245     auto p = jit_wino_transform_call_s();
246
247     p.dst = tinp;
248     p.G = G;
249     p.M = I;
250     p.Mw = Iw;
251     p.T = T;
252
253
254     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
255
256     for (int nb_tile_block_ur = 0;
257             nb_tile_block_ur < jcp.nb_tile_block_ur;
258             nb_tile_block_ur++) {
259
260         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
261                 tile_block_ur++) {
262
263             int img = tile_index / (jcp.jtiles * jcp.itiles);
264             int ti = tile_index % jcp.itiles;
265             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
266             float *pinp_b = &(input(img, 0, 0, 0, 0));
267
268             p.src = pinp_b;
269             p.tile_block_ur = tile_block_ur;
270             p.nb_tile_block_ur = nb_tile_block_ur;
271             p.tj = tj;
272             p.ti = ti;
273
274             kernel_->input_transform_data_ker(&p);
275
276             tile_index++;
277         }
278     }
279 }
280
// Data-path driver (forward when is_fwd is true, the backward-data roles
// otherwise) that runs its work as separate parallel phases:
//   1. input transform into V, plus the weight transform into U unless
//      jcp.prop_kind is forward_inference;
//   2. the GEMM kernel over (N block, alpha, alpha, M block);
//   3. output transform from M into the output tensor.
// Under the OpenMP runtime all phases share one parallel region and are
// separated by PRAGMA_OMP(barrier).
//
// MB         - leading (minibatch) extent of the input/output calculators and
//              of the image loops
// inp_ptr    - source tensor, out_ptr - destination tensor
// wei_ptr    - weights; used directly as U for forward_inference
// bias_ptr   - bias, indexed as [dimM / dimM_simd_block][dimM_simd_block]
// scratchpad - provides the key_wino_U / key_wino_V / key_wino_M buffers
281 template <bool is_fwd>
282 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
283         const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
284         const memory_tracking::grantor_t &scratchpad) const {
285     const auto &jcp = kernel_->jcp;
286     const auto &p_ops = attr_->post_ops_;
287
    // Spatial extents swap roles between the forward and backward instances.
288     const int inph = is_fwd ? jcp.ih : jcp.oh;
289     const int inpw = is_fwd ? jcp.iw : jcp.ow;
290     const int outh = is_fwd ? jcp.oh : jcp.ih;
291     const int outw = is_fwd ? jcp.ow : jcp.iw;
292
293     /* Notation:
294        FWD: dimM:oc, dimN:ntiles, dimK:ic,
295        BWD: dimM:ic, dimN:ntiles, dimK:oc,
296        FWD/BWD: V: src/diff_dst transform, U:weight transform,
297                 M:dst/diff_src transform  */
298     array_offset_calculator<float, 5> input(inp_ptr,
299             MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
300             jcp.dimK_reg_block);
301     array_offset_calculator<float, 5> output(out_ptr,
302             MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
303             jcp.dimM_simd_block);
304     array_offset_calculator<float, 6> weights(wei_ptr,
305             jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
306             jcp.ic_simd_block, jcp.oc_simd_block);
307     array_offset_calculator<float, 2> bias(bias_ptr,
308             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
309
    // M and V take opposite scratchpad keys in the backward instantiation.
310     array_offset_calculator<float, 8> M(is_fwd
311             ? scratchpad.template get<float>(key_wino_M)
312             : scratchpad.template get<float>(key_wino_V),
313             jcp.dimN_nb_block, jcp.dimM_nb_block,
314             alpha, alpha,
315             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316             jcp.dimN_reg_block, jcp.dimM_simd_block);
317
    // For forward_inference no weight transform is run below and wei_ptr is
    // consumed as U as-is.
    // NOTE(review): presumably the weights arrive already transformed in that
    // case - confirm against the primitive descriptor.
318     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
319             ? wei_ptr
320             : scratchpad.template get<float>(key_wino_U);
321
322     array_offset_calculator<float, 8> U(wino_wei,
323             jcp.dimM_nb_block,
324             alpha, alpha,
325             jcp.dimK_nb_block,
326             jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
327             jcp.dimK_reg_block, jcp.dimM_simd_block);
328     array_offset_calculator<float, 8> V(is_fwd
329             ? scratchpad.template get<float>(key_wino_V)
330             : scratchpad.template get<float>(key_wino_M),
331             jcp.dimN_nb_block, alpha, alpha,
332             jcp.dimN_block, jcp.dimK_nb_block,
333             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
334
    // When oc is padded, the last bias slice is copied into a zero-filled
    // simd_w-wide buffer that replaces the real bias pointer for the last
    // M block in the output transform.
335     const bool wants_padded_bias = jcp.with_bias
336         && jcp.oc_without_padding != jcp.oc;
337     float last_slice_bias[simd_w] = {0};
338     if (wants_padded_bias) {
339         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
340             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
341     }
342
    // Under OpenMP the loops below already sit inside PRAGMA_OMP(parallel),
    // so the in-region variant parallel_nd_in_omp is selected; any other
    // runtime uses plain parallel_nd.
343 #if MKLDNN_THR == MKLDNN_THR_OMP
344 #define PARALLEL_ND parallel_nd_in_omp
345 #else
346 #define PARALLEL_ND parallel_nd
347 #endif
348
349 #if MKLDNN_THR == MKLDNN_THR_OMP
350 PRAGMA_OMP(parallel)
351 #endif
352     {
        // Phase 1a: input transform, one call per (image, K block).
353         PARALLEL_ND(MB, jcp.dimK_nb_block, jcp.dimK_block,
354                 [&](int img, int K_blk1, int K_blk2) {
355                 input_transform_data(img, jcp,
356                     &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
357                             0, 0, 0)),
358                         &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
359                 });
360
        // Phase 1b: weight transform; U is indexed M-major, so the oc/ic
        // indices are swapped for the backward instantiation.
361         if (jcp.prop_kind != prop_kind::forward_inference) {
362             PARALLEL_ND(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
363                 (jcp.ic_block * jcp.ic_reg_block),
364                 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
365                     float *U_base_ptr = is_fwd
366                         ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
367                         : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
368                     weight_transform_data(jcp,
369                         &(weights(
370                                 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
371                                 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
372                                 0, 0, 0, 0)),
373                         U_base_ptr);
374             });
375         }
376
377
378 PRAGMA_OMP(barrier)
379
        // Phase 2: GEMM; K_blk1 is also forwarded to the kernel as its last
        // argument.
380         PARALLEL_ND(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
381             [&](int N_blk1, int oj, int oi, int M_blk1) {
382             for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
383                  K_blk1++)
384             for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
385                 kernel_->gemm_loop_ker(
386                         (float *)&(M(N_blk1, M_blk1, oj, oi,
387                             N_blk2, 0, 0, 0)),
388                         (const float *)&(U(M_blk1, oj, oi,
389                             K_blk1, 0, 0, 0, 0)),
390                         (const float *)&(V(N_blk1, oj, oi,
391                             N_blk2, K_blk1, 0, 0, 0)), K_blk1);
392         });
393
394 PRAGMA_OMP(barrier)
395
        // Phase 3: output transform, one call per (image, M block).
396         PARALLEL_ND(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
397                     [&](int img, int M_blk1, int M_blk2) {
398             const int M_blk =
399                 M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
400
401             float *bias_ptr = wants_padded_bias
402                 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
403                 ? last_slice_bias : &bias(M_blk, 0);
404             output_transform_data(img, jcp, p_ops,
405                     &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
406                     &(output(img, M_blk, 0, 0, 0)), bias_ptr);
407         });
408     }
409
410 #undef PARALLEL_ND
411 }
412
// Data-path driver in which the weight transform runs first (skipped for
// forward_inference) and then each tile block is processed end to end by a
// single thread: input transform, GEMM and output transform for that tile
// block, using the scratch slices M(ithr, ...) and V(ithr, ...).
//
// MB         - leading (minibatch) extent of the input/output calculators
// inp_ptr    - source tensor, out_ptr - destination tensor
// wei_ptr    - weights; used directly as U for forward_inference
// bias_ptr   - bias, indexed as [oc / oc_simd_block][oc_simd_block]
// scratchpad - provides the key_wino_U / key_wino_V / key_wino_M buffers
413 template <bool is_fwd>
414 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(const int MB,
415         float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
416         const memory_tracking::grantor_t &scratchpad) const {
417
418     const auto &jcp = kernel_->jcp;
419     const auto &p_ops = attr_->post_ops_;
420
    // Spatial extents swap roles between the forward and backward instances.
421     const int inph = is_fwd ? jcp.ih : jcp.oh;
422     const int inpw = is_fwd ? jcp.iw : jcp.ow;
423     const int outh = is_fwd ? jcp.oh : jcp.ih;
424     const int outw = is_fwd ? jcp.ow : jcp.iw;
425
426     array_offset_calculator<float, 5> input(inp_ptr,
427         MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
428     array_offset_calculator<float, 5> output(out_ptr,
429         MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
430     array_offset_calculator<float, 6> weights(wei_ptr,
431         jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
432         jcp.ic_simd_block, jcp.oc_simd_block);
433     array_offset_calculator<float, 2> bias(bias_ptr,
434         jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
435
    // For forward_inference no weight transform is run below and wei_ptr is
    // consumed as U as-is.
    // NOTE(review): presumably the weights arrive already transformed in that
    // case - confirm against the primitive descriptor.
436     auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
437                 ? wei_ptr
438                 : scratchpad.template get<float>(key_wino_U);
439
440     array_offset_calculator<float, 8> U(wino_wei,
441             jcp.dimM_nb_block,
442             alpha, alpha,
443             jcp.dimK_nb_block,
444             jcp.dimM_block  * jcp.dimM_reg_block, jcp.dimK_block,
445             jcp.dimK_reg_block, jcp.dimM_simd_block);
446
    // M and V are declared with a leading extent of 0 and are only ever
    // indexed with the thread number in that position.
    // NOTE(review): presumably the leading extent does not enter the offset
    // computation - confirm in array_offset_calculator.
447     array_offset_calculator<float, 8> M(is_fwd
448             ? scratchpad.template get<float>(key_wino_M)
449             : scratchpad.template get<float>(key_wino_V),
450             0, jcp.dimM_nb_block, alpha, alpha,
451             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
452             jcp.dimN_reg_block, jcp.dimM_simd_block);
453     array_offset_calculator<float, 8> V(is_fwd
454             ? scratchpad.template get<float>(key_wino_V)
455             : scratchpad.template get<float>(key_wino_M),
456             0, alpha, alpha, jcp.dimN_block,
457             jcp.dimK_nb_block, jcp.dimK_block,
458             jcp.dimN_reg_block, jcp.dimK_reg_block);
459
    // When oc is padded, the last bias slice is copied into a zero-filled
    // simd_w-wide buffer that replaces the real bias pointer for the last
    // M block in the output transform.
460     const bool wants_padded_bias = jcp.with_bias
461         && jcp.oc_without_padding != jcp.oc;
462     float last_slice_bias[simd_w] = {0};
463     if (wants_padded_bias) {
464         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
465             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
466     }
467
    // Weight transform; U is indexed M-major, so the oc/ic indices are
    // swapped for the backward instantiation.
468     if (jcp.prop_kind != prop_kind::forward_inference) {
469
470         parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
471                     [&](int ofm1, int ifm1, int ofm2, int ifm2) {
472             float *U_base_ptr = is_fwd
473                               ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
474                               : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
475             weight_transform_data(jcp,
476                     &(weights(
477                         ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
478                         ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
479                         0, 0, 0, 0)),
480                     U_base_ptr);
481         });
482     }
483
484 PRAGMA_OMP(parallel)
485     {
486
    // Selects this thread's slice of the M and V scratch buffers.
487     int ithr = mkldnn_get_thread_num();
488
489 PRAGMA_OMP(for schedule(static))
490     for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
        // Step 1: input transform of this tile block for every K block.
491         for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
492             for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
493
494                 input_transform_tileblock_data(
495                         tile_block, jcp,
496                         &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
497                         &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
498             }
499         }
500
        // Step 2: GEMM; K_blk1 is also forwarded to the kernel as its last
        // argument.
501         for (int oj = 0; oj < alpha; oj++) {
502             for (int oi = 0; oi < alpha; oi++) {
503                 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
504                 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
505                 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
506                     kernel_->gemm_loop_ker(
507                             (float *)&(M(ithr, M_blk1, oj, oi,
508                                     N_blk, 0, 0, 0)),
509                             (const float *)&(U(M_blk1, oj, oi, K_blk1,
510                                     0, 0, 0, 0)),
511                             (const float *)&(V(ithr, oj, oi,
512                                     N_blk, K_blk1, 0, 0, 0)), K_blk1);
513             }
514         }
515
        // Step 3: output transform of this tile block for every M block.
516         for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
517             for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
518                   M_blk2++) {
519                 const int M_blk =
520                     M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
521
522                 float *bias_ptr = wants_padded_bias
523                     && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
524                     ? last_slice_bias : &bias(M_blk, 0);
525
526                 output_transform_tileblock_data(tile_block, jcp, p_ops,
527                         &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
528                         &(output(0, M_blk, 0, 0, 0)), bias_ptr);
529             }
530         }
531     }
532     }
533 }
534
// Explicit instantiations of the data-path implementation for both values of
// is_fwd, since the member templates are defined in this translation unit.
535 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
536 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
537
538 namespace {
539
540 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
541         float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
542     using namespace nstl;
543     const size_t block_size = 16 * 1024 / sizeof(float);
544     const size_t blocks_number = nelems / block_size;
545     const size_t tail = nelems % block_size;
546
547 PRAGMA_OMP(parallel)
548     {
549         const int ithr = mkldnn_get_thread_num();
550         const int nthr = mkldnn_get_num_threads();
551         size_t start{ 0 }, end{ 0 };
552         balance211(blocks_number, nthr, ithr, start, end);
553
554         for (size_t nb = start; nb < end; ++nb) {
555             size_t start_e = nb * block_size;
556             size_t end_e = start_e + block_size;
557             size_t input_start = max(start_e, min(input_starts[0], end_e));
558             size_t input_end = max(start_e, min(input_ends[0], end_e));
559
560             PRAGMA_OMP_SIMD()
561             for (size_t e = start_e; e < input_start; e++) {
562                 output[e] = 0.f;
563             }
564
565             PRAGMA_OMP_SIMD()
566             for (size_t e = input_start; e < input_end; e++) {
567                 output[e] = input_ptrs[0][e];
568             }
569
570             PRAGMA_OMP_SIMD()
571             for (size_t e = input_end; e < end_e; e++) {
572                 output[e] = 0.f;
573             }
574
575             for (size_t a = 1; a < num_arrs; a++) {
576                 input_start = max(start_e, input_starts[a]);
577                 input_end = min(input_ends[a], end_e);
578
579                 PRAGMA_OMP_SIMD()
580                 for (size_t e = input_start; e < input_end; e++) {
581                     output[e] += input_ptrs[a][e];
582                 }
583             }
584         }
585
586         if (tail != 0 && ithr == nthr - 1) {
587             size_t start_e = nelems - tail;
588             size_t end_e = nelems;
589             size_t input_start = max(start_e, min(input_starts[0], end_e));
590             size_t input_end = max(start_e, min(input_ends[0], end_e));
591
592             PRAGMA_OMP_SIMD()
593             for (size_t e = start_e; e < input_start; e++) {
594                 output[e] = 0.f;
595             }
596
597             PRAGMA_OMP_SIMD()
598             for (size_t e = input_start; e < input_end; e++) {
599                 output[e] = input_ptrs[0][e];
600             }
601
602             PRAGMA_OMP_SIMD()
603             for (size_t e = input_end; e < end_e; e++) {
604                 output[e] = 0.f;
605             }
606
607             for (size_t a = 1; a < num_arrs; a++) {
608                 input_start = max(start_e, input_starts[a]);
609                 input_end = min(input_ends[a], end_e);
610
611                 PRAGMA_OMP_SIMD()
612                 for (size_t e = input_start; e < input_end; e++) {
613                     output[e] += input_ptrs[a][e];
614                 }
615             }
616         }
617     }
618 }
619
// Capacity of the fixed-size per-thread pointer tables (e.g. the
// `input_ptrs` array built before calling array_sum in this file).
620 const int max_threads_number = 1024;
621
622 // Sum to the first buffer array
623 void array_sum(size_t num_arrs, float *output,
624     size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
625     const size_t block_size = 16 * 1024 / sizeof(float);
626     const size_t blocks_number = nelems / block_size;
627     const size_t tail = nelems % block_size;
628
629 PRAGMA_OMP(parallel)
630     {
631         const size_t ithr = mkldnn_get_thread_num();
632         const size_t nthr = mkldnn_get_num_threads();
633         size_t start{ 0 }, end{ 0 };
634         balance211(blocks_number, nthr, ithr, start, end);
635
636         for (size_t nb = start; nb < end; ++nb) {
637             size_t start_e = nb * block_size;
638             size_t end_e = start_e + block_size;
639             if (!reduce_to_first) {
640                 PRAGMA_OMP_SIMD()
641                 for (size_t e = start_e; e < end_e; e++) {
642                     output[e] = input_ptrs[0][e];
643                 }
644             }
645             for (size_t a = 1; a < num_arrs; a++) {
646                 PRAGMA_OMP_SIMD()
647                 for (size_t e = start_e; e < end_e; e++) {
648                     output[e] += input_ptrs[a][e];
649                 }
650             }
651         }
652
653         if (tail != 0 && ithr == nthr - 1) {
654             size_t start_e = nelems - tail;
655             size_t end_e = nelems;
656             if (!reduce_to_first) {
657                 PRAGMA_OMP_SIMD()
658                 for (size_t e = start_e; e < end_e; e++) {
659                     output[e] = input_ptrs[0][e];
660                 }
661             }
662             for (size_t a = 1; a < num_arrs; a++) {
663                 PRAGMA_OMP_SIMD()
664                 for (size_t e = start_e; e < end_e; e++) {
665                     output[e] += input_ptrs[a][e];
666                 }
667             }
668         }
669     }
670 }
671 } //bwdw namespace
672
// Backward-weights driver in which every thread works on private scratch
// slices (V, M, Us and a private copy of diff_weights, all indexed by ithr).
// For each ifm1 block the tile blocks are distributed across threads; per
// tile block a thread runs the src transform, the diff_dst transform (with
// bias accumulation on the ifm1 == 0 pass when jcp.with_bias is set), the
// GEMM and the diff-weights transform. After the parallel region the private
// diff_weights (and diff_bias) copies are summed with array_sum.
//
// scratchpad - provides key_wino_U (Us followed by the private diff_weights
//              copies at offset U_sz), key_wino_M, key_wino_V and
//              key_conv_bia_reduction.
673 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
674 _execute_backward_weights_SDGtWo(
675         const memory_tracking::grantor_t &scratchpad) const {
676     const auto &jcp = kernel_->jcp;
677     const int nthreads = jcp.nthr;
678
679     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
680             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
681     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
682             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
683     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
684             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
685
    // The scratch calculators below are declared with a leading extent of 0
    // and are only ever indexed with the thread number in that position.
686     array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
687             0, alpha, alpha,
688             jcp.oc_block, jcp.ic_block,
689             jcp.ic_simd_block,
690             jcp.oc_reg_block,
691             jcp.oc_simd_block);
692
    // The private diff_weights copies start U_sz floats into key_wino_U.
693     const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
694         * jcp.ic / jcp.nb_ic;
695     array_offset_calculator<float, 7>diff_weights_prv(
696             scratchpad.get<float>(key_wino_U) + U_sz,
697             0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
698
699     array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
700             0, alpha, alpha,
701             jcp.oc_block,
702             jcp.nb_tile_block_ur,
703             jcp.tile_block_ur,
704             jcp.oc_reg_block,
705             jcp.oc_simd_block);
706
707     array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
708             0, alpha, alpha,
709             jcp.ic_block,
710             jcp.nb_tile_block_ur,
711             jcp.tile_block_ur,
712             jcp.ic_simd_block);
713
714     array_offset_calculator<float, 2> diff_bias_prv(
715             scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
716
    // Kernel call structure, scratch tiles and the coefficient tables handed
    // to the transform kernels through `G`.
717     auto trans_ker_p = jit_wino_transform_call_s();
718     float I[alpha][alpha][simd_w];
719     float T[alpha][alpha][simd_w];
720     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
721                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
722     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
723        0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
724        1.13777777777778f};
725     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
726
    // firstprivate gives every thread its own copy of the call structure and
    // of the I / T scratch tiles.
727 PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
728 {
    // Zero the private bias accumulators before they are accumulated into.
729     if (jcp.with_bias) {
730         parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
731             [&](int ithr, int ofm){
732                 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
733                 PRAGMA_OMP_SIMD()
734                 for (int v = 0; v < simd_w; v++) {
735                     pdbias[v] = 0.0f;
736                 }
737         });
738     }
739
740     int ithr = mkldnn_get_thread_num();
741     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
        // Per-thread count of tile blocks already handled for this ifm1; the
        // first one overwrites the private diff_weights, later ones
        // accumulate.
        // NOTE(review): a thread that receives no tile block from the
        // work-sharing loop never writes its diff_weights_prv slice, yet the
        // reduction below reads all nthreads slices - confirm the
        // configuration guarantees every thread gets at least one tile block.
742         int first_tblk = 0;
743 PRAGMA_OMP(for)
744         for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
745             int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
746             int img = tile_index / (jcp.itiles * jcp.jtiles);
747             trans_ker_p.ti = tile_index % jcp.itiles;
748             trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
749             trans_ker_p.M = I;
750             trans_ker_p.T = T;
751             trans_ker_p.G = G_I_3x3_4x4;
            // Source transform into this thread's V slice.
752             for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
753                 int ifm = ifm1 * jcp.ic_block + ifm2;
754                 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
755                 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
756                 kernel_->src_transform(&trans_ker_p);
757             }
758
759             for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
                // diff_dst transform into this thread's M slice; the bias is
                // accumulated only on the ifm1 == 0 pass.
760                 trans_ker_p.G = G_W_3x3_4x4;
761                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
762                     int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
763                     trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
764                     trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
765                     if (jcp.with_bias && ifm1 == 0) {
766                         trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
767                         kernel_->diff_dst_transform_wbias(&trans_ker_p);
768                     } else {
769                         kernel_->diff_dst_transform(&trans_ker_p);
770                     }
771                 }
772
                // GEMM for every (oj, oi) position into this thread's Us.
773                 for (int oj = 0; oj < alpha; ++oj) {
774                     for (int oi = 0; oi < alpha; ++oi) {
775                         kernel_->gemm_loop_ker_first_iter(
776                                 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
777                                 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
778                                 &(V(ithr, oj, oi, 0, 0, 0, 0)));
779                     }
780                 }
                // diff-weights transform from Us into the private
                // diff_weights copy.
781                 trans_ker_p.G = G_O_3x3_4x4;
782                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
783                     for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
784                         int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
785                                 + ofm3;
786                         for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
787                             int ifm = ifm1 * jcp.ic_block + ifm2;
788                             trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
789                                         ofm2, ifm2, 0, ofm3, 0));
790                             trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
791                                         ofm, ifm, 0, 0, 0, 0));
792                             if (first_tblk == 0) {
793                                 kernel_->diff_weights_transform(&trans_ker_p);
794                             } else {
795                                 kernel_->diff_weights_transform_accum(&trans_ker_p);
796                             }
797                         }
798                     }
799                 }
800             }
801             ++first_tblk;
802         }
803     }
804 }
805
806     // Reduce diff-weights
807     {
808         float *output = (float *)(this->memory(0));
809         float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
810         int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
        // NOTE(review): input_ptrs has max_threads_number slots but is filled
        // for nthreads entries; no check that nthreads <= max_threads_number
        // is visible here - confirm it is guaranteed elsewhere.
811         float *input_ptrs[max_threads_number];
812         for (int i = 0; i < nthreads; ++i) {
813             input_ptrs[i] = input_base + nelems * i;
814         }
        // reduce_to_first == false: output is initialized from slice 0.
815         array_sum(nthreads, output, nelems, input_ptrs, false);
816
        // The bias slices are jcp.oc apart, but only oc_without_padding
        // elements are reduced into the bias output.
817         if (jcp.with_bias) {
818             output = (float *)(this->memory(1));
819             input_base = scratchpad.get<float>(key_conv_bia_reduction);
820             for (int i = 0; i < nthreads; ++i) {
821                 input_ptrs[i] = input_base + jcp.oc * i;
822             }
823             array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
824                     false);
825         }
826     }
827 }
828
// Backward-by-weights driver for the Winograd 4x3 fp32 convolution
// (the "S_D_Giot_W" execution strategy).
//
// Visible pipeline:
//   1. transform src into the V scratch buffer (kernel_->src_transform);
//   2. transform diff_dst into the M scratch buffer
//      (kernel_->diff_dst_transform[_wbias]);
//   3. run the JIT GEMM kernels on (M, V), writing into a per-thread slice
//      of Us;
//   4. combine the per-thread Us slices into U via subarray_sum;
//   5. transform U into diff_weights (kernel_->diff_weights_transform);
//   6. if jcp.with_bias, sum the per-thread diff_bias_prv rows into diff_bias.
//
// @param scratchpad grantor from which the key_wino_U / key_wino_M /
//        key_wino_V / key_conv_bia_reduction buffers are obtained.
829 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
830 _execute_backward_weights_S_D_Giot_W(
831         const memory_tracking::grantor_t &scratchpad) const {
832     const auto &jcp = kernel_->jcp;
833     const int nthreads = jcp.nthr;
834
        // Multi-dimensional index helpers over the primitive's inputs/outputs.
        // The channel dimension is split into blocks of simd_w elements.
835     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
836             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
837     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
838             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
839     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
840             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
841     array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
842
        // U: the reduced result that step 5 transforms into diff_weights.
843     array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
844             jcp.nb_ic, jcp.nb_oc,
845             alpha, alpha,
846             jcp.oc_block, jcp.ic_block,
847             jcp.ic_simd_block,
848             jcp.oc_reg_block,
849             jcp.oc_simd_block);
850
        // Us: same layout as U with one extra leading index that the code below
        // always fills with the thread id. It starts U_size floats past U in the
        // same key_wino_U buffer, so thread i's slice begins at
        // U + U_size * (i + 1) (this matches input_ptrs in the reduction below).
        // The leading extent is passed as 0 — presumably it does not take part
        // in the offset computation; confirm in array_offset_calculator.
851     const int U_size = jcp.oc * jcp.ic * alpha * alpha;
852     array_offset_calculator<float, 10> Us(
853             scratchpad.get<float>(key_wino_U) + U_size,
854             0, jcp.nb_ic, jcp.nb_oc,
855             alpha, alpha,
856             jcp.oc_block, jcp.ic_block,
857             jcp.ic_simd_block,
858             jcp.oc_reg_block,
859             jcp.oc_simd_block);
860
        // M: destination of the diff_dst transform, read by the GEMM kernels.
861     array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
862             jcp.nb_oc,
863             jcp.tile_block,
864             alpha, alpha,
865             jcp.oc_block,
866             jcp.nb_tile_block_ur,
867             jcp.tile_block_ur ,
868             jcp.oc_reg_block,
869             jcp.oc_simd_block);
870
        // V: destination of the src transform, read by the GEMM kernels.
871     array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
872             jcp.nb_ic,
873             jcp.tile_block,
874             alpha, alpha,
875             jcp.ic_block,
876             jcp.nb_tile_block_ur, jcp.tile_block_ur,
877             jcp.ic_simd_block);
878
        // One row of jcp.oc bias accumulators per thread.
879     array_offset_calculator<float, 2> diff_bias_prv(
880             scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
881
        // Per-thread [start, end) ranges, in floats relative to the start of the
        // thread's Us slice, recorded during the GEMM stage and later handed to
        // subarray_sum.
882     size_t input_starts[max_threads_number] = {0};
883     size_t input_ends[max_threads_number] = {0};
        // Counts the GEMM work items a thread has processed (made firstprivate
        // below, so every thread starts from 0).
884     size_t first_tblk = 0;
885
        // Argument block for the JIT transform kernels, plus the coefficient
        // tables that are passed to them through trans_ker_p.G.
886     auto trans_ker_p = jit_wino_transform_call_s();
887     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
888                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
889     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
890         0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
891         0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
892     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
        // Work arrays passed to the transform kernels as trans_ker_p.M / .T.
893     float I[alpha][alpha][simd_w];
894     float T[alpha][alpha][simd_w];
895
    // Assuming PRAGMA_OMP expands to an OpenMP pragma: firstprivate gives every
    // thread its own first_tblk counter, kernel argument block and I/T arrays.
896 PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
897 {
        // Clear the per-thread bias accumulators before anything is added to them.
898     if (jcp.with_bias) {
899         parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
900             diff_bias_prv(ithr, ofm) = 0.0f;
901         });
902     }
903
904     trans_ker_p.G = G_I_3x3_4x4;
905     trans_ker_p.M = I;
906     trans_ker_p.T = T;
907
        // Stage 1: src -> V, one work item per (ic block, ic sub-block, image).
908     parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
909         [&](int ifm1, int ifm2, int img){
910          size_t ifm = ifm1 * jcp.ic_block + ifm2;
             // Index of the image's first tile (itiles * jtiles tiles per image),
             // decomposed into (tblk1, tblk2, tblk3) with extents
             // (.., nb_tile_block_ur, tile_block_ur).
911          size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
912          size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
913          size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
914              % jcp.nb_tile_block_ur;
915          size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
916              / jcp.nb_tile_block_ur;
             // Tile position inside block tblk1; dst points at the start of that
             // block in V.
917          trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
918          trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
919          trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
920          kernel_->src_transform(&trans_ker_p);
921     });
922
        // Id of the calling thread; selects this thread's row of diff_bias_prv
        // and its slice of Us in the stages below.
923     int ithr = mkldnn_get_thread_num();
        // Stage 2: diff_dst -> M, one work item per (oc block, oc sub-block, image).
924     trans_ker_p.G = G_W_3x3_4x4;
925     parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
926         [&](int ofm1, int ofm2, int img){
927         int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
            // Same tile-index decomposition as in stage 1.
928         size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
929         size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
930         size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
931             % jcp.nb_tile_block_ur;
932         size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
933             / jcp.nb_tile_block_ur;
934         trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
935         trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
936         trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
            // With bias, the kernel also receives this thread's accumulator slice
            // (presumably it adds the diff_dst contribution there — confirm in
            // diff_dst_transform_wbias).
937         if (jcp.with_bias) {
938             trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
939             kernel_->diff_dst_transform_wbias(&trans_ker_p);
940         } else {
941             kernel_->diff_dst_transform(&trans_ker_p);
942         }
943     });
944
        // M and V are indexed by block, not by thread, so the GEMM stage may read
        // data produced by other threads: wait for all transforms to finish.
945     PRAGMA_OMP(barrier)
946
        // Stage 3: GEMM over (ic block, oc block, oj, oi, tile block); results go
        // to this thread's slice of Us.
947     parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
948         [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
            // Maintain the [start, end) range of this thread's Us slice that it
            // writes. On the thread's first work item: start = offset of the
            // (ifm1, ofm1, oj, oi) block, end = start + one block.
949         if (first_tblk == 0) {
950             input_starts[ithr] =
951                 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
952                             0, 0))
953                 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
954                             0, 0, 0));
955             input_ends[ithr] = input_starts[ithr]
956                     + jcp.oc_block * jcp.ic_block
957                       * jcp.ic_simd_block * jcp.oc_reg_block
958                       * jcp.oc_simd_block;
959         }
            // A later work item with tblk1 == 0 extends the range by one block.
            // NOTE(review): this presumes the thread visits blocks contiguously in
            // increasing order — verify against parallel_nd_in_omp's work split.
960         else if (tblk1 == 0) {
961             input_ends[ithr] += jcp.oc_block * jcp.ic_block
962                 * jcp.ic_simd_block * jcp.oc_reg_block
963                 * jcp.oc_simd_block;
964         }
965
            // Use the "first iteration" kernel for the thread's first work item and
            // whenever a new tile-block sequence starts (tblk1 == 0); otherwise use
            // the regular kernel (presumably first_iter overwrites while
            // gemm_loop_ker accumulates — confirm in the JIT kernel).
966         if (first_tblk == 0 || tblk1 == 0) {
967             kernel_->gemm_loop_ker_first_iter(
968                     &(Us(ithr, ifm1, ofm1, oj, oi,
969                             0, 0, 0, 0, 0)),
970                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
971                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
972         } else {
973             kernel_->gemm_loop_ker(
974                     &(Us(ithr, ifm1, ofm1, oj, oi,
975                             0, 0, 0, 0, 0)),
976                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
977                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
978         }
979         ++first_tblk;
980     });
981 }
982
983     // Reduce diff-weights: combine the per-thread Us slices into U, passing the
        // [input_starts, input_ends) range recorded for each thread (presumably
        // subarray_sum only reads each input inside its range — confirm).
984     {
985         float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
986         size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
987         float *input_ptrs[max_threads_number];
            // Thread i's slice sits (i + 1) * nelems floats after U.
988         for (int i = 0; i < nthreads; ++i)
989             input_ptrs[i] = output + nelems * (i + 1);
990         subarray_sum(nthreads, output, nelems, input_ptrs,
991                 input_starts, input_ends);
992     }
993
        // Stage 5: transform the reduced U into diff_weights.
        // NOTE(review): trans_ker_p.M / .T were assigned only on the firstprivate
        // copies inside the first parallel region, so this copy does not carry
        // them — confirm diff_weights_transform does not need them.
994     trans_ker_p.G = G_O_3x3_4x4;
995 PRAGMA_OMP(parallel firstprivate(trans_ker_p))
996     {
997         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
998             [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
                // Flat output / input channel-block indices into diff_weights.
999             int ofm = (ofm1 * jcp.oc_block + ofm2)
1000                 * jcp.oc_reg_block + ofm3;
1001             int ifm = ifm1 * jcp.ic_block + ifm2;
1002             trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
1003                         ofm2, ifm2, 0, ofm3, 0));
1004             trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
1005                         0, 0, 0, 0));
1006             kernel_->diff_weights_transform(&trans_ker_p);
1007         });
1008     }
1009
         // Bias reduction: diff_bias = sum over threads of diff_bias_prv, processed
         // in chunks of simd_w channels.
1010     if (jcp.with_bias) {
1011         parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
1012             float* pbias = &(diff_bias(ofm1 * simd_w));
1013             float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
1014
                 // The last chunk stops at oc_without_padding, so elements past
                 // that bound are not written to diff_bias.
1015             const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1016                 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1017
                 // Initialise from thread 0's row ...
1018             PRAGMA_OMP_SIMD()
1019             for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1020                 pbias[ofm2] = pbias_prv[ofm2];
1021             }
1022
                 // ... then add the rows of the remaining threads.
1023             for (int ithr = 1; ithr < nthreads; ++ithr) {
1024                 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1025                 PRAGMA_OMP_SIMD()
1026                 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1027                     pbias[ofm2] += pbias_prv[ofm2];
1028                 }
1029             }
1030         });
1031     }
1032 }
1033
1034 }
1035 }
1036 }
1037 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s