1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
21 #include "mkldnn_types.h"
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
28 #include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
31 #define pragma_unroll _Pragma("unroll")
41 using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::memory_tracking::names;
44 using namespace mkldnn::impl::utils;
// Transform one block of weights `wp` into the Winograd domain, writing the
// result to `twp`. The work itself is done by the JIT kernel entry
// `weights_transform_data_ker`, which receives its arguments through `p`.
// The _execute_data_* drivers call this once per (output, input) channel
// block after computing a destination pointer `U_base_ptr` into the U
// scratch buffer; presumably that pointer is what arrives here as `twp`.
// NOTE(review): this listing omits several original lines of this function,
// including its opening brace and the lines that fill `p`. The use of
// G/Fw/F/T is therefore not visible here; confirm against the full source.
46 template <bool is_fwd>
47 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49 float *wp, float *twp) const
// Transform coefficient table. How it is consumed is in lines not shown.
51 float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52 1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
// Scratch tiles sized by the Winograd tile (alpha), the kernel (kh x kw)
// and the SIMD width.
55 float Fw[alpha][alpha][simd_w][simd_w];
56 float F[kh][kw][simd_w][simd_w];
57 float T[alpha][3][simd_w];
// Value-initialized argument pack for the JIT kernel call.
58 auto p = jit_wino_transform_call_s();
67 kernel_->weights_transform_data_ker(&p);
// Output-transform every (tj, ti) tile of image `image` via the JIT kernel
// `output_transform_data_ker`. In _execute_data_W_S_G_D, `toutp` points
// into the M scratch buffer and `pout_b` at the image's slice of `output`.
// `bias` and the post-ops `p_ops` are presumably forwarded through `p` in
// lines not shown.
//
// Tiles are numbered linearly across the minibatch. The image's first tile
// index is split into mixed-radix digits:
//   tile_base_index = (tile_block * jcp.nb_tile_block_ur + nb_tile_block_ur)
//                     * jcp.tile_block_ur + tile_block_ur
// After each tile these digits are advanced with carry. Only the wrap checks
// are visible below; the increments are in omitted lines.
// NOTE(review): this listing omits the template header (orig line 70), the
// declaration of `tile_block` (orig line 93) and most of the carry logic.
71 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73 const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
// Transform coefficient table.
75 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
76 float Ow[alpha][alpha][simd_w];
77 float O[tile_size][tile_size][simd_w];
78 float T[tile_size][alpha][simd_w];
80 auto p = jit_wino_transform_call_s();
// Decompose the image's first linear tile index into its blocked coordinates.
89 int tile_base_index = image * jcp.itiles * jcp.jtiles;
90 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91 int nb_tile_block_ur =
92 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
94 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
96 for (int tj = 0; tj < jcp.jtiles; tj++) {
97 for (int ti = 0; ti < jcp.itiles; ti++) {
// Tell the kernel where this tile lives in the blocked M layout.
99 p.tile_block_ur = tile_block_ur;
100 p.nb_tile_block_ur = nb_tile_block_ur;
101 p.tile_block = tile_block;
105 kernel_->output_transform_data_ker(&p);
// Carry: tile_block_ur wraps at jcp.tile_block_ur and nb_tile_block_ur at
// jcp.nb_tile_block_ur.
108 if (tile_block_ur >= jcp.tile_block_ur) {
112 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113 nb_tile_block_ur = 0;
// Output-transform all tiles of one tile block `tile_block`. A block can
// span several images: `img` is recomputed for every tile.
// The block's first linear tile index is
//   tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur.
// For each (nb_tile_block_ur, tile_block_ur) position, the code:
//   - decodes the tile's image `img` and its (ti, tj) coordinates;
//   - sets `p.dst` to the start of image `img` in `outp`, using the layout
//     [img][dimM / dimM_simd_block][outh][outw][dimM_simd_block];
//   - calls the JIT output-transform kernel.
// The output is oh x ow for the forward pass (is_fwd) and ih x iw for
// backward data.
// NOTE(review): the inner loop increment and the advance of `tile_index`
// are not part of this listing.
120 template<bool is_fwd>
121 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123 const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124 float *toutp, float *outp, float *bias) const {
// Transform coefficient table.
126 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
127 float Ow[alpha][alpha][simd_w];
128 float O[tile_size][tile_size][simd_w];
129 float T[tile_size][alpha][simd_w];
131 auto p = jit_wino_transform_call_s();
// Spatial size of the destination depends on the direction.
140 int outw = is_fwd ? jcp.ow : jcp.iw;
141 int outh = is_fwd ? jcp.oh : jcp.ih;
// Linear index of the first tile in this tile block.
143 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
145 for (int nb_tile_block_ur = 0;
146 nb_tile_block_ur < jcp.nb_tile_block_ur;
147 nb_tile_block_ur++) {
149 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Decode (image, tile row, tile column) from the linear tile index.
151 int img = tile_index / (jcp.jtiles * jcp.itiles);
152 int ti = tile_index % jcp.itiles;
153 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
155 p.tile_block_ur = tile_block_ur;
156 p.nb_tile_block_ur = nb_tile_block_ur;
157 p.tile_block = tile_block;
// Base of image `img` in the blocked destination tensor.
160 p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161 * outh * outw * jcp.dimM_simd_block;
163 kernel_->output_transform_data_ker(&p);
// Input-transform every (tj, ti) tile of image `image` via the JIT kernel
// `input_transform_data_ker`. In _execute_data_W_S_G_D, `inp` points at one
// K (input-channel) block of the image in `input`, and `tinp` points into
// the V scratch buffer at that K block.
//
// The image's first linear tile index is decomposed into
// (tile_block, nb_tile_block_ur, tile_block_ur) exactly as in
// output_transform_data. The digits are then advanced with carry after each
// tile; only the wrap checks are visible here.
// NOTE(review): the declaration of `tile_block` (orig line 196) and most of
// the carry logic are omitted from this listing.
171 template<bool is_fwd>
172 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
173 ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174 float *inp, float *tinp) const
// Transform coefficient table.
176 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
179 float Iw[alpha][alpha][simd_w];
180 float I[alpha][alpha][simd_w];
181 float T[alpha][alpha][simd_w];
183 auto p = jit_wino_transform_call_s();
// Decompose the image's first linear tile index into its blocked coordinates.
192 int tile_base_index = image * jcp.itiles * jcp.jtiles;
193 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194 int nb_tile_block_ur =
195 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
197 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
199 for (int tj = 0; tj < jcp.jtiles; tj++) {
200 for (int ti = 0; ti < jcp.itiles; ti++) {
// Tell the kernel where this tile lives in the blocked V layout.
202 p.tile_block_ur = tile_block_ur;
203 p.nb_tile_block_ur = nb_tile_block_ur;
204 p.tile_block = tile_block;
208 kernel_->input_transform_data_ker(&p);
// Carry: tile_block_ur wraps at jcp.tile_block_ur and nb_tile_block_ur at
// jcp.nb_tile_block_ur.
211 if (tile_block_ur >= jcp.tile_block_ur) {
215 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216 nb_tile_block_ur = 0;
// Input-transform all tiles of one tile block `tile_block`. Tiles may come
// from different images: `img` is decoded per tile from the linear tile
// index, which starts at tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur.
// `inp` is viewed as [mb][dimK / simd_w][inph][inpw][simd_w]. Each tile's
// transform is performed by the JIT kernel `input_transform_data_ker`.
// The input is ih x iw for the forward pass and oh x ow for backward data.
// NOTE(review): the first extent(s) of the 7-D `output` view (orig line 241),
// the loop increment and the `tile_index` advance are not in this listing.
223 template <bool is_fwd>
224 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
225 ::input_transform_tileblock_data(int tile_block,
226 const jit_conv_winograd_conf_t &jcp,
227 float *inp, float *tinp) const
// Transform coefficient table.
229 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
231 float Iw[alpha][alpha][simd_w];
232 float I[alpha][alpha][simd_w];
233 float T[alpha][alpha][simd_w];
// Spatial size of the source depends on the direction.
235 const int inph = is_fwd ? jcp.ih : jcp.oh;
236 const int inpw = is_fwd ? jcp.iw : jcp.ow;
238 array_offset_calculator<float, 5> input(inp,
239 jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
240 array_offset_calculator<float, 7> output(tinp,
242 jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243 jcp.dimN_reg_block, jcp.dimK_reg_block);
245 auto p = jit_wino_transform_call_s();
// Linear index of the first tile in this tile block.
254 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
256 for (int nb_tile_block_ur = 0;
257 nb_tile_block_ur < jcp.nb_tile_block_ur;
258 nb_tile_block_ur++) {
260 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Decode (image, tile row, tile column) and locate the image's input base.
263 int img = tile_index / (jcp.jtiles * jcp.itiles);
264 int ti = tile_index % jcp.itiles;
265 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
266 float *pinp_b = &(input(img, 0, 0, 0, 0));
269 p.tile_block_ur = tile_block_ur;
270 p.nb_tile_block_ur = nb_tile_block_ur;
274 kernel_->input_transform_data_ker(&p);
// Forward / backward-data Winograd driver for the W_S_G_D schedule.
// `is_fwd` selects the direction. The input is ih x iw and the output
// oh x ow in the forward pass; for backward data these sizes swap, and so do
// the scratchpad keys backing the V and M views.
// Work is issued as a sequence of parallel_nd_in_omp loops:
//   1. input transform of every (image, K block) into V;
//   2. weights transform into U. This is skipped for forward_inference,
//      where `wino_wei` presumably points at pre-transformed weights (that
//      branch, orig line 319, is not shown);
//   3. GEMM kernel calls over M, U and V for every
//      (N block, oj, oi, M block) and every K block. The kernel receives
//      K_blk1, presumably to choose init vs. accumulate;
//   4. output transform of M into `output`, with bias and post-ops.
// NOTE(review): this listing omits a number of original lines of this
// function (parts of several array views, the enclosing parallel region if
// any, and closing braces).
281 template <bool is_fwd>
282 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
283 const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
284 const memory_tracking::grantor_t &scratchpad) const {
285 const auto &jcp = kernel_->jcp;
286 const auto &p_ops = attr_->post_ops_;
288 const int inph = is_fwd ? jcp.ih : jcp.oh;
289 const int inpw = is_fwd ? jcp.iw : jcp.ow;
290 const int outh = is_fwd ? jcp.oh : jcp.ih;
291 const int outw = is_fwd ? jcp.ow : jcp.iw;
// NOTE(review): the next four lines were the body of a block comment in the
// original; its opening `/*` (orig line 293) is missing from this listing.
294 FWD: dimM:oc, dimN:ntiles, dimK:ic,
295 BWD: dimM:ic, dimN:ntiles, dimK:oc,
296 FWD/BWD: V: src/diff_dst transform, U:weight transform,
297 M:dst/diff_src transform */
298 array_offset_calculator<float, 5> input(inp_ptr,
299 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
301 array_offset_calculator<float, 5> output(out_ptr,
302 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
303 jcp.dimM_simd_block);
304 array_offset_calculator<float, 6> weights(wei_ptr,
305 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
306 jcp.ic_simd_block, jcp.oc_simd_block);
307 array_offset_calculator<float, 2> bias(bias_ptr,
308 jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
// For backward data the M and V views are backed by swapped scratchpad keys.
310 array_offset_calculator<float, 8> M(is_fwd
311 ? scratchpad.template get<float>(key_wino_M)
312 : scratchpad.template get<float>(key_wino_V),
313 jcp.dimN_nb_block, jcp.dimM_nb_block,
315 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316 jcp.dimN_reg_block, jcp.dimM_simd_block);
318 auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
320 : scratchpad.template get<float>(key_wino_U);
322 array_offset_calculator<float, 8> U(wino_wei,
326 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
327 jcp.dimK_reg_block, jcp.dimM_simd_block);
328 array_offset_calculator<float, 8> V(is_fwd
329 ? scratchpad.template get<float>(key_wino_V)
330 : scratchpad.template get<float>(key_wino_M),
331 jcp.dimN_nb_block, alpha, alpha,
332 jcp.dimN_block, jcp.dimK_nb_block,
333 jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
// When output channels are padded, copy only the valid channels of the last
// bias slice into a zero-filled buffer, so padded lanes see a 0 bias.
335 const bool wants_padded_bias = jcp.with_bias
336 && jcp.oc_without_padding != jcp.oc;
337 float last_slice_bias[simd_w] = {0};
338 if (wants_padded_bias) {
339 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
340 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Phase 1: input transform, one call per (image, K block).
345 parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
346 [&](int img, int K_blk1, int K_blk2) {
347 input_transform_data(img, jcp,
348 &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
350 &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// Phase 2: weights transform into U. U's leading index is always the dimM
// block (oc forward, ic backward), matching how the GEMM phase indexes U.
353 if (jcp.prop_kind != prop_kind::forward_inference) {
354 parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
355 (jcp.ic_block * jcp.ic_reg_block),
356 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
357 float *U_base_ptr = is_fwd
358 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
359 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
360 weight_transform_data(jcp,
362 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
363 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
// Phase 3: GEMMs for each (N block, oj, oi, M block) over all K blocks.
371 parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
372 [&](int N_blk1, int oj, int oi, int M_blk1) {
373 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
375 for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
376 kernel_->gemm_loop_ker(
377 (float *)&(M(N_blk1, M_blk1, oj, oi,
379 (const float *)&(U(M_blk1, oj, oi,
380 K_blk1, 0, 0, 0, 0)),
381 (const float *)&(V(N_blk1, oj, oi,
382 N_blk2, K_blk1, 0, 0, 0)), K_blk1);
// Phase 4: output transform per (image, dimM simd block). M_blk is the
// global simd-block index of the dimM (output channel) dimension.
387 parallel_nd_in_omp(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
388 [&](int img, int M_blk1, int M_blk2) {
390 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
392 float *bias_ptr = wants_padded_bias
393 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
394 ? last_slice_bias : &bias(M_blk, 0);
395 output_transform_data(img, jcp, p_ops,
396 &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
397 &(output(img, M_blk, 0, 0, 0)), bias_ptr);
// Forward / backward-data Winograd driver for the W_SGD schedule.
// The weights are transformed into U once, up front (skipped for
// forward_inference). Then tile blocks are distributed over threads with a
// static OpenMP `for` schedule. Each thread processes its tile blocks end to
// end: input transform -> GEMM -> output transform. It works in its own
// slice of the V and M scratch buffers, selected by `ithr`.
// The leading extent of the M and V views is given as 0 because they are
// indexed by thread id. Presumably the leading extent does not enter the
// offset computation; TODO confirm against array_offset_calculator.
// NOTE(review): the opening of the parallel region and several other
// original lines (e.g. 426, 430-432, 465, 468-475) are missing from this
// listing.
402 template <bool is_fwd>
403 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(const int MB,
404 float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
405 const memory_tracking::grantor_t &scratchpad) const {
407 const auto &jcp = kernel_->jcp;
408 const auto &p_ops = attr_->post_ops_;
410 const int inph = is_fwd ? jcp.ih : jcp.oh;
411 const int inpw = is_fwd ? jcp.iw : jcp.ow;
412 const int outh = is_fwd ? jcp.oh : jcp.ih;
413 const int outw = is_fwd ? jcp.ow : jcp.iw;
415 array_offset_calculator<float, 5> input(inp_ptr,
416 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
417 array_offset_calculator<float, 5> output(out_ptr,
418 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
419 array_offset_calculator<float, 6> weights(wei_ptr,
420 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
421 jcp.ic_simd_block, jcp.oc_simd_block);
// NOTE(review): `bias` is shaped by jcp.oc here but by jcp.dimM in
// _execute_data_W_S_G_D, while both index it with a dimM block. The two are
// equivalent when dimM == oc (the forward case).
422 array_offset_calculator<float, 2> bias(bias_ptr,
423 jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
425 auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
427 : scratchpad.template get<float>(key_wino_U);
429 array_offset_calculator<float, 8> U(wino_wei,
433 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
434 jcp.dimK_reg_block, jcp.dimM_simd_block);
// Per-thread M and V: the first index is the thread id.
436 array_offset_calculator<float, 8> M(is_fwd
437 ? scratchpad.template get<float>(key_wino_M)
438 : scratchpad.template get<float>(key_wino_V),
439 0, jcp.dimM_nb_block, alpha, alpha,
440 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
441 jcp.dimN_reg_block, jcp.dimM_simd_block);
442 array_offset_calculator<float, 8> V(is_fwd
443 ? scratchpad.template get<float>(key_wino_V)
444 : scratchpad.template get<float>(key_wino_M),
445 0, alpha, alpha, jcp.dimN_block,
446 jcp.dimK_nb_block, jcp.dimK_block,
447 jcp.dimN_reg_block, jcp.dimK_reg_block);
// When output channels are padded, copy only the valid channels of the last
// bias slice into a zero-filled buffer, so padded lanes see a 0 bias.
449 const bool wants_padded_bias = jcp.with_bias
450 && jcp.oc_without_padding != jcp.oc;
451 float last_slice_bias[simd_w] = {0};
452 if (wants_padded_bias) {
453 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
454 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Weights transform into U, done once for all tile blocks.
457 if (jcp.prop_kind != prop_kind::forward_inference) {
459 parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
460 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
461 float *U_base_ptr = is_fwd
462 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
463 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
464 weight_transform_data(jcp,
466 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
467 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
476 int ithr = mkldnn_get_thread_num();
// Tile blocks are split statically across threads; everything below writes
// only this thread's M/V slices.
478 PRAGMA_OMP(for schedule(static))
479 for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
480 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
481 for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
// Input transform of this tile block for one K block. The image-0 base is
// passed; the callee offsets to each tile's image itself.
483 input_transform_tileblock_data(
485 &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
486 &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// GEMMs for every Winograd point (oj, oi) of this tile block.
490 for (int oj = 0; oj < alpha; oj++) {
491 for (int oi = 0; oi < alpha; oi++) {
492 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
493 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
494 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
495 kernel_->gemm_loop_ker(
496 (float *)&(M(ithr, M_blk1, oj, oi,
498 (const float *)&(U(M_blk1, oj, oi, K_blk1,
500 (const float *)&(V(ithr, oj, oi,
501 N_blk, K_blk1, 0, 0, 0)), K_blk1);
// Output transform of this tile block, one dimM simd block at a time.
505 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
506 for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
509 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
511 float *bias_ptr = wants_padded_bias
512 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
513 ? last_slice_bias : &bias(M_blk, 0);
515 output_transform_tileblock_data(tile_block, jcp, p_ops,
516 &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
517 &(output(0, M_blk, 0, 0, 0)), bias_ptr);
// Explicitly instantiate both directions: is_fwd = true (forward) and
// is_fwd = false (backward data).
524 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
525 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
// Element-wise sum of `num_arrs` float arrays into `output`, each of length
// `nelems`.
// Only elements of input_ptrs[a] inside [input_starts[a], input_ends[a])
// are read. For a >= 1 the range is clamped to the current block, and an
// empty or out-of-block range simply skips the inner loop.
// Input 0 initializes `output` inside its range. The loops over
// [start_e, input_start) and [input_end, end_e) handle the rest of the
// block; their bodies are not shown, but they presumably zero `output`.
// Work is split into blocks of 16 KiB worth of floats (block_size elements)
// and distributed across threads with balance211. The last thread also
// handles the `tail` elements.
// NOTE(review): the enclosing parallel region and several loop bodies are
// omitted from this listing.
529 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
530 float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
531 using namespace nstl;
532 const size_t block_size = 16 * 1024 / sizeof(float);
533 const size_t blocks_number = nelems / block_size;
534 const size_t tail = nelems % block_size;
538 const int ithr = mkldnn_get_thread_num();
539 const int nthr = mkldnn_get_num_threads();
540 size_t start{ 0 }, end{ 0 };
541 balance211(blocks_number, nthr, ithr, start, end);
543 for (size_t nb = start; nb < end; ++nb) {
544 size_t start_e = nb * block_size;
545 size_t end_e = start_e + block_size;
// Clamp input 0's valid range to this block.
546 size_t input_start = max(start_e, min(input_starts[0], end_e));
547 size_t input_end = max(start_e, min(input_ends[0], end_e));
550 for (size_t e = start_e; e < input_start; e++) {
555 for (size_t e = input_start; e < input_end; e++) {
556 output[e] = input_ptrs[0][e];
560 for (size_t e = input_end; e < end_e; e++) {
// Accumulate the remaining arrays over their valid sub-ranges.
564 for (size_t a = 1; a < num_arrs; a++) {
565 input_start = max(start_e, input_starts[a]);
566 input_end = min(input_ends[a], end_e);
569 for (size_t e = input_start; e < input_end; e++) {
570 output[e] += input_ptrs[a][e];
// Tail: the last nelems % block_size elements, processed by the last thread.
575 if (tail != 0 && ithr == nthr - 1) {
576 size_t start_e = nelems - tail;
577 size_t end_e = nelems;
578 size_t input_start = max(start_e, min(input_starts[0], end_e));
579 size_t input_end = max(start_e, min(input_ends[0], end_e));
582 for (size_t e = start_e; e < input_start; e++) {
587 for (size_t e = input_start; e < input_end; e++) {
588 output[e] = input_ptrs[0][e];
592 for (size_t e = input_end; e < end_e; e++) {
596 for (size_t a = 1; a < num_arrs; a++) {
597 input_start = max(start_e, input_starts[a]);
598 input_end = min(input_ends[a], end_e);
601 for (size_t e = input_start; e < input_end; e++) {
602 output[e] += input_ptrs[a][e];
// Capacity of the on-stack per-thread arrays (input_ptrs, input_starts,
// input_ends) used by the backward-weights reductions.
// NOTE(review): no check visible in this listing ensures
// jcp.nthr <= max_threads_number before those arrays are indexed by thread.
609 const int max_threads_number = 1024;
611 // Sum to the first buffer array
// Element-wise sum of `num_arrs` arrays of `nelems` floats into `output`.
// - If `reduce_to_first` is false, `output` is first overwritten with
//   input_ptrs[0].
// - If it is true (the default), that copy is skipped, so `output` must
//   already hold the first array (presumably output aliases input_ptrs[0]
//   in that mode).
// Arrays 1..num_arrs-1 are then added element-wise.
// Work is split into 16 KiB blocks of floats across threads via balance211;
// the last thread also processes the tail.
612 void array_sum(size_t num_arrs, float *output,
613 size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
614 const size_t block_size = 16 * 1024 / sizeof(float);
615 const size_t blocks_number = nelems / block_size;
616 const size_t tail = nelems % block_size;
620 const size_t ithr = mkldnn_get_thread_num();
621 const size_t nthr = mkldnn_get_num_threads();
622 size_t start{ 0 }, end{ 0 };
623 balance211(blocks_number, nthr, ithr, start, end);
625 for (size_t nb = start; nb < end; ++nb) {
626 size_t start_e = nb * block_size;
627 size_t end_e = start_e + block_size;
628 if (!reduce_to_first) {
630 for (size_t e = start_e; e < end_e; e++) {
631 output[e] = input_ptrs[0][e];
634 for (size_t a = 1; a < num_arrs; a++) {
636 for (size_t e = start_e; e < end_e; e++) {
637 output[e] += input_ptrs[a][e];
// Tail: the last nelems % block_size elements, processed by the last thread.
642 if (tail != 0 && ithr == nthr - 1) {
643 size_t start_e = nelems - tail;
644 size_t end_e = nelems;
645 if (!reduce_to_first) {
647 for (size_t e = start_e; e < end_e; e++) {
648 output[e] = input_ptrs[0][e];
651 for (size_t a = 1; a < num_arrs; a++) {
653 for (size_t e = start_e; e < end_e; e++) {
654 output[e] += input_ptrs[a][e];
// Backward-weights Winograd driver for the SDGtWo schedule.
// Each of the jcp.nthr threads keeps private scratch:
//   - transformed src tiles V and diff_dst tiles M;
//   - the GEMM result Us;
//   - a full-size partial diff_weights (diff_weights_prv), placed U_sz
//     floats into the key_wino_U scratchpad;
//   - a partial diff_bias row (diff_bias_prv).
// For each input-channel block and tile block, a thread:
//   1. transforms src into V;
//   2. for each output-channel block, transforms diff_dst into M, using the
//      "wbias" variant, presumably accumulating the bias gradient, on the
//      first input-channel block when bias is enabled;
//   3. runs the GEMM into Us for all alpha x alpha points;
//   4. transforms Us back into diff_weights_prv, writing on the first tile
//      block and accumulating afterwards (keyed on `first_tblk`, whose
//      declaration and update are not in this listing).
// Finally the per-thread partials are summed into diff_weights and diff_bias
// with array_sum.
// NOTE(review): how tile blocks are divided among threads, and several
// view extents, are in lines omitted from this listing.
662 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
663 _execute_backward_weights_SDGtWo(
664 const memory_tracking::grantor_t &scratchpad) const {
665 const auto &jcp = kernel_->jcp;
666 const int nthreads = jcp.nthr;
668 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
669 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
670 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
671 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
672 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
673 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
675 array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
677 jcp.oc_block, jcp.ic_block,
// Size (in floats) of all threads' Us. The per-thread diff_weights partials
// follow it in the same scratchpad buffer.
682 const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
683 * jcp.ic / jcp.nb_ic;
684 array_offset_calculator<float, 7>diff_weights_prv(
685 scratchpad.get<float>(key_wino_U) + U_sz,
686 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
688 array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
691 jcp.nb_tile_block_ur,
696 array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
699 jcp.nb_tile_block_ur,
// One row of partial bias gradients per thread.
703 array_offset_calculator<float, 2> diff_bias_prv(
704 scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
706 auto trans_ker_p = jit_wino_transform_call_s();
707 float I[alpha][alpha][simd_w];
708 float T[alpha][alpha][simd_w];
// Coefficient tables handed to the src, diff_dst and diff_weights
// transform kernels respectively (via trans_ker_p.G below).
709 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
710 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
711 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
712 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
714 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
716 PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
// Zero-initialize each thread's partial bias row.
719 parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
720 [&](int ithr, int ofm){
721 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
723 for (int v = 0; v < simd_w; v++) {
729 int ithr = mkldnn_get_thread_num();
730 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
733 for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
// First tile of this tile block, decoded into (image, ti, tj).
734 int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
735 int img = tile_index / (jcp.itiles * jcp.jtiles);
736 trans_ker_p.ti = tile_index % jcp.itiles;
737 trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
// src -> V for every input channel of this block.
740 trans_ker_p.G = G_I_3x3_4x4;
741 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
742 int ifm = ifm1 * jcp.ic_block + ifm2;
743 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
744 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
745 kernel_->src_transform(&trans_ker_p);
748 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
// diff_dst -> M. On the first input-channel block, the "wbias" variant also
// receives this thread's bias row.
749 trans_ker_p.G = G_W_3x3_4x4;
750 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
751 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
752 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
753 trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
754 if (jcp.with_bias && ifm1 == 0) {
755 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
756 kernel_->diff_dst_transform_wbias(&trans_ker_p);
758 kernel_->diff_dst_transform(&trans_ker_p);
// GEMM for every Winograd point into this thread's Us.
762 for (int oj = 0; oj < alpha; ++oj) {
763 for (int oi = 0; oi < alpha; ++oi) {
764 kernel_->gemm_loop_ker_first_iter(
765 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
766 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
767 &(V(ithr, oj, oi, 0, 0, 0, 0)));
// Us -> diff_weights_prv: overwrite the first time, accumulate afterwards.
770 trans_ker_p.G = G_O_3x3_4x4;
771 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
772 for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
773 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
775 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
776 int ifm = ifm1 * jcp.ic_block + ifm2;
777 trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
778 ofm2, ifm2, 0, ofm3, 0));
779 trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
780 ofm, ifm, 0, 0, 0, 0));
781 if (first_tblk == 0) {
782 kernel_->diff_weights_transform(&trans_ker_p);
784 kernel_->diff_weights_transform_accum(&trans_ker_p);
795 // Reduce diff-weights
// Sum the per-thread diff_weights partials into the user's diff_weights.
797 float *output = (float *)(this->memory(0));
798 float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
799 int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
// NOTE(review): indexed up to nthreads; see max_threads_number.
800 float *input_ptrs[max_threads_number];
801 for (int i = 0; i < nthreads; ++i) {
802 input_ptrs[i] = input_base + nelems * i;
804 array_sum(nthreads, output, nelems, input_ptrs, false);
// Sum the per-thread bias partials. Only the unpadded channels are written.
807 output = (float *)(this->memory(1));
808 input_base = scratchpad.get<float>(key_conv_bia_reduction);
809 for (int i = 0; i < nthreads; ++i) {
810 input_ptrs[i] = input_base + jcp.oc * i;
812 array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
// Backward-weights Winograd driver for the S_D_Giot_W schedule.
// Unlike SDGtWo, src and diff_dst are transformed once, for all images, into
// shared V (indexed by ifm1) and M (indexed by ofm1) buffers. The phases are:
//   1. GEMM: (ifm1, ofm1, oj, oi, tblk1) is split across threads. Each thread
//      accumulates into its own copy Us of the transformed weight gradient;
//      the Us copies follow U in the key_wino_U scratchpad.
//   2. Bookkeeping: each thread records in input_starts/input_ends the range
//      of Us it wrote, so subarray_sum only reads written regions when
//      reducing into U.
//   3. U is output-transformed into diff_weights.
//   4. The per-thread bias partials are summed into diff_bias. In the last
//      simd block, only the unpadded channels are summed.
// NOTE(review): several view extents and the `first_tblk` update are in
// lines omitted from this listing.
818 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
819 _execute_backward_weights_S_D_Giot_W(
820 const memory_tracking::grantor_t &scratchpad) const {
821 const auto &jcp = kernel_->jcp;
822 const int nthreads = jcp.nthr;
824 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
825 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
826 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
827 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
828 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
829 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
830 array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
// U: reduced transformed weight gradient (also the subarray_sum target).
832 array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
833 jcp.nb_ic, jcp.nb_oc,
835 jcp.oc_block, jcp.ic_block,
// Us: per-thread copies (first index = thread), starting U_size floats
// after U.
840 const int U_size = jcp.oc * jcp.ic * alpha * alpha;
841 array_offset_calculator<float, 10> Us(
842 scratchpad.get<float>(key_wino_U) + U_size,
843 0, jcp.nb_ic, jcp.nb_oc,
845 jcp.oc_block, jcp.ic_block,
850 array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
855 jcp.nb_tile_block_ur,
860 array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
865 jcp.nb_tile_block_ur, jcp.tile_block_ur,
868 array_offset_calculator<float, 2> diff_bias_prv(
869 scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
// Per-thread [start, end) offsets (relative to Us of thread 0) of the
// region each thread wrote. Consumed by subarray_sum.
871 size_t input_starts[max_threads_number] = {0};
872 size_t input_ends[max_threads_number] = {0};
873 size_t first_tblk = 0;
875 auto trans_ker_p = jit_wino_transform_call_s();
// Coefficient tables for the src, diff_dst and diff_weights transforms.
876 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
877 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
878 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
879 0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
880 0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
881 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
882 float I[alpha][alpha][simd_w];
883 float T[alpha][alpha][simd_w];
// first_tblk is private per thread, so the "first GEMM of this thread"
// test below is evaluated per thread.
885 PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
// Zero each thread's partial bias row.
888 parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
889 diff_bias_prv(ithr, ofm) = 0.0f;
// src -> V, once per (input channel, image). The image's first linear tile
// index is split into (tblk1, tblk2, tblk3) digits; the tile counter within
// the block is passed to the kernel.
893 trans_ker_p.G = G_I_3x3_4x4;
897 parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
898 [&](int ifm1, int ifm2, int img){
899 size_t ifm = ifm1 * jcp.ic_block + ifm2;
900 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
901 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
902 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
903 % jcp.nb_tile_block_ur;
904 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
905 / jcp.nb_tile_block_ur;
906 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
907 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
908 trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
909 kernel_->src_transform(&trans_ker_p);
// diff_dst -> M, once per (output channel, image), with the same tile
// decomposition. The bias variant accumulates into this thread's row.
912 int ithr = mkldnn_get_thread_num();
913 trans_ker_p.G = G_W_3x3_4x4;
914 parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
915 [&](int ofm1, int ofm2, int img){
916 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
917 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
918 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
919 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
920 % jcp.nb_tile_block_ur;
921 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
922 / jcp.nb_tile_block_ur;
923 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
924 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
925 trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
927 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
928 kernel_->diff_dst_transform_wbias(&trans_ker_p);
930 kernel_->diff_dst_transform(&trans_ker_p);
// GEMM phase.
// On a thread's first iteration, record where its writes into Us begin and
// end. On later iterations with tblk1 == 0 (a new (ifm1, ofm1, oj, oi)
// slot), extend the end.
// NOTE(review): this assumes the slots a thread visits form one contiguous
// range of Us, presumably guaranteed by parallel_nd_in_omp's static
// partitioning; TODO confirm.
936 parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
937 [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
938 if (first_tblk == 0) {
940 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
942 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
944 input_ends[ithr] = input_starts[ithr]
945 + jcp.oc_block * jcp.ic_block
946 * jcp.ic_simd_block * jcp.oc_reg_block
949 else if (tblk1 == 0) {
950 input_ends[ithr] += jcp.oc_block * jcp.ic_block
951 * jcp.ic_simd_block * jcp.oc_reg_block
// The first GEMM into a slot initializes it; later tile blocks accumulate.
955 if (first_tblk == 0 || tblk1 == 0) {
956 kernel_->gemm_loop_ker_first_iter(
957 &(Us(ithr, ifm1, ofm1, oj, oi,
959 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
960 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
962 kernel_->gemm_loop_ker(
963 &(Us(ithr, ifm1, ofm1, oj, oi,
965 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
966 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
972 // Reduce diff-weights
// U is the reduction target. Thread i's Us copy starts nelems * (i + 1)
// floats after U (nelems equals U_size above).
974 float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
975 size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
// NOTE(review): indexed up to nthreads; see max_threads_number.
976 float *input_ptrs[max_threads_number];
977 for (int i = 0; i < nthreads; ++i)
978 input_ptrs[i] = output + nelems * (i + 1);
979 subarray_sum(nthreads, output, nelems, input_ptrs,
980 input_starts, input_ends);
// Output transform: reduced U -> diff_weights.
983 trans_ker_p.G = G_O_3x3_4x4;
984 PRAGMA_OMP(parallel firstprivate(trans_ker_p))
986 parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
987 [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
988 int ofm = (ofm1 * jcp.oc_block + ofm2)
989 * jcp.oc_reg_block + ofm3;
990 int ifm = ifm1 * jcp.ic_block + ifm2;
991 trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
992 ofm2, ifm2, 0, ofm3, 0));
993 trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
995 kernel_->diff_weights_transform(&trans_ker_p);
// Reduce the per-thread bias rows into diff_bias. In the last simd block
// only the oc_without_padding channels are written.
1000 parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
1001 float* pbias = &(diff_bias(ofm1 * simd_w));
1002 float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
1004 const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1005 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1008 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1009 pbias[ofm2] = pbias_prv[ofm2];
1012 for (int ithr = 1; ithr < nthreads; ++ithr) {
1013 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1015 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1016 pbias[ofm2] += pbias_prv[ofm2];
1026 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s