1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifdef __INTEL_COMPILER
18 #include <immintrin.h>
21 #include "mkldnn_types.h"
23 #include "c_types_map.hpp"
24 #include "mkldnn_thread.hpp"
25 #include "type_helpers.hpp"
28 #include "jit_avx512_common_convolution_winograd.hpp"
/* Loop-unrolling hint: expands to _Pragma("unroll"). */
31 #define pragma_unroll _Pragma("unroll")
40 using namespace memory_tracking::names;
/* Last-level cache size as returned by get_cache_size(3, false), evaluated
 * once during static initialization. _execute_data_W_S_G_D compares the size
 * of the transformed-input buffer against 2 * LLC_cache_size to decide
 * whether that buffer is written with streaming stores.
 * NOTE(review): the meaning of the `false` argument (presumably "total, not
 * per core") is defined elsewhere -- confirm. */
44 unsigned int LLC_cache_size = get_cache_size(3, false);
/* Copies one simd_w-wide float vector from src_mem to dest.
 * ICC path: one _mm512_load_ps (the aligned-load intrinsic) assigned through
 * an __m512 pointer, so src_mem and dest are expected to be 64-byte aligned
 * there -- NOTE(review): confirm for the local I[][] buffers passed as dest.
 * Non-ICC builds use the element-wise copy loop. */
46 void inline load_ps(float *dest, const float *src_mem) {
47 #ifdef __INTEL_COMPILER
48 __m512 *Iv512 = (__m512 *)dest;
49 Iv512[0] = _mm512_load_ps(src_mem);
52 for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v];
/* Writes one simd_w-wide vector from data to dest.
 * ICC path: either a non-temporal _mm512_stream_ps or a regular
 * _mm512_store_ps, presumably selected by streamout. data is read through
 * an __m512 pointer.
 * NOTE(review): both intrinsics are aligned stores, so dest must be 64-byte
 * aligned even when streamout is false; data must be aligned as well. Callers
 * pass stack arrays (Iw / O) as data -- confirm their alignment, and confirm
 * whether _mm512_storeu_ps was intended for the non-streaming branch.
 * Non-ICC builds use an element-wise loop over simd_w lanes. */
56 void inline store_output(float *dest, const float *data, bool streamout) {
57 #ifdef __INTEL_COMPILER
59 _mm512_stream_ps(dest, *((__m512 *)data));
61 _mm512_store_ps(dest, *((__m512 *)data));
64 for (int v = 0; v < simd_w; v++)
/* Accumulates one simd_w-wide vector into dest: dest = data + dest, then,
 * when with_relu_postsum is set, clamps negative lanes to zero (a ReLU
 * applied after the sum).
 * ICC path: unaligned loads of both operands, then either a streaming or a
 * regular store to dest (presumably selected by streamout).
 * NOTE(review): _mm512_store_ps is also an aligned store, so the
 * non-streaming branch is not an unaligned fallback; the data-path caller
 * passes streamout = output_is_aligned -- confirm whether _mm512_storeu_ps
 * was intended.
 * Non-ICC builds implement the same steps with element-wise loops; data is
 * non-const and presumably used as scratch there -- TODO confirm. */
69 void inline accum_output(
70 float *dest, float *data, bool streamout, bool with_relu_postsum) {
71 #ifdef __INTEL_COMPILER
72 __m512 _data = _mm512_loadu_ps(data);
73 __m512 _dest = _mm512_loadu_ps(dest);
74 _data = _mm512_add_ps(_data, _dest);
75 if (with_relu_postsum)
76 _data = _mm512_max_ps(_data, _mm512_setzero_ps());
78 _mm512_stream_ps(dest, _data);
80 _mm512_store_ps(dest, _data);
83 for (int v = 0; v < simd_w; v++)
86 if (with_relu_postsum) {
88 for (int v = 0; v < simd_w; v++)
94 for (int v = 0; v < simd_w; v++)
100 using namespace mkldnn::impl::status;
101 using namespace mkldnn::impl::memory_format;
102 using namespace mkldnn::impl::utils;
/* Winograd weight transform for F(4x4, 3x3): Fw_ = G * F * G^T.
 *
 * F holds a 3x3 kernel whose taps are 16 x 16 blocks of (j, k) scalars; the
 * transform is applied independently to every (j, k) pair. G is the 6x3
 * matrix spelled out by the 1-D kernel below:
 *   row 0: ( 1.13777777777778,   0,                  0                )
 *   row 1: (-0.688403361344538, -0.430252100840336, -0.26890756302521 )
 *   row 2: (-0.688403361344538,  0.430252100840336, -0.26890756302521 )
 *   row 3: ( 0.119514472455649,  0.179271708683473,  0.26890756302521 )
 *   row 4: ( 0.119514472455649, -0.179271708683473,  0.26890756302521 )
 *   row 5: ( 0,                  0,                  1                )
 */
void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) {
    // out = G * in: three 16-wide input vectors spaced in_stride floats
    // apart produce six output vectors spaced out_stride floats apart.
    const auto apply_G = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int k = 0; k < 16; k++) {
            const float f0 = in[k];
            const float f1 = in[in_stride + k];
            const float f2 = in[2 * in_stride + k];

            const float even = 0.26890756302521f * f2;
            const float neg = -even - 0.688403361344538f * f0;
            const float pos = even + 0.119514472455649f * f0;

            out[0 * out_stride + k] = 1.13777777777778f * f0;
            out[1 * out_stride + k] = neg - 0.430252100840336f * f1;
            out[2 * out_stride + k] = neg + 0.430252100840336f * f1;
            out[3 * out_stride + k] = pos + 0.179271708683473f * f1;
            out[4 * out_stride + k] = pos - 0.179271708683473f * f1;
            out[5 * out_stride + k] = f2;
        }
    };

    float T[6][3][16];
    const float *src = &F[0][0][0][0];
    float *tmp = &T[0][0][0];
    float *dst = &Fw_[0][0][0][0];

    for (int j = 0; j < 16; j++) {
        // Pass 1, down each kernel column c: T[.][c] = G * F[.][c][j].
        for (int c = 0; c < 3; c++)
            apply_G(src + (c * 16 + j) * 16, 3 * 16 * 16,
                    tmp + c * 16, 3 * 16);
        // Pass 2, along each intermediate row r: Fw_[r][.][j] = G * T[r][.].
        for (int r = 0; r < 6; r++)
            apply_G(&T[r][0][0], 16,
                    dst + (r * 6 * 16 + j) * 16, 16 * 16);
    }
}
/* Winograd output transform for F(4x4, 3x3): O = A^T * Mw * A, applied to
 * each of the 16 lanes independently. A^T (4x6) is spelled out by the 1-D
 * kernel below; its rows are successive powers of 0.625 and 1.5:
 *   row 0: (1, 1,            1,            1,      1,     0)
 *   row 1: (0, 0.625,       -0.625,        1.5,   -1.5,   0)
 *   row 2: (0, 0.390625,     0.390625,     2.25,   2.25,  0)
 *   row 3: (0, 0.244140625, -0.244140625,  3.375, -3.375, 1)
 */
void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) {
    // out = A^T * in: six 16-wide input vectors spaced in_stride floats
    // apart produce four output vectors spaced out_stride floats apart.
    const auto apply_AT = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int v = 0; v < 16; v++) {
            const float m0 = in[0 * in_stride + v];
            const float m1 = in[1 * in_stride + v];
            const float m2 = in[2 * in_stride + v];
            const float m3 = in[3 * in_stride + v];
            const float m4 = in[4 * in_stride + v];
            const float m5 = in[5 * in_stride + v];

            const float sum12 = m1 + m2;
            const float sum34 = m3 + m4;
            const float dif12 = m1 - m2;
            const float dif34 = m3 - m4;

            out[0 * out_stride + v] = sum12 + sum34 + m0;
            out[1 * out_stride + v] = dif12 * 0.625f + dif34 * 1.5f;
            out[2 * out_stride + v] = sum12 * 0.390625f + sum34 * 2.25f;
            out[3 * out_stride + v] =
                    dif12 * 0.244140625f + dif34 * 3.375f + m5;
        }
    };

    float T[4][6][16];
    const float *src = &Mw[0][0][0];
    float *tmp = &T[0][0][0];

    // Pass 1: combine the six rows of Mw, column by column.
    for (int c = 0; c < 6; c++)
        apply_AT(src + c * 16, 6 * 16, tmp + c * 16, 6 * 16);
    // Pass 2: combine the six columns of each intermediate row.
    for (int r = 0; r < 4; r++)
        apply_AT(&T[r][0][0], 16, &O[r][0][0], 16);
}
/* Winograd transform Fw = G * F * G^T with the 6x4 matrix G spelled out by
 * the 1-D kernel below:
 *   row 0: ( 1/4,   0,     0,    0   )
 *   row 1: (-1/6,  -1/6,  -1/6, -1/6 )
 *   row 2: (-1/6,   1/6,  -1/6,  1/6 )
 *   row 3: ( 1/24,  1/12,  1/6,  1/3 )
 *   row 4: ( 1/24, -1/12,  1/6, -1/3 )
 *   row 5: ( 0,     0,     0,    1   )
 * Only the leading 4 x 4 entries of F are read; each of the 16 lanes is
 * transformed independently.
 */
void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16])
{
    // out = G * in: four 16-wide input vectors spaced in_stride floats
    // apart produce six output vectors spaced out_stride floats apart.
    const auto apply_G = [](const float *in, int in_stride,
            float *out, int out_stride) {
        const float rcp3 = 1.0f / 3.0f;
        const float rcp4 = 1.0f / 4.0f;
        const float rcp6 = 1.0f / 6.0f;
        const float rcp12 = 1.0f / 12.0f;
        const float rcp24 = 1.0f / 24.0f;

        for (int v = 0; v < 16; v++) {
            const float f0 = in[0 * in_stride + v];
            const float f1 = in[1 * in_stride + v];
            const float f2 = in[2 * in_stride + v];
            const float f3 = in[3 * in_stride + v];

            const float c2 = f2 * rcp6;
            const float neg = f0 * -rcp6 - c2;
            const float pos = f0 * rcp24 + c2;
            const float odd1 = (f1 + f3) * rcp6;
            const float odd2 = f1 * rcp12 + f3 * rcp3;

            out[0 * out_stride + v] = f0 * rcp4;
            out[1 * out_stride + v] = neg - odd1;
            out[2 * out_stride + v] = neg + odd1;
            out[3 * out_stride + v] = pos + odd2;
            out[4 * out_stride + v] = pos - odd2;
            out[5 * out_stride + v] = f3;
        }
    };

    float T[6][4][16];
    const float *src = &F[0][0][0];
    float *tmp = &T[0][0][0];

    // Pass 1: for each of the first four columns c of F (rows are 6 * 16
    // floats apart), T[.][c] = G * F[0..3][c].
    for (int c = 0; c < 4; c++)
        apply_G(src + c * 16, 6 * 16, tmp + c * 16, 4 * 16);
    // Pass 2: Fw[r][.] = G * T[r][0..3].
    for (int r = 0; r < 6; r++)
        apply_G(&T[r][0][0], 16, &Fw[r][0][0], 16);
}
/* Winograd inverse transform M = A^T * Mw * A with the 3x6 matrix A^T
 * spelled out by the 1-D kernel below:
 *   row 0: (1, 1,  1, 1,  1, 0)
 *   row 1: (0, 1, -1, 2, -2, 0)
 *   row 2: (0, 1,  1, 4,  4, 1)
 * applied independently to each of the 16 x 16 (j, l) lanes.
 */
void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16])
{
    // out = A^T * in: six 16-wide input vectors spaced in_stride floats
    // apart produce three output vectors spaced out_stride floats apart.
    const auto apply_AT = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int l = 0; l < 16; l++) {
            const float m0 = in[0 * in_stride + l];
            const float m1 = in[1 * in_stride + l];
            const float m2 = in[2 * in_stride + l];
            const float m3 = in[3 * in_stride + l];
            const float m4 = in[4 * in_stride + l];
            const float m5 = in[5 * in_stride + l];

            const float sum12 = m1 + m2;
            const float sum34 = m3 + m4;
            const float tail = sum34 * 4.0f + m5;

            out[0 * out_stride + l] = m0 + sum12 + sum34;
            out[1 * out_stride + l] = (m1 - m2) + 2.0f * (m3 - m4);
            out[2 * out_stride + l] = sum12 + tail;
        }
    };

    float T[3][6][16];
    const float *src = &Mw[0][0][0][0];
    float *tmp = &T[0][0][0];
    float *dst = &M[0][0][0][0];

    for (int j = 0; j < 16; j++) {
        // Pass 1: for each column c, T[.][c] = A^T * Mw[.][c][j].
        for (int c = 0; c < 6; c++)
            apply_AT(src + (c * 16 + j) * 16, 6 * 16 * 16,
                    tmp + c * 16, 6 * 16);
        // Pass 2: M[r][.][j] = A^T * T[r][.].
        for (int r = 0; r < 3; r++)
            apply_AT(&T[r][0][0], 16,
                    dst + (r * 3 * 16 + j) * 16, 16 * 16);
    }
}
/* Winograd input transform for F(4x4, 3x3): Iw = B^T * I * B, applied to
 * each of the 16 lanes independently. B^T (6x6) is spelled out by the 1-D
 * kernel below:
 *   row 0: (0.87890625,  0,          -2.640625,  0,        1, 0)
 *   row 1: (0,          -1.40625,    -2.25,      0.625,    1, 0)
 *   row 2: (0,           1.40625,    -2.25,     -0.625,    1, 0)
 *   row 3: (0,          -0.5859375,  -0.390625,  1.5,      1, 0)
 *   row 4: (0,           0.5859375,  -0.390625, -1.5,      1, 0)
 *   row 5: (0,           0.87890625,  0,        -2.640625, 0, 1)
 */
void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16])
{
    // out = B^T * in: six 16-wide input vectors spaced in_stride floats
    // apart produce six output vectors spaced out_stride floats apart.
    const auto apply_BT = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int v = 0; v < 16; v++) {
            const float x0 = in[0 * in_stride + v];
            const float x1 = in[1 * in_stride + v];
            const float x2 = in[2 * in_stride + v];
            const float x3 = in[3 * in_stride + v];
            const float x4 = in[4 * in_stride + v];
            const float x5 = in[5 * in_stride + v];

            const float a24 = x2 * -2.25f + x4;
            const float a13 = x1 * -2.25f + x3;
            const float b24 = x2 * -0.390625f + x4;
            const float b13 = x1 * -0.390625f + x3;
            const float c04 = x0 * 0.87890625f + x4;
            const float c15 = x1 * 0.87890625f + x5;

            out[0 * out_stride + v] = x2 * -2.640625f + c04;
            out[1 * out_stride + v] = a13 * 0.625f + a24;
            out[2 * out_stride + v] = a13 * -0.625f + a24;
            out[3 * out_stride + v] = b13 * 1.5f + b24;
            out[4 * out_stride + v] = b13 * -1.5f + b24;
            out[5 * out_stride + v] = x3 * -2.640625f + c15;
        }
    };

    float T[6][6][16];
    const float *src = &I[0][0][0];
    float *tmp = &T[0][0][0];

    // Pass 1: combine the six rows of I, column by column.
    for (int c = 0; c < 6; c++)
        apply_BT(src + c * 16, 6 * 16, tmp + c * 16, 6 * 16);
    // Pass 2: combine the six columns of each intermediate row.
    for (int r = 0; r < 6; r++)
        apply_BT(&T[r][0][0], 16, &Iw[r][0][0], 16);
}
/* Backward-weights variant of the 6x4 Winograd transform Fw = G * F * G^T.
 * G is spelled out by the 1-D kernel below:
 *   row 0: ( 1.13777777777778,  0,                  0,                 0                 )
 *   row 1: (-0.688403361344538,-0.430252100840336, -0.26890756302521, -0.168067226890756 )
 *   row 2: (-0.688403361344538, 0.430252100840336, -0.26890756302521,  0.168067226890756 )
 *   row 3: ( 0.119514472455649, 0.179271708683473,  0.26890756302521,  0.403361344537815 )
 *   row 4: ( 0.119514472455649,-0.179271708683473,  0.26890756302521, -0.403361344537815 )
 *   row 5: ( 0,                 0,                  0,                 1                 )
 * Only the leading 4 x 4 entries of F are read; each of the 16 lanes is
 * transformed independently.
 */
void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16])
{
    // out = G * in: four 16-wide input vectors spaced in_stride floats
    // apart produce six output vectors spaced out_stride floats apart.
    const auto apply_G = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int v = 0; v < 16; v++) {
            const float f0 = in[0 * in_stride + v];
            const float f1 = in[1 * in_stride + v];
            const float f2 = in[2 * in_stride + v];
            const float f3 = in[3 * in_stride + v];

            const float c2 = f2 * 0.26890756302521f;
            const float neg = f0 * -0.688403361344538f - c2;
            const float pos = f0 * 0.119514472455649f + c2;
            const float odd1 = f1 * 0.430252100840336f +
                    f3 * 0.168067226890756f;
            const float odd2 = f1 * 0.179271708683473f +
                    f3 * 0.403361344537815f;

            out[0 * out_stride + v] = f0 * 1.13777777777778f;
            out[1 * out_stride + v] = neg - odd1;
            out[2 * out_stride + v] = neg + odd1;
            out[3 * out_stride + v] = pos + odd2;
            out[4 * out_stride + v] = pos - odd2;
            out[5 * out_stride + v] = f3;
        }
    };

    float T[6][4][16];
    const float *src = &F[0][0][0];
    float *tmp = &T[0][0][0];

    // Pass 1: for each of the first four columns c of F (rows are 6 * 16
    // floats apart), T[.][c] = G * F[0..3][c].
    for (int c = 0; c < 4; c++)
        apply_G(src + c * 16, 6 * 16, tmp + c * 16, 4 * 16);
    // Pass 2: Fw[r][.] = G * T[r][0..3].
    for (int r = 0; r < 6; r++)
        apply_G(&T[r][0][0], 16, &Fw[r][0][0], 16);
}
/* Backward-weights variant of the inverse transform M = A^T * Mw * A with
 * the 3x6 matrix A^T spelled out by the 1-D kernel below:
 *   row 0: (1, 1,         1,        1,    1,    0)
 *   row 1: (0, 0.625,    -0.625,    1.5, -1.5,  0)
 *   row 2: (0, 0.390625,  0.390625, 2.25, 2.25, 1)
 * applied independently to each of the 16 x 16 (j, v) lanes.
 */
void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16])
{
    // out = A^T * in: six 16-wide input vectors spaced in_stride floats
    // apart produce three output vectors spaced out_stride floats apart.
    const auto apply_AT = [](const float *in, int in_stride,
            float *out, int out_stride) {
        for (int v = 0; v < 16; v++) {
            const float m0 = in[0 * in_stride + v];
            const float m1 = in[1 * in_stride + v];
            const float m2 = in[2 * in_stride + v];
            const float m3 = in[3 * in_stride + v];
            const float m4 = in[4 * in_stride + v];
            const float m5 = in[5 * in_stride + v];

            const float sum12 = m1 + m2;
            const float sum34 = m3 + m4;
            const float tail = sum34 * 2.25f + m5;

            out[0 * out_stride + v] = m0 + sum12 + sum34;
            out[1 * out_stride + v] = 0.625f * (m1 - m2) +
                    1.5f * (m3 - m4);
            out[2 * out_stride + v] = sum12 * 0.390625f + tail;
        }
    };

    float T[3][6][16];
    const float *src = &Mw[0][0][0][0];
    float *tmp = &T[0][0][0];
    float *dst = &M[0][0][0][0];

    for (int j = 0; j < 16; j++) {
        // Pass 1: for each column c, T[.][c] = A^T * Mw[.][c][j].
        for (int c = 0; c < 6; c++)
            apply_AT(src + (c * 16 + j) * 16, 6 * 16 * 16,
                    tmp + c * 16, 6 * 16);
        // Pass 2: M[r][.][j] = A^T * T[r][.].
        for (int r = 0; r < 3; r++)
            apply_AT(&T[r][0][0], 16,
                    dst + (r * 3 * 16 + j) * 16, 16 * 16);
    }
}
/* Winograd F(4x4, 3x3) input transform of one image (one channel block).
 *
 * For every (tj, ti) tile of the padded plane this function:
 *   - gathers an alpha x alpha patch of simd_w-wide pixels into I;
 *   - transforms it with trans_I_4x4_3x3;
 *   - writes the alpha x alpha result vectors into tinp, viewed through the
 *     8-D blocked "V" accessor `output`, using store_output (streamout is
 *     forwarded).
 * The blocked tile position is derived from the image's first global tile
 * index and advanced once per tile.
 *
 * is_fwd == true : inp is a src plane (ih x iw) padded by l_pad / t_pad.
 * is_fwd == false: inp is a diff_dst plane (oh x ow) for backward data; the
 *                  effective padding is recomputed from the jcp sizes.
 */
436 template <bool is_fwd>
437 void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
438 float *inp, float *tinp, bool streamout = true)
440 const int inpw = is_fwd ? jcp.iw : jcp.ow;
441 const int inph = is_fwd ? jcp.ih : jcp.oh;
442 const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
// NOTE(review): assuming stride 1 and no dilation, the width expression
// above equals kw - 1 - l_pad (the usual backward-data left padding). The
// height expression below equals kh - 1 - b_pad rather than kh - 1 - t_pad.
// The two agree only when t_pad == b_pad -- confirm whether
// jcp.ih + jcp.b_pad - jcp.oh was intended.
443 const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
444 const int wp_max = inpw + l_pad;
445 const int hp_max = inph + t_pad;
446 float Iw[alpha][alpha][simd_w];
447 float I[alpha][alpha][simd_w];
// NOTE(review): `input` does not appear to be referenced below; source
// pixels are addressed from inp directly.
449 array_offset_calculator<float, 5> input(inp,
450 jcp.mb, jcp.dimK/simd_w, inph, inpw,
452 array_offset_calculator<float, 8> output(tinp,
453 jcp.dimN_nb_block, alpha, alpha,
454 jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
455 jcp.dimN_reg_block, jcp.dimK_reg_block);
// Decompose this image's first global tile index into the blocked tile
// coordinates used to address V.
457 int tile_base_index = image * jcp.itiles * jcp.jtiles;
458 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
459 int nb_tile_block_ur =
460 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
462 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
464 for (int tj = 0; tj < jcp.jtiles; tj++) {
465 for (int ti = 0; ti < jcp.itiles; ti++) {
466 for (int j = 0; j < alpha; j++) {
467 int ydim = tj * tile_size + j;
468 if ((t_pad <= ydim) && (ydim < hp_max)) {
// Row start of the patch. The per-pixel stride is hard-coded as 16 floats
// (presumably simd_w == 16 -- TODO confirm).
469 float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ;
470 for (int i = 0; i < alpha; i++) {
471 int xdim = ti * tile_size + i;
472 if ((l_pad <= xdim) && (xdim < wp_max)) {
473 float *pinp_i = pinp_j + (xdim - l_pad) * 16;
474 load_ps(I[j][i], pinp_i);
477 for (int v = 0; v < simd_w; v++) {
483 for (int i = 0; i < alpha; i++) {
485 for (int v = 0; v < simd_w; v++) {
492 trans_I_4x4_3x3(Iw, I);
494 for (int j = 0; j < alpha; j++) {
495 for (int i = 0; i < alpha; i++) {
496 store_output(&(output(tile_block, j, i,
497 nb_tile_block_ur, 0, 0,
499 Iw[j][i], streamout);
// Wrap-around of the blocked tile counters.
503 if (tile_block_ur >= jcp.tile_block_ur) {
507 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
508 nb_tile_block_ur = 0;
/* Winograd F(4x4, 3x3) weight transform of one simd_w x simd_w block of
 * kh x kw filters.
 *
 * The block is gathered into F, transformed with trans_W_4x4_3x3, and the
 * alpha x alpha x simd_w x simd_w result is written to twp ("U").
 *
 * For backward data (!is_fwd) the kernel is read spatially flipped, at taps
 * (2 - j, 2 - i). The two F assignments below differ in which simd index
 * is the row: one of them transposes the simd_w x simd_w block, presumably
 * selected by is_fwd -- confirm.
 */
515 template <bool is_fwd>
516 void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
517 float *wp, float *twp)
521 array_offset_calculator<float, 6> input(wp,
522 jcp.oc/jcp.oc_simd_block,
523 jcp.ic/jcp.ic_simd_block,
526 array_offset_calculator<float, 8> output(twp,
530 jcp.dimM_block, jcp.dimK_block,
532 float Fw[alpha][alpha][simd_w][simd_w];
533 float F[kh][kw][simd_w][simd_w];
535 for (int j = 0; j < kh; j++) {
536 for (int i = 0; i < kw; i++) {
537 for (int v1 = 0; v1 < simd_w; v1++) {
538 float *base_inp = is_fwd
539 ? &(input(0, 0, j, i, v1, 0))
540 : &(input(0, 0, 2 - j, 2 - i, v1, 0));
542 for (int v2 = 0; v2 < simd_w; v2++) {
544 F[j][i][v1][v2] = *(base_inp + v2);
546 F[j][i][v2][v1] = *(base_inp + v2);
552 trans_W_4x4_3x3(Fw, F);
554 for (int j = 0; j < alpha; j++) {
555 for (int i = 0; i < alpha; i++) {
556 for (int v1 = 0; v1 < simd_w; v1++) {
558 for (int v2 = 0; v2 < simd_w; v2++) {
559 output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2];
/* Winograd F(4x4, 3x3) output transform of one image (one channel block).
 *
 * For every (tj, ti) tile this function:
 *   - gathers the tile's alpha x alpha vectors from toutp (the 8-D blocked
 *     "M" accessor `input`);
 *   - applies trans_O_4x4_3x3 to obtain a tile_size x tile_size output tile;
 *   - adds bias[v] when with_bias, and rewrites negative values through the
 *     conditional expression below when with_relu_presum;
 *   - writes each pixel to pout_b with accum_output (presumably when
 *     with_sum; ReLU after the sum when with_relu_postsum) or with
 *     store_output.
 *
 * with_relu_postsum is set when p_ops.find(primitive_kind::eltwise, 1)
 * finds an entry, presumably an eltwise that follows the sum post-op.
 * streamout is forwarded to accum_output / store_output.
 */
566 template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
567 void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
568 const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias,
569 bool streamout = true) {
570 float Ow[alpha][alpha][simd_w];
571 float O[tile_size][tile_size][simd_w];
572 int outw = is_fwd ? jcp.ow : jcp.iw;
573 int outh = is_fwd ? jcp.oh : jcp.ih;
575 /* Prepare for PostOps */
576 bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
578 array_offset_calculator<float, 8> input(toutp,
579 jcp.dimN_nb_block, jcp.dimM_nb_block,
581 jcp.dimN_block, jcp.dimM_block,
582 jcp.dimN_reg_block, jcp.dimM_simd_block);
// Decompose this image's first global tile index into the blocked tile
// coordinates used to address M.
584 int tile_base_index = image * jcp.itiles * jcp.jtiles;
585 int tile_block_ur = tile_base_index % jcp.tile_block_ur;
586 int nb_tile_block_ur =
587 (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
589 (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
591 for (int tj = 0; tj < jcp.jtiles; tj++) {
592 for (int ti = 0; ti < jcp.itiles; ti++) {
593 for (int j = 0; j < alpha; j++) {
594 for (int i = 0; i < alpha; i++) {
596 for (int v = 0; v < simd_w; v++) {
597 Ow[j][i][v] = input(tile_block, 0,
605 trans_O_4x4_3x3(Ow, O);
607 for (int j = 0; j < tile_size; j++) {
608 int ydim = tj * tile_size + j;
610 float *pout_j = pout_b + ydim * outw * simd_w;
611 for (int i = 0; i < tile_size; i++) {
612 int xdim = ti * tile_size + i;
614 float *pout_i = pout_j + xdim * simd_w;
617 for (int v = 0; v < simd_w; v++) {
618 O[j][i][v] += with_bias ? bias[v] : 0.f;
620 && with_relu_presum && O[j][i][v] < 0.f
627 accum_output(pout_i, O[j][i], streamout,
630 store_output(pout_i, O[j][i], streamout);
// Wrap-around of the blocked tile counters.
636 if (tile_block_ur >= jcp.tile_block_ur) {
640 if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
641 nb_tile_block_ur = 0;
/* Backward-weights: Winograd input transform of one image of inp into tinp
 * ("V").
 *
 * Each alpha x alpha tile is gathered from inp. Positions outside
 * [t_pad, ih + t_pad) x [l_pad, iw + l_pad) take the else-branches instead
 * of reading input. The tile is then transformed with trans_I_4x4_3x3.
 *
 * 4FMA path (presumably when ver_4fma): transformed tiles are staged in
 * Iw_temp (alpha x alpha x tile_4fma x simd_w) and handed to
 * transpose_4fma_ker once tile_4fma of them are collected. A trailing
 * partial group is flushed at the end when conv.tile_4fma_padding != 0.
 * Other path: each transformed vector is written with store_output.
 *
 * NOTE(review): conv is taken by value, copying the whole configuration on
 * every call; a const reference would avoid the copy.
 */
648 template <bool ver_4fma>
649 void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
650 float *inp, float *tinp, float *Iw_temp,
651 void (*transpose_4fma_ker)(float *, float *))
654 const int ifwp = conv.iw + conv.l_pad;
655 const int ifhp = conv.ih + conv.t_pad;
656 float I[alpha][alpha][simd_w];
657 float Iw[alpha][alpha][simd_w];
659 array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp,
660 alpha, alpha, conv.tile_4fma, simd_w);
661 array_offset_calculator<float, 5> input(inp,
662 conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w);
663 array_offset_calculator<float, 8> output(tinp,
664 conv.nb_ic, alpha, alpha,
665 conv.tile_block, conv.ic_block,
666 conv.nb_tile_block_ur, conv.tile_block_ur,
667 conv.ic_simd_block * conv.tile_4fma);
// Each image owns itiles * jtiles tiles plus tile_4fma_padding filler
// slots; decompose its first slot into blocked coordinates.
669 int tile_base_index =
670 image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding);
672 int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur;
673 int nb_tile_block_ur =
674 (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
675 % conv.nb_tile_block_ur;
676 int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
677 / conv.nb_tile_block_ur;
679 for (int tj = 0; tj < conv.jtiles; tj++) {
680 for (int ti = 0; ti < conv.itiles; ti++) {
681 for (int j = 0; j < alpha; j++) {
682 int ydim = tj * tile_size + j;
683 if ((conv.t_pad <= ydim) && ydim < ifhp) {
684 for (int i = 0; i < alpha; i++) {
685 int xdim = ti * tile_size + i;
686 if ((conv.l_pad <= xdim) && xdim < ifwp) {
688 for (int v = 0; v < simd_w; v++) {
689 I[j][i][v] = input(0, 0,
691 xdim - conv.l_pad, v);
695 for (int v = 0; v < simd_w; v++) {
701 for (int i = 0; i < alpha; i++) {
703 for (int v = 0; v < simd_w; v++) {
709 trans_I_4x4_3x3(Iw, I);
712 for (int j = 0; j < alpha; j++) {
713 for (int i = 0; i < alpha; i++) {
714 float *Iw_temp_base = &(Iw_trans_temp(j, i,
717 for (int v = 0; v < simd_w; v++) {
718 Iw_temp_base[v] = Iw[j][i][v];
// A full group of tile_4fma staged tiles: transpose it into V.
723 if (tile_4fma == conv.tile_4fma) {
724 float *outp = &(output(0, 0, 0,
726 nb_tile_block_ur, tile_block_ur, 0));
727 transpose_4fma_ker(outp, (float *)Iw_temp);
732 for (int j = 0; j < alpha; j++) {
733 for (int i = 0; i < alpha; i++) {
734 store_output(&(output(0, j, i,
736 nb_tile_block_ur, tile_block_ur, 0)),
// Wrap-around of the blocked tile counters (note: compared with ==
// here, whereas the data-path transforms use >=).
743 if (tile_block_ur == conv.tile_block_ur) {
747 if (nb_tile_block_ur == conv.nb_tile_block_ur) {
748 nb_tile_block_ur = 0;
// Flush a trailing partial 4FMA group by filling the remaining staging
// slots [tile_4fma, conv.tile_4fma) before the final transpose.
754 if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) {
756 for (int j = 0; j < alpha; j++) {
757 for (int i = 0; i < alpha; i++) {
758 for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) {
759 float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0));
761 for (int v = 0; v < simd_w; v++) {
767 float *outp = &(output(0, 0, 0,
769 nb_tile_block_ur, tile_block_ur, 0));
770 transpose_4fma_ker(outp, (float *)Iw_temp);
/* Backward-weights: Winograd transform of one image of inp into tinp ("M")
 * using trans_W_3x3_4x4_wu.
 *
 * That function reads only the first 4 rows and columns of each
 * alpha x alpha patch. Tiles start at (0, 0) with no padding: only
 * positions with ydim < oh and xdim < ow read inp.
 *
 * When with_bias, every pixel inside the tile_size x tile_size interior of
 * a tile is added into dbias. Tiles are stepped by tile_size, so each pixel
 * is counted once.
 *
 * NOTE(review): conv is taken by value (full configuration copy per call).
 */
774 template <bool with_bias>
775 void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
776 float *inp, float *tinp, float *dbias)
779 const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding;
780 float I[alpha][alpha][simd_w];
781 float Iw[alpha][alpha][simd_w];
783 array_offset_calculator<float, 5> input(inp,
784 conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block);
785 array_offset_calculator<float, 8> output(tinp,
786 conv.nb_oc, alpha, alpha,
787 conv.tile_block, conv.oc_block,
788 conv.nb_tile_block_ur,
789 conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
// Decompose this image's first tile slot into blocked coordinates; here
// tile_block_ur counts individual tiles (up to tile_block_ur * tile_4fma).
791 int tile_base_index = image * total_tiles;
792 int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma);
793 int nb_tile_block_ur =
794 (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
795 % conv.nb_tile_block_ur;
796 int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
797 / conv.nb_tile_block_ur;
799 for (int tj = 0; tj < conv.jtiles; tj++) {
800 for (int ti = 0; ti < conv.itiles; ti++) {
801 for (int j = 0; j < alpha; j++) {
802 int ydim = tj * tile_size + j;
803 if (ydim < conv.oh) {
804 for (int i = 0; i < alpha; i++) {
805 int xdim = ti * tile_size + i;
806 if (xdim < conv.ow) {
807 float *input_base = &(input(0, 0, ydim, xdim, 0));
810 for (int v = 0; v < simd_w; v++) {
811 I[j][i][v] = input_base[v];
// Bias gradient: sum only the tile interior so overlapping tiles do not
// double-count pixels.
813 if (with_bias && j < tile_size && i < tile_size) {
815 for (int v = 0; v < simd_w; v++) {
816 dbias[v] += input_base[v];
821 for (int v = 0; v < simd_w; v++) {
827 for (int i = 0; i < alpha; i++) {
829 for (int v = 0; v < simd_w; v++) {
836 trans_W_3x3_4x4_wu(Iw, I);
838 for (int j = 0; j < alpha; j++) {
839 for (int i = 0; i < alpha; i++) {
840 store_output(&(output(0, j, i,
// Wrap-around of the blocked tile counters.
848 if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) {
852 if (nb_tile_block_ur >= conv.nb_tile_block_ur) {
853 nb_tile_block_ur = 0;
/* Backward-weights: inverse transform of one (ic, oc) simd block of
 * Winograd-domain gradients (twp, "U") back to a kh x kw spatial block of
 * diff_weights in wp, using trans_O_3x3_4x4_wu. Each resulting vector is
 * written with store_output.
 *
 * NOTE(review): Fw / F are sized with simd_w but filled with loops bounded
 * by conv.ic_simd_block / conv.oc_simd_block, which assumes both equal
 * simd_w. conv is taken by value (full configuration copy per call).
 */
860 void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
861 float *wp, float *twp)
865 float Fw[alpha][alpha][simd_w][simd_w];
866 float F[kh][kw][simd_w][simd_w];
868 array_offset_calculator<float, 8> input(twp,
869 conv.nb_ic, conv.nb_oc,
871 conv.oc_block, conv.ic_block,
872 conv.ic_simd_block, conv.oc_simd_block);
873 array_offset_calculator<float, 6> output(wp,
874 conv.oc/simd_w, conv.ic/simd_w,
876 conv.ic_simd_block, conv.oc_simd_block);
878 for (int j = 0; j < alpha; j++) {
879 for (int i = 0; i < alpha; i++) {
880 for (int v = 0; v < conv.ic_simd_block; v++) {
882 for (int k = 0; k < conv.oc_simd_block; k++) {
883 Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k);
889 trans_O_3x3_4x4_wu(Fw, F);
891 for (int j = 0; j < kh; j++) {
892 for (int i = 0; i < kw; i++) {
893 for (int v = 0; v < conv.ic_simd_block; v++) {
894 store_output(&(output(0, 0, j, i, v, 0)),
/* Forward (is_fwd) / backward-data Winograd F(4x4, 3x3) driver.
 *
 * Phases, each distributed with parallel_nd_in_omp:
 *   1. input transform of every (image, K block) into V;
 *   2. weight transform of every (oc block, ic block) pair into U. For
 *      backward data, the ic / oc roles in the U indices are swapped;
 *   3. for every (N block, oj, oi, M block, N sub-block), the GEMM kernel
 *      computes M from U and V. The first K block goes through
 *      gemm_loop_ker_first_iter and the remaining K blocks through
 *      gemm_loop_ker (presumably accumulating);
 *   4. output transform of every (image, M block) from M into the
 *      destination, adding bias and applying post-ops.
 * For backward data, the scratchpad buffers keyed key_wino_M and
 * key_wino_V swap roles (see the M and V accessors).
 */
901 template <bool is_fwd>
902 void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
903 const int MB, float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
904 const memory_tracking::grantor_t &scratchpad) const{
905 const auto &jcp = kernel_->jcp;
906 const auto &p_ops = attr_->post_ops_;
908 const int inph = is_fwd ? jcp.ih : jcp.oh;
909 const int inpw = is_fwd ? jcp.iw : jcp.ow;
910 const int outh = is_fwd ? jcp.oh : jcp.ih;
911 const int outw = is_fwd ? jcp.ow : jcp.iw;
913 /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
914 * and conv primitive with PostOps with relu before sum
915 * (PostOps relu after sum is handled later) */
916 auto output_transform = jcp.with_bias
919 ? output_transform_data<is_fwd, true, true, true>
920 : output_transform_data<is_fwd, true, true, false>)
922 ? output_transform_data<is_fwd, true, false, true>
923 : output_transform_data<is_fwd, true, false, false>))
926 ? output_transform_data<is_fwd, false, true, true>
927 : output_transform_data<is_fwd, false, true, false>)
929 ? output_transform_data<is_fwd, false, false, true>
930 : output_transform_data<is_fwd, false, false, false>));
933 FWD: dimM:oc, dimN:ntiles, dimK:ic,
934 BWD: dimM:ic, dimN:ntiles, dimK:oc,
935 FWD/BWD: V: src/diff_dst transform, U:weight transform,
936 M:dst/diff_src transform */
937 array_offset_calculator<float, 5> input(inp_ptr,
938 MB, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
940 array_offset_calculator<float, 5> output(out_ptr,
941 MB, jcp.dimM/jcp.dimM_simd_block, outh, outw,
942 jcp.dimM_simd_block);
943 array_offset_calculator<float, 6> weights(wei_ptr,
944 jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
945 jcp.ic_simd_block, jcp.oc_simd_block);
946 array_offset_calculator<float, 2> bias(bias_ptr,
947 jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
949 array_offset_calculator<float, 8> M(is_fwd
950 ? scratchpad.template get<float>(key_wino_M)
951 : scratchpad.template get<float>(key_wino_V),
952 jcp.dimN_nb_block, jcp.dimM_nb_block,
954 jcp.dimN_block, jcp.dimM_block,
955 jcp.dimN_reg_block, jcp.dimM_simd_block);
956 array_offset_calculator<float, 8> U(
957 scratchpad.template get<float>(key_wino_U),
961 jcp.dimM_block, jcp.dimK_block,
962 jcp.dimK_reg_block, jcp.dimM_simd_block);
963 array_offset_calculator<float, 8> V(is_fwd
964 ? scratchpad.template get<float>(key_wino_V)
965 : scratchpad.template get<float>(key_wino_M),
966 jcp.dimN_nb_block, alpha, alpha,
967 jcp.dimN_block, jcp.dimK_nb_block,
968 jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
// Write V with streaming stores when it is larger than twice the LLC.
970 bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
971 > 2 * LLC_cache_size ? true : false;
// Passed to output_transform as its streamout argument.
// NOTE(review): on the ICC path the non-streaming store is also an aligned
// store -- see store_output / accum_output.
973 const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
// With a padded oc, the last bias slice is replaced by a zero-initialized
// local copy holding only its oc_without_padding % oc_simd_block valid
// entries (presumably so the user's bias is never read past its end).
975 const bool wants_padded_bias = jcp.with_bias
976 && jcp.oc_without_padding != jcp.oc;
977 float last_slice_bias[simd_w] = {0};
978 if (wants_padded_bias) {
979 for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
980 last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
// Phase 1: input transform into V.
985 parallel_nd_in_omp(MB, jcp.dimK_nb_block, jcp.dimK_block,
986 [&](int img, int K_blk1, int K_blk2) {
987 input_transform_data<is_fwd>(img, jcp,
988 &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
989 &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout);
// Phase 2: weight transform into U.
992 parallel_nd_in_omp(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
993 [&](int ofm1, int ifm1, int ofm2, int ifm2) {
994 float *U_base_ptr = is_fwd
995 ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
996 : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
997 weight_transform_data<is_fwd>(jcp,
998 &(weights(ofm1 * jcp.oc_block + ofm2,
999 ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
// Phase 3: batched GEMMs in the Winograd domain, one per (oj, oi) point.
1004 parallel_nd_in_omp(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
1005 [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
1007 kernel_->gemm_loop_ker_first_iter(
1008 (float *)&(M(N_blk1, M_blk1, oj, oi,
1010 (const float *)&(U(M_blk1, oj, oi,
1012 (const float *)&(V(N_blk1, oj, oi,
1013 N_blk2, 0, 0, 0, 0)));
1014 for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
1015 kernel_->gemm_loop_ker(
1016 (float *)&(M(N_blk1, M_blk1, oj, oi,
1018 (const float *)&(U(M_blk1, oj, oi,
1019 K_blk1, 0, 0, 0, 0)),
1020 (const float *)&(V(N_blk1, oj, oi,
// Phase 4: output transform (bias and post-ops) into the destination.
1030 parallel_nd_in_omp(MB, jcp.dimM_nb_block, jcp.dimM_block,
1031 [&](int img, int M_blk1, int M_blk2) {
1033 const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
1035 float *bias_ptr = wants_padded_bias
1036 && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
1037 ? last_slice_bias : &bias(M_blk, 0);
1039 output_transform(img, jcp, p_ops,
1040 &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
1041 &(output(img, M_blk, 0, 0, 0)),
1042 bias_ptr, output_is_aligned);
/* Explicit instantiations: is_fwd = true (forward) and is_fwd = false
 * (backward data). */
1048 template struct _jit_avx512_common_convolution_winograd_t<true>;
1049 template struct _jit_avx512_common_convolution_winograd_t<false>;
/* When the bias was padded (pd()->wants_padded_bias()), diff_bias was
 * accumulated in the key_conv_padded_bias scratchpad buffer. Copy its first
 * oc_without_padding entries into the user's diff_bias (memory(1));
 * otherwise do nothing. */
1051 void jit_avx512_common_convolution_winograd_bwd_weights_t::
1052 _maybe_execute_diff_bias_copy(
1053 const memory_tracking::grantor_t &scratchpad) const {
1054 if (pd()->wants_padded_bias()) {
1055 auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
1056 float *diff_bias = (float *)this->memory(1);
1057 for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
1058 diff_bias[oc] = padded_bias[oc];
/* Backward-weights Winograd driver.
 *
 * Within an OpenMP parallel region of jcp.nthr threads it:
 *   - zeroes the per-thread bias partials (diff_bias_prv) and diff_bias
 *     when with_bias;
 *   - transforms every (image, ic block) of input_memory(0) into V. For
 *     ver_4fma it stages through this thread's trans_buffer row;
 *   - transforms every (image, oc block) of input_memory(1) into M, summing
 *     bias gradients into this thread's diff_bias_prv row;
 *   - for each ic block, runs the GEMM kernel over (oj, oi, oc block),
 *     covering tile block 0 with gemm_loop_ker_first_iter and the rest with
 *     gemm_loop_ker, producing U;
 *   - inverse-transforms U into memory(0) (diff_weights);
 *   - reduces the per-thread bias partials into diff_bias.
 * Finally it copies the valid part of a padded bias out of the scratchpad.
 */
1062 void jit_avx512_common_convolution_winograd_bwd_weights_t::
1063 _execute_backward_weights_S_D_G_W(
1064 const memory_tracking::grantor_t &scratchpad) const {
1065 const auto &jcp = kernel_->jcp;
1066 const int nthreads = jcp.nthr;
1068 auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
1069 diff_src_transform_bwd_weights<true> :
1070 diff_src_transform_bwd_weights<false>;
1071 auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
1072 ? diff_dst_transform_bwd_weights<true>
1073 : diff_dst_transform_bwd_weights<false>;
1075 array_offset_calculator<float, 5> diff_src((float *)this->input_memory(0),
1076 jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w);
1077 array_offset_calculator<float, 5> diff_dst((float *)this->input_memory(1),
1078 jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
1079 array_offset_calculator<float, 6> diff_weights((float *)this->memory(0),
1080 jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
// Bias gradients go to scratch when oc is padded; _maybe_execute_diff_bias_copy
// moves them to the user buffer at the end.
1081 array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
1082 ? scratchpad.get<float>(key_conv_padded_bias)
1083 : (float *)this->memory(1), jcp.oc/simd_w, simd_w);
1085 array_offset_calculator<float, 8> U(
1086 scratchpad.get<float>(key_wino_U),
1087 jcp.nb_ic, jcp.nb_oc,
1089 jcp.oc_block, jcp.ic_block,
1090 jcp.ic_simd_block, jcp.oc_simd_block);
1092 array_offset_calculator<float, 8> M(
1093 scratchpad.get<float>(key_wino_M),
1094 jcp.nb_oc, alpha, alpha,
1095 jcp.tile_block, jcp.oc_block,
1096 jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
1098 array_offset_calculator<float, 8> V(
1099 scratchpad.get<float>(key_wino_V),
1100 jcp.nb_ic, alpha, alpha,
1101 jcp.tile_block, jcp.ic_block,
1102 jcp.nb_tile_block_ur, jcp.tile_block_ur,
1103 jcp.ic_simd_block * jcp.tile_4fma);
// Per-thread staging buffer for the 4FMA transpose.
1105 const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
1106 * jcp.ic_simd_block;
1107 array_offset_calculator<float, 2> trans_buffer(
1108 scratchpad.get<float>(key_conv_tr_src),
// Per-thread partial bias gradients, reduced into diff_bias at the end.
1112 array_offset_calculator<float, 2> diff_bias_prv(
1113 scratchpad.get<float>(key_conv_bia_reduction),
1117 PRAGMA_OMP(parallel num_threads(nthreads))
1119 if (jcp.with_bias) {
1120 parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
1121 diff_bias_prv(ithr, ofm) = 0.0f;
1124 PRAGMA_OMP(for nowait)
1125 for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
1127 for (int v = 0; v < simd_w; v++)
1128 diff_bias(bofm, v) = 0.0f;
// Index of the calling thread; selects its trans_buffer and diff_bias_prv
// rows inside the lambdas below.
1132 const int ithread = mkldnn_get_thread_num();
1134 parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
1135 [&](int img, int ifm1, int ifm2) {
1136 float *transb = jcp.ver == ver_4fma
1137 ? &(trans_buffer(ithread, 0))
1139 diff_src_transform_bwd_weights_ver(img, jcp,
1140 &(diff_src(img, ifm1 * jcp.ic_block + ifm2,
1142 &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
1144 kernel_->transpose_4fma_ker);
1147 parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
1148 [&](int img, int ofm1, int ofm2) {
1149 float *dbias = jcp.with_bias
1150 ? &(diff_bias_prv(ithread,
1151 simd_w * (ofm1 * jcp.oc_block + ofm2)))
1153 diff_dst_transform_bwd_weights_ver(img, jcp,
1154 &(diff_dst(img, ofm1 * jcp.oc_block + ofm2,
1156 &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)),
// Winograd-domain GEMMs: one parallel sweep over (oj, oi, oc block) per ic
// block.
1162 for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
1163 parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
1164 [&](int oj, int oi, int ofm1) {
1165 kernel_->gemm_loop_ker_first_iter(
1166 (float *)&(U(ifm1, ofm1, oj, oi,
1168 (const float *)&(M(ofm1, oj, oi,
1170 (const float *)&(V(ifm1, oj, oi,
1172 for (int tile_block = 1; tile_block < jcp.tile_block;
1174 kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1,
1177 (const float *)&(M(ofm1, oj, oi, tile_block,
1179 (const float *)&(V(ifm1, oj, oi, tile_block,
// Inverse transform of U into diff_weights.
1187 parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
1188 [&](int ifm1, int ofm1, int ofm2, int ifm2) {
1189 diff_weights_transform_bwd_weights(jcp,
1190 &(diff_weights(ofm1 * jcp.oc_block + ofm2,
1191 ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
1192 &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
// Reduce per-thread bias partials into diff_bias.
// NOTE(review): diff_bias_prv is addressed both as (ithr, ofm) and with the
// flat offset ithr * jcp.oc + ofm1 * simd_w; the two agree only if its row
// length is jcp.oc -- confirm against its declared shape.
1195 if (jcp.with_bias) {
1197 for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
1198 for (int ithr = 0; ithr < nthreads; ithr++) {
1199 float* base_bias_ptr = &(diff_bias(ofm1, 0));
1200 float* base_bias_prv_ptr = &(diff_bias_prv(
1201 ithr * jcp.oc + ofm1 * simd_w));
1203 for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
1204 base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
1211 _maybe_execute_diff_bias_copy(scratchpad);
1217 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s