Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_convolution_winograd.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
19 #endif
20
21 #include "mkldnn_types.h"
22
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 #include "jit_avx512_common_convolution_winograd.hpp"
29 #include "jit_avx512_core_convolution_winograd.hpp"
30
31 #ifndef _MSC_VER
32 #define pragma_unroll _Pragma("unroll")
33 #else
34 #define pragma_unroll
35 #endif
36
37
38 namespace mkldnn {
39 namespace impl {
40 namespace cpu {
41
42 using namespace mkldnn::impl::status;
43 using namespace mkldnn::impl::memory_format;
44 using namespace mkldnn::impl::utils;
45
46 template <bool is_fwd>
47 void _jit_avx512_core_convolution_winograd_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49         float *wp, float *twp)
50 {
51     float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52                  1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
53     const int kh = 3;
54     const int kw = 3;
55     float Fw[alpha][alpha][simd_w][simd_w];
56     float F[kh][kw][simd_w][simd_w];
57     float T[alpha][3][simd_w];
58     auto p = jit_wino_transform_call_s();
59
60     p.src = wp;
61     p.dst = twp;
62     p.G = G;
63     p.M = F;
64     p.Mw = Fw;
65     p.T = T;
66
67     kernel_->weights_transform_data_ker(&p);
68 }
69
70 template<bool is_fwd>
71 void _jit_avx512_core_convolution_winograd_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73     const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) {
74
75     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
76     float Ow[alpha][alpha][simd_w];
77     float O[tile_size][tile_size][simd_w];
78     float T[tile_size][alpha][simd_w];
79
80     auto p = jit_wino_transform_call_s();
81     p.src = toutp;
82     p.dst = pout_b;
83     p.G = G;
84     p.M = O;
85     p.Mw = Ow;
86     p.T = T;
87     p.bias = bias;
88
89     int tile_base_index = image * jcp.itiles * jcp.jtiles;
90     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91     int nb_tile_block_ur =
92         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
93     int tile_block =
94         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
95
96     for (int tj = 0; tj < jcp.jtiles; tj++) {
97         for (int ti = 0; ti < jcp.itiles; ti++) {
98
99             p.tile_block_ur = tile_block_ur;
100             p.nb_tile_block_ur = nb_tile_block_ur;
101             p.tile_block = tile_block;
102             p.tj = tj;
103             p.ti = ti;
104
105             kernel_->output_transform_data_ker(&p);
106
107             tile_block_ur++;
108             if (tile_block_ur >= jcp.tile_block_ur) {
109                 tile_block_ur = 0;
110                 nb_tile_block_ur++;
111             }
112             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113                 nb_tile_block_ur = 0;
114                 tile_block++;
115             }
116         }
117     }
118 }
119
120 template<bool is_fwd>
121 void _jit_avx512_core_convolution_winograd_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123     const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124     float *toutp, float *outp, float *bias) {
125
126     float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
127     float Ow[alpha][alpha][simd_w];
128     float O[tile_size][tile_size][simd_w];
129     float T[tile_size][alpha][simd_w];
130
131     auto p = jit_wino_transform_call_s();
132     p.src = toutp;
133     p.dst = outp;
134     p.G = G;
135     p.M = O;
136     p.Mw = Ow;
137     p.T = T;
138     p.bias = bias;
139
140     int outw = is_fwd ? jcp.ow : jcp.iw;
141     int outh = is_fwd ? jcp.oh : jcp.ih;
142
143     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
144
145     for (int nb_tile_block_ur = 0;
146         nb_tile_block_ur < jcp.nb_tile_block_ur;
147         nb_tile_block_ur++) {
148
149         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
150             tile_block_ur++) {
151             int img = tile_index / (jcp.jtiles * jcp.itiles);
152             int ti = tile_index % jcp.itiles;
153             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
154
155             p.tile_block_ur = tile_block_ur;
156             p.nb_tile_block_ur = nb_tile_block_ur;
157             p.tile_block = tile_block;
158             p.tj = tj;
159             p.ti = ti;
160             p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161                                * outh * outw * jcp.dimM_simd_block;
162
163             kernel_->output_transform_data_ker(&p);
164
165             tile_index++;
166         }
167     }
168 }
169
170
171 template<bool is_fwd>
172 void _jit_avx512_core_convolution_winograd_t<is_fwd>
173     ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174         float *inp, float *tinp)
175 {
176     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177                  0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
178
179     float Iw[alpha][alpha][simd_w];
180     float I[alpha][alpha][simd_w];
181     float T[alpha][alpha][simd_w];
182
183     auto p = jit_wino_transform_call_s();
184
185     p.src = inp;
186     p.dst = tinp;
187     p.G = G;
188     p.M = I;
189     p.Mw = Iw;
190     p.T = T;
191
192     int tile_base_index = image * jcp.itiles * jcp.jtiles;
193     int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194     int nb_tile_block_ur =
195         (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
196     int tile_block =
197         (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
198
199     for (int tj = 0; tj < jcp.jtiles; tj++) {
200         for (int ti = 0; ti < jcp.itiles; ti++) {
201
202             p.tile_block_ur = tile_block_ur;
203             p.nb_tile_block_ur = nb_tile_block_ur;
204             p.tile_block = tile_block;
205             p.tj = tj;
206             p.ti = ti;
207
208             kernel_->input_transform_data_ker(&p);
209
210             tile_block_ur++;
211             if (tile_block_ur >= jcp.tile_block_ur) {
212                 tile_block_ur = 0;
213                 nb_tile_block_ur++;
214             }
215             if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216                 nb_tile_block_ur = 0;
217                 tile_block++;
218             }
219         }
220     }
221 }
222
223 template <bool is_fwd>
224 void _jit_avx512_core_convolution_winograd_t<is_fwd>
225     ::input_transform_tileblock_data(int tile_block,
226         const jit_conv_winograd_conf_t &jcp,
227         float *inp, float *tinp)
228 {
229     float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
231     float Iw[alpha][alpha][simd_w];
232     float I[alpha][alpha][simd_w];
233     float T[alpha][alpha][simd_w];
234
235     const int inph = is_fwd ? jcp.ih : jcp.oh;
236     const int inpw = is_fwd ? jcp.iw : jcp.ow;
237
238     array_offset_calculator<float, 5> input(inp,
239         jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
240     array_offset_calculator<float, 7> output(tinp,
241         alpha, alpha,
242         jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243         jcp.dimN_reg_block, jcp.dimK_reg_block);
244
245     auto p = jit_wino_transform_call_s();
246
247     p.dst = tinp;
248     p.G = G;
249     p.M = I;
250     p.Mw = Iw;
251     p.T = T;
252
253
254     int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
255
256     for (int nb_tile_block_ur = 0;
257             nb_tile_block_ur < jcp.nb_tile_block_ur;
258             nb_tile_block_ur++) {
259
260         for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
261                 tile_block_ur++) {
262
263             int img = tile_index / (jcp.jtiles * jcp.itiles);
264             int ti = tile_index % jcp.itiles;
265             int tj = (tile_index / jcp.itiles) % jcp.jtiles;
266             float *pinp_b = &(input(img, 0, 0, 0, 0));
267
268             p.src = pinp_b;
269             p.tile_block_ur = tile_block_ur;
270             p.nb_tile_block_ur = nb_tile_block_ur;
271             p.tj = tj;
272             p.ti = ti;
273
274             kernel_->input_transform_data_ker(&p);
275
276             tile_index++;
277         }
278     }
279 }
280
281 template <bool is_fwd>
282 void _jit_avx512_core_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
283         const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
284     const auto &jcp = kernel_->jcp;
285     const auto &p_ops = attr_->post_ops_;
286
287     const int inph = is_fwd ? jcp.ih : jcp.oh;
288     const int inpw = is_fwd ? jcp.iw : jcp.ow;
289     const int outh = is_fwd ? jcp.oh : jcp.ih;
290     const int outw = is_fwd ? jcp.ow : jcp.iw;
291
292     /* Notation:
293        FWD: dimM:oc, dimN:ntiles, dimK:ic,
294        BWD: dimM:ic, dimN:ntiles, dimK:oc,
295        FWD/BWD: V: src/diff_dst transform, U:weight transform,
296                 M:dst/diff_src transform  */
297     array_offset_calculator<float, 5> input(inp_ptr,
298             MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
299             jcp.dimK_reg_block);
300     array_offset_calculator<float, 5> output(out_ptr,
301             MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
302             jcp.dimM_simd_block);
303     array_offset_calculator<float, 6> weights(wei_ptr,
304             jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
305             jcp.ic_simd_block, jcp.oc_simd_block);
306     array_offset_calculator<float, 2> bias(bias_ptr,
307             jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
308
309     array_offset_calculator<float, 8> M(
310             (float *)((is_fwd
311                     ? (this->scratchpad_)->M_ptr()
312                     : (this->scratchpad_)->V_ptr())),
313             jcp.dimN_nb_block, jcp.dimM_nb_block,
314             alpha, alpha,
315             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316             jcp.dimN_reg_block, jcp.dimM_simd_block);
317     array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
318             jcp.dimM_nb_block,
319             alpha, alpha,
320             jcp.dimK_nb_block,
321             jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
322             jcp.dimK_reg_block, jcp.dimM_simd_block);
323     array_offset_calculator<float, 8> V(
324             (float *)((is_fwd
325                     ? (this->scratchpad_)->V_ptr()
326                     : (this->scratchpad_)->M_ptr())),
327             jcp.dimN_nb_block, alpha, alpha,
328             jcp.dimN_block, jcp.dimK_nb_block,
329             jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
330
331     const bool want_padded_bias = jcp.with_bias
332         && jcp.oc_without_padding != jcp.oc;
333     float last_slice_bias[simd_w] = {0};
334     if (want_padded_bias) {
335         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
336             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
337     }
338
339 #pragma omp parallel
340     {
341 #pragma omp for nowait collapse(3)
342         for (int img = 0; img < MB; img++){
343             for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++){
344                 for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++){
345
346                     input_transform_data(img, jcp,
347                         &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
348                                 0, 0, 0)),
349                         &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
350
351                 }
352             }
353         }
354
355 #pragma omp for nowait collapse(4) schedule(static)
356         for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++){
357             for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++){
358                 for (int ofm2 = 0; ofm2 < jcp.oc_block * jcp.oc_reg_block;
359                      ofm2++){
360                     for (int ifm2 = 0; ifm2 < jcp.ic_block * jcp.ic_reg_block;
361                          ifm2++){
362                         float *U_base_ptr = is_fwd
363                         ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
364                         : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
365                         weight_transform_data(jcp,
366                             &(weights(
367                                 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
368                                 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
369                                 0, 0, 0, 0)),
370                             U_base_ptr);
371                     }
372                 }
373             }
374         }
375
376 #pragma omp barrier
377
378 #pragma omp for collapse(4) nowait schedule(static)
379         for (int N_blk1 = 0; N_blk1 < jcp.dimN_nb_block; N_blk1++){
380             for (int oj = 0; oj < alpha; oj++){
381                 for (int oi = 0; oi < alpha; oi++){
382                     for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++){
383                         for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
384                              K_blk1++)
385                         for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
386                             kernel_->gemm_loop_ker(
387                                     (float *)&(M(N_blk1, M_blk1, oj, oi,
388                                         N_blk2, 0, 0, 0)),
389                                     (const float *)&(U(M_blk1, oj, oi,
390                                         K_blk1, 0, 0, 0, 0)),
391                                     (const float *)&(V(N_blk1, oj, oi,
392                                         N_blk2, K_blk1, 0, 0, 0)), K_blk1);
393                     }
394                 }
395             }
396         }
397
398 #pragma omp barrier
399
400 #pragma omp for collapse(3)
401         for (int img = 0; img < MB; img++){
402             for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++){
403                 for (int M_blk2 = 0;
404                         M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; M_blk2++)
405                 {
406                     const int M_blk =
407                         M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
408
409                     float *bias_ptr = want_padded_bias
410                         && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
411                         ? last_slice_bias : &bias(M_blk, 0);
412
413                     output_transform_data(img, jcp, p_ops,
414                             &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
415                             &(output(img, M_blk, 0, 0, 0)), bias_ptr);
416                 }
417             }
418         }
419     }
420 }
421
422 template void
423 _jit_avx512_core_convolution_winograd_t<true>::_execute_data_W_S_G_D(
424         const int, float *, float *, float *, float *);
425 template void
426 _jit_avx512_core_convolution_winograd_t<false>::_execute_data_W_S_G_D(
427         const int, float *, float *, float *, float *);
428
429 template <bool is_fwd>
430 void _jit_avx512_core_convolution_winograd_t<is_fwd>::_execute_data_W_SGD(
431         const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
432     const auto &jcp = kernel_->jcp;
433     const auto &p_ops = attr_->post_ops_;
434
435     const int inph = is_fwd ? jcp.ih : jcp.oh;
436     const int inpw = is_fwd ? jcp.iw : jcp.ow;
437     const int outh = is_fwd ? jcp.oh : jcp.ih;
438     const int outw = is_fwd ? jcp.ow : jcp.iw;
439
440     array_offset_calculator<float, 5> input(inp_ptr,
441         MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
442     array_offset_calculator<float, 5> output(out_ptr,
443         MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
444     array_offset_calculator<float, 6> weights(wei_ptr,
445         jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
446         jcp.ic_simd_block, jcp.oc_simd_block);
447     array_offset_calculator<float, 2> bias(bias_ptr,
448         jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
449
450     array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
451             jcp.dimM_nb_block,
452             alpha, alpha,
453             jcp.dimK_nb_block,
454             jcp.dimM_block  * jcp.dimM_reg_block, jcp.dimK_block,
455             jcp.dimK_reg_block, jcp.dimM_simd_block);
456
457     array_offset_calculator<float, 8> M(
458             (float *)((is_fwd
459                     ? (this->scratchpad_)->M_ptr()
460                     : (this->scratchpad_)->V_ptr())),
461             0, jcp.dimM_nb_block, alpha, alpha,
462             jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
463             jcp.dimN_reg_block, jcp.dimM_simd_block);
464     array_offset_calculator<float, 8> V(
465             (float *)((is_fwd
466                     ? (this->scratchpad_)->V_ptr()
467                     : (this->scratchpad_)->M_ptr())),
468             0, alpha, alpha, jcp.dimN_block,
469             jcp.dimK_nb_block, jcp.dimK_block,
470             jcp.dimN_reg_block, jcp.dimK_reg_block);
471
472     const bool want_padded_bias = jcp.with_bias
473         && jcp.oc_without_padding != jcp.oc;
474     float last_slice_bias[simd_w] = {0};
475     if (want_padded_bias) {
476         for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
477             last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
478     }
479
480 #pragma omp parallel
481     {
482 #pragma omp for collapse(4) schedule(static)
483     for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
484         for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
485             for (int ofm2 = 0; ofm2 < jcp.oc_block * jcp.oc_reg_block; ofm2++) {
486                 for (int ifm2 = 0; ifm2 < jcp.ic_block * jcp.ic_reg_block;
487                       ifm2++) {
488                     float *U_base_ptr = is_fwd
489                                       ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
490                                       : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
491                     weight_transform_data(jcp,
492                             &(weights(
493                                 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
494                                 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
495                                 0, 0, 0, 0)),
496                             U_base_ptr);
497                 }
498             }
499         }
500     }
501
502     int ithr = omp_get_thread_num();
503
504     #pragma omp for schedule(static)
505     for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
506         for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
507             for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
508
509                 input_transform_tileblock_data(
510                         tile_block, jcp,
511                         &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
512                         &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
513             }
514         }
515
516         for (int oj = 0; oj < alpha; oj++) {
517             for (int oi = 0; oi < alpha; oi++) {
518                 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
519                 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
520                 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
521                     kernel_->gemm_loop_ker(
522                             (float *)&(M(ithr, M_blk1, oj, oi,
523                                     N_blk, 0, 0, 0)),
524                             (const float *)&(U(M_blk1, oj, oi, K_blk1,
525                                     0, 0, 0, 0)),
526                             (const float *)&(V(ithr, oj, oi,
527                                     N_blk, K_blk1, 0, 0, 0)), K_blk1);
528             }
529         }
530
531         for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
532             for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
533                   M_blk2++) {
534                 const int M_blk =
535                     M_blk1 * jcp.dimM_block  * jcp.dimM_reg_block + M_blk2;
536
537                 float *bias_ptr = want_padded_bias
538                     && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
539                     ? last_slice_bias : &bias(M_blk, 0);
540
541                 output_transform_tileblock_data(tile_block, jcp, p_ops,
542                         &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
543                         &(output(0, M_blk, 0, 0, 0)), bias_ptr);
544             }
545         }
546     }
547     }
548 }
549
550 template void
551 _jit_avx512_core_convolution_winograd_t<true>::_execute_data_W_SGD(
552         const int, float *, float *, float *, float *);
553 template void
554 _jit_avx512_core_convolution_winograd_t<false>::_execute_data_W_SGD(
555         const int, float *, float *, float *, float *);
556
557 namespace {
558
559 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
560         float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
561     using namespace nstl;
562     const size_t block_size = 16 * 1024 / sizeof(float);
563     const size_t blocks_number = nelems / block_size;
564     const size_t tail = nelems % block_size;
565
566 #pragma omp parallel
567     {
568         const int ithr = omp_get_thread_num();
569         const int nthr = omp_get_num_threads();
570         size_t start{ 0 }, end{ 0 };
571         balance211(blocks_number, nthr, ithr, start, end);
572
573         for (size_t nb = start; nb < end; ++nb) {
574             size_t start_e = nb * block_size;
575             size_t end_e = start_e + block_size;
576             size_t input_start = max(start_e, min(input_starts[0], end_e));
577             size_t input_end = max(start_e, min(input_ends[0], end_e));
578
579             PRAGMA_OMP_SIMD()
580             for (size_t e = start_e; e < input_start; e++) {
581                 output[e] = 0.f;
582             }
583
584             PRAGMA_OMP_SIMD()
585             for (size_t e = input_start; e < input_end; e++) {
586                 output[e] = input_ptrs[0][e];
587             }
588
589             PRAGMA_OMP_SIMD()
590             for (size_t e = input_end; e < end_e; e++) {
591                 output[e] = 0.f;
592             }
593
594             for (size_t a = 1; a < num_arrs; a++) {
595                 input_start = max(start_e, input_starts[a]);
596                 input_end = min(input_ends[a], end_e);
597
598                 PRAGMA_OMP_SIMD()
599                 for (size_t e = input_start; e < input_end; e++) {
600                     output[e] += input_ptrs[a][e];
601                 }
602             }
603         }
604
605         if (tail != 0 && ithr == nthr - 1) {
606             size_t start_e = nelems - tail;
607             size_t end_e = nelems;
608             size_t input_start = max(start_e, min(input_starts[0], end_e));
609             size_t input_end = max(start_e, min(input_ends[0], end_e));
610
611             PRAGMA_OMP_SIMD()
612             for (size_t e = start_e; e < input_start; e++) {
613                 output[e] = 0.f;
614             }
615
616             PRAGMA_OMP_SIMD()
617             for (size_t e = input_start; e < input_end; e++) {
618                 output[e] = input_ptrs[0][e];
619             }
620
621             PRAGMA_OMP_SIMD()
622             for (size_t e = input_end; e < end_e; e++) {
623                 output[e] = 0.f;
624             }
625
626             for (size_t a = 1; a < num_arrs; a++) {
627                 input_start = max(start_e, input_starts[a]);
628                 input_end = min(input_ends[a], end_e);
629
630                 PRAGMA_OMP_SIMD()
631                 for (size_t e = input_start; e < input_end; e++) {
632                     output[e] += input_ptrs[a][e];
633                 }
634             }
635         }
636     }
637 }
638
639 const int max_threads_number = 1024;
640
641 // Sum to the first buffer array
642 void array_sum(size_t num_arrs, float *output,
643     size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
644     const size_t block_size = 16 * 1024 / sizeof(float);
645     const size_t blocks_number = nelems / block_size;
646     const size_t tail = nelems % block_size;
647
648 #pragma omp parallel
649     {
650         const size_t ithr = omp_get_thread_num();
651         const size_t nthr = omp_get_num_threads();
652         size_t start{ 0 }, end{ 0 };
653         balance211(blocks_number, nthr, ithr, start, end);
654
655         for (size_t nb = start; nb < end; ++nb) {
656             size_t start_e = nb * block_size;
657             size_t end_e = start_e + block_size;
658             if (!reduce_to_first) {
659                 PRAGMA_OMP_SIMD()
660                 for (size_t e = start_e; e < end_e; e++) {
661                     output[e] = input_ptrs[0][e];
662                 }
663             }
664             for (size_t a = 1; a < num_arrs; a++) {
665                 PRAGMA_OMP_SIMD()
666                 for (size_t e = start_e; e < end_e; e++) {
667                     output[e] += input_ptrs[a][e];
668                 }
669             }
670         }
671
672         if (tail != 0 && ithr == nthr - 1) {
673             size_t start_e = nelems - tail;
674             size_t end_e = nelems;
675             if (!reduce_to_first) {
676                 PRAGMA_OMP_SIMD()
677                 for (size_t e = start_e; e < end_e; e++) {
678                     output[e] = input_ptrs[0][e];
679                 }
680             }
681             for (size_t a = 1; a < num_arrs; a++) {
682                 PRAGMA_OMP_SIMD()
683                 for (size_t e = start_e; e < end_e; e++) {
684                     output[e] += input_ptrs[a][e];
685                 }
686             }
687         }
688     }
689 }
690 } //bwdw namespace
691
692 void jit_avx512_core_convolution_winograd_bwd_weights_t::
693 _execute_backward_weights_SDGtWo() {
694     const auto &jcp = kernel_->jcp;
695     const int nthreads = scratchpad_->num_threads();
696
697     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
698             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
699     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
700             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
701     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
702             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
703
704     array_offset_calculator<float, 8> Us((float *)(scratchpad_->U_ptr()),
705             0, alpha, alpha,
706             jcp.oc_block, jcp.ic_block,
707             jcp.ic_simd_block,
708             jcp.oc_reg_block,
709             jcp.oc_simd_block);
710
711     int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
712         * jcp.ic / jcp.nb_ic * sizeof(float);
713     array_offset_calculator<float, 7>diff_weights_prv(
714             (float *)(scratchpad_->U_ptr() + U_sz),
715             0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
716
717     array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
718             0, alpha, alpha,
719             jcp.oc_block,
720             jcp.nb_tile_block_ur,
721             jcp.tile_block_ur,
722             jcp.oc_reg_block,
723             jcp.oc_simd_block);
724
725     array_offset_calculator<float, 7> V((float *)(scratchpad_->V_ptr()),
726             0, alpha, alpha,
727             jcp.ic_block,
728             jcp.nb_tile_block_ur,
729             jcp.tile_block_ur,
730             jcp.ic_simd_block);
731
732     array_offset_calculator<float, 2> diff_bias_prv(
733             (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
734
735     auto trans_ker_p = jit_wino_transform_call_s();
736     float I[alpha][alpha][simd_w];
737     float T[alpha][alpha][simd_w];
738     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
739                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
740     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
741        0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
742        1.13777777777778f};
743     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
744
745 #pragma omp parallel firstprivate(trans_ker_p, I, T)
746 {
747     if (jcp.with_bias) {
748 #pragma omp for nowait collapse(2)
749         for (int ithr = 0; ithr < nthreads; ithr++) {
750             for (int ofm = 0; ofm < jcp.oc / simd_w; ofm++) {
751                 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
752                 PRAGMA_OMP_SIMD()
753                 for (int v = 0; v < simd_w; v++) {
754                     pdbias[v] = 0.0f;
755                 }
756             }
757         }
758     }
759
760     int ithr = omp_get_thread_num();
761     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
762         int first_tblk = 0;
763 #pragma omp for
764         for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
765             int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
766             int img = tile_index / (jcp.itiles * jcp.jtiles);
767             trans_ker_p.ti = tile_index % jcp.itiles;
768             trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
769             trans_ker_p.M = I;
770             trans_ker_p.T = T;
771             trans_ker_p.G = G_I_3x3_4x4;
772             for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
773                 int ifm = ifm1 * jcp.ic_block + ifm2;
774                 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
775                 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
776                 kernel_->src_transform(&trans_ker_p);
777             }
778
779             for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
780                 trans_ker_p.G = G_W_3x3_4x4;
781                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
782                     int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
783                     trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
784                     trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
785                     if (jcp.with_bias && ifm1 == 0) {
786                         trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
787                         kernel_->diff_dst_transform_wbias(&trans_ker_p);
788                     } else {
789                         kernel_->diff_dst_transform(&trans_ker_p);
790                     }
791                 }
792
793                 for (int oj = 0; oj < alpha; ++oj) {
794                     for (int oi = 0; oi < alpha; ++oi) {
795                         kernel_->gemm_loop_ker_first_iter(
796                                 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
797                                 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
798                                 &(V(ithr, oj, oi, 0, 0, 0, 0)));
799                     }
800                 }
801                 trans_ker_p.G = G_O_3x3_4x4;
802                 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
803                     for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
804                         int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
805                                 + ofm3;
806                         for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
807                             int ifm = ifm1 * jcp.ic_block + ifm2;
808                             trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
809                                         ofm2, ifm2, 0, ofm3, 0));
810                             trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
811                                         ofm, ifm, 0, 0, 0, 0));
812                             if (first_tblk == 0) {
813                                 kernel_->diff_weights_transform(&trans_ker_p);
814                             } else {
815                                 kernel_->diff_weights_transform_accum(&trans_ker_p);
816                             }
817                         }
818                     }
819                 }
820             }
821             ++first_tblk;
822         }
823     }
824 }
825
826     // Reduce diff-weights
827     {
828         float *output = (float *)(this->memory(0));
829         float *input_base = (float *)(scratchpad_->U_ptr() + U_sz);
830         int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
831         float *input_ptrs[max_threads_number];
832         for (int i = 0; i < nthreads; ++i) {
833             input_ptrs[i] = input_base + nelems * i;
834         }
835         array_sum(nthreads, output, nelems, input_ptrs, false);
836
837         if (jcp.with_bias) {
838             output = (float *)(this->memory(1));
839             input_base = (float *)(scratchpad_->bias_ptr());
840             for (int i = 0; i < nthreads; ++i) {
841                 input_ptrs[i] = input_base + jcp.oc * i;
842             }
843             array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
844                     false);
845         }
846     }
847 }
848
849 void jit_avx512_core_convolution_winograd_bwd_weights_t::
850 _execute_backward_weights_S_D_Giot_W() {
851     const auto &jcp = kernel_->jcp;
852     const int nthreads = scratchpad_->num_threads();
853
854     array_offset_calculator<float, 5> src((float *)this->input_memory(0),
855             jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
856     array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
857             jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
858     array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
859             jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
860     array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
861
862     array_offset_calculator<float, 9> U((float *)(scratchpad_->U_ptr()),
863             jcp.nb_ic, jcp.nb_oc,
864             alpha, alpha,
865             jcp.oc_block, jcp.ic_block,
866             jcp.ic_simd_block,
867             jcp.oc_reg_block,
868             jcp.oc_simd_block);
869
870     int U_size = jcp.oc * jcp.ic * alpha * alpha * sizeof(float);
871     array_offset_calculator<float, 10> Us(
872             (float *)(scratchpad_->U_ptr() + U_size),
873             0, jcp.nb_ic, jcp.nb_oc,
874             alpha, alpha,
875             jcp.oc_block, jcp.ic_block,
876             jcp.ic_simd_block,
877             jcp.oc_reg_block,
878             jcp.oc_simd_block);
879
880     array_offset_calculator<float, 9> M((float *)(scratchpad_->M_ptr()),
881             jcp.nb_oc,
882             jcp.tile_block,
883             alpha, alpha,
884             jcp.oc_block,
885             jcp.nb_tile_block_ur,
886             jcp.tile_block_ur ,
887             jcp.oc_reg_block,
888             jcp.oc_simd_block);
889
890     array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
891             jcp.nb_ic,
892             jcp.tile_block,
893             alpha, alpha,
894             jcp.ic_block,
895             jcp.nb_tile_block_ur, jcp.tile_block_ur,
896             jcp.ic_simd_block);
897
898     array_offset_calculator<float, 2> diff_bias_prv(
899             (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
900
901     size_t input_starts[max_threads_number];
902     size_t input_ends[max_threads_number];
903     size_t first_tblk = 0;
904
905     auto trans_ker_p = jit_wino_transform_call_s();
906     float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
907                0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
908     float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
909         0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
910         0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
911     float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
912     float I[alpha][alpha][simd_w];
913     float T[alpha][alpha][simd_w];
914
915 #pragma omp parallel firstprivate(first_tblk, trans_ker_p, I, T)
916 {
917     if (jcp.with_bias) {
918 #pragma omp for nowait collapse(2)
919         for (int ithr = 0; ithr < nthreads; ++ithr) {
920             for (int ofm = 0; ofm < jcp.oc; ++ofm) {
921                 diff_bias_prv(ithr, ofm) = 0.0f;
922             }
923         }
924     }
925
926     trans_ker_p.G = G_I_3x3_4x4;
927     trans_ker_p.M = I;
928     trans_ker_p.T = T;
929 #pragma omp for collapse(3) nowait
930     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
931          for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
932              for (int img = 0; img < jcp.mb; img++) {
933                  size_t ifm = ifm1 * jcp.ic_block + ifm2;
934                  size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
935                  size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
936                  size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
937                      % jcp.nb_tile_block_ur;
938                  size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
939                      / jcp.nb_tile_block_ur;
940                  trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
941                  trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
942                  trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
943                  kernel_->src_transform(&trans_ker_p);
944              }
945          }
946     }
947
948     int ithr = omp_get_thread_num();
949     trans_ker_p.G = G_W_3x3_4x4;
950 #pragma omp for collapse(3)
951     for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
952         for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
953             for (int img = 0; img < jcp.mb; ++img) {
954                 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
955                 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
956                 size_t tblk3 = tile_base_index  % jcp.tile_block_ur;
957                 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
958                     % jcp.nb_tile_block_ur;
959                 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
960                     / jcp.nb_tile_block_ur;
961                 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
962                 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
963                 trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
964                 if (jcp.with_bias) {
965                     trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
966                     kernel_->diff_dst_transform_wbias(&trans_ker_p);
967                 } else {
968                     kernel_->diff_dst_transform(&trans_ker_p);
969                 }
970             }
971         }
972     }
973
974 #pragma omp for collapse(5) nowait schedule(static)
975     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
976         for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
977             for (int oj = 0; oj < alpha; ++oj) {
978                 for (int oi = 0; oi < alpha; ++oi) {
979                     for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
980                         if (first_tblk == 0) {
981                             input_starts[ithr] =
982                                 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
983                                             0, 0))
984                                 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
985                                             0, 0, 0));
986                             input_ends[ithr] = input_starts[ithr]
987                                     + jcp.oc_block * jcp.ic_block
988                                       * jcp.ic_simd_block * jcp.oc_reg_block
989                                       * jcp.oc_simd_block;
990                         }
991                         else if (tblk1 == 0) {
992                             input_ends[ithr] += jcp.oc_block * jcp.ic_block
993                                 * jcp.ic_simd_block * jcp.oc_reg_block
994                                 * jcp.oc_simd_block;
995                         }
996
997                         if (first_tblk == 0 || tblk1 == 0) {
998                             kernel_->gemm_loop_ker_first_iter(
999                                     &(Us(ithr, ifm1, ofm1, oj, oi,
1000                                             0, 0, 0, 0, 0)),
1001                                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
1002                                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
1003                         } else {
1004                             kernel_->gemm_loop_ker(
1005                                     &(Us(ithr, ifm1, ofm1, oj, oi,
1006                                             0, 0, 0, 0, 0)),
1007                                     &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
1008                                     &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
1009                         }
1010                         ++first_tblk;
1011                     }
1012                 }
1013             }
1014         }
1015     }
1016 }
1017
1018     // Reduce diff-weights
1019     {
1020         float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
1021         size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
1022         float *input_ptrs[max_threads_number];
1023         for (int i = 0; i < nthreads; ++i)
1024             input_ptrs[i] = output + nelems * (i + 1);
1025         subarray_sum(nthreads, output, nelems, input_ptrs,
1026                 input_starts, input_ends);
1027     }
1028
1029     trans_ker_p.G = G_O_3x3_4x4;
1030 #pragma omp parallel for collapse(5) firstprivate(trans_ker_p)
1031     for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
1032         for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
1033             for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
1034                 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
1035                     for (int ofm3 = 0;  ofm3 < jcp.oc_reg_block; ++ofm3) {
1036                         int ofm = (ofm1 * jcp.oc_block + ofm2)
1037                             * jcp.oc_reg_block + ofm3;
1038                         int ifm = ifm1 * jcp.ic_block + ifm2;
1039                         trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
1040                                     ofm2, ifm2, 0, ofm3, 0));
1041                         trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
1042                                     0, 0, 0, 0));
1043                         kernel_->diff_weights_transform(&trans_ker_p);
1044                     }
1045                 }
1046             }
1047         }
1048     }
1049
1050     if (jcp.with_bias) {
1051 #pragma omp parallel for
1052         for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ++ofm1) {
1053             float* pbias = &(diff_bias(ofm1 * simd_w));
1054             float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
1055
1056             const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1057                 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1058
1059             PRAGMA_OMP_SIMD()
1060             for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1061                 pbias[ofm2] = pbias_prv[ofm2];
1062             }
1063
1064             for (int ithr = 1; ithr < nthreads; ++ithr) {
1065                 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1066                 PRAGMA_OMP_SIMD()
1067                 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1068                     pbias[ofm2] += pbias_prv[ofm2];
1069                 }
1070             }
1071         }
1072     }
1073 }
1074
1075 }
1076 }
1077 }
1078 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s