1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
21 #include "mkldnn_types.h"
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
28 #include "jit_avx512_common_convolution_winograd.hpp"
29 #include "jit_avx512_core_convolution_winograd.hpp"
32 #define pragma_unroll _Pragma("unroll")
42 using namespace mkldnn::impl::status;
43 using namespace mkldnn::impl::memory_format;
44 using namespace mkldnn::impl::utils;
// Transforms one block of convolution weights: prepares a
// jit_wino_transform_call_s argument block and invokes
// kernel_->weights_transform_data_ker on it.
//   jcp - convolution configuration
//   wp  - presumably the untransformed weights (input) -- TODO confirm
//   twp - presumably the transformed weights (output) -- TODO confirm
46 template <bool is_fwd>
47 void _jit_avx512_core_convolution_winograd_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49 float *wp, float *twp)
// Constant coefficient table for the weight transform.
// NOTE(review): presumably handed to the kernel through `p` -- confirm.
51 float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52 1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
// Local scratch arrays; their extents come from alpha, kh/kw and simd_w.
55 float Fw[alpha][alpha][simd_w][simd_w];
56 float F[kh][kw][simd_w][simd_w];
57 float T[alpha][3][simd_w];
// Value-initialized argument block passed to the kernel.
58 auto p = jit_wino_transform_call_s();
67 kernel_->weights_transform_data_ker(&p);
// Output transform for all tiles of one image: iterates over the image's
// jtiles x itiles tiles and calls kernel_->output_transform_data_ker per tile.
//   image  - image index, used to compute the linear index of its first tile
//   jcp    - convolution configuration
//   p_ops  - post-ops descriptor (presumably forwarded via `p` -- confirm)
//   toutp  - presumably the Winograd-domain data to transform -- TODO confirm
//   pout_b - presumably the destination for this image -- TODO confirm
//   bias   - presumably the bias block to apply -- TODO confirm
71 void _jit_avx512_core_convolution_winograd_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73 const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) {
// Constant coefficient table for the output transform.
75 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
// Local scratch arrays sized by alpha, tile_size and simd_w.
76 float Ow[alpha][alpha][simd_w];
77 float O[tile_size][tile_size][simd_w];
78 float T[tile_size][alpha][simd_w];
// Value-initialized argument block passed to the kernel.
80 auto p = jit_wino_transform_call_s();
// Linear index of the first tile of this image, then its decomposition into
// the (tile_block, nb_tile_block_ur, tile_block_ur) blocking coordinates:
// tile_block_ur is the fastest-varying component.
89 int tile_base_index = image * jcp.itiles * jcp.jtiles;
90 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91 int nb_tile_block_ur =
92 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
// Outermost coordinate; presumably assigned to `tile_block` (used below).
94 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
// Visit tiles row by row (tj outer, ti inner).
96 for (int tj = 0; tj < jcp.jtiles; tj++) {
97 for (int ti = 0; ti < jcp.itiles; ti++) {
// Tell the kernel which blocked tile position it is processing.
99 p.tile_block_ur = tile_block_ur;
100 p.nb_tile_block_ur = nb_tile_block_ur;
101 p.tile_block = tile_block;
105 kernel_->output_transform_data_ker(&p);
// Carry handling for the blocked tile counters: when the inner counter reaches
// its limit the next level is presumably advanced, and nb_tile_block_ur wraps
// back to 0 at its own limit.
108 if (tile_block_ur >= jcp.tile_block_ur) {
112 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113 nb_tile_block_ur = 0;
// Output transform for all tiles belonging to one tile block: walks the
// nb_tile_block_ur x tile_block_ur positions of `tile_block`, maps each back to
// (image, tj, ti) and calls kernel_->output_transform_data_ker.
//   tile_block - index of the tile block to process
//   jcp        - convolution configuration
//   p_ops      - post-ops descriptor (presumably forwarded via `p` -- confirm)
//   toutp      - presumably the Winograd-domain data to transform -- TODO confirm
//   outp       - base of the destination tensor; offset per image below
//   bias       - presumably the bias block to apply -- TODO confirm
120 template<bool is_fwd>
121 void _jit_avx512_core_convolution_winograd_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123 const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124 float *toutp, float *outp, float *bias) {
// Constant coefficient table for the output transform.
126 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
// Local scratch arrays sized by alpha, tile_size and simd_w.
127 float Ow[alpha][alpha][simd_w];
128 float O[tile_size][tile_size][simd_w];
129 float T[tile_size][alpha][simd_w];
// Value-initialized argument block passed to the kernel.
131 auto p = jit_wino_transform_call_s();
// Destination spatial size depends on direction: ow/oh when is_fwd,
// iw/ih otherwise.
140 int outw = is_fwd ? jcp.ow : jcp.iw;
141 int outh = is_fwd ? jcp.oh : jcp.ih;
// Linear index of the first tile covered by this tile block.
143 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
145 for (int nb_tile_block_ur = 0;
146 nb_tile_block_ur < jcp.nb_tile_block_ur;
147 nb_tile_block_ur++) {
149 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Recover image and tile coordinates from the linear tile index
// (ti varies fastest, then tj, then the image).
151 int img = tile_index / (jcp.jtiles * jcp.itiles);
152 int ti = tile_index % jcp.itiles;
153 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
155 p.tile_block_ur = tile_block_ur;
156 p.nb_tile_block_ur = nb_tile_block_ur;
157 p.tile_block = tile_block;
// Destination for this image: skip `img` whole images, each holding
// (dimM / dimM_simd_block) * outh * outw * dimM_simd_block floats.
160 p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161 * outh * outw * jcp.dimM_simd_block;
163 kernel_->output_transform_data_ker(&p);
// Input transform for all tiles of one image: iterates over the image's
// jtiles x itiles tiles and calls kernel_->input_transform_data_ker per tile.
//   image - image index, used to compute the linear index of its first tile
//   jcp   - convolution configuration
//   inp   - presumably the source data for this image -- TODO confirm
//   tinp  - presumably the destination for transformed data -- TODO confirm
171 template<bool is_fwd>
172 void _jit_avx512_core_convolution_winograd_t<is_fwd>
173 ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174 float *inp, float *tinp)
// Constant coefficient table for the input transform.
176 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
// Local alpha x alpha x simd_w scratch arrays.
179 float Iw[alpha][alpha][simd_w];
180 float I[alpha][alpha][simd_w];
181 float T[alpha][alpha][simd_w];
// Value-initialized argument block passed to the kernel.
183 auto p = jit_wino_transform_call_s();
// Linear index of the first tile of this image, then its decomposition into
// the (tile_block, nb_tile_block_ur, tile_block_ur) blocking coordinates:
// tile_block_ur is the fastest-varying component.
192 int tile_base_index = image * jcp.itiles * jcp.jtiles;
193 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194 int nb_tile_block_ur =
195 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
// Outermost coordinate; presumably assigned to `tile_block` (used below).
197 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
// Visit tiles row by row (tj outer, ti inner).
199 for (int tj = 0; tj < jcp.jtiles; tj++) {
200 for (int ti = 0; ti < jcp.itiles; ti++) {
// Tell the kernel which blocked tile position it is processing.
202 p.tile_block_ur = tile_block_ur;
203 p.nb_tile_block_ur = nb_tile_block_ur;
204 p.tile_block = tile_block;
208 kernel_->input_transform_data_ker(&p);
// Carry handling for the blocked tile counters: when the inner counter reaches
// its limit the next level is presumably advanced, and nb_tile_block_ur wraps
// back to 0 at its own limit.
211 if (tile_block_ur >= jcp.tile_block_ur) {
215 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216 nb_tile_block_ur = 0;
// Input transform for all tiles belonging to one tile block: walks the
// nb_tile_block_ur x tile_block_ur positions of `tile_block`, maps each back to
// (image, tj, ti) and calls kernel_->input_transform_data_ker.
//   tile_block - index of the tile block to process
//   jcp        - convolution configuration
//   inp        - base of the source tensor, viewed as a 5-D array below
//   tinp       - base of the transformed buffer, viewed as a 7-D array below
223 template <bool is_fwd>
224 void _jit_avx512_core_convolution_winograd_t<is_fwd>
225 ::input_transform_tileblock_data(int tile_block,
226 const jit_conv_winograd_conf_t &jcp,
227 float *inp, float *tinp)
// Constant coefficient table for the input transform.
229 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
// Local alpha x alpha x simd_w scratch arrays.
231 float Iw[alpha][alpha][simd_w];
232 float I[alpha][alpha][simd_w];
233 float T[alpha][alpha][simd_w];
// Source spatial size depends on direction: ih/iw when is_fwd, oh/ow otherwise.
235 const int inph = is_fwd ? jcp.ih : jcp.oh;
236 const int inpw = is_fwd ? jcp.iw : jcp.ow;
// Source viewed as [mb][dimK / simd_w][inph][inpw][simd_w].
238 array_offset_calculator<float, 5> input(inp,
239 jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
// Transformed buffer viewed as a 7-D blocked array.
240 array_offset_calculator<float, 7> output(tinp,
242 jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243 jcp.dimN_reg_block, jcp.dimK_reg_block);
// Value-initialized argument block passed to the kernel.
245 auto p = jit_wino_transform_call_s();
// Linear index of the first tile covered by this tile block.
254 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
256 for (int nb_tile_block_ur = 0;
257 nb_tile_block_ur < jcp.nb_tile_block_ur;
258 nb_tile_block_ur++) {
260 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Recover image and tile coordinates from the linear tile index
// (ti varies fastest, then tj, then the image).
263 int img = tile_index / (jcp.jtiles * jcp.itiles);
264 int ti = tile_index % jcp.itiles;
265 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
// Start of image `img` inside the source view.
266 float *pinp_b = &(input(img, 0, 0, 0, 0));
269 p.tile_block_ur = tile_block_ur;
270 p.nb_tile_block_ur = nb_tile_block_ur;
274 kernel_->input_transform_data_ker(&p);
// Runs the Winograd data computation (direction selected by is_fwd) as four
// separate work-shared stages:
//   1. input transform of every image into V,
//   2. weight transform into U,
//   3. GEMM of U and V into M,
//   4. output transform of M into the destination tensor.
// Parameters:
//   MB       - number of images to process
//   inp_ptr  - source tensor
//   out_ptr  - destination tensor
//   wei_ptr  - weights tensor
//   bias_ptr - bias vector (read when building the per-block bias pointer)
281 template <bool is_fwd>
282 void _jit_avx512_core_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
283 const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
284 const auto &jcp = kernel_->jcp;
285 const auto &p_ops = attr_->post_ops_;
// Source/destination spatial sizes swap roles with the direction.
287 const int inph = is_fwd ? jcp.ih : jcp.oh;
288 const int inpw = is_fwd ? jcp.iw : jcp.ow;
289 const int outh = is_fwd ? jcp.oh : jcp.ih;
290 const int outw = is_fwd ? jcp.ow : jcp.iw;
293 FWD: dimM:oc, dimN:ntiles, dimK:ic,
294 BWD: dimM:ic, dimN:ntiles, dimK:oc,
295 FWD/BWD: V: src/diff_dst transform, U:weight transform,
296 M:dst/diff_src transform */
// Multi-dimensional views over the raw user buffers.
297 array_offset_calculator<float, 5> input(inp_ptr,
298 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
300 array_offset_calculator<float, 5> output(out_ptr,
301 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
302 jcp.dimM_simd_block);
303 array_offset_calculator<float, 6> weights(wei_ptr,
304 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
305 jcp.ic_simd_block, jcp.oc_simd_block);
306 array_offset_calculator<float, 2> bias(bias_ptr,
307 jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
// Views over the scratchpad buffers. M and V each pick between M_ptr() and
// V_ptr() with mirrored choices; NOTE(review): the selecting condition is
// presumably is_fwd -- confirm.
309 array_offset_calculator<float, 8> M(
311 ? (this->scratchpad_)->M_ptr()
312 : (this->scratchpad_)->V_ptr())),
313 jcp.dimN_nb_block, jcp.dimM_nb_block,
315 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316 jcp.dimN_reg_block, jcp.dimM_simd_block);
317 array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
321 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
322 jcp.dimK_reg_block, jcp.dimM_simd_block);
323 array_offset_calculator<float, 8> V(
325 ? (this->scratchpad_)->V_ptr()
326 : (this->scratchpad_)->M_ptr())),
327 jcp.dimN_nb_block, alpha, alpha,
328 jcp.dimN_block, jcp.dimK_nb_block,
329 jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
// When oc is padded, the last simd block of the user's bias is shorter than a
// full block: copy its valid entries into a zero-initialized stack buffer so a
// full simd_w-wide block can be read for the last slice.
331 const bool want_padded_bias = jcp.with_bias
332 && jcp.oc_without_padding != jcp.oc;
333 float last_slice_bias[simd_w] = {0};
334 if (want_padded_bias) {
335 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
336 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Stage 1: input transform, one call per (image, K block) into V.
// `nowait`: threads do not synchronize at the end of this loop.
341 #pragma omp for nowait collapse(3)
342 for (int img = 0; img < MB; img++){
343 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++){
344 for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++){
346 input_transform_data(img, jcp,
347 &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
349 &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// Stage 2: weight transform, one call per (oc, ic) simd-block pair into U.
355 #pragma omp for nowait collapse(4) schedule(static)
356 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++){
357 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++){
358 for (int ofm2 = 0; ofm2 < jcp.oc_block * jcp.oc_reg_block;
360 for (int ifm2 = 0; ifm2 < jcp.ic_block * jcp.ic_reg_block;
// U is indexed M-major: (ofm, ifm) order when is_fwd, swapped otherwise.
362 float *U_base_ptr = is_fwd
363 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
364 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
365 weight_transform_data(jcp,
367 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
368 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
// Stage 3: GEMM. For every (N block, oj, oi, M block) the K blocks and inner
// N blocks are iterated sequentially, calling gemm_loop_ker with K_blk1 as the
// last argument.
378 #pragma omp for collapse(4) nowait schedule(static)
379 for (int N_blk1 = 0; N_blk1 < jcp.dimN_nb_block; N_blk1++){
380 for (int oj = 0; oj < alpha; oj++){
381 for (int oi = 0; oi < alpha; oi++){
382 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++){
383 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
385 for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
386 kernel_->gemm_loop_ker(
387 (float *)&(M(N_blk1, M_blk1, oj, oi,
389 (const float *)&(U(M_blk1, oj, oi,
390 K_blk1, 0, 0, 0, 0)),
391 (const float *)&(V(N_blk1, oj, oi,
392 N_blk2, K_blk1, 0, 0, 0)), K_blk1);
// Stage 4: output transform, one call per (image, M block).
400 #pragma omp for collapse(3)
401 for (int img = 0; img < MB; img++){
402 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++){
404 M_blk2 < jcp.dimM_block * jcp.dimM_reg_block; M_blk2++)
// Flat M block index; presumably assigned to `M_blk` (used below).
407 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
// Use the padded stack copy for the last block when padding is in effect
// (&& binds tighter than ?:), otherwise point into the user's bias.
// NOTE(review): this local shadows the `bias_ptr` function parameter.
409 float *bias_ptr = want_padded_bias
410 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
411 ? last_slice_bias : &bias(M_blk, 0);
413 output_transform_data(img, jcp, p_ops,
414 &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
415 &(output(img, M_blk, 0, 0, 0)), bias_ptr);
// Declarations of _execute_data_W_S_G_D for is_fwd = true and is_fwd = false;
// presumably explicit template instantiations -- confirm.
423 _jit_avx512_core_convolution_winograd_t<true>::_execute_data_W_S_G_D(
424 const int, float *, float *, float *, float *);
426 _jit_avx512_core_convolution_winograd_t<false>::_execute_data_W_S_G_D(
427 const int, float *, float *, float *, float *);
// Runs the Winograd data computation (direction selected by is_fwd) with the
// weight transform done up front and the input transform, GEMM and output
// transform fused per tile block, using per-thread slices of V and M.
// Parameters:
//   MB       - number of images
//   inp_ptr  - source tensor
//   out_ptr  - destination tensor
//   wei_ptr  - weights tensor
//   bias_ptr - bias vector (read when building the per-block bias pointer)
429 template <bool is_fwd>
430 void _jit_avx512_core_convolution_winograd_t<is_fwd>::_execute_data_W_SGD(
431 const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr) {
432 const auto &jcp = kernel_->jcp;
433 const auto &p_ops = attr_->post_ops_;
// Source/destination spatial sizes swap roles with the direction.
435 const int inph = is_fwd ? jcp.ih : jcp.oh;
436 const int inpw = is_fwd ? jcp.iw : jcp.ow;
437 const int outh = is_fwd ? jcp.oh : jcp.ih;
438 const int outw = is_fwd ? jcp.ow : jcp.iw;
// Multi-dimensional views over the raw user buffers.
440 array_offset_calculator<float, 5> input(inp_ptr,
441 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
442 array_offset_calculator<float, 5> output(out_ptr,
443 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
444 array_offset_calculator<float, 6> weights(wei_ptr,
445 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
446 jcp.ic_simd_block, jcp.oc_simd_block);
447 array_offset_calculator<float, 2> bias(bias_ptr,
448 jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
// Views over the scratchpad buffers.
450 array_offset_calculator<float, 8> U((float *)((this->scratchpad_)->U_ptr()),
454 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
455 jcp.dimK_reg_block, jcp.dimM_simd_block);
// M and V are indexed by thread id in their first dimension below; that
// dimension's extent is passed as 0 (presumably unused for offset
// computation -- confirm).
457 array_offset_calculator<float, 8> M(
459 ? (this->scratchpad_)->M_ptr()
460 : (this->scratchpad_)->V_ptr())),
461 0, jcp.dimM_nb_block, alpha, alpha,
462 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
463 jcp.dimN_reg_block, jcp.dimM_simd_block);
464 array_offset_calculator<float, 8> V(
466 ? (this->scratchpad_)->V_ptr()
467 : (this->scratchpad_)->M_ptr())),
468 0, alpha, alpha, jcp.dimN_block,
469 jcp.dimK_nb_block, jcp.dimK_block,
470 jcp.dimN_reg_block, jcp.dimK_reg_block);
// When oc is padded, copy the valid entries of the user's last bias block into
// a zero-initialized stack buffer so a full simd_w-wide block can be read.
472 const bool want_padded_bias = jcp.with_bias
473 && jcp.oc_without_padding != jcp.oc;
474 float last_slice_bias[simd_w] = {0};
475 if (want_padded_bias) {
476 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
477 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Weight transform into U. No `nowait`, so the loop ends with the implicit
// barrier of `omp for`; U is read by the GEMM below.
482 #pragma omp for collapse(4) schedule(static)
483 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ofm1++) {
484 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
485 for (int ofm2 = 0; ofm2 < jcp.oc_block * jcp.oc_reg_block; ofm2++) {
486 for (int ifm2 = 0; ifm2 < jcp.ic_block * jcp.ic_reg_block;
// U is indexed M-major: (ofm, ifm) order when is_fwd, swapped otherwise.
488 float *U_base_ptr = is_fwd
489 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
490 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
491 weight_transform_data(jcp,
493 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
494 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
// Thread id selects this thread's private slice of V and M.
502 int ithr = omp_get_thread_num();
// Tile blocks are distributed over threads; each tile block goes through
// input transform -> GEMM -> output transform on the owning thread.
504 #pragma omp for schedule(static)
505 for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
// Input transform of this tile block into V(ithr, ...).
506 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
507 for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
509 input_transform_tileblock_data(
511 &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
512 &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// GEMM for every (oj, oi) position: M(ithr) from U and V(ithr).
516 for (int oj = 0; oj < alpha; oj++) {
517 for (int oi = 0; oi < alpha; oi++) {
518 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
519 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
520 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
521 kernel_->gemm_loop_ker(
522 (float *)&(M(ithr, M_blk1, oj, oi,
524 (const float *)&(U(M_blk1, oj, oi, K_blk1,
526 (const float *)&(V(ithr, oj, oi,
527 N_blk, K_blk1, 0, 0, 0)), K_blk1);
// Output transform of this tile block from M(ithr, ...) into the destination.
531 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
532 for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
// Flat M block index; presumably assigned to `M_blk` (used below).
535 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
// Use the padded stack copy for the last block when padding is in effect
// (&& binds tighter than ?:), otherwise point into the user's bias.
// NOTE(review): this local shadows the `bias_ptr` function parameter.
537 float *bias_ptr = want_padded_bias
538 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
539 ? last_slice_bias : &bias(M_blk, 0);
541 output_transform_tileblock_data(tile_block, jcp, p_ops,
542 &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
543 &(output(0, M_blk, 0, 0, 0)), bias_ptr);
// Declarations of _execute_data_W_SGD for is_fwd = true and is_fwd = false;
// presumably explicit template instantiations -- confirm.
551 _jit_avx512_core_convolution_winograd_t<true>::_execute_data_W_SGD(
552 const int, float *, float *, float *, float *);
554 _jit_avx512_core_convolution_winograd_t<false>::_execute_data_W_SGD(
555 const int, float *, float *, float *, float *);
// Sums `num_arrs` partially-populated arrays into `output`, where array `a`
// contributes only its element range [input_starts[a], input_ends[a]).
//   num_arrs     - number of input arrays
//   output       - destination, `nelems` floats
//   nelems       - number of elements to produce
//   input_ptrs   - base pointers of the input arrays (indexed like `output`)
//   input_starts - per-array first contributing element index
//   input_ends   - per-array one-past-last contributing element index
// Work is split into fixed-size blocks distributed over threads; the thread id
// is taken from OpenMP (presumably inside a parallel region -- confirm).
559 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
560 float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
// max/min below resolve through this namespace.
561 using namespace nstl;
// 16 KiB worth of floats per block; the remainder is handled as a tail.
562 const size_t block_size = 16 * 1024 / sizeof(float);
563 const size_t blocks_number = nelems / block_size;
564 const size_t tail = nelems % block_size;
568 const int ithr = omp_get_thread_num();
569 const int nthr = omp_get_num_threads();
// This thread's [start, end) range of whole blocks.
570 size_t start{ 0 }, end{ 0 };
571 balance211(blocks_number, nthr, ithr, start, end);
573 for (size_t nb = start; nb < end; ++nb) {
574 size_t start_e = nb * block_size;
575 size_t end_e = start_e + block_size;
// Clamp array 0's valid range into this block's [start_e, end_e).
576 size_t input_start = max(start_e, min(input_starts[0], end_e));
577 size_t input_end = max(start_e, min(input_ends[0], end_e));
// Elements before array 0's valid range (presumably zero-filled -- confirm).
580 for (size_t e = start_e; e < input_start; e++) {
// Elements inside array 0's valid range are copied.
585 for (size_t e = input_start; e < input_end; e++) {
586 output[e] = input_ptrs[0][e];
// Elements after array 0's valid range (presumably zero-filled -- confirm).
590 for (size_t e = input_end; e < end_e; e++) {
// Accumulate every other array over the intersection of its valid range
// with this block.
594 for (size_t a = 1; a < num_arrs; a++) {
595 input_start = max(start_e, input_starts[a]);
596 input_end = min(input_ends[a], end_e);
599 for (size_t e = input_start; e < input_end; e++) {
600 output[e] += input_ptrs[a][e];
// The last thread additionally handles the tail that does not fill a block,
// using the same clamp / copy / accumulate sequence.
605 if (tail != 0 && ithr == nthr - 1) {
606 size_t start_e = nelems - tail;
607 size_t end_e = nelems;
608 size_t input_start = max(start_e, min(input_starts[0], end_e));
609 size_t input_end = max(start_e, min(input_ends[0], end_e));
612 for (size_t e = start_e; e < input_start; e++) {
617 for (size_t e = input_start; e < input_end; e++) {
618 output[e] = input_ptrs[0][e];
622 for (size_t e = input_end; e < end_e; e++) {
626 for (size_t a = 1; a < num_arrs; a++) {
627 input_start = max(start_e, input_starts[a]);
628 input_end = min(input_ends[a], end_e);
631 for (size_t e = input_start; e < input_end; e++) {
632 output[e] += input_ptrs[a][e];
// Capacity of the fixed-size per-thread arrays (input_ptrs, input_starts,
// input_ends) declared in the backward-weights routines below.
639 const int max_threads_number = 1024;
// Element-wise sum of `num_arrs` arrays of `nelems` floats into `output`.
//   num_arrs        - number of input arrays
//   output          - destination buffer
//   nelems          - number of elements per array
//   input_ptrs      - base pointers of the input arrays
//   reduce_to_first - when false, `output` is first initialized with a copy of
//                     input_ptrs[0]; when true that copy is skipped, so
//                     `output` presumably already holds (or aliases) the first
//                     array -- confirm against callers
// Work is split into fixed-size blocks distributed over threads; the thread id
// is taken from OpenMP (presumably inside a parallel region -- confirm).
641 // Sum to the first buffer array
642 void array_sum(size_t num_arrs, float *output,
643 size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
// 16 KiB worth of floats per block; the remainder is handled as a tail.
644 const size_t block_size = 16 * 1024 / sizeof(float);
645 const size_t blocks_number = nelems / block_size;
646 const size_t tail = nelems % block_size;
650 const size_t ithr = omp_get_thread_num();
651 const size_t nthr = omp_get_num_threads();
// This thread's [start, end) range of whole blocks.
652 size_t start{ 0 }, end{ 0 };
653 balance211(blocks_number, nthr, ithr, start, end);
655 for (size_t nb = start; nb < end; ++nb) {
656 size_t start_e = nb * block_size;
657 size_t end_e = start_e + block_size;
// Seed the output with array 0 unless it is already the accumulator.
658 if (!reduce_to_first) {
660 for (size_t e = start_e; e < end_e; e++) {
661 output[e] = input_ptrs[0][e];
// Add the remaining arrays.
664 for (size_t a = 1; a < num_arrs; a++) {
666 for (size_t e = start_e; e < end_e; e++) {
667 output[e] += input_ptrs[a][e];
// The last thread additionally handles the tail that does not fill a block.
672 if (tail != 0 && ithr == nthr - 1) {
673 size_t start_e = nelems - tail;
674 size_t end_e = nelems;
675 if (!reduce_to_first) {
677 for (size_t e = start_e; e < end_e; e++) {
678 output[e] = input_ptrs[0][e];
681 for (size_t a = 1; a < num_arrs; a++) {
683 for (size_t e = start_e; e < end_e; e++) {
684 output[e] += input_ptrs[a][e];
// Backward-weights computation in which each thread keeps private copies of
// the transformed buffers (V, M, Us), a private diff_weights accumulator and a
// private diff_bias accumulator; the per-thread results are summed at the end
// with array_sum.
692 void jit_avx512_core_convolution_winograd_bwd_weights_t::
693 _execute_backward_weights_SDGtWo() {
694 const auto &jcp = kernel_->jcp;
695 const int nthreads = scratchpad_->num_threads();
// Views over the user tensors (simd_w-blocked channel layout).
697 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
698 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
699 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
700 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
701 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
702 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
// Per-thread GEMM results in the scratchpad's U buffer (indexed by ithr below).
704 array_offset_calculator<float, 8> Us((float *)(scratchpad_->U_ptr()),
706 jcp.oc_block, jcp.ic_block,
// Offset of the per-thread diff_weights accumulators behind the Us area.
// NOTE(review): U_sz includes sizeof(float), so it is a byte count and the
// pointer arithmetic below is only right if U_ptr() returns a byte pointer --
// confirm. It is also computed in `int`; check it cannot overflow for large
// shapes.
711 int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
712 * jcp.ic / jcp.nb_ic * sizeof(float);
713 array_offset_calculator<float, 7>diff_weights_prv(
714 (float *)(scratchpad_->U_ptr() + U_sz),
715 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
// Per-thread transformed diff_dst (M) and src (V) buffers.
717 array_offset_calculator<float, 8> M((float *)(scratchpad_->M_ptr()),
720 jcp.nb_tile_block_ur,
725 array_offset_calculator<float, 7> V((float *)(scratchpad_->V_ptr()),
728 jcp.nb_tile_block_ur,
// Per-thread bias accumulators: [nthreads][oc].
732 array_offset_calculator<float, 2> diff_bias_prv(
733 (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
// Kernel argument block, scratch arrays and the coefficient tables for the
// three transforms (src, diff_dst, diff_weights).
735 auto trans_ker_p = jit_wino_transform_call_s();
736 float I[alpha][alpha][simd_w];
737 float T[alpha][alpha][simd_w];
738 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
739 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
740 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
741 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
743 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
// Each thread gets its own initialized copy of the argument block and scratch
// arrays.
745 #pragma omp parallel firstprivate(trans_ker_p, I, T)
// Initialize every thread's bias accumulator, one simd_w block at a time
// (presumably to zero -- confirm).
748 #pragma omp for nowait collapse(2)
749 for (int ithr = 0; ithr < nthreads; ithr++) {
750 for (int ofm = 0; ofm < jcp.oc / simd_w; ofm++) {
751 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
753 for (int v = 0; v < simd_w; v++) {
// Thread id selects this thread's private slices of V, M, Us and the
// accumulators.
760 int ithr = omp_get_thread_num();
761 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
764 for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
// First tile of this tile block, mapped back to image and tile coordinates.
765 int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
766 int img = tile_index / (jcp.itiles * jcp.jtiles);
767 trans_ker_p.ti = tile_index % jcp.itiles;
768 trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
// Source transform of every ic block in this ifm1 group into V(ithr).
771 trans_ker_p.G = G_I_3x3_4x4;
772 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
773 int ifm = ifm1 * jcp.ic_block + ifm2;
774 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
775 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
776 kernel_->src_transform(&trans_ker_p);
779 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
// diff_dst transform into M(ithr).
780 trans_ker_p.G = G_W_3x3_4x4;
781 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
782 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
783 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
784 trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
// The bias-accumulating kernel variant is used only for ifm1 == 0, so each
// diff_dst block contributes to the bias once rather than once per ifm1.
785 if (jcp.with_bias && ifm1 == 0) {
786 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
787 kernel_->diff_dst_transform_wbias(&trans_ker_p);
789 kernel_->diff_dst_transform(&trans_ker_p);
// GEMM for every (oj, oi) position: Us(ithr) from M(ithr) and V(ithr).
793 for (int oj = 0; oj < alpha; ++oj) {
794 for (int oi = 0; oi < alpha; ++oi) {
795 kernel_->gemm_loop_ker_first_iter(
796 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
797 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
798 &(V(ithr, oj, oi, 0, 0, 0, 0)));
// Transform Us(ithr) back into this thread's diff_weights accumulator.
801 trans_ker_p.G = G_O_3x3_4x4;
802 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
803 for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
804 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
806 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
807 int ifm = ifm1 * jcp.ic_block + ifm2;
808 trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
809 ofm2, ifm2, 0, ofm3, 0));
810 trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
811 ofm, ifm, 0, 0, 0, 0));
// first_tblk == 0 selects the plain transform; otherwise presumably the
// accumulating variant below is used -- confirm.
812 if (first_tblk == 0) {
813 kernel_->diff_weights_transform(&trans_ker_p);
815 kernel_->diff_weights_transform_accum(&trans_ker_p);
826 // Reduce diff-weights
// Sum the nthreads private accumulators (each nelems floats, laid out back to
// back from input_base) into the user's diff_weights. reduce_to_first is
// false, so the output is seeded from thread 0's buffer.
// NOTE(review): input_ptrs holds max_threads_number entries and no check that
// nthreads <= max_threads_number is visible here -- confirm.
828 float *output = (float *)(this->memory(0));
829 float *input_base = (float *)(scratchpad_->U_ptr() + U_sz);
830 int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
831 float *input_ptrs[max_threads_number];
832 for (int i = 0; i < nthreads; ++i) {
833 input_ptrs[i] = input_base + nelems * i;
835 array_sum(nthreads, output, nelems, input_ptrs, false);
// Same reduction for the bias: per-thread rows are jcp.oc apart, but only
// oc_without_padding elements are summed into the user's diff_bias.
838 output = (float *)(this->memory(1));
839 input_base = (float *)(scratchpad_->bias_ptr());
840 for (int i = 0; i < nthreads; ++i) {
841 input_ptrs[i] = input_base + jcp.oc * i;
843 array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
// Backward-weights computation with shared transformed buffers: src and
// diff_dst are transformed into shared V and M, each thread accumulates GEMM
// results into its own Us slice, the slices are reduced into U with
// subarray_sum, U is transformed into diff_weights, and finally the per-thread
// bias accumulators are reduced into diff_bias.
849 void jit_avx512_core_convolution_winograd_bwd_weights_t::
850 _execute_backward_weights_S_D_Giot_W() {
851 const auto &jcp = kernel_->jcp;
852 const int nthreads = scratchpad_->num_threads();
// Views over the user tensors (simd_w-blocked channel layout).
854 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
855 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
856 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
857 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
858 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
859 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
860 array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
// Reduced GEMM result at the start of the scratchpad's U buffer.
862 array_offset_calculator<float, 9> U((float *)(scratchpad_->U_ptr()),
863 jcp.nb_ic, jcp.nb_oc,
865 jcp.oc_block, jcp.ic_block,
// Per-thread GEMM accumulators placed right after U (indexed by ithr below).
// NOTE(review): U_size includes sizeof(float), so it is a byte count and the
// pointer arithmetic below is only right if U_ptr() returns a byte pointer --
// confirm. It is also computed in `int`; check it cannot overflow for large
// shapes.
870 int U_size = jcp.oc * jcp.ic * alpha * alpha * sizeof(float);
871 array_offset_calculator<float, 10> Us(
872 (float *)(scratchpad_->U_ptr() + U_size),
873 0, jcp.nb_ic, jcp.nb_oc,
875 jcp.oc_block, jcp.ic_block,
// Shared transformed diff_dst (M) and src (V) buffers.
880 array_offset_calculator<float, 9> M((float *)(scratchpad_->M_ptr()),
885 jcp.nb_tile_block_ur,
890 array_offset_calculator<float, 8> V((float *)(scratchpad_->V_ptr()),
895 jcp.nb_tile_block_ur, jcp.tile_block_ur,
// Per-thread bias accumulators: [nthreads][oc].
898 array_offset_calculator<float, 2> diff_bias_prv(
899 (float *)(scratchpad_->bias_ptr()), nthreads, jcp.oc);
// Per-thread element range of Us that the thread actually wrote; consumed by
// subarray_sum during the reduction.
// NOTE(review): these arrays hold max_threads_number entries and no check
// that nthreads <= max_threads_number is visible here -- confirm.
901 size_t input_starts[max_threads_number];
902 size_t input_ends[max_threads_number];
// Copied into each thread by firstprivate below; 0 presumably means the thread
// has not processed any GEMM iteration yet -- confirm.
903 size_t first_tblk = 0;
// Kernel argument block, coefficient tables for the three transforms and
// scratch arrays.
905 auto trans_ker_p = jit_wino_transform_call_s();
906 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
907 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
908 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
909 0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
910 0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
911 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
912 float I[alpha][alpha][simd_w];
913 float T[alpha][alpha][simd_w];
// Each thread gets its own initialized copy of first_tblk, the argument block
// and the scratch arrays.
915 #pragma omp parallel firstprivate(first_tblk, trans_ker_p, I, T)
// Zero every thread's bias accumulator.
918 #pragma omp for nowait collapse(2)
919 for (int ithr = 0; ithr < nthreads; ++ithr) {
920 for (int ofm = 0; ofm < jcp.oc; ++ofm) {
921 diff_bias_prv(ithr, ofm) = 0.0f;
// Source transform: one kernel call per (ic block, image) into shared V.
926 trans_ker_p.G = G_I_3x3_4x4;
929 #pragma omp for collapse(3) nowait
930 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
931 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
932 for (int img = 0; img < jcp.mb; img++) {
933 size_t ifm = ifm1 * jcp.ic_block + ifm2;
// Blocked coordinates of the image's first tile; tblk1 picks the tile block in
// V and tile_count is the tile's offset inside that block.
934 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
935 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
936 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
937 % jcp.nb_tile_block_ur;
938 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
939 / jcp.nb_tile_block_ur;
940 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
941 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
942 trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
943 kernel_->src_transform(&trans_ker_p);
// diff_dst transform: one kernel call per (oc block, image) into shared M.
// This loop has no `nowait`, so it ends with the implicit barrier of
// `omp for`; presumably this is what orders both transforms before the GEMM
// below -- confirm.
948 int ithr = omp_get_thread_num();
949 trans_ker_p.G = G_W_3x3_4x4;
950 #pragma omp for collapse(3)
951 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
952 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
953 for (int img = 0; img < jcp.mb; ++img) {
954 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
955 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
956 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
957 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
958 % jcp.nb_tile_block_ur;
959 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
960 / jcp.nb_tile_block_ur;
961 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
962 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
963 trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
// Bias-accumulating variant writes into this thread's private bias row;
// the plain variant is presumably used when there is no bias -- confirm.
965 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
966 kernel_->diff_dst_transform_wbias(&trans_ker_p);
968 kernel_->diff_dst_transform(&trans_ker_p);
// GEMM: iterations over (ifm1, ofm1, oj, oi, tblk1) are statically distributed;
// each thread accumulates into its own Us slice and records which element
// range of that slice it touched.
974 #pragma omp for collapse(5) nowait schedule(static)
975 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
976 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
977 for (int oj = 0; oj < alpha; ++oj) {
978 for (int oi = 0; oi < alpha; ++oi) {
979 for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
// First iteration of this thread: the range start is the offset of the
// current Us chunk from the beginning of the thread's slice, and the range
// end is one chunk past it.
980 if (first_tblk == 0) {
982 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
984 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
986 input_ends[ithr] = input_starts[ithr]
987 + jcp.oc_block * jcp.ic_block
988 * jcp.ic_simd_block * jcp.oc_reg_block
// A later iteration that starts a new chunk (tblk1 == 0) extends the range
// by one chunk.
991 else if (tblk1 == 0) {
992 input_ends[ithr] += jcp.oc_block * jcp.ic_block
993 * jcp.ic_simd_block * jcp.oc_reg_block
// The first write to a chunk uses the first-iteration kernel; later tile
// blocks presumably accumulate into the same chunk with gemm_loop_ker --
// confirm.
997 if (first_tblk == 0 || tblk1 == 0) {
998 kernel_->gemm_loop_ker_first_iter(
999 &(Us(ithr, ifm1, ofm1, oj, oi,
1001 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
1002 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
1004 kernel_->gemm_loop_ker(
1005 &(Us(ithr, ifm1, ofm1, oj, oi,
1007 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
1008 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
1018 // Reduce diff-weights
// Sum the per-thread Us slices into U. Thread i's slice starts
// nelems * (i + 1) floats after the start of U, and only the element range
// recorded in input_starts / input_ends is read from each slice.
1020 float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
1021 size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
1022 float *input_ptrs[max_threads_number];
1023 for (int i = 0; i < nthreads; ++i)
1024 input_ptrs[i] = output + nelems * (i + 1);
1025 subarray_sum(nthreads, output, nelems, input_ptrs,
1026 input_starts, input_ends);
// Transform the reduced U into the user's diff_weights, one kernel call per
// (oc simd block, ic simd block).
1029 trans_ker_p.G = G_O_3x3_4x4;
1030 #pragma omp parallel for collapse(5) firstprivate(trans_ker_p)
1031 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
1032 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
1033 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
1034 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
1035 for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
1036 int ofm = (ofm1 * jcp.oc_block + ofm2)
1037 * jcp.oc_reg_block + ofm3;
1038 int ifm = ifm1 * jcp.ic_block + ifm2;
1039 trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
1040 ofm2, ifm2, 0, ofm3, 0));
1041 trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
1043 kernel_->diff_weights_transform(&trans_ker_p);
// Bias reduction: for each simd_w block of output channels, copy thread 0's
// partial sums and add those of the remaining threads.
1050 if (jcp.with_bias) {
1051 #pragma omp parallel for
1052 for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ++ofm1) {
1053 float* pbias = &(diff_bias(ofm1 * simd_w));
1054 float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
// The last block is shortened to the unpadded channel count so that only
// oc_without_padding elements of diff_bias are written.
1056 const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1057 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1060 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1061 pbias[ofm2] = pbias_prv[ofm2];
1064 for (int ithr = 1; ithr < nthreads; ++ithr) {
1065 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1067 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1068 pbias[ofm2] += pbias_prv[ofm2];
1078 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s