1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
21 #include "mkldnn_types.h"
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
28 #include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
31 #define pragma_unroll _Pragma("unroll")
41 using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::memory_tracking::names;
44 using namespace mkldnn::impl::utils;
// Transform one weight block from the spatial (kh x kw) domain into the
// Winograd (alpha x alpha) domain by delegating to the JIT-generated
// weights_transform_data_ker.
//   wp  - pointer to the source weights block
//   twp - pointer to the destination (transformed) weights block
// NOTE(review): several lines (e.g. the p.src/p.dst/p.G setup between the
// struct creation and the kernel call) appear elided in this view; comments
// describe only what is visible.
46 template <bool is_fwd>
47 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
48 ::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
49 float *wp, float *twp) const
// Presumably the Winograd G-matrix coefficients for F(4x4, 3x3) weight
// transform, consumed by the JIT kernel — TODO confirm against the kernel.
51 float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
52 1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
// Scratch buffers: Fw = transformed weights, F = spatial weights,
// T = intermediate per-row transform results.
55 float Fw[alpha][alpha][simd_w][simd_w];
56 float F[kh][kw][simd_w][simd_w];
57 float T[alpha][3][simd_w];
// Argument struct handed to the JIT transform kernel.
58 auto p = jit_wino_transform_call_s();
67 kernel_->weights_transform_data_ker(&p);
// Transform the Winograd-domain results for one image back to the spatial
// output domain (dst for FWD, diff_src for BWD), applying post-ops and bias
// via the JIT-generated output_transform_data_ker.
//   image  - image (minibatch) index whose tiles are processed
//   toutp  - Winograd-domain input (M buffer)
//   pout_b - spatial-domain output base pointer
// NOTE(review): interior lines (tile_block init, p.src/p.dst setup, loop
// closings, counter increments) appear elided in this view.
71 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
72 (int image, const jit_conv_winograd_conf_t &jcp,
73 const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
// Presumably the A-matrix coefficients of the inverse (output) Winograd
// transform for F(4x4, 3x3) — TODO confirm against the kernel.
75 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
// Scratch buffers: Ow = Winograd-domain tile, O = spatial tile,
// T = intermediate transform rows.
76 float Ow[alpha][alpha][simd_w];
77 float O[tile_size][tile_size][simd_w];
78 float T[tile_size][alpha][simd_w];
80 auto p = jit_wino_transform_call_s();
// Decompose the flat tile index of this image's first tile into the
// (tile_block, nb_tile_block_ur, tile_block_ur) blocking coordinates.
89 int tile_base_index = image * jcp.itiles * jcp.jtiles;
90 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
91 int nb_tile_block_ur =
92 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
94 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
// Walk every tile of the image in row-major (tj, ti) order, invoking the
// JIT output transform once per tile.
96 for (int tj = 0; tj < jcp.jtiles; tj++) {
97 for (int ti = 0; ti < jcp.itiles; ti++) {
99 p.tile_block_ur = tile_block_ur;
100 p.nb_tile_block_ur = nb_tile_block_ur;
101 p.tile_block = tile_block;
105 kernel_->output_transform_data_ker(&p);
// Advance the blocked tile counters, carrying into the next level when a
// block is exhausted (increment lines appear elided in this view).
108 if (tile_block_ur >= jcp.tile_block_ur) {
112 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
113 nb_tile_block_ur = 0;
// Transform one tile_block worth of Winograd-domain results back to the
// spatial output domain (dst for FWD, diff_src for BWD). Unlike
// output_transform_data, this iterates the tiles belonging to a single
// tile_block rather than a single image.
//   tile_block - index of the tile block to process
//   toutp      - Winograd-domain input (per-thread M buffer)
//   outp       - spatial-domain output base pointer
// NOTE(review): interior lines (inner-loop closing, p.src setup, tile_index
// increment) appear elided in this view.
120 template<bool is_fwd>
121 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
122 ::output_transform_tileblock_data(int tile_block,
123 const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
124 float *toutp, float *outp, float *bias) const {
// Inverse (output) transform coefficients — same constants as in
// output_transform_data.
126 float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
127 float Ow[alpha][alpha][simd_w];
128 float O[tile_size][tile_size][simd_w];
129 float T[tile_size][alpha][simd_w];
131 auto p = jit_wino_transform_call_s();
// Output spatial dims swap roles between FWD (oh/ow) and BWD (ih/iw).
140 int outw = is_fwd ? jcp.ow : jcp.iw;
141 int outh = is_fwd ? jcp.oh : jcp.ih;
// First flat tile index covered by this tile_block.
143 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
145 for (int nb_tile_block_ur = 0;
146 nb_tile_block_ur < jcp.nb_tile_block_ur;
147 nb_tile_block_ur++) {
149 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Recover image and (ti, tj) tile coordinates from the flat tile index.
151 int img = tile_index / (jcp.jtiles * jcp.itiles);
152 int ti = tile_index % jcp.itiles;
153 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
155 p.tile_block_ur = tile_block_ur;
156 p.nb_tile_block_ur = nb_tile_block_ur;
157 p.tile_block = tile_block;
// Point the kernel at this image's slice of the spatial output.
160 p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
161 * outh * outw * jcp.dimM_simd_block;
163 kernel_->output_transform_data_ker(&p);
// Transform one image's input (src for FWD, diff_dst for BWD) from the
// spatial domain into the Winograd (alpha x alpha) domain via the
// JIT-generated input_transform_data_ker.
//   image - image (minibatch) index whose tiles are processed
//   inp   - spatial-domain input pointer
//   tinp  - Winograd-domain output (V buffer)
// NOTE(review): interior lines (p.src/p.dst setup, counter increments, loop
// closings) appear elided in this view.
171 template<bool is_fwd>
172 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
173 ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
174 float *inp, float *tinp) const
// Presumably the B-matrix coefficients of the input Winograd transform for
// F(4x4, 3x3) — TODO confirm against the kernel.
176 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
177 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
// Scratch buffers: Iw = transformed tile, I = spatial tile,
// T = intermediate transform rows.
179 float Iw[alpha][alpha][simd_w];
180 float I[alpha][alpha][simd_w];
181 float T[alpha][alpha][simd_w];
183 auto p = jit_wino_transform_call_s();
// Decompose this image's first flat tile index into blocked tile
// coordinates (same scheme as output_transform_data).
192 int tile_base_index = image * jcp.itiles * jcp.jtiles;
193 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
194 int nb_tile_block_ur =
195 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
197 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
// Walk every tile of the image, invoking the JIT input transform per tile.
199 for (int tj = 0; tj < jcp.jtiles; tj++) {
200 for (int ti = 0; ti < jcp.itiles; ti++) {
202 p.tile_block_ur = tile_block_ur;
203 p.nb_tile_block_ur = nb_tile_block_ur;
204 p.tile_block = tile_block;
208 kernel_->input_transform_data_ker(&p);
// Advance blocked tile counters with carry (increments elided in this view).
211 if (tile_block_ur >= jcp.tile_block_ur) {
215 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
216 nb_tile_block_ur = 0;
// Transform one tile_block worth of input (src for FWD, diff_dst for BWD)
// into the Winograd domain. Tile-block-scoped counterpart of
// input_transform_data, used by the W_SGD execution scheme.
//   tile_block - index of the tile block to process
//   inp        - spatial-domain input pointer
//   tinp       - Winograd-domain output (per-thread V buffer)
// NOTE(review): interior lines (inner-loop closing, remaining p field
// setup, tile_index increment) appear elided in this view.
223 template <bool is_fwd>
224 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
225 ::input_transform_tileblock_data(int tile_block,
226 const jit_conv_winograd_conf_t &jcp,
227 float *inp, float *tinp) const
// Input-transform coefficients — same constants as input_transform_data.
229 float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
230 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
231 float Iw[alpha][alpha][simd_w];
232 float I[alpha][alpha][simd_w];
233 float T[alpha][alpha][simd_w];
// Input spatial dims swap roles between FWD (ih/iw) and BWD (oh/ow).
235 const int inph = is_fwd ? jcp.ih : jcp.oh;
236 const int inpw = is_fwd ? jcp.iw : jcp.ow;
// Logical multi-dim views over the flat input / transformed-output buffers.
238 array_offset_calculator<float, 5> input(inp,
239 jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
240 array_offset_calculator<float, 7> output(tinp,
242 jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
243 jcp.dimN_reg_block, jcp.dimK_reg_block);
245 auto p = jit_wino_transform_call_s();
// First flat tile index covered by this tile_block.
254 int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
256 for (int nb_tile_block_ur = 0;
257 nb_tile_block_ur < jcp.nb_tile_block_ur;
258 nb_tile_block_ur++) {
260 for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
// Recover image and (ti, tj) tile coordinates from the flat tile index.
263 int img = tile_index / (jcp.jtiles * jcp.itiles);
264 int ti = tile_index % jcp.itiles;
265 int tj = (tile_index / jcp.itiles) % jcp.jtiles;
266 float *pinp_b = &(input(img, 0, 0, 0, 0));
269 p.tile_block_ur = tile_block_ur;
270 p.nb_tile_block_ur = nb_tile_block_ur;
274 kernel_->input_transform_data_ker(&p);
// W_S_G_D execution scheme: perform the whole convolution as four global
// phases over all images — (W)eight transform, (S)rc transform, (G)emm,
// (D)st transform — each phase parallelized independently.
// FWD computes dst from src/weights; BWD-data computes diff_src from
// diff_dst/weights (the is_fwd template flag swaps the roles below).
// NOTE(review): some interior lines (#else/#endif of the threading macro
// block, PRAGMA_OMP parallel open/close, trailing braces) appear elided in
// this view.
281 template <bool is_fwd>
282 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
283 const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
284 const memory_tracking::grantor_t &scratchpad) const {
285 const auto &jcp = kernel_->jcp;
286 const auto &p_ops = attr_->post_ops_;
// Spatial dims swap roles between FWD and BWD-data.
288 const int inph = is_fwd ? jcp.ih : jcp.oh;
289 const int inpw = is_fwd ? jcp.iw : jcp.ow;
290 const int outh = is_fwd ? jcp.oh : jcp.ih;
291 const int outw = is_fwd ? jcp.ow : jcp.iw;
294 FWD: dimM:oc, dimN:ntiles, dimK:ic,
295 BWD: dimM:ic, dimN:ntiles, dimK:oc,
296 FWD/BWD: V: src/diff_dst transform, U:weight transform,
297 M:dst/diff_src transform */
// Logical multi-dim views over caller-provided flat buffers.
298 array_offset_calculator<float, 5> input(inp_ptr,
299 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
301 array_offset_calculator<float, 5> output(out_ptr,
302 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
303 jcp.dimM_simd_block);
304 array_offset_calculator<float, 6> weights(wei_ptr,
305 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
306 jcp.ic_simd_block, jcp.oc_simd_block);
307 array_offset_calculator<float, 2> bias(bias_ptr,
308 jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
// M: Winograd-domain result buffer. FWD and BWD swap the M/V scratchpad
// keys so the same scratchpad layout serves both directions.
310 array_offset_calculator<float, 8> M(is_fwd
311 ? scratchpad.template get<float>(key_wino_M)
312 : scratchpad.template get<float>(key_wino_V),
313 jcp.dimN_nb_block, jcp.dimM_nb_block,
315 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
316 jcp.dimN_reg_block, jcp.dimM_simd_block);
// For forward inference the weights were pre-transformed (the
// non-scratchpad branch is elided in this view); otherwise transform into
// the key_wino_U scratchpad each call.
318 auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
320 : scratchpad.template get<float>(key_wino_U);
322 array_offset_calculator<float, 8> U(wino_wei,
326 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
327 jcp.dimK_reg_block, jcp.dimM_simd_block);
// V: Winograd-domain transformed input buffer.
328 array_offset_calculator<float, 8> V(is_fwd
329 ? scratchpad.template get<float>(key_wino_V)
330 : scratchpad.template get<float>(key_wino_M),
331 jcp.dimN_nb_block, alpha, alpha,
332 jcp.dimN_block, jcp.dimK_nb_block,
333 jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
// When oc was padded up to a multiple of the simd width, build a local
// bias slice that is zero beyond oc_without_padding so padded lanes do not
// pick up garbage bias values.
335 const bool wants_padded_bias = jcp.with_bias
336 && jcp.oc_without_padding != jcp.oc;
337 float last_slice_bias[simd_w] = {0};
338 if (wants_padded_bias) {
339 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
340 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Pick the nested-in-omp or standalone parallel-nd flavor depending on the
// threading runtime.
343 #if MKLDNN_THR == MKLDNN_THR_OMP
344 #define PARALLEL_ND parallel_nd_in_omp
346 #define PARALLEL_ND parallel_nd
349 #if MKLDNN_THR == MKLDNN_THR_OMP
// Phase S: transform all input tiles of all images into V.
353 PARALLEL_ND(MB, jcp.dimK_nb_block, jcp.dimK_block,
354 [&](int img, int K_blk1, int K_blk2) {
355 input_transform_data(img, jcp,
356 &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
358 &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// Phase W: transform weights into U (skipped for forward inference, where
// U already holds pre-transformed weights).
361 if (jcp.prop_kind != prop_kind::forward_inference) {
362 PARALLEL_ND(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
363 (jcp.ic_block * jcp.ic_reg_block),
364 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
// U is indexed (M, K) — oc-major for FWD, ic-major for BWD.
365 float *U_base_ptr = is_fwd
366 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
367 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
368 weight_transform_data(jcp,
370 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
371 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
// Phase G: per (alpha x alpha) Winograd point, accumulate M = U * V over
// all K blocks via the JIT gemm kernel.
380 PARALLEL_ND(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
381 [&](int N_blk1, int oj, int oi, int M_blk1) {
382 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
384 for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
385 kernel_->gemm_loop_ker(
386 (float *)&(M(N_blk1, M_blk1, oj, oi,
388 (const float *)&(U(M_blk1, oj, oi,
389 K_blk1, 0, 0, 0, 0)),
390 (const float *)&(V(N_blk1, oj, oi,
391 N_blk2, K_blk1, 0, 0, 0)), K_blk1);
// Phase D: inverse-transform M back to the spatial output, applying bias
// (padded-safe slice for the last oc block) and post-ops.
396 PARALLEL_ND(MB, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
397 [&](int img, int M_blk1, int M_blk2) {
399 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
401 float *bias_ptr = wants_padded_bias
402 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
403 ? last_slice_bias : &bias(M_blk, 0);
404 output_transform_data(img, jcp, p_ops,
405 &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
406 &(output(img, M_blk, 0, 0, 0)), bias_ptr);
// W_SGD execution scheme: transform weights once (W), then fuse the
// src-transform, gemm, and dst-transform phases (SGD) per tile_block inside
// one parallel region, using per-thread M and V scratch buffers. Better
// cache locality than W_S_G_D when the full V/M buffers would not fit.
// NOTE(review): some interior lines (PRAGMA_OMP parallel open, the
// pre-transformed-weights branch of wino_wei, trailing braces) appear
// elided in this view.
413 template <bool is_fwd>
414 void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(const int MB,
415 float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
416 const memory_tracking::grantor_t &scratchpad) const {
418 const auto &jcp = kernel_->jcp;
419 const auto &p_ops = attr_->post_ops_;
// Spatial dims swap roles between FWD and BWD-data.
421 const int inph = is_fwd ? jcp.ih : jcp.oh;
422 const int inpw = is_fwd ? jcp.iw : jcp.ow;
423 const int outh = is_fwd ? jcp.oh : jcp.ih;
424 const int outw = is_fwd ? jcp.ow : jcp.iw;
// Logical multi-dim views over caller-provided flat buffers.
426 array_offset_calculator<float, 5> input(inp_ptr,
427 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
428 array_offset_calculator<float, 5> output(out_ptr,
429 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
430 array_offset_calculator<float, 6> weights(wei_ptr,
431 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
432 jcp.ic_simd_block, jcp.oc_simd_block);
433 array_offset_calculator<float, 2> bias(bias_ptr,
434 jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
// For forward inference U holds pre-transformed weights (that branch is
// elided in this view); otherwise use the key_wino_U scratchpad.
436 auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
438 : scratchpad.template get<float>(key_wino_U);
440 array_offset_calculator<float, 8> U(wino_wei,
444 jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
445 jcp.dimK_reg_block, jcp.dimM_simd_block);
// Per-thread M / V Winograd-domain scratch buffers; the leading dimension
// (dim 0, extent unused here) is indexed by thread id below. FWD and BWD
// swap the M/V scratchpad keys.
447 array_offset_calculator<float, 8> M(is_fwd
448 ? scratchpad.template get<float>(key_wino_M)
449 : scratchpad.template get<float>(key_wino_V),
450 0, jcp.dimM_nb_block, alpha, alpha,
451 jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
452 jcp.dimN_reg_block, jcp.dimM_simd_block);
453 array_offset_calculator<float, 8> V(is_fwd
454 ? scratchpad.template get<float>(key_wino_V)
455 : scratchpad.template get<float>(key_wino_M),
456 0, alpha, alpha, jcp.dimN_block,
457 jcp.dimK_nb_block, jcp.dimK_block,
458 jcp.dimN_reg_block, jcp.dimK_reg_block);
// Zero-padded bias slice for the last oc block when oc is padded (same
// rationale as in _execute_data_W_S_G_D).
460 const bool wants_padded_bias = jcp.with_bias
461 && jcp.oc_without_padding != jcp.oc;
462 float last_slice_bias[simd_w] = {0};
463 if (wants_padded_bias) {
464 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
465 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Phase W: transform all weights into U up front (skipped for forward
// inference).
468 if (jcp.prop_kind != prop_kind::forward_inference) {
470 parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
471 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
// U is indexed (M, K) — oc-major for FWD, ic-major for BWD.
472 float *U_base_ptr = is_fwd
473 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
474 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
475 weight_transform_data(jcp,
477 ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
478 ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
// Fused SGD phases: tile_blocks are statically distributed over threads;
// each thread uses its own M/V slice indexed by ithr.
487 int ithr = mkldnn_get_thread_num();
489 PRAGMA_OMP(for schedule(static))
490 for (int tile_block = 0; tile_block < jcp.tile_block; tile_block++) {
// S: transform this tile_block's input into the thread-local V.
491 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
492 for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
494 input_transform_tileblock_data(
496 &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
497 &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
// G: per Winograd point, accumulate M = U * V over K blocks.
501 for (int oj = 0; oj < alpha; oj++) {
502 for (int oi = 0; oi < alpha; oi++) {
503 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
504 for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
505 for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
506 kernel_->gemm_loop_ker(
507 (float *)&(M(ithr, M_blk1, oj, oi,
509 (const float *)&(U(M_blk1, oj, oi, K_blk1,
511 (const float *)&(V(ithr, oj, oi,
512 N_blk, K_blk1, 0, 0, 0)), K_blk1);
// D: inverse-transform the thread-local M into the spatial output with
// bias (padded-safe slice for the last oc block) and post-ops.
516 for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
517 for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
520 M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
522 float *bias_ptr = wants_padded_bias
523 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
524 ? last_slice_bias : &bias(M_blk, 0);
526 output_transform_tileblock_data(tile_block, jcp, p_ops,
527 &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
528 &(output(0, M_blk, 0, 0, 0)), bias_ptr);
// Explicit instantiations: forward (is_fwd = true) and backward-data
// (is_fwd = false) variants of the Winograd convolution driver.
535 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
536 template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
// Parallel reduction of num_arrs partial-result arrays into `output`.
// Array 0 is copied (it defines the baseline); arrays 1..num_arrs-1 are
// accumulated, but each array `a` contributes only within its valid range
// [input_starts[a], input_ends[a]) — elements outside that window were
// never written by the producing thread. Work is blocked into 16KB chunks
// and distributed over threads with balance211; the tail (< one block) is
// handled by the last thread.
// NOTE(review): the parallel-region open, zero-fill loop bodies for the
// regions outside array 0's window, and several closing braces appear
// elided in this view.
540 void subarray_sum(size_t num_arrs, float *output, size_t nelems,
541 float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
542 using namespace nstl;
// 16KB blocks keep each chunk within L1-friendly working-set size.
543 const size_t block_size = 16 * 1024 / sizeof(float);
544 const size_t blocks_number = nelems / block_size;
545 const size_t tail = nelems % block_size;
549 const int ithr = mkldnn_get_thread_num();
550 const int nthr = mkldnn_get_num_threads();
551 size_t start{ 0 }, end{ 0 };
552 balance211(blocks_number, nthr, ithr, start, end);
554 for (size_t nb = start; nb < end; ++nb) {
555 size_t start_e = nb * block_size;
556 size_t end_e = start_e + block_size;
// Clamp array 0's valid window to this block's element range.
557 size_t input_start = max(start_e, min(input_starts[0], end_e));
558 size_t input_end = max(start_e, min(input_ends[0], end_e));
// Before array 0's window: presumably zero-initialize (body elided).
561 for (size_t e = start_e; e < input_start; e++) {
// Inside array 0's window: copy as the reduction baseline.
566 for (size_t e = input_start; e < input_end; e++) {
567 output[e] = input_ptrs[0][e];
// After array 0's window: presumably zero-initialize (body elided).
571 for (size_t e = input_end; e < end_e; e++) {
// Accumulate the remaining arrays over their clamped valid windows.
575 for (size_t a = 1; a < num_arrs; a++) {
576 input_start = max(start_e, input_starts[a]);
577 input_end = min(input_ends[a], end_e);
580 for (size_t e = input_start; e < input_end; e++) {
581 output[e] += input_ptrs[a][e];
// Tail (nelems % block_size): same copy/accumulate logic, done by the
// last thread only.
586 if (tail != 0 && ithr == nthr - 1) {
587 size_t start_e = nelems - tail;
588 size_t end_e = nelems;
589 size_t input_start = max(start_e, min(input_starts[0], end_e));
590 size_t input_end = max(start_e, min(input_ends[0], end_e));
593 for (size_t e = start_e; e < input_start; e++) {
598 for (size_t e = input_start; e < input_end; e++) {
599 output[e] = input_ptrs[0][e];
603 for (size_t e = input_end; e < end_e; e++) {
607 for (size_t a = 1; a < num_arrs; a++) {
608 input_start = max(start_e, input_starts[a]);
609 input_end = min(input_ends[a], end_e);
612 for (size_t e = input_start; e < input_end; e++) {
613 output[e] += input_ptrs[a][e];
// Upper bound on thread count used to size stack arrays of per-thread
// pointers in the reduction helpers below.
620 const int max_threads_number = 1024;
622 // Sum to the first buffer array
// Parallel elementwise sum of num_arrs arrays into `output`. With
// reduce_to_first == true, `output` already holds array 0's data and only
// arrays 1.. are accumulated; with false, array 0 is first copied into
// `output`. Work is blocked into 16KB chunks distributed via balance211,
// with the tail handled by the last thread (same scheme as subarray_sum).
// NOTE(review): the parallel-region open, unroll pragmas, and several
// closing braces appear elided in this view.
623 void array_sum(size_t num_arrs, float *output,
624 size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
625 const size_t block_size = 16 * 1024 / sizeof(float);
626 const size_t blocks_number = nelems / block_size;
627 const size_t tail = nelems % block_size;
631 const size_t ithr = mkldnn_get_thread_num();
632 const size_t nthr = mkldnn_get_num_threads();
633 size_t start{ 0 }, end{ 0 };
634 balance211(blocks_number, nthr, ithr, start, end);
636 for (size_t nb = start; nb < end; ++nb) {
637 size_t start_e = nb * block_size;
638 size_t end_e = start_e + block_size;
// Seed output from array 0 unless it is already the accumulator.
639 if (!reduce_to_first) {
641 for (size_t e = start_e; e < end_e; e++) {
642 output[e] = input_ptrs[0][e];
645 for (size_t a = 1; a < num_arrs; a++) {
647 for (size_t e = start_e; e < end_e; e++) {
648 output[e] += input_ptrs[a][e];
// Tail (nelems % block_size): same logic, last thread only.
653 if (tail != 0 && ithr == nthr - 1) {
654 size_t start_e = nelems - tail;
655 size_t end_e = nelems;
656 if (!reduce_to_first) {
658 for (size_t e = start_e; e < end_e; e++) {
659 output[e] = input_ptrs[0][e];
662 for (size_t a = 1; a < num_arrs; a++) {
664 for (size_t e = start_e; e < end_e; e++) {
665 output[e] += input_ptrs[a][e];
// Backward-weights, SDGtWo scheme: each thread processes whole tile_blocks
// end-to-end — (S)rc transform, (D)iff_dst transform, (G)emm accumulate,
// then per-tile_block (t) inverse (W)eight transform into a private
// diff-weights copy ("o" = own). Per-thread diff-weights and diff-bias are
// reduced across threads at the end with array_sum.
// NOTE(review): several interior lines (parallel-region open, trans_ker_p
// field setup, first_tblk bookkeeping, loop closings, the last
// G_W_3x3_4x4 element, the with-bias guard around the final reduction)
// appear elided in this view.
673 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
674 _execute_backward_weights_SDGtWo(
675 const memory_tracking::grantor_t &scratchpad) const {
676 const auto &jcp = kernel_->jcp;
677 const int nthreads = jcp.nthr;
// Logical views over src (input 0), diff_dst (input 1) and the diff_weights
// output.
679 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
680 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
681 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
682 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
683 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
684 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
// Per-thread Winograd-domain gemm accumulator (U scratchpad).
686 array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
688 jcp.oc_block, jcp.ic_block,
// The per-thread spatial diff-weights copies live in the same scratchpad
// directly after the Us region; U_sz is that region's element count.
693 const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
694 * jcp.ic / jcp.nb_ic;
695 array_offset_calculator<float, 7>diff_weights_prv(
696 scratchpad.get<float>(key_wino_U) + U_sz,
697 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
// Per-thread Winograd-domain diff_dst (M) and src (V) transforms.
699 array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
702 jcp.nb_tile_block_ur,
707 array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
710 jcp.nb_tile_block_ur,
// Per-thread diff-bias accumulator, reduced at the end.
714 array_offset_calculator<float, 2> diff_bias_prv(
715 scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
717 auto trans_ker_p = jit_wino_transform_call_s();
718 float I[alpha][alpha][simd_w];
719 float T[alpha][alpha][simd_w];
// Winograd transform coefficient tables handed to the JIT kernels:
// presumably input (I), weight (W) and output (O) matrices for
// F(4x4, 3x3) — TODO confirm against the kernel generators.
720 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
721 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
722 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
723 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
725 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
727 PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
// Zero each thread's private diff-bias slice before accumulation.
730 parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
731 [&](int ithr, int ofm){
732 float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
734 for (int v = 0; v < simd_w; v++) {
740 int ithr = mkldnn_get_thread_num();
741 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
// Tile blocks are distributed over threads (the sharing construct is
// elided in this view); each iteration handles one whole tile_block.
744 for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
745 int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
746 int img = tile_index / (jcp.itiles * jcp.jtiles);
747 trans_ker_p.ti = tile_index % jcp.itiles;
748 trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
// S: transform this ic block's src tiles into the thread-local V.
751 trans_ker_p.G = G_I_3x3_4x4;
752 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
753 int ifm = ifm1 * jcp.ic_block + ifm2;
754 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
755 trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
756 kernel_->src_transform(&trans_ker_p);
759 for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
// D: transform diff_dst tiles into the thread-local M; the with-bias
// variant also accumulates diff-bias, and runs only on the first ic block
// so bias is not counted multiple times.
760 trans_ker_p.G = G_W_3x3_4x4;
761 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
762 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
763 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
764 trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
765 if (jcp.with_bias && ifm1 == 0) {
766 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
767 kernel_->diff_dst_transform_wbias(&trans_ker_p);
769 kernel_->diff_dst_transform(&trans_ker_p);
// G: per Winograd point, Us = M * V^T (first-iter kernel overwrites rather
// than accumulates).
773 for (int oj = 0; oj < alpha; ++oj) {
774 for (int oi = 0; oi < alpha; ++oi) {
775 kernel_->gemm_loop_ker_first_iter(
776 &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
777 &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
778 &(V(ithr, oj, oi, 0, 0, 0, 0)));
// tW: inverse-transform Us into the thread-private spatial diff-weights;
// first tile_block writes, later ones accumulate.
781 trans_ker_p.G = G_O_3x3_4x4;
782 for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
783 for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
784 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
786 for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
787 int ifm = ifm1 * jcp.ic_block + ifm2;
788 trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
789 ofm2, ifm2, 0, ofm3, 0));
790 trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
791 ofm, ifm, 0, 0, 0, 0));
792 if (first_tblk == 0) {
793 kernel_->diff_weights_transform(&trans_ker_p);
795 kernel_->diff_weights_transform_accum(&trans_ker_p);
806 // Reduce diff-weights
// Sum the per-thread diff-weights copies into the real output.
808 float *output = (float *)(this->memory(0));
809 float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
810 int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
811 float *input_ptrs[max_threads_number];
812 for (int i = 0; i < nthreads; ++i) {
813 input_ptrs[i] = input_base + nelems * i;
815 array_sum(nthreads, output, nelems, input_ptrs, false);
// Reduce per-thread diff-bias (only the unpadded oc range is written).
818 output = (float *)(this->memory(1));
819 input_base = scratchpad.get<float>(key_conv_bia_reduction);
820 for (int i = 0; i < nthreads; ++i) {
821 input_ptrs[i] = input_base + jcp.oc * i;
823 array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
// Backward-weights, S_D_Giot_W scheme: global phases — (S)rc transform of
// all tiles, (D)iff_dst transform of all tiles, (G)emm with per-thread
// partial U accumulators ("iot" = input/output/tile blocking), a
// subarray_sum reduction of those partials, then the inverse (W)eight
// transform into diff_weights, plus a final diff-bias reduction.
// NOTE(review): several interior lines (parallel-region opens, barriers,
// first_tblk updates, some trans_ker_p field setup, loop closings) appear
// elided in this view.
829 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
830 _execute_backward_weights_S_D_Giot_W(
831 const memory_tracking::grantor_t &scratchpad) const {
832 const auto &jcp = kernel_->jcp;
833 const int nthreads = jcp.nthr;
// Logical views over src, diff_dst, and the diff_weights/diff_bias outputs.
835 array_offset_calculator<float, 5> src((float *)this->input_memory(0),
836 jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
837 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
838 jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
839 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
840 jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
841 array_offset_calculator<float, 1> diff_bias((float *)this->memory(1), jcp.oc);
// U: final (reduced) Winograd-domain weight gradient.
843 array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
844 jcp.nb_ic, jcp.nb_oc,
846 jcp.oc_block, jcp.ic_block,
// Us: per-thread partial accumulators, stored directly after U in the same
// scratchpad (offset by U_size elements); dim 0 is indexed by thread id.
851 const int U_size = jcp.oc * jcp.ic * alpha * alpha;
852 array_offset_calculator<float, 10> Us(
853 scratchpad.get<float>(key_wino_U) + U_size,
854 0, jcp.nb_ic, jcp.nb_oc,
856 jcp.oc_block, jcp.ic_block,
// M / V: Winograd-domain diff_dst and src transforms over all tiles.
861 array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
866 jcp.nb_tile_block_ur,
871 array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
876 jcp.nb_tile_block_ur, jcp.tile_block_ur,
// Per-thread diff-bias accumulator.
879 array_offset_calculator<float, 2> diff_bias_prv(
880 scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
// Per-thread element ranges actually written into Us, used by subarray_sum
// so each thread's partial contributes only where it accumulated.
882 size_t input_starts[max_threads_number] = {0};
883 size_t input_ends[max_threads_number] = {0};
884 size_t first_tblk = 0;
886 auto trans_ker_p = jit_wino_transform_call_s();
// Winograd coefficient tables (same constants as in the SDGtWo scheme):
// presumably input / weight / output transform matrices for F(4x4, 3x3) —
// TODO confirm against the kernel generators.
887 float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
888 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
889 float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
890 0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
891 0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
892 float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
893 float I[alpha][alpha][simd_w];
894 float T[alpha][alpha][simd_w];
896 PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
// Zero the per-thread diff-bias accumulators.
899 parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
900 diff_bias_prv(ithr, ofm) = 0.0f;
// Phase S: transform every (ic block, image) src slice into V.
904 trans_ker_p.G = G_I_3x3_4x4;
908 parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
909 [&](int ifm1, int ifm2, int img){
910 size_t ifm = ifm1 * jcp.ic_block + ifm2;
// Decompose the image's first flat tile index into blocked coordinates.
911 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
912 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
913 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
914 % jcp.nb_tile_block_ur;
915 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
916 / jcp.nb_tile_block_ur;
917 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
918 trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
919 trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
920 kernel_->src_transform(&trans_ker_p);
// Phase D: transform every (oc block, image) diff_dst slice into M,
// accumulating diff-bias (the with-bias guard is elided in this view).
923 int ithr = mkldnn_get_thread_num();
924 trans_ker_p.G = G_W_3x3_4x4;
925 parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
926 [&](int ofm1, int ofm2, int img){
927 int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
928 size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
929 size_t tblk3 = tile_base_index % jcp.tile_block_ur;
930 size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
931 % jcp.nb_tile_block_ur;
932 size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
933 / jcp.nb_tile_block_ur;
934 trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
935 trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
936 trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
938 trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
939 kernel_->diff_dst_transform_wbias(&trans_ker_p);
941 kernel_->diff_dst_transform(&trans_ker_p);
// Phase G: accumulate Us(ithr) += M * V^T per (ic blk, oc blk, Winograd
// point, tile blk), recording each thread's written element range for the
// later windowed reduction.
947 parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
948 [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
// First unit this thread touches: record the start of its written window.
949 if (first_tblk == 0) {
951 (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
953 - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
955 input_ends[ithr] = input_starts[ithr]
956 + jcp.oc_block * jcp.ic_block
957 * jcp.ic_simd_block * jcp.oc_reg_block
// Subsequent units at tblk1 == 0 extend the window by one unit.
960 else if (tblk1 == 0) {
961 input_ends[ithr] += jcp.oc_block * jcp.ic_block
962 * jcp.ic_simd_block * jcp.oc_reg_block
// tblk1 == 0 (or very first unit) overwrites the accumulator; later tile
// blocks accumulate into it.
966 if (first_tblk == 0 || tblk1 == 0) {
967 kernel_->gemm_loop_ker_first_iter(
968 &(Us(ithr, ifm1, ofm1, oj, oi,
970 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
971 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
973 kernel_->gemm_loop_ker(
974 &(Us(ithr, ifm1, ofm1, oj, oi,
976 &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
977 &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
983 // Reduce diff-weights
// Windowed reduction of the per-thread Us partials (stored contiguously
// after U) into U, honoring each thread's recorded [start, end) range.
985 float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
986 size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
987 float *input_ptrs[max_threads_number];
988 for (int i = 0; i < nthreads; ++i)
989 input_ptrs[i] = output + nelems * (i + 1);
990 subarray_sum(nthreads, output, nelems, input_ptrs,
991 input_starts, input_ends);
// Phase W: inverse-transform the reduced U into spatial diff_weights.
994 trans_ker_p.G = G_O_3x3_4x4;
995 PRAGMA_OMP(parallel firstprivate(trans_ker_p))
997 parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
998 [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
999 int ofm = (ofm1 * jcp.oc_block + ofm2)
1000 * jcp.oc_reg_block + ofm3;
1001 int ifm = ifm1 * jcp.ic_block + ifm2;
1002 trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
1003 ofm2, ifm2, 0, ofm3, 0));
1004 trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
1006 kernel_->diff_weights_transform(&trans_ker_p);
// Reduce per-thread diff-bias into the real diff_bias output; the last
// simd block is clipped to oc_without_padding so padded lanes are dropped.
1010 if (jcp.with_bias) {
1011 parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
1012 float* pbias = &(diff_bias(ofm1 * simd_w));
1013 float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
1015 const int blk_sz = ofm1 == jcp.oc / simd_w - 1
1016 ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
1019 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1020 pbias[ofm2] = pbias_prv[ofm2];
1023 for (int ithr = 1; ithr < nthreads; ++ithr) {
1024 pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
1026 for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
1027 pbias[ofm2] += pbias_prv[ofm2];
1037 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s