Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21
22 #include "cpu_barrier.hpp"
23 #include "cpu_memory.hpp"
24
25 #include "jit_avx512_common_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28 #define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
38
39 namespace {
40
41 constexpr auto small_spatial = 14;
42 unsigned int L1_cache_size = get_cache_size(1, true);
43
44 inline void pick_loop_order(jit_conv_conf_t &jcp) {
45     using namespace prop_kind;
46     assert(one_of(jcp.prop_kind,
47                 forward_training, forward_inference, backward_data));
48     auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
49     auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
50
51     // ow-threading is currently implemented for forward only
52     // TODO: single code for fwd and bwd after ow-thr for bwd
53     // meaningless switch was removed
54     if (jcp.prop_kind == backward_data) {
55         jcp.loop_order = (w <= small_spatial && h <= small_spatial)
56             ? loop_cgn : loop_gnc;
57     } else {
58         jcp.loop_order = (w <= small_spatial && h <= small_spatial)
59             ? loop_cwgn : loop_gncw;
60     }
61 }
62
63 inline bool is_1stconv(const jit_conv_conf_t &jcp) {
64     if (mayiuse(avx512_core) && !mayiuse(avx512_core_vnni))
65         return (jcp.ic < 16 && jcp.ngroups == 1);
66     else
67         return one_of(jcp.ic, 1, 3);
68 }
69
70 inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
71     return (jcp.nb_ow > 1);
72 }
73
74 inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
75     return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
76 }
77
78 }
79
80 template<typename Vmm>
81 void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
82 {
83     for (int k = 0; k < jcp.nb_oc_blocking; k++)
84         for (int j = 0; j < ur_w; j++) {
85             Vmm vmm = vmm_out(j, k);
86             vpxord(vmm, vmm, vmm);
87             if (!is_owb_prefetching(jcp)) {
88                 size_t aux_output_offset = get_output_offset(j, k);
89                 mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
90                             aux_output_offset, reg_out_long_offt));
91             }
92         }
93 }
94
95 template<typename Vmm>
96 void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
97 {
98     Label no_update_label, store_label, postproc_label;
99
100     mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
101     if (jcp.with_bias) {
102         mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
103     }
104
105     if (!jcp.with_sum) {
106         cmp(reg_channel, 0);
107         je(no_update_label, T_NEAR);
108     }
109
110     for (int k = 0; k < jcp.nb_oc_blocking; k++)
111         for (int j = 0; j < ur_w; j++) {
112             Vmm vmm = vmm_out(j, k);
113             size_t aux_output_offset = get_output_offset(j, k);
114             vadd(vmm,
115                 make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
116         }
117
118     if (!jcp.with_sum) {
119         jmp(postproc_label, T_NEAR);
120     } else {
121         cmp(reg_channel, 0);
122         jne(postproc_label, T_NEAR);
123     }
124
125     L(no_update_label);
126     if (jcp.with_bias) {
127         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
128             int bias_offset = jcp.typesize_out * k * jcp.oc_block;
129             for (int j = 0; j < ur_w; j++) {
130                 Vmm vmm = vmm_out(j, k);
131                 vadd(vmm, EVEX_compress_addr(reg_bias, bias_offset));
132             }
133             mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
134         }
135     }
136
137     L(postproc_label);
138
139     cmp(reg_channel, jcp.nb_ic - 1);
140     jl(store_label, T_NEAR);
141
142     int eltwise_inj_idx = 0;
143     int depthwise_inj_idx = 0;
144     const auto &p = attr_.post_ops_;
145
146     for (int i = 0; i < p.len_; i++) {
147         auto& post_op = p.entry_[i];
148         if (post_op.is_eltwise()) {
149             if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
150                 Vmm vmm_zero = vmm_wei;
151                 vpxord(vmm_zero, vmm_zero, vmm_zero);
152
153                 for (int k = 0; k < jcp.nb_oc_blocking; k++)
154                     for (int j = 0; j < ur_w; j++) {
155                         Vmm vmm = vmm_out(j, k);
156                         vpcmpd(k1, vmm, vmm_zero, _cmp_lt_os);
157                         vpmulld(vmm | k1, vmm, vmm_zero);
158                     }
159             } else {
160                 if (ur_w == jcp.ur_w) {
161                     eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0,
162                                                             jcp.nb_oc_blocking * jcp.ur_w);
163                 } else {
164                     for (int k = 0; k < jcp.nb_oc_blocking; k++)
165                         eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w,
166                                                                                  k * jcp.ur_w + ur_w);
167                 }
168             }
169
170             eltwise_inj_idx++;
171         } else if (post_op.is_depthwise()) {
172             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
173             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
174
175             add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
176             add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);
177
178             for (int k = 0; k < jcp.nb_oc_blocking; k++) {
179                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
180                         k*jcp.ur_w, k*jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
181
182                 add(reg_d_weights, jcp.oc_block * sizeof(float));
183                 add(reg_d_bias, jcp.oc_block * sizeof(float));
184             }
185
186             depthwise_inj_idx++;
187         }
188     }
189
190     L(store_label);
191     for (int k = 0; k < jcp.nb_oc_blocking; k++)
192         for (int j = 0; j < ur_w; j++) {
193             Vmm vmm = vmm_out(j, k);
194             size_t aux_output_offset = (size_t)typesize *
195                 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
196             vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
197                         reg_out_long_offt), vmm);
198             if (!is_owb_prefetching(jcp))
199                 mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
200                             aux_output_offset, reg_out_long_offt));
201         }
202 }
203
204 template<typename Vmm>
205 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
206     int pad_l, int pad_r)
207 {
208 }
209
210 template<>
211 void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
212         int pad_l, int pad_r)
213 {
214     assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
215
216     int iw = jcp.iw;
217     int ih = jcp.ih;
218     int kw = jcp.kw;
219     int stride_w = jcp.stride_w;
220     int ic_block = jcp.ic_block;
221     int oc_block = jcp.oc_block;
222
223     Label kh_label, kd_label;
224
225     if (one_of(jcp.ndims, 3, 4)) {
226         mov(aux_reg_inp, reg_inp);
227         mov(aux_reg_ker, reg_ker);
228         mov(aux_reg_inp_prf, reg_inp_prf);
229     }
230
231     size_t max_input_offset = (size_t)jcp.typesize_in
232         * ((size_t)(kw + ur_w * stride_w - pad_l)
233                 + (size_t)ic_block * iw * ih * jcp.id);
234     assert(reg_inp_prf == reg_long_offt);
235     if (max_input_offset > INT_MAX) push(reg_inp_prf);
236
237     if (jcp.ndims == 5) {
238         push(reg_out_prf);
239         push(reg_out);
240
241         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
242         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
243         mov(aux_reg_inp_d, reg_inp);
244         mov(aux_reg_inp_d_prf, reg_inp_prf);
245
246         L(kd_label);
247     }
248     mov(reg_kj, reg_kh);
249     if (jcp.ndims == 5) {
250         mov(aux_reg_inp, aux_reg_inp_d);
251         mov(aux_reg_ker, aux_reg_ker_d);
252         mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
253     }
254
255     L(kh_label);
256     for (int ki = 0; ki < kw; ki += 4) {
257         for (int ic = 0; ic < ic_block; ic++) {
258             for (int i = 0; i < 4; i++) {
259                 int aux_ker_offset
260                         = jcp.typesize_in
261                         * ((ki + i) * oc_block
262                                   + ic * kw * jcp.kh * jcp.kd * oc_block);
263                 if (ki + i < kw)
264                     vmovups(vmm_ker(i),
265                         EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
266                 else
267                     vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
268             }
269
270             int j_start = get_ow_start(ki, pad_l);
271             int j_end = get_ow_end(ur_w, ki, pad_r);
272
273             for (int j = j_start, prf_count=0; j < j_end; j++) {
274                 size_t aux_input_offset = (size_t)jcp.typesize_in
275                         * ((size_t)(ki + j * stride_w
276                             - pad_l) + (size_t)ic * iw * ih * jcp.id);
277                 v4fmaddps(vmm_out(j, 0), vmm_ker(0),
278                         EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
279                         reg_long_offt));
280                 if (ki + prf_count < kw && prf_count < 4
281                     && ((ki < 2 && j % 4) || j % 2)) {
282                     int aux_ker_offset = jcp.typesize_in
283                         * ((ki + prf_count) * oc_block
284                         + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block);
285                     mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
286                         aux_ker_offset));
287                     prf_count++;
288                 }
289                 if (ki == 0
290                     && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
291                     mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
292                         aux_input_offset, reg_long_offt));
293                 }
294                 if (ki == 1
295                     && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
296                     mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
297                         aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
298                 }
299             }
300         }
301     }
302     add(aux_reg_ker, jcp.typesize_in * kw * oc_block);
303     add(aux_reg_inp, jcp.typesize_in * iw);
304     add(aux_reg_inp_prf, jcp.typesize_in * iw);
305
306     dec(reg_kj);
307     cmp(reg_kj, 0);
308     jg(kh_label, T_NEAR);
309
310     if (jcp.ndims == 5) {
311         add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
312         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
313         add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw);
314
315         dec(reg_ki);
316         cmp(reg_ki, 0);
317         jg(kd_label, T_NEAR);
318
319         pop(reg_out);
320         pop(reg_out_prf);
321     }
322
323     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
324 }
325
326 template<typename Vmm>
327 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
328     int pad_l, int pad_r)
329 {
330 }
331
332 template<>
333 void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
334         int pad_l, int pad_r)
335 {
336     int stride_w = jcp.stride_w;
337     int ic_block = jcp.ic_block;
338     int oc_block = jcp.oc_block;
339     Label kh_label, last_iter_label, loop_end_label, kd_label;
340     int ker_load_number = 4;
341     int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
342     int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
343
344     bool check_last_kh = (jcp.kh > 3);
345     bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);
346
347     int oi_ipref_t0 = get_ow_start(0, pad_l);
348     int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);
349
350     assert(jcp.oc % jcp.nb_oc_blocking == 0);
351
352     auto kernel_offset = [=](int ocb, int ic, int ki) {
353         int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
354         int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
355         int ic_offset = ic * jcp.oc_block;
356         return typesize * (blk_offset + ic_offset);
357     };
358     auto kernel_loads = [=](int ki, int ic, int kk) {
359         for (int ii = 0; ii < ker_load_number; ii++) {
360             int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
361             vmovups(vmm_ker(ii),
362                 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
363         }
364     };
365     auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
366         if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
367             && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
368             int aux_inp_offset
369                     = typesize
370                     * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
371                               + (jcp.dilate_h + 1) * jcp.iw * ic_block);
372             prefetcht0(EVEX_compress_addr(aux_reg_inp,
373                     aux_inp_offset));
374             oi_ipref_t0++;
375         }
376     };
377
378     if (one_of(jcp.ndims, 3, 4)) {
379         mov(aux_reg_inp, reg_inp);
380         mov(aux_reg_ker, reg_ker);
381         mov(aux_reg_ker_prf, reg_ker_prf);
382         mov(aux_reg_inp_prf, reg_inp_prf);
383     }
384
385     if (jcp.ndims == 5) {
386         push(reg_out_prf);
387         push(reg_out);
388
389         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
390         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
391         mov(aux_reg_inp_d, reg_inp);
392         mov(aux_reg_inp_d_prf, reg_inp_prf);
393         mov(aux_reg_ker_d_prf, reg_ker_prf);
394         L(kd_label);
395         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
396     } else {
397         mov(reg_kj, reg_kh);
398     }
399     if (jcp.ndims == 5) {
400         mov(aux_reg_inp, aux_reg_inp_d);
401         mov(aux_reg_ker, aux_reg_ker_d);
402         mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
403         mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
404     }
405
406     align(16);
407     L(kh_label);
408     int kw = jcp.kw;
409     if (check_last_kh) {
410         for (int ki = 0; ki < kw; ki++)
411             for (int ic = 0; ic < ic_block; ic += 4)
412                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
413                     bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
414                         && ki == kw - 1 && (ic + 4) == ic_block);
415
416                     if (last_kernel_loads) {
417                         cmp(reg_kj, 1);
418                         je(last_iter_label, T_NEAR);
419                     }
420
421                     kernel_loads(ki, ic, kk);
422                     for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
423                              prf_count_t0 = 0;
424                             oi < get_ow_end(ur_w, ki, pad_r); oi++) {
425                         int aux_input_offset = typesize
426                                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
427                                            - pad_l) * ic_block
428                                                        + ic);
429                         v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
430                             EVEX_compress_addr(aux_reg_inp, aux_input_offset));
431
432                         if (oi % 2) {
433                             if (prf_count_t0 < 4) {
434                                 int aux_kernel_prf;
435                                 if (last_kernel_loads)
436                                     aux_kernel_prf= kernel_offset(0,
437                                         prf_count_t0 + ic + 4
438                                         - ic_block, 0) + typesize * kw
439                                         * oc_block * ic_block;
440                                 else
441                                     aux_kernel_prf = kernel_offset(kk, ic + 4
442                                         + prf_count_t0, ki);
443                                 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
444                                     aux_kernel_prf));
445                                 prf_count_t0++;
446                             } else if (prf_count_t1 < 4) {
447                                 mic_prefetcht1(EVEX_compress_addr(
448                                     aux_reg_ker_prf, kernel_offset(kk, ic
449                                     + prf_count_t1, ki)));
450                                 prf_count_t1++;
451                             }
452                         } else
453                            prefetch_inp_next_kh(ki, 2, prf_count_t0,
454                                prf_count_t1);
455                     }
456
457                     if (last_kernel_loads) {
458                         jmp(loop_end_label, T_NEAR);
459
460                         L(last_iter_label);
461
462                         kernel_loads(ki, ic, kk);
463                         for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
464                                  prf_count_t0 = 0;
465                                 oi < get_ow_end(ur_w, ki, pad_r); oi++) {
466                             int aux_input_offset = typesize
467                                     * ((ki * (jcp.dilate_w + 1) + oi * stride_w
468                                                - pad_l) * ic_block
469                                                            + ic);
470                             v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
471                                 EVEX_compress_addr(aux_reg_inp,
472                                     aux_input_offset));
473                             if (oi % 2) {
474                                 if (prf_count_t0 < 4) {
475                                     mic_prefetcht0(EVEX_compress_addr(
476                                         aux_reg_ker_prf, kernel_offset(0,
477                                         prf_count_t0, 0)));
478                                     prf_count_t0++;
479                                 } else if (prf_count_t1 < 4) {
480                                     mic_prefetcht1(EVEX_compress_addr(
481                                         aux_reg_ker_prf, kernel_offset(kk,
482                                         ic + prf_count_t1, ki)));
483                                     prf_count_t1++;
484                                 }
485                             }
486                         }
487                         L(loop_end_label);
488                     }
489                 }
490     } else {
491         for (int ki = 0; ki < kw; ki++)
492             for (int ic = 0; ic < ic_block; ic += 4)
493                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
494                     kernel_loads(ki, ic, kk);
495                     for (int oi = get_ow_start(ki, pad_l),
496                             prf_count_t1 = 0, prf_count_t0 = 0;
497                             oi < get_ow_end(ur_w, ki, pad_r); oi++) {
498                         int aux_input_offset = typesize
499                                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w
500                                 - pad_l) * ic_block + ic);
501                         v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
502                             EVEX_compress_addr(aux_reg_inp,
503                                 aux_input_offset));
504
505                         if (!is_owb_prefetching(jcp)) {
506                             if ((oi % 2) && (prf_count_t1 < 4)) {
507                                 mic_prefetcht1(EVEX_compress_addr(
508                                     aux_reg_ker_prf, kernel_offset(kk,
509                                     ic + prf_count_t1, ki)));
510                                 prf_count_t1++;
511                             }
512                         } else {
513                             if (!(ki == 0 && ic == 0)
514                                 && !(ki == kw-1 && ic == 0) &&
515                                 (oi % 2) && (prf_count_t1 < 4)
516                                 ) {
517                                 mic_prefetcht0(EVEX_compress_addr(
518                                     aux_reg_ker, kernel_offset(kk,
519                                     ic + 4 + prf_count_t0, ki)));
520                                 prf_count_t0++;
521                             }
522                         }
523                         if (!is_owb_prefetching(jcp)) {
524                             if (pref_current_inp) {
525                                 if (ki == 0 && ic == 0 && kk == 0)
526                                     mic_prefetcht0(EVEX_compress_addr(
527                                         aux_reg_inp,
528                                         aux_input_offset + shift_input_ptr));
529                             } else {
530                                 if (ki == 1 && ic == 0 && kk == 0)
531                                     mic_prefetcht1(EVEX_compress_addr(
532                                         aux_reg_inp_prf, aux_input_offset));
533                             }
534                         } else {
535                             int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
536                             int inp_shift
537                                 = jcp.typesize_in * ur_w * stride_w * inp_mult;
538                             bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
539                             if (ki == 0 && ic == 0 && kk_pref_slot)
540                                     mic_prefetcht1(EVEX_compress_addr(
541                                         aux_reg_inp,
542                                         aux_input_offset + inp_shift));
543
544                             if (ki == kw - 1 && ic == 0 && kk_pref_slot)
545                                     mic_prefetcht0(EVEX_compress_addr(
546                                         aux_reg_inp,
547                                         aux_input_offset + inp_shift));
548                         }
549                     }
550                 }
551     }
552
553     add(aux_reg_ker, shift_kernel_ptr);
554     add(aux_reg_inp, shift_input_ptr);
555     add(aux_reg_ker_prf, shift_kernel_ptr);
556     add(aux_reg_inp_prf, shift_input_ptr);
557
558     dec(reg_kj);
559     cmp(reg_kj, 0);
560     jg(kh_label, T_NEAR);
561
562     if (jcp.ndims == 5) {
563         add(aux_reg_inp_d,
564                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
565         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
566                 * jcp.ic_block);
567         add(aux_reg_inp_d_prf,
568                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
569         add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
570                 * jcp.ic_block);
571
572         dec(reg_ki);
573         cmp(reg_ki, 0);
574         jg(kd_label, T_NEAR);
575
576         pop(reg_out);
577         pop(reg_out_prf);
578     }
579 }
580
581 template<typename Vmm>
582 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
583         int pad_l, int pad_r)
584 {
585     bool prf_ker = true;
586     bool prf_inp = true;
587     int ih = jcp.ih;
588     int stride_w = jcp.stride_w;
589     int id = jcp.id;
590     int iw = jcp.iw;
591     int kw = jcp.kw;
592     int ic_block = jcp.ic_block;
593     int oc_block = jcp.oc_block;
594     int nb_oc_block = jcp.nb_oc_blocking;
595     Label kh_label, kd_label;
596
597     int ker_pipeline_depth = 4;
598     assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
599     assert(oc_block >= ker_pipeline_depth);
600
601     int num_ker_loads = ic_block * nb_oc_block * kw;
602     int num_ker_prfs = prf_ker ? num_ker_loads : 0;
603     int num_inp_prfs = prf_inp ?
604             ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
605             0;
606     if (jcp.is_1stconv && prf_inp) {
607         num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
608     }
609     int num_prfs = num_ker_prfs + num_inp_prfs;
610     int num_fmas = num_ker_loads * ur_w;
611     int prf_inst_spacing
612             = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
613     int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
614     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
615
616     if (one_of(jcp.ndims, 3, 4)) {
617         mov(aux_reg_inp, reg_inp);
618         mov(aux_reg_ker, reg_ker);
619         mov(aux_reg_inp_prf, reg_inp_prf);
620         mov(aux_reg_ker_prf, reg_ker_prf);
621     }
622
623     size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
624     assert(reg_inp_prf == reg_long_offt);
625     if (max_input_offset > INT_MAX) push(reg_inp_prf);
626
627
628     if (jcp.ndims == 5) {
629         push(reg_out_prf);
630         push(reg_out);
631
632         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
633         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
634         mov(aux_reg_inp_d, reg_inp);
635         mov(aux_reg_inp_d_prf, reg_inp_prf);
636         mov(aux_reg_ker_d_prf, reg_ker_prf);
637
638         L(kd_label);
639         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
640     } else {
641         mov(reg_kj, reg_kh);
642     }
643
644     if (jcp.ndims == 5) {
645         mov(aux_reg_inp, aux_reg_inp_d);
646         mov(aux_reg_ker, aux_reg_ker_d);
647         mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
648         mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
649     }
650
651     align(16);
652     L(kh_label);
653     {
654         int step = 0;
655         int ker_prfs = 0;
656         for (int ki = 0; ki < kw; ki++) {
657             for (int ic = 0; ic < ic_block; ic++) {
658                 int aux_kernel_offset = 0;
659                 if (step == 0) {
660                     for (int i = 0; i < ker_pipeline_depth; i++) {
661                         aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
662                         vmovups(vmm_ker(i), EVEX_compress_addr(
663                                         aux_reg_ker, aux_kernel_offset));
664                     }
665                 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
666                     int load_offset = ker_pipeline_depth - 1;
667                     int ker_load_reg_idx
668                         = (step + load_offset) % ker_pipeline_depth;
669                     aux_kernel_offset
670                             = get_kernel_offset(ki, ic, 0, load_offset);
671                     vmovups(vmm_ker(ker_load_reg_idx),
672                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
673                 }
674
675                 bool ker_prf_inserted = false;
676                 Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
677                 int j_start = get_ow_start(ki, pad_l);
678                 int j_end = get_ow_end(ur_w, ki, pad_r);
679                 for (int j = j_start; j < j_end; j++) {
680                     size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
681                     auto addr = EVEX_compress_addr_safe(aux_reg_inp,
682                             aux_input_offset, reg_long_offt, true);
683                     vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
684                     int fma_idx = step * ur_w + j;
685                     int prf_slot_idx = fma_idx / prf_inst_spacing;
686                     if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
687                         if (prf_ker && !ker_prf_inserted
688                                 && ker_prfs < num_ker_prfs) {
689                             int ker_prf_offset
690                                     = jcp.typesize_in * ker_prfs * jcp.oc_block;
691                             mic_prefetcht2(EVEX_compress_addr(
692                                     aux_reg_ker_prf, ker_prf_offset));
693                             ker_prf_inserted = true;
694                             ker_prfs++;
695                         } else if (prf_inp) {
696                             int inp_prf_idx = prf_slot_idx - ker_prfs;
697                             if (inp_prf_idx < num_inp_prfs) {
698                                 size_t inp_prf_stride = nstl::max(kw, stride_w);
699                                 size_t inp_prf_offset;
700                                 if (!jcp.is_1stconv) {
701                                     inp_prf_offset
702                                             = ic_block * jcp.typesize_in
703                                             * ((inp_prf_idx / kw)
704                                             * inp_prf_stride
705                                             + (inp_prf_idx % kw));
706                                 } else {
707                                     size_t ic_prf_stride =
708                                         (size_t)jcp.typesize_in * iw * ih * id;
709                                     size_t iw_prf_stride
710                                             = jcp.typesize_in * jcp.simd_w;
711                                     inp_prf_offset = ((inp_prf_idx / ic_block)
712                                             * iw_prf_stride
713                                             + (inp_prf_idx % ic_block)
714                                             * ic_prf_stride);
715                                 }
716                                 mic_prefetcht0(EVEX_compress_addr_safe(
717                                         aux_reg_inp_prf, inp_prf_offset,
718                                         reg_long_offt));
719                             }
720                         }
721                     }
722                 }
723                 step++;
724             }
725         }
726         add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
727         if (prf_ker)
728             add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
729         add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
730         if (prf_inp)
731             add(aux_reg_inp_prf,
732                     jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
733         dec(reg_kj);
734         cmp(reg_kj, 0);
735         jg(kh_label, T_NEAR);
736     }
737
738
739     if (jcp.ndims == 5) {
740         add(aux_reg_inp_d,
741                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
742         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
743                 * jcp.ic_block);
744         add(aux_reg_inp_d_prf,
745                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
746         add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
747                 * jcp.ic_block);
748
749         dec(reg_ki);
750         cmp(reg_ki, 0);
751         jg(kd_label, T_NEAR);
752
753         pop(reg_out);
754         pop(reg_out_prf);
755     }
756     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
757 }
758
759 template<typename Vmm>
760 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
761     int pad_l, int pad_r)
762 {
763     int kw = jcp.kw;
764     int stride_w = jcp.stride_w;
765     int ic_block = jcp.ic_block;
766     int oc_block = jcp.oc_block;
767     int nb_oc_block = jcp.nb_oc_blocking;
768     Label kh_label, kd_label;
769     int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
770         * jcp.ic_block;
771     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
772     int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
773         * inp_mul;
774
775
776     auto input_offset = [=](int oi, int ic, int ki) {
777         return (size_t)jcp.typesize_in
778                 * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
779                 * inp_mul + (size_t)ic
780                 * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
781     };
782
783     if (one_of(jcp.ndims, 3, 4)) {
784         mov(aux_reg_inp, reg_inp);
785         mov(aux_reg_ker, reg_ker);
786     }
787
788     if (jcp.ndims == 5) {
789         push(reg_out);
790
791         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
792         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
793         mov(aux_reg_inp_d, reg_inp);
794
795         L(kd_label);
796         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
797     } else {
798         mov(reg_kj, reg_kh);
799     }
800
801     if (jcp.ndims == 5) {
802         mov(aux_reg_inp, aux_reg_inp_d);
803         mov(aux_reg_ker, aux_reg_ker_d);
804     }
805
806     L(kh_label);
807     {
808         for (int ki = 0; ki < kw; ki++) {
809             int jj_start = get_ow_start(ki, pad_l);
810             int jj_end = get_ow_end(ur_w, ki, pad_r);
811             for (int ic = 0; ic < ic_block; ic++) {
812                 if (jcp.kernel_kind == expl_bcast) {
813                     for (int jj = jj_start; jj < jj_end; jj++) {
814                         size_t aux_input_offset = input_offset(jj, ic, ki);
815                         vbroadcastss(vmm_inp(jj, nb_oc_block),
816                             EVEX_compress_addr_safe(aux_reg_inp,
817                             aux_input_offset, reg_long_offt));
818                     }
819                 }
820                 for (int ii = 0; ii < nb_oc_block; ii++) {
821                     int aux_kernel_offset = jcp.typesize_in
822                         * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
823                         * oc_block + ki * ic_block * oc_block + ic * oc_block);
824                     if (jj_end - jj_start > 0)
825                         vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
826                             aux_kernel_offset));
827                     for (int jj = jj_start; jj < jj_end; jj++)
828                         if (jcp.kernel_kind == expl_bcast)
829                             vfmadd231ps(vmm_out(jj, ii),
830                                 vmm_inp(jj, nb_oc_block), vmm_wei);
831                         else {
832                             size_t aux_input_offset = input_offset(jj, ic, ki);
833                             vfmadd231ps(vmm_out(jj, ii), vmm_wei,
834                                 EVEX_compress_addr_safe(aux_reg_inp,
835                                 aux_input_offset, reg_long_offt, true));
836                         }
837                 }
838             }
839         }
840         add(aux_reg_ker, shift_kernel_ptr);
841         add(aux_reg_inp, shift_input_ptr);
842         dec(reg_kj);
843         cmp(reg_kj, 0);
844         jg(kh_label, T_NEAR);
845     }
846
847     if (jcp.ndims == 5) {
848         add(aux_reg_inp_d,
849                 typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
850         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
851                 * jcp.ic_block);
852
853         dec(reg_ki);
854         cmp(reg_ki, 0);
855         jg(kd_label, T_NEAR);
856
857         pop(reg_out);
858     }
859 }
860
861 template<typename Vmm>
862 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_vnni(
863     int ur_w, int pad_l, int pad_r)
864 {
865 }
866
867 template<>
868 void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_vnni(
869         int ur_w, int pad_l, int pad_r)
870 {
871     Label kh_label, kd_label;
872     const int ker_reg_base_idx = 28;
873     const int channel_inc = jcp.ver == ver_4vnni ? 4 : 1;
874     const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1;
875     const int shift_kernel_ptr = jcp.typesize_in * jcp.kw
876                                * jcp.oc_block * jcp.ic_block;
877     const int shift_input_ptr
878             = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
879
880     size_t max_input_offset = (size_t)jcp.typesize_in
881                 * jcp.ic_block * jcp.iw * jcp.ih * jcp.id;
882     assert(reg_inp_prf == reg_long_offt);
883     if (max_input_offset > INT_MAX) push(reg_inp_prf);
884
885
886     if (one_of(jcp.ndims, 3, 4)) {
887         mov(aux_reg_inp, reg_inp);
888         mov(aux_reg_ker, reg_ker);
889         mov(aux_reg_ker_prf, reg_ker_prf);
890         mov(aux_reg_inp_prf, reg_inp_prf);
891     }
892
893     if (jcp.ndims == 5) {
894         push(reg_out_prf);
895         push(reg_out);
896
897         mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
898         mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
899         mov(aux_reg_inp_d, reg_inp);
900         mov(aux_reg_inp_d_prf, reg_inp_prf);
901         mov(aux_reg_ker_d_prf, reg_ker_prf);
902
903         L(kd_label);
904         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
905     } else {
906         mov(reg_kj, reg_kh);
907     }
908     if (jcp.ndims == 5) {
909         mov(aux_reg_inp, aux_reg_inp_d);
910         mov(aux_reg_ker, aux_reg_ker_d);
911         mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
912         mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
913     }
914
915     L(kh_label); {
916         for (int ki = 0; ki < jcp.kw; ki++) {
917             int ow_start = get_ow_start(ki, pad_l);
918             int ow_end = get_ow_end(ur_w, ki, pad_r);
919             for (int ic = 0; ic < jcp.ic_block / 2; ic += channel_inc) {
920                 if (jcp.kernel_kind == expl_bcast) {
921                     for (int oi = ow_start; oi < ow_end; oi++) {
922                         size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
923                         vpbroadcastd(vmm_inp(oi, jcp.nb_oc_blocking),
924                             EVEX_compress_addr_safe(aux_reg_inp, input_offset,
925                             reg_long_offt));
926                     }
927                 }
928                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
929                     if (jcp.kernel_kind == expl_bcast) {
930                         int kernel_offset = get_kernel_offset(ki, ic, kk, 0);
931                         vmovups(vmm_wei,
932                             EVEX_compress_addr(aux_reg_ker, kernel_offset));
933                     } else {
934                         for (int ii = 0; ii < ker_load_number; ii++) {
935                             int kernel_offset
936                                 = get_kernel_offset(ki, ic, kk, ii);
937                             vmovups(Zmm(ker_reg_base_idx + ii),
938                                     EVEX_compress_addr(
939                                             aux_reg_ker, kernel_offset));
940                         }
941                     }
942                     for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) {
943                         size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
944                         if (jcp.kernel_kind == expl_bcast) {
945                             vpdpwssd(vmm_out(oi, kk), vmm_wei,
946                                 vmm_inp(oi, jcp.nb_oc_blocking));
947                         } else {
948                             if (jcp.ver == ver_4vnni)
949                                 vp4dpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
950                                 EVEX_compress_addr_safe(aux_reg_inp,
951                                     input_offset, reg_long_offt, false));
952                             else
953                                 vpdpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
954                                 EVEX_compress_addr_safe(aux_reg_inp,
955                                     input_offset, reg_long_offt, true));
956                         }
957                         if ((oi % 2) && (prf_count < ker_load_number)) {
958                             int kernel_offset = get_kernel_offset(
959                                 ki, ic, kk, prf_count++);
960                             mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
961                                 kernel_offset));
962                         }
963                         if (!(oi % 2) && ki == 0 && ic == 0 && kk == 0) {
964                             mic_prefetcht1(EVEX_compress_addr_safe(
965                                 aux_reg_inp_prf, input_offset, reg_long_offt));
966                         }
967                         if (!(oi % 2) && ki == 1 && ic == 0 && kk == 0) {
968                             mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
969                                 input_offset + shift_input_ptr, reg_long_offt));
970                         }
971                     }
972                 }
973             }
974         }
975         add(aux_reg_ker_prf, shift_kernel_ptr);
976         add(aux_reg_inp_prf, shift_input_ptr);
977         add(aux_reg_ker, shift_kernel_ptr);
978         add(aux_reg_inp, shift_input_ptr);
979
980         dec(reg_kj);
981         cmp(reg_kj, 0);
982         jg(kh_label, T_NEAR);
983     }
984
985     if (jcp.ndims == 5) {
986         add(aux_reg_inp_d, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
987         add(aux_reg_ker_d, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
988                 * jcp.ic_block);
989         add(aux_reg_inp_d_prf, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
990         add(aux_reg_ker_d_prf, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
991                 * jcp.ic_block);
992
993         dec(reg_ki);
994         cmp(reg_ki, 0);
995         jg(kd_label, T_NEAR);
996
997         pop(reg_out);
998         pop(reg_out_prf);
999     }
1000     if (max_input_offset > INT_MAX) pop(reg_inp_prf);
1001 }
1002
1003 template<typename Vmm>
1004 void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
1005         int pad_l, int pad_r)
1006 {
1007     if (jcp.ndims == 5) push(reg_oi);
1008
1009     prepare_output(ur_w);
1010
1011     Label skip_compute_loop;
1012     if (jcp.ndims == 5) {
1013         if ((jcp.dilate_d >= jcp.id)
1014                 || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
1015             mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
1016             cmp(reg_kj, 0);
1017             je(skip_compute_loop, T_NEAR);
1018         }
1019     }
1020     if ((jcp.dilate_h >= jcp.ih)
1021             || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
1022         mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
1023         cmp(reg_kj, 0);
1024         je(skip_compute_loop, T_NEAR);
1025     }
1026
1027     if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
1028         compute_loop_vnni(ur_w, pad_l, pad_r);
1029     else if (jcp.ver == ver_4fma)
1030         if(jcp.is_1stconv)
1031             compute_loop_4fma_1st(ur_w, pad_l, pad_r);
1032         else
1033             compute_loop_4fma(ur_w, pad_l, pad_r);
1034     else if (jcp.ver == ver_fma)
1035         if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
1036                 || mayiuse(avx512_mic))
1037             compute_loop_fma(ur_w, pad_l, pad_r);
1038         else
1039             if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
1040                 compute_loop_fma(ur_w, pad_l, pad_r);
1041             else
1042                 compute_loop_fma_core(ur_w, pad_l, pad_r);
1043     else
1044         assert(!"unknown convolution version");
1045
1046     L(skip_compute_loop);
1047     store_output(ur_w);
1048     if (jcp.ndims == 5) pop(reg_oi);
1049 }
1050
1051 template<typename Vmm>
1052 void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
1053 {
1054     const auto &p = attr_.post_ops_;
1055     for (int i = 0; i < p.len_; i++) {
1056         auto &post_op = p.entry_[i];
1057         if (post_op.is_eltwise()) {
1058             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
1059                     this,
1060                     post_op.eltwise.alg,
1061                     post_op.eltwise.alpha,
1062                     post_op.eltwise.beta
1063             ));
1064         } else if (post_op.is_depthwise()) {
1065             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
1066                     this,
1067                     post_op.depthwise.alg
1068             ));
1069         }
1070     }
1071
1072     int iw = jcp.iw;
1073     int ow = jcp.ow;
1074     int ow_block = jcp.ow_block;
1075     int nb_ow = jcp.nb_ow;
1076     int kw = jcp.kw;
1077     int l_pad = jcp.l_pad;
1078     int ur_w = jcp.ur_w;
1079     int ur_w_tail = jcp.ur_w_tail;
1080     int dilate_w = jcp.dilate_w + 1;
1081     int stride_w = jcp.stride_w;
1082
1083     int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1084     int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
1085     int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
1086     int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
1087     int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;
1088
1089     preamble();
1090     mov(reg_inp, ptr[param1 + GET_OFF(src)]);
1091     mov(reg_out, ptr[param1 + GET_OFF(dst)]);
1092     mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
1093     mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1094     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
1095
1096     int r_pad = nstl::max(
1097             0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
1098     int n_oi = ow / ur_w;
1099     int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
1100             - (iw + l_pad - 1);
1101
1102     if (!is_ow_threading_on(jcp)) {
1103         // ow is being processed as a whole - with left and right paddings
1104         if (r_pad1 > 0) n_oi--;
1105
1106         if (ow == ur_w) {
1107             mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]);
1108             mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]);
1109             compute_loop(ur_w, l_pad, r_pad);
1110         } else {
1111             mov(reg_inp_prf, reg_inp);
1112             mov(reg_out_prf, reg_out);
1113             if (n_oi == 0) {
1114                 add(reg_inp_prf, inp_shift_pad);
1115                 add(reg_out_prf, out_shift);
1116                 compute_loop(ur_w, l_pad, r_pad1);
1117                 add(reg_inp, inp_shift_pad);
1118                 add(reg_out, out_shift);
1119                 if (ur_w_tail != 0) {
1120                     add(reg_inp_prf, inp_shift);
1121                     add(reg_out_prf, out_shift);
1122                     compute_loop(ur_w_tail, 0, r_pad);
1123                 }
1124             } else {
1125                 if (l_pad > 0) {
1126                     n_oi--;
1127                     add(reg_inp_prf, inp_shift_pad);
1128                     add(reg_out_prf, out_shift);
1129                     compute_loop(ur_w, l_pad, 0);
1130                     add(reg_inp, inp_shift_pad);
1131                     add(reg_out, out_shift);
1132                 }
1133                 if (n_oi > 0) {
1134                     xor_(reg_oi, reg_oi);
1135                     Label ow_loop_label;
1136                     L(ow_loop_label);
1137                     {
1138                         add(reg_inp_prf, inp_shift);
1139                         add(reg_out_prf, out_shift);
1140                         compute_loop(ur_w, 0, 0);
1141                         add(reg_inp, inp_shift);
1142                         add(reg_out, out_shift);
1143                         inc(reg_oi);
1144                         cmp(reg_oi, n_oi);
1145                         jl(ow_loop_label, T_NEAR);
1146                     }
1147                 }
1148                 if (r_pad1 > 0) {
1149                     add(reg_inp_prf, inp_shift);
1150                     add(reg_out_prf, out_shift);
1151                     compute_loop(ur_w, 0, r_pad1);
1152                     add(reg_inp, inp_shift);
1153                     add(reg_out, out_shift);
1154                 }
1155                 if (ur_w_tail != 0) {
1156                     add(reg_inp_prf, inp_shift);
1157                     add(reg_out_prf, out_shift);
1158                     compute_loop(ur_w_tail, 0, r_pad);
1159                 }
1160             }
1161         }
1162     } else {
1163         // ow block is only processed.
1164         // Number of block is passed as parameter owb,
1165         // and padding processing depends on this number.
1166
1167         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
1168         Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
1169
1170         assert(ow_block % ur_w == 0);
1171         int n_oi_not_last_ow_block = ow_block / ur_w;
1172         // to simplify code (and general regs usage),
1173         // size of ow block must be >= 2 * ur_w
1174         assert(n_oi_not_last_ow_block > 1);
1175         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
1176         int n_oi_first_ow_block = n_oi_not_last_ow_block;
1177
1178         int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
1179
1180         // prepare right padding
1181         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
1182         bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
1183         bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
1184
1185         if (last_ow_block_padded) n_oi_last_ow_block--;
1186         else if (first_ow_block_padded) n_oi_first_ow_block--;
1187         else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
1188
1189         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1190         cmp(reg_owb, 0); // is that the first ow-block ?
1191         jg(middle_ow_blocks_label, T_NEAR);
1192
1193         // the first ow block, compute left padding
1194
1195         mov(reg_oi, n_oi_first_ow_block);
1196         mov(reg_inp_prf, reg_inp);
1197         mov(reg_out_prf, reg_out);
1198
1199         if (l_pad > 0) {
1200             mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1201             add(reg_inp_prf, inp_shift_pad);
1202             add(reg_out_prf, out_shift);
1203             compute_loop(ur_w, l_pad, 0);
1204             add(reg_inp, inp_shift_pad);
1205             add(reg_out, out_shift);
1206             dec(reg_oi);
1207         }
1208         jmp(oi_loop_label, T_NEAR);
1209
1210         // middle or last ow block entry
1211
1212         L(middle_ow_blocks_label);
1213
1214         if (l_pad > 0) {
1215             // just to consider left padding, not compute
1216             add(reg_inp, inp_shift_pad_second_block);
1217             add(reg_inp_prf, inp_shift_pad_second_block);
1218         }
1219
1220         // set number of iteration for oi-loop
1221         cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
1222         mov(reg_oi, n_oi_last_ow_block);
1223         je(oi_loop_label, T_NEAR);
1224         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1225         mov(reg_oi, n_oi_next_last_ow_block);
1226         je(oi_loop_label, T_NEAR);
1227         mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
1228
1229         // oi loop w/o padding
1230         L(oi_loop_label);
1231         mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1232         L(oi_loop_start_label);
1233             cmp(reg_oi, 0);
1234             jle(oi_loop_end_label, T_NEAR);
1235
1236             add(reg_inp_prf, inp_shift);
1237             add(reg_out_prf, out_shift);
1238             compute_loop(ur_w, 0, 0);
1239             add(reg_inp, inp_shift);
1240             add(reg_out, out_shift);
1241             dec(reg_oi);
1242             jmp(oi_loop_start_label, T_NEAR);
1243         L(oi_loop_end_label);
1244
1245         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1246
1247         cmp(reg_owb, 0); // first ow-block ?
1248         if (first_ow_block_padded) {
1249             je(last_oi_label, T_NEAR);
1250         } else {
1251             je(end_label, T_NEAR);
1252         }
1253         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1254         jl(end_label, T_NEAR);
1255         if (next_last_ow_block_padded) {
1256             je(last_oi_label, T_NEAR);
1257         } else {
1258             je(end_label, T_NEAR);
1259         }
1260         // that is last block
1261         if (!last_ow_block_padded) {
1262             jmp(tail_label, T_NEAR);
1263         }
1264
1265         // last oi block with right padding
1266         L(last_oi_label);
1267         mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1268         add(reg_inp_prf, inp_shift);
1269         add(reg_out_prf, out_shift);
1270         compute_loop(ur_w, 0, r_pad1);
1271         add(reg_inp, inp_shift);
1272         add(reg_out, out_shift);
1273
1274         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1275         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
1276         jl(end_label, T_NEAR);
1277
1278         L(tail_label);
1279         mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
1280         if (ur_w_tail != 0) {
1281             add(reg_inp_prf, inp_shift);
1282             add(reg_out_prf, out_shift);
1283             compute_loop(ur_w_tail, 0, r_pad);
1284         }
1285         L(end_label);
1286     }
1287     postamble();
1288
1289     for (auto& inj : eltwise_injectors)
1290         inj->prepare_table();
1291 }
1292
1293 bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
1294         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1295     const auto &p = attr.post_ops_;
1296
1297     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
1298     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
1299     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1300     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
1301
1302     switch (p.len_) {
1303     case 0: return true;
1304     case 1: return is_simple(0) || is_sum(0);
1305     case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
1306     case 3: return is_sum(0) && is_simple(1) && is_simple(2);
1307     default: return false;
1308     }
1309
1310     return false;
1311 }
1312
1313 status_t jit_avx512_common_conv_fwd_kernel::init_conf(
1314             jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1315             cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
1316             cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
1317             const primitive_attr_t &attr, int nthreads)
1318 {
1319     using namespace prop_kind;
1320
1321     if (!mayiuse(avx512_common))
1322         return status::unimplemented;
1323
1324     const memory_desc_wrapper src_d(&src_pd);
1325     const memory_desc_wrapper weights_d(&weights_pd);
1326     const memory_desc_wrapper dst_d(&dst_pd);
1327     const memory_desc_wrapper bias_d(&bias_pd);
1328
1329     const int regs = 28;
1330     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1331     int ndims = src_d.ndims();
1332
1333     jcp = zero<decltype(jcp)>();
1334     jcp.ndims = ndims;
1335     jcp.prop_kind = cd.prop_kind;
1336     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1337     jcp.mb = src_d.dims()[0];
1338     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1339     jcp.oc_without_padding = jcp.oc;
1340     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1341     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1342     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
1343     jcp.iw = src_d.dims()[ndims-1];
1344     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
1345     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
1346     jcp.ow = dst_d.dims()[ndims-1];
1347     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
1348     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
1349     jcp.kw = weights_d.dims()[with_groups + ndims-1];
1350     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1351     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
1352     jcp.l_pad = cd.padding[0][ndims-3];
1353     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1354     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
1355     jcp.stride_w = cd.strides[ndims-3];
1356     jcp.src_fmt = src_d.format();
1357
1358     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1359     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
1360     jcp.dilate_w = cd.dilates[ndims-3];
1361
1362     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
1363             - (jcp.ih + jcp.t_pad - 1);
1364     jcp.back_pad = (jcp.od - 1) * jcp.stride_d
1365             + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
1366
1367     jcp.is_1stconv = is_1stconv(jcp);
1368
1369     bool ok_to_pad_channels = true
1370         && jcp.ngroups == 1
1371         && src_d.data_type() == data_type::f32;
1372
1373     const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
1374     jcp.simd_w = full_simd_w;
1375     bool ok_to_try_xmm = true
1376         && mayiuse(avx512_core)
1377         && src_d.data_type() == data_type::f32
1378         && !jcp.is_1stconv
1379         && !ok_to_pad_channels
1380         && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
1381         && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
1382     if (ok_to_try_xmm)
1383         jcp.simd_w = 4;
1384
1385     jcp.oc_block = jcp.simd_w;
1386     jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
1387     jcp.aligned_threads = 0;
1388
1389     if (ok_to_pad_channels) {
1390         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1391         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1392     }
1393     bool args_ok = true
1394         && jcp.oc % jcp.oc_block == 0
1395         && jcp.ic % jcp.ic_block == 0;
1396     if (!args_ok)
1397         return status::unimplemented;
1398
1399     if (!post_ops_ok(jcp, attr))
1400         return status::unimplemented;
1401
1402     const auto &p = attr.post_ops_;
1403     jcp.with_sum = p.find(primitive_kind::sum) != -1;
1404     const int eltwise_ind = p.find(primitive_kind::eltwise);
1405     jcp.with_eltwise = eltwise_ind != -1;
1406     if (jcp.with_eltwise) {
1407         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
1408         if (dst_d.data_type() == data_type::s32) return status::unimplemented;
1409     }
1410
1411     auto src_format = jcp.is_1stconv
1412         ? pick(ndims - 3, ncw, nchw, ncdhw)
1413         : ((jcp.simd_w == 4)
1414             ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
1415             : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
1416     auto dst_format = (jcp.simd_w == 4)
1417         ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
1418         : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1419     auto wei_format = with_groups
1420         ? ((jcp.simd_w == 4)
1421             ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
1422             : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
1423         : ((jcp.simd_w == 4)
1424             ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
1425             : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
1426
1427     if (src_d.format() == any)
1428         CHECK(src_pd.set_format(src_format));
1429     if (src_d.format() != src_format)
1430         return status::unimplemented;
1431
1432     if (dst_d.format() == any)
1433         CHECK(dst_pd.set_format(dst_format));
1434     if (dst_d.format() != dst_format)
1435         return status::unimplemented;
1436
1437     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
1438     if (jcp.with_bias) {
1439         if (bias_d.format() == any)
1440             CHECK(bias_pd.set_format(x));
1441         if (bias_d.format() != x)
1442             return status::unimplemented;
1443     }
1444
1445     if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
1446          && src_d.data_type() == data_type::s16
1447          && weights_d.data_type() == data_type::s16
1448          && dst_d.data_type() == data_type::s32)
1449     {
1450         if (jcp.is_1stconv)
1451             return status::unimplemented;
1452
1453         if (mayiuse(avx512_mic_4ops)) {
1454             jcp.ver = ver_4vnni;
1455         } else {
1456             jcp.ver = ver_vnni;
1457         }
1458         jcp.typesize_in = sizeof(int16_t);
1459         jcp.typesize_out = sizeof(int32_t);
1460
1461         const auto w_format = with_groups
1462             ? pick(ndims - 3, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i)
1463             : pick(ndims - 3, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i);
1464         if (weights_d.format() == any)
1465             CHECK(weights_pd.set_format(w_format));
1466         if (weights_d.format() != w_format)
1467             return status::unimplemented;
1468     } else if (mayiuse(avx512_common) &&
1469             src_d.data_type() == data_type::f32
1470          && weights_d.data_type() == data_type::f32
1471          && dst_d.data_type() == data_type::f32) {
1472         jcp.ver = ver_fma;
1473         jcp.typesize_in = sizeof(float);
1474         jcp.typesize_out = sizeof(float);
1475         if (mayiuse(avx512_mic_4ops))
1476            jcp.ver = ver_4fma;
1477
1478         if (jcp.is_1stconv) {
1479             // TODO: fix & remove constraints below
1480             bool not_for_4fma
1481                     = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad),
1482                             nstl::max(jcp.kw, jcp.kh) < 7);
1483             bool is_dilated
1484                     = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
1485             if (one_of(true, not_for_4fma, is_dilated))
1486                 jcp.ver = ver_fma;
1487             if (jcp.ver == ver_4fma) {
1488                 const auto w_format = with_groups
1489                     ? ((jcp.simd_w == 4)
1490                         ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
1491                         : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
1492                     : ((jcp.simd_w == 4)
1493                         ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
1494                         : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
1495                 if (weights_d.format() == any)
1496                     CHECK(weights_pd.set_format(w_format));
1497                 if (weights_d.format() != w_format)
1498                     return status::unimplemented;
1499             } else {
1500                 const auto w_format = with_groups
1501                     ? ((jcp.simd_w == 4)
1502                         ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
1503                         : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
1504                     : ((jcp.simd_w == 4)
1505                         ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
1506                         : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
1507                 if (weights_d.format() == any)
1508                     CHECK(weights_pd.set_format(w_format));
1509                 if (weights_d.format() != w_format)
1510                     return status::unimplemented;
1511             }
1512         } else {
1513             if (weights_d.format() == any)
1514                 CHECK(weights_pd.set_format(wei_format));
1515             if (weights_d.format() != wei_format)
1516                 return status::unimplemented;
1517         }
1518     } else {
1519         return status::unimplemented;
1520     }
1521
1522     if (jcp.is_1stconv) {
1523         jcp.ur_w = nstl::min(jcp.ow, regs);
1524     } else {
1525         // avx512_core guard - just to avoid possible regression for other archs
1526         if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
1527             jcp.ur_w = nstl::min(jcp.ow, regs);
1528         } else {
1529             for (int ur_w = regs; ur_w > 0; --ur_w) {
1530                 if (jcp.ow % ur_w == 0) {
1531                     jcp.ur_w = ur_w;
1532                     break;
1533                 }
1534             }
1535         }
1536         if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
1537             jcp.ur_w = nstl::min(jcp.ow, regs);
1538         }
1539     }
1540     // TODO (Tanya): currently applied to Segnet convolutions only.
1541     // Need to try for other topologies
1542     if (jcp.ow > 150 && jcp.ur_w < regs/2)
1543         jcp.ur_w = regs;
1544
1545     int n_oi = (jcp.ow / jcp.ur_w);
1546     int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w
1547             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
1548     if (jcp.l_pad > 0 && r_pad > 0)
1549         n_oi--;
1550
1551     bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
1552             && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
1553     if (large_code_size) {
1554         const int max_code_size = 24 * 1024;
1555         const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
1556         int mult = 1;
1557         if (jcp.l_pad > 0) mult += 1;
1558         if (r_pad > 0) mult += 1;
1559         for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
1560             if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
1561                 jcp.ur_w = ur_w;
1562                 break;
1563             }
1564         }
1565     }
1566
1567     /* Grouped channel offset to support 'non-blocked data' format for
1568      * convolution sizes with '(input_channel / ngroups) < simd' */
1569     jcp.nonblk_group_off
1570             = (jcp.ngroups > 1 && one_of(src_d.format(), ncw, nchw, ncdhw)) ?
1571             jcp.ic :
1572             1;
1573
1574     jcp.nb_ic = jcp.ic / jcp.ic_block;
1575     jcp.nb_oc = jcp.oc / jcp.oc_block;
1576     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1577
1578     auto is_ow_threading_applicable = [=]() {
1579         return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
1580                 && IMPLICATION(mayiuse(avx512_mic),
1581                            jcp.ver == ver_4fma
1582                                    && IMPLICATION(jcp.mb != 1,
1583                                               jcp.ih == 1 && jcp.kh == 1)));
1584     };
1585
1586     if (jcp.ver == ver_4vnni) {
1587         jcp.kernel_kind = embd_bcast;
1588     }
1589     if (jcp.ver == ver_vnni) {
1590         // TODO: kernel_kind and nb_oc_blocking selection
1591         //       should be tuned on real HW
1592         if (jcp.ow <= 8 && jcp.oh <= 8 && jcp.od <= 8) {
1593             jcp.kernel_kind = expl_bcast;
1594             jcp.nb_oc_blocking = 2;
1595         } else {
1596             jcp.kernel_kind = embd_bcast;
1597             jcp.nb_oc_blocking = 2;
1598         }
1599         if (jcp.nb_oc_blocking > 1) {
1600             if (jcp.nb_oc < jcp.nb_oc_blocking) jcp.nb_oc_blocking = jcp.nb_oc;
1601             if (jcp.nb_oc % jcp.nb_oc_blocking != 0)
1602                 for (int i = jcp.nb_oc_blocking; i > 0; i--)
1603                     if (jcp.nb_oc % i == 0) {
1604                         jcp.nb_oc_blocking = i;
1605                         break;
1606                     }
1607             jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
1608             if (jcp.ow < jcp.ur_w)
1609                 jcp.ur_w = jcp.ow;
1610         }
1611     }
1612
1613     if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) {
1614         if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
1615                     && jcp.oh <= 8 && jcp.ow == jcp.oh)
1616                 || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
1617             if (jcp.nb_oc % 2 == 0) {
1618                 jcp.nb_oc_blocking = 2;
1619                 jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
1620             }
1621         } else {
1622             for (int i = jcp.nb_oc; i > 0; i--)
1623                 if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
1624                     jcp.nb_oc_blocking = i;
1625                     break;
1626                 }
1627         }
1628         if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
1629             if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
1630                     && jcp.ow != 2 * jcp.ur_w) {
1631                 jcp.nb_oc_blocking = 2;
1632                 jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
1633             }
1634         }
1635     }
1636
1637     jcp.ow_block = jcp.ow;
1638
1639     auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
1640         int nb_ow = div_up(jcp.ow, ow_block);
1641         int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
1642         int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
1643         float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
1644         float thr_eff = disbalance * (float)work_amount
1645             / rnd_up(work_amount, nthreads);
1646         return thr_eff;
1647     };
1648
1649     auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
1650         int res_ow_block = jcp.ow;
1651         eff = get_thr_eff(nb_oc_blocking, res_ow_block);
1652         if (!is_ow_threading_applicable())
1653             return res_ow_block;
1654
1655         int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
1656         if (jcp.ver == ver_4fma)
1657             L2_part /= 2;
1658         int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
1659         int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
1660         int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
1661             * jcp.kw * jcp.kh;
1662         int nurw_cache = (L2_part - 2 * size_wei_chunk)
1663             / (2 * size_dst_chunk + 2 * size_src_chunk);
1664         // current design of generate() requires ow_block >= 2 * ur_w
1665         int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
1666
1667         int ow_block_thr = ow_block_cache;
1668         eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
1669
1670         int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
1671         int start_nb_ow = div_up(jcp.ow, ow_block_thr);
1672         for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
1673             int ow_block
1674                 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
1675             float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
1676             if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
1677                 break;
1678             if (div_up(jcp.ow, ow_block) != nb_ow)
1679                 continue;
1680             float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
1681             float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
1682             if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
1683                 ow_block_thr = ow_block;
1684                 eff = thr_eff;
1685             }
1686             eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
1687             if (eff > eff_threshold)
1688                 break;
1689         }
1690         res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
1691         eff = get_thr_eff(nb_oc_blocking, res_ow_block);
1692         return res_ow_block;
1693     };
1694
1695
1696     if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
1697         int try_nb_oc_blocking = 2;
1698         unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
1699             * jcp.ic_block * jcp.kh * jcp.kd;
1700         unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block
1701             * try_nb_oc_blocking;
1702         unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
1703             * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
1704         unsigned int ker_total_size = ker_inp_size + ker_out_size
1705             + ker_wei_size;
1706
1707         bool embd_bcast_condition = true
1708             && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size)
1709             && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
1710             && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
1711
1712         if (jcp.mb == 1) {
1713             unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
1714                     * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
1715             unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
1716
1717             // Estimate whether we need to limit the number of threads
1718             // and calculate this number. Includes some heuristic.
1719             int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
1720             int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
1721             int job_size_min = work_amount / nthreads;
1722             int job_size_max = div_up(work_amount, nthreads);
1723             int ch_max = rnd_up(jcp.oh, job_size_max);
1724             int ch_min = (job_size_min == 0)
1725                 ? jcp.oh
1726                 : rnd_up(jcp.oh, job_size_min);
1727             bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
1728                     && (jcp.oh != 8 || ch_max / jcp.oh > 1);
1729             bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
1730                     && (jcp.oh != 8 || ch_min / jcp.oh > 1);
1731             bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
1732                     || nthreads > oc_chunks;
1733             if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
1734                 && wei_size / inp_size > 24
1735                 && (not_aligned_max || not_aligned_min)
1736                 && eligible_case) {
1737                 jcp.aligned_threads = nthreads;
1738                 for (int i = nthreads; i > 0; i--) {
1739                     if (oc_chunks % i == 0 || i % oc_chunks == 0) {
1740                         jcp.aligned_threads = i;
1741                         break;
1742                     }
1743                 }
1744             }
1745         }
1746
1747         if (jcp.kw > 3
1748                 || (jcp.stride_w == 1 && jcp.stride_h == 1
1749                            && embd_bcast_condition)
1750                 || ((jcp.stride_w != 1 || jcp.stride_h != 1)
1751                            && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
1752                                       && embd_bcast_condition)))
1753                 || (jcp.mb == 1
1754                            && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
1755                                       || (jcp.ow <= 147 && jcp.oc <= 96)))) {
1756             jcp.kernel_kind = embd_bcast;
1757             jcp.ur_w = nstl::min(jcp.ow, regs);
1758             jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1759             if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
1760                     && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
1761                     && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
1762                     && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
1763                 jcp.nb_oc_blocking = try_nb_oc_blocking;
1764                 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1765             }
1766         } else {
1767             jcp.kernel_kind = expl_bcast;
1768             jcp.nb_ic_blocking = 1;
1769             if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
1770                 float best_thr_eff = 0.f;
1771                 int best_nb_oc_blocking = 1;
1772                 for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
1773                     if (jcp.nb_oc % i == 0) {
1774                         float thr_eff;
1775                         int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
1776                         get_ow_block(i, ur_w, thr_eff);
1777                         if (thr_eff > 1.05f * best_thr_eff) {
1778                             best_nb_oc_blocking = i;
1779                             best_thr_eff = thr_eff;
1780                         }
1781                     }
1782                 }
1783                 jcp.nb_oc_blocking = best_nb_oc_blocking;
1784                 jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
1785             }
1786         }
1787     }
1788
1789     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1790
1791     args_ok = true
1792         && jcp.l_pad <= jcp.ur_w
1793         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1794         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
1795         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
1796         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
1797     if (!args_ok)
1798         return status::unimplemented;
1799
1800     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
1801                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
1802                     - (jcp.iw + jcp.l_pad - 1));
1803     if (r_pad_no_tail > jcp.ur_w)
1804         return status::unimplemented;
1805
1806     pick_loop_order(jcp);
1807
1808     jcp.nb_ic_L2 = jcp.nb_ic;
1809
1810     float thr_eff;
1811     jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
1812     jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1813
1814     const int L2_size = get_cache_size(2, true) / sizeof(float);
1815     // Source and output data needs to fit in L2,
1816     // leaving some space for weights and prefetching.
1817     int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
1818                            - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
1819             / (jcp.stride_h * jcp.iw + jcp.ow));
1820     jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
1821
1822     // TODO check for 4vnni
1823     if (jcp.ver == ver_4fma) {
1824         if (!is_ow_threading_on(jcp)) {
1825             for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
1826                   divf++) {
1827                 size_t l2_src
1828                     = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
1829                 size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
1830                     * jcp.oh * jcp.od;
1831                 size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
1832                     * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
1833                 if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
1834                     if (jcp.kh == 3 && jcp.oh == 7) {
1835                         jcp.nb_ic_L2 = 1;
1836                         break;
1837                     }
1838                     temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf
1839                                     : jcp.nb_ic_L2);
1840                 } else {
1841                     jcp.nb_ic_L2 = temp_nb;
1842                     break;
1843                 }
1844             }
1845         } else if (jcp.ic > 64) {
1846             jcp.nb_ic_L2 = 2; /* according to performance data*/
1847         }
1848     }
1849
1850     return status::success;
1851 }
1852
1853 void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
1854         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1855     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
1856         scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
1857 }
1858
1859 void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
1860 {
1861     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1862         for (int j = 0; j < ur_w; j++) {
1863             Zmm zmm = zmm_out(j, k);
1864             vpxord(zmm, zmm, zmm);
1865             size_t aux_src_offset
1866                 = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
1867                 * jcp.ic_block;
1868             mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
1869                         reg_long_offt));
1870         }
1871     }
1872 }
1873
1874 void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
1875 {
1876     Label no_update_label;
1877
1878     mov(reg_channel, ptr[param + GET_OFF(channel)]);
1879     cmp(reg_channel, 0);
1880     je(no_update_label, T_NEAR);
1881     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1882         for (int j = 0; j < ur_w; j++) {
1883             Zmm zmm = zmm_out(j, k);
1884             size_t aux_src_offset = (size_t)typesize
1885                 * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
1886             vadd(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
1887                         reg_long_offt));
1888         }
1889     }
1890
1891     L(no_update_label);
1892     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
1893         for (int j = 0; j < ur_w; j++) {
1894             Zmm zmm = zmm_out(j, k);
1895             size_t aux_src_offset = (size_t)typesize
1896                 * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
1897             vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
1898                         reg_long_offt), zmm);
1899             mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
1900                         reg_long_offt));
1901         }
1902     }
1903 }
1904
1905 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
1906         int ur_w, int l_overflow, int r_overflow)
1907 {
1908     int ow = jcp.ow;
1909     int kw = jcp.kw;
1910     int ic_block = jcp.ic_block;
1911     int oc_block = jcp.oc_block;
1912     Label kh_label, last_iter_label, loop_end_label, kd_label;
1913     int ker_load_number = 4;
1914     int shift_ker_ptr = typesize * kw * oc_block * ic_block;
1915     int shift_dst_ptr = typesize * ow * oc_block;
1916     int ii_dpref_t0 = get_iw_start(0, l_overflow);
1917     int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow);
1918
1919     bool check_last_kh = (jcp.kh > 3);
1920     auto kernel_offset = [=](int icb, int oc, int ki) {
1921         int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
1922         int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
1923         int oc_offset = oc * jcp.oc_block;
1924         return typesize * (blk_offset + oc_offset);
1925     };
1926     auto kernel_loads = [=](int ki, int oc, int kk) {
1927         for (int ii = 0; ii < ker_load_number; ii++) {
1928             int aux_kernel_offset = kernel_offset(kk, oc + ii, ki);
1929             vmovups(zmm_ker(ii),
1930                 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1931         }
1932     };
1933     auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
1934         if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
1935             && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) {
1936             int aux_dst_offset = typesize * ((ii_dpref_t0
1937                 + jcp.l_pad) * oc_block + jcp.ow * oc_block);
1938             prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1939             ii_dpref_t0++;
1940         }
1941     };
1942
1943     if (one_of(jcp.ndims, 3, 4)) {
1944         mov(aux_reg_dst, reg_dst);
1945         mov(aux_reg_ker, reg_ker);
1946         mov(aux_reg_dst_prf, reg_dst_prf);
1947         mov(aux_reg_ker_prf, reg_ker_prf);
1948     }
1949
1950     if (jcp.ndims == 5) {
1951         push(reg_src_prf);
1952         push(reg_src);
1953
1954         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1955         mov(aux_reg_dst_d, reg_dst);
1956         mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1957         mov(aux_reg_dst_d_prf, reg_dst_prf);
1958         mov(aux_reg_ker_d_prf, reg_ker_prf);
1959
1960         L(kd_label);
1961         mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1962     } else {
1963         mov(reg_kj, reg_kh);
1964     }
1965
1966     if (jcp.ndims == 5) {
1967         mov(aux_reg_dst, aux_reg_dst_d);
1968         mov(aux_reg_ker, aux_reg_ker_d);
1969         mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
1970         mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
1971     }
1972
1973     align(16);
1974     L(kh_label);
1975     if (check_last_kh) {
1976         for (int ki = 0; ki < kw; ki++)
1977         for (int oc = 0; oc < oc_block; oc += 4)
1978         for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1979             bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1
1980                 && ki == kw - 1 && (oc + 4) == oc_block);
1981
1982             if (last_kernel_loads) {
1983                 cmp(reg_kj, 1);
1984                 je(last_iter_label, T_NEAR);
1985             }
1986
1987             kernel_loads(ki, oc, kk);
1988             for (int ii = get_iw_start(ki, l_overflow),
1989                     prf_count_t0 = 0, prf_count_t1 = 0;
1990                     ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
1991                 int aux_dst_offset = typesize
1992                     * ((ii + jcp.l_pad - ki) * oc_block + oc);
1993                 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
1994                     EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1995
1996                 if (ii % 2) {
1997                     if (prf_count_t0 < 4) {
1998                         int aux_kernel_prf;
1999                         if (last_kernel_loads)
2000                             aux_kernel_prf= kernel_offset(0, prf_count_t0
2001                                 + oc + 4 - oc_block, 0) + typesize * kw
2002                                 * oc_block * ic_block;
2003                         else
2004                             aux_kernel_prf = kernel_offset(kk, oc + 4
2005                                 + prf_count_t0, ki);
2006                         mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
2007                             aux_kernel_prf));
2008                         prf_count_t0++;
2009                     } else if (prf_count_t1 < 4) {
2010                         mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
2011                             kernel_offset(kk, oc + prf_count_t1, ki)));
2012                         prf_count_t1++;
2013                     }
2014                 } else
2015                     prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1);
2016             }
2017             if (last_kernel_loads) {
2018                 jmp(loop_end_label, T_NEAR);
2019
2020                 L(last_iter_label);
2021
2022                 kernel_loads(ki, oc, kk);
2023                 for (int ii = get_iw_start(ki, l_overflow),
2024                         prf_count_t0 = 0, prf_count_t1 = 0;
2025                         ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
2026                     int aux_dst_offset = typesize
2027                         * ((ii + jcp.l_pad - ki) * oc_block + oc);
2028                     v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
2029                             EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
2030                     if (ii % 2) {
2031                         if (prf_count_t0 < 4) {
2032                             mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
2033                                 kernel_offset(0, prf_count_t0, 0)));
2034                             prf_count_t0++;
2035                         } else if (prf_count_t1 < 4) {
2036                             mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
2037                                 kernel_offset(kk, oc + prf_count_t1, ki)));
2038                             prf_count_t1++;
2039                         }
2040                     }
2041                 }
2042                 L(loop_end_label);
2043             }
2044         }
2045     } else {
2046         for (int ki = 0; ki < kw; ki++)
2047         for (int oc = 0; oc < oc_block; oc += 4)
2048         for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
2049             kernel_loads(ki, oc, kk);
2050
2051             for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0;
2052                     ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
2053                 int aux_dst_offset = typesize
2054                     * ((ii + jcp.l_pad - ki) * oc_block + oc);
2055                 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
2056                     EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
2057                 if ((ii % 2) && (prf_count_t1 < 4)) {
2058                     mic_prefetcht1(EVEX_compress_addr(
2059                         aux_reg_ker_prf, kernel_offset(kk,
2060                         oc + prf_count_t1, ki)));
2061                     prf_count_t1++;
2062                 }
2063                 if ( ki == 1 && oc == 0 && kk == 0)
2064                     mic_prefetcht1(EVEX_compress_addr(
2065                         aux_reg_dst_prf, aux_dst_offset));
2066             }
2067         }
2068     }
2069
2070     add(aux_reg_ker, shift_ker_ptr);
2071     sub(aux_reg_dst, shift_dst_ptr);
2072     add(aux_reg_ker_prf, shift_ker_ptr);
2073     sub(aux_reg_dst_prf, shift_dst_ptr);
2074
2075     dec(reg_kj);
2076     cmp(reg_kj, 0);
2077     jg(kh_label, T_NEAR);
2078
2079     if (jcp.ndims == 5) {
2080         sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block);
2081         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2082         sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block);
2083         add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh *oc_block * ic_block);
2084
2085         dec(reg_ki);
2086         cmp(reg_ki, 0);
2087         jg(kd_label, T_NEAR);
2088
2089         pop(reg_src);
2090         pop(reg_src_prf);
2091     }
2092 }
2093
2094 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
2095         int ur_w, int l_overflow, int r_overflow)
2096 {
2097     int ow = jcp.ow;
2098     int kw = jcp.kw;
2099     int ic_block = jcp.ic_block;
2100     int oc_block = jcp.oc_block;
2101     const int channel_inc = jcp.ver == ver_4vnni ? 4 : 1;
2102     const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1;
2103     Label kh_label;
2104
2105     auto kernel_offset = [=](int icb, int oc, int ki) {
2106         int blk_idx = icb * jcp.kh * jcp.kw + ki;
2107         int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
2108         int oc_offset = oc * jcp.oc_block;
2109         return jcp.typesize_in * (blk_offset + oc_offset);
2110     };
2111
2112     mov(aux_reg_dst, reg_dst);
2113     mov(aux_reg_ker, reg_ker);
2114     mov(aux_reg_dst_prf, reg_dst_prf);
2115     mov(aux_reg_ker_prf, reg_ker_prf);
2116
2117     mov(reg_kj, reg_kh);
2118     L(kh_label); {
2119         for (int ki = 0; ki < kw; ki++) {
2120             int jj_start = get_iw_start(ki, l_overflow);
2121             int jj_end = get_iw_end(ur_w, ki, r_overflow);
2122             for (int oc = 0; oc < oc_block / 2; oc += channel_inc) {
2123                 if (jcp.kernel_kind == expl_bcast) {
2124                     for (int jj = jj_start; jj < jj_end; jj++) {
2125                         int aux_dst_offset = jcp.typesize_in
2126                             * ((jj + jcp.l_pad - ki) * oc_block + 2 * oc);
2127                         vpbroadcastd(zmm_inp(jj, jcp.nb_ic_blocking),
2128                             ptr[aux_reg_dst + aux_dst_offset]);
2129                     }
2130                 }
2131                 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
2132                     if (jcp.kernel_kind == expl_bcast) {
2133                         int aux_kernel_offset = kernel_offset(kk, 2 * oc, ki);
2134                         vmovups(zmm_wei,
2135                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
2136                     } else {
2137                         for (int ii = 0; ii < ker_load_number; ii++) {
2138                             int aux_kernel_offset
2139                                 = kernel_offset(kk, 2 * (oc + ii), ki);
2140                             vmovups(zmm_ker(ii),
2141                                 EVEX_compress_addr(aux_reg_ker,
2142                                 aux_kernel_offset));
2143                         }
2144                     }
2145
2146                     for (int jj = jj_start, prf_count = 0; jj < jj_end; jj++) {
2147                         int aux_dst_offset = jcp.typesize_in
2148                             * ((jj + jcp.l_pad - ki) * oc_block + 2 * oc);
2149                         if (jcp.kernel_kind == expl_bcast) {
2150                             vpdpwssd(zmm_out(jj, kk), zmm_wei,
2151                                 zmm_inp(jj, jcp.nb_ic_blocking));
2152                         } else {
2153                             vpXdpwssd(zmm_out(jj, kk), zmm_ker(0),
2154                                 aux_reg_dst, aux_dst_offset);
2155                         }
2156
2157                         if ((jj % 2) && (prf_count < 4)) {
2158                             int aux_kernel_prf
2159                                 = kernel_offset(kk, oc + prf_count, ki);
2160                             mic_prefetcht1(EVEX_compress_addr(
2161                                     aux_reg_ker_prf, aux_kernel_prf));
2162                             prf_count++;
2163                         }
2164                         if (!(jj % 2) && ki == 0 && oc == 0 && kk == 0) {
2165                             mic_prefetcht1(EVEX_compress_addr(aux_reg_dst_prf,
2166                                     aux_dst_offset));
2167                         }
2168                         if (!(jj % 2) && ki == 1 && oc == 0 && kk == 0) {
2169                             mic_prefetcht0(EVEX_compress_addr(aux_reg_dst,
2170                                     aux_dst_offset + jcp.typesize_in
2171                                     * ow * oc_block));
2172                         }
2173                     }
2174                 }
2175             }
2176         }
2177
2178         add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
2179         sub(aux_reg_dst, jcp.typesize_in * ow * oc_block);
2180         add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
2181         sub(aux_reg_dst_prf, jcp.typesize_in * ow * oc_block);
2182
2183         dec(reg_kj);
2184         cmp(reg_kj, 0);
2185         jg(kh_label, T_NEAR);
2186     }
2187 }
2188
2189 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
2190         int ur_w, int l_overflow, int r_overflow)
2191 {
2192     Label kh_label, kd_label;
2193     int kw = jcp.kw;
2194     int ow = jcp.ow;
2195
2196     int ic_block = jcp.ic_block;
2197     int oc_block = jcp.oc_block;
2198     int l_pad = jcp.l_pad;
2199     int dilate_w = jcp.dilate_w + 1;
2200     int stride_w = jcp.stride_w;
2201     int stride_h = jcp.stride_h;
2202
2203     int ker_pipeline_depth = 4;
2204     assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
2205     assert(oc_block >= ker_pipeline_depth);
2206
2207     int num_ker_loads = oc_block * kw;
2208     int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
2209                        + nstl::max(0, kw - stride_w);
2210     int num_prfs = num_ker_loads + num_inp_prfs;
2211     int num_fmas = num_ker_loads * ur_w / stride_w;
2212     int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
2213     int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
2214
2215     if (one_of(jcp.ndims, 3, 4)) {
2216         mov(aux_reg_dst, reg_dst);
2217         mov(aux_reg_ker, reg_ker);
2218
2219         mov(aux_reg_dst_prf, reg_dst_prf);
2220         mov(aux_reg_ker_prf, reg_ker_prf);
2221     }
2222
2223     if (jcp.ndims == 5) {
2224         push(reg_src_prf);
2225         push(reg_src);
2226
2227         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
2228         mov(aux_reg_dst_d, reg_dst);
2229         mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
2230         mov(aux_reg_dst_d_prf, reg_dst_prf);
2231         mov(aux_reg_ker_d_prf, reg_ker_prf);
2232
2233         L(kd_label);
2234         mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2235     } else {
2236         mov(reg_kj, reg_kh);
2237     }
2238
2239     if (jcp.ndims == 5) {
2240         mov(aux_reg_dst, aux_reg_dst_d);
2241         mov(aux_reg_ker, aux_reg_ker_d);
2242         mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
2243         mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
2244     }
2245
2246     L(kh_label); {
2247         int step = 0;
2248         int ker_prfs = 0;
2249         for (int ki = 0; ki < kw; ki++) {
2250             for (int oc = 0; oc < oc_block; oc++) {
2251                 if (step == 0) {
2252                     for (int i = 0; i < ker_pipeline_depth; i++) {
2253                         int aux_kernel_offset = typesize * ((oc + i) * oc_block
2254                                 + ki * ic_block * oc_block);
2255                         vmovups(zmm_ker(i), EVEX_compress_addr(
2256                                     aux_reg_ker, aux_kernel_offset));
2257                     }
2258                 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
2259                     int load_offset = ker_pipeline_depth - 1;
2260                     int ker_load_reg_idx
2261                         = (step + load_offset) % ker_pipeline_depth;
2262                     int aux_kernel_offset = typesize * ((oc + load_offset)
2263                             * oc_block + ki * ic_block * oc_block);
2264                     vmovups(zmm_ker(ker_load_reg_idx),
2265                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
2266                 }
2267
2268                 bool ker_prf_inserted = false;
2269                 auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);
2270
2271                 int jj_start = get_iw_start(ki, l_overflow);
2272                 int jj_end = get_iw_end(ur_w, ki, r_overflow);
2273                 assert(stride_w != 1
2274                         || jj_start == nstl::max(0,
2275                             l_overflow - (kw - 1 - ki) * dilate_w));
2276                 assert(stride_w != 1
2277                         || jj_end == ur_w - nstl::max(0,
2278                             r_overflow - ki * dilate_w));
2279
2280                 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
2281                     assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
2282                     int aux_dst_offset = typesize *
2283                         (((jj + l_pad - ki * dilate_w)
2284                                 / stride_w) * jcp.oc_block + oc);
2285                     vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
2286                         EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));
2287
2288                     int fma_idx = (step * ur_w + jj) / stride_w;
2289                     int prf_slot_idx = fma_idx / prf_inst_spacing;
2290                     if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
2291                         if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
2292                             int ker_prf_offset = typesize
2293                                 * ker_prfs * jcp.oc_block;
2294                             mic_prefetcht1(EVEX_compress_addr(
2295                                         aux_reg_ker_prf, ker_prf_offset));
2296                             ker_prf_inserted = true;
2297                             ker_prfs++;
2298                         } else {
2299                             int inp_prf_idx = prf_slot_idx - ker_prfs;
2300                             if (inp_prf_idx < num_inp_prfs) {
2301                                 int inp_prf_offset
2302                                     = ic_block * typesize
2303                                     * ((inp_prf_idx / kw) * kw
2304                                             + (inp_prf_idx % kw));
2305                                 mic_prefetcht0(EVEX_compress_addr(
2306                                             aux_reg_dst_prf, inp_prf_offset));
2307                             }
2308                         }
2309                     }
2310                 }
2311                 step++;
2312             }
2313         }
2314
2315         add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
2316         sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
2317         add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
2318         sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);
2319
2320         dec(reg_kj);
2321         cmp(reg_kj, 0);
2322         jg(kh_label, T_NEAR);
2323     }
2324     if (jcp.ndims == 5) {
2325         sub(aux_reg_dst_d,
2326                 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2327         add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
2328                 * oc_block * ic_block);
2329         sub(aux_reg_dst_d_prf,
2330                 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2331         add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
2332                 * oc_block * ic_block);
2333
2334         dec(reg_ki);
2335         cmp(reg_ki, 0);
2336         jg(kd_label, T_NEAR);
2337     }
2338
2339     if (jcp.ndims == 5)
2340     {
2341         pop(reg_src);
2342         pop(reg_src_prf);
2343     }
2344 }
2345
2346 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
2347         int ur_w, int l_overflow, int r_overflow)
2348 {
2349     int kw = jcp.kw;
2350     int ow = jcp.ow;
2351     int dilate_w = jcp.dilate_w + 1;
2352     int stride_w = jcp.stride_w;
2353     int ic_block = jcp.ic_block;
2354     int oc_block = jcp.oc_block;
2355     int nb_ic_block = jcp.nb_ic_blocking;
2356     Label kh_label, kd_label;
2357
2358     int shift_ker_ptr = typesize * kw * oc_block * ic_block;
2359     int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
2360
2361     auto output_offset = [=](int oi, int oc, int ki) {
2362         return typesize *
2363             (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc);
2364     };
2365     auto kernel_offset = [=](int icb, int oc, int ki) {
2366         int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
2367         int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
2368         int oc_offset = oc * jcp.oc_block;
2369         return typesize * (blk_offset + oc_offset);
2370     };
2371
2372     if (one_of(jcp.ndims, 3, 4)) {
2373         mov(aux_reg_dst, reg_dst);
2374         mov(aux_reg_ker, reg_ker);
2375     }
2376
2377     if (jcp.ndims == 5) {
2378         push(reg_src_prf);
2379         push(reg_src);
2380
2381         mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
2382         mov(aux_reg_dst_d, reg_dst);
2383         mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
2384
2385         L(kd_label);
2386         mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2387     } else {
2388         mov(reg_kj, reg_kh);
2389     }
2390
2391     if (jcp.ndims == 5) {
2392         mov(aux_reg_dst, aux_reg_dst_d);
2393         mov(aux_reg_ker, aux_reg_ker_d);
2394     }
2395
2396     L(kh_label);
2397     {
2398         for (int ki = 0; ki < kw; ki++) {
2399             int jj_start = get_iw_start(ki, l_overflow);
2400             int jj_end = get_iw_end(ur_w, ki, r_overflow);
2401             for (int oc = 0; oc < oc_block; oc++) {
2402                 if (jcp.kernel_kind == expl_bcast) {
2403                     for (int jj = jj_start; jj < jj_end; jj++) {
2404                         int aux_output_offset = output_offset(jj, oc, ki);
2405                         vbroadcastss(zmm_inp(jj, nb_ic_block),
2406                             ptr[aux_reg_dst + aux_output_offset]);
2407                     }
2408                 }
2409                 for (int ii = 0; ii < nb_ic_block; ii++) {
2410                     int aux_kernel_offset = kernel_offset(ii, oc, ki);
2411                     if (jj_end - jj_start > 0)
2412                         vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
2413                             aux_kernel_offset));
2414                     for (int jj = jj_start; jj < jj_end; jj += stride_w)
2415                         if (jcp.kernel_kind == expl_bcast)
2416                             vfmadd231ps(zmm_out(jj, ii),
2417                                 zmm_inp(jj, nb_ic_block), zmm_wei);
2418                         else
2419                             vfmadd231ps(zmm_out(jj, ii), zmm_wei,
2420                                 EVEX_compress_addr(aux_reg_dst,
2421                                 output_offset(jj, oc, ki), true));
2422                 }
2423             }
2424         }
2425         add(aux_reg_ker, shift_ker_ptr);
2426         sub(aux_reg_dst, shift_dst_ptr);
2427         dec(reg_kj);
2428         cmp(reg_kj, 0);
2429         jg(kh_label, T_NEAR);
2430     }
2431
2432     if (jcp.ndims == 5) {
2433         sub(aux_reg_dst_d,
2434                 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2435         add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2436
2437         dec(reg_ki);
2438         cmp(reg_ki, 0);
2439         jg(kd_label, T_NEAR);
2440
2441         pop(reg_src);
2442         pop(reg_src_prf);
2443     }
2444 }
2445
2446 inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
2447         int ur_w, int l_overflow, int r_overflow)
2448 {
2449     if (jcp.ndims == 5) push(reg_oi);
2450
2451     prepare_output(ur_w);
2452
2453     Label skip_compute_loop;
2454     if (jcp.ndims == 5) {
2455         mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
2456         cmp(reg_kj, 0);
2457         je(skip_compute_loop, T_NEAR);
2458     }
2459     mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2460     cmp(reg_kj, 0);
2461     je(skip_compute_loop, T_NEAR);
2462
2463     if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
2464         compute_loop_vnni(ur_w, l_overflow, r_overflow);
2465     else if (jcp.ver == ver_4fma)
2466         compute_loop_4fma(ur_w, l_overflow, r_overflow);
2467     else if (jcp.ver == ver_fma)
2468         if (mayiuse(avx512_mic))
2469             compute_loop_fma(ur_w, l_overflow, r_overflow);
2470         else
2471           if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
2472               compute_loop_fma(ur_w, l_overflow, r_overflow);
2473           else
2474               compute_loop_fma_core(ur_w, l_overflow, r_overflow);
2475     else
2476         assert("!unknown convolution version");
2477
2478     L(skip_compute_loop);
2479     store_output(ur_w);
2480     if (jcp.ndims == 5) pop(reg_oi);
2481 }
2482
2483 void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
2484 {
2485     int iw = jcp.iw;
2486     int kw = jcp.kw;
2487     int ur_w = jcp.ur_w;
2488     int ic_block = jcp.ic_block;
2489     int oc_block = jcp.oc_block;
2490     int ur_w_tail = jcp.ur_w_tail;
2491     int dilate_w = jcp.dilate_w + 1;
2492     int stride_w = jcp.stride_w;
2493
2494     int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
2495     int src_shift = jcp.typesize_out * ur_w * oc_block;
2496
2497     preamble();
2498
2499     mov(reg_src, ptr[param + GET_OFF(src)]);
2500     mov(reg_dst, ptr[param + GET_OFF(dst)]);
2501     mov(reg_ker, ptr[param + GET_OFF(filt)]);
2502
2503     mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
2504     mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]);
2505     mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]);
2506     mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]);
2507
2508     int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
2509     int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
2510                     - nstl::max(0, jcp.r_pad)) / stride_w);
2511     int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
2512                     - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
2513
2514     int n_oi = iw / ur_w;
2515     if (r_overflow1 > 0) n_oi--;
2516
2517     if (ur_w == iw) {
2518         compute_loop(ur_w, l_overflow, r_overflow);
2519     } else if (n_oi == 0) {
2520         compute_loop(ur_w, l_overflow, r_overflow1);
2521         add(reg_src, src_shift);
2522         add(reg_dst, dst_shift);
2523         add(reg_src_prf, src_shift);
2524         add(reg_dst_prf, dst_shift);
2525         if (ur_w_tail != 0)
2526             compute_loop(ur_w_tail, 0, r_overflow);
2527     } else {
2528         xor_(reg_oi, reg_oi);
2529         if (l_overflow > 0) {
2530             compute_loop(ur_w, l_overflow, 0);
2531             add(reg_src, src_shift);
2532             add(reg_dst, dst_shift);
2533             add(reg_src_prf, src_shift);
2534             add(reg_dst_prf, dst_shift);
2535
2536             inc(reg_oi);
2537         }
2538         if ((l_overflow <= 0 && n_oi > 0)
2539             || (l_overflow > 0 && n_oi > 1)) {
2540             Label ow_loop_label;
2541             L(ow_loop_label); {
2542                 compute_loop(ur_w, 0, 0);
2543                 add(reg_src, src_shift);
2544                 add(reg_dst, dst_shift);
2545                 add(reg_src_prf, src_shift);
2546                 add(reg_dst_prf, dst_shift);
2547
2548                 inc(reg_oi);
2549                 cmp(reg_oi, n_oi);
2550                 jl(ow_loop_label, T_NEAR);
2551             }
2552         }
2553         if (r_overflow1 > 0) {
2554             compute_loop(ur_w, 0, r_overflow1);
2555             add(reg_src, src_shift);
2556             add(reg_dst, dst_shift);
2557             add(reg_src_prf, src_shift);
2558             add(reg_dst_prf, dst_shift);
2559         }
2560         if (ur_w_tail != 0) {
2561             compute_loop(ur_w_tail, 0, r_overflow);
2562         }
2563     }
2564
2565     postamble();
2566 }
2567
2568 status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
2569         jit_conv_conf_t &jcp,
2570         const convolution_desc_t &cd,
2571         const memory_desc_wrapper &diff_src_d,
2572         const memory_desc_wrapper &weights_d,
2573         const memory_desc_wrapper &diff_dst_d)
2574 {
2575     if (!mayiuse(avx512_common)) return status::unimplemented;
2576
2577     jcp = zero<decltype(jcp)>();
2578
2579     jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
2580     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
2581     int ndims = diff_src_d.ndims();
2582
2583     jcp.ndims = ndims;
2584     jcp.prop_kind = cd.prop_kind;
2585
2586     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2587     jcp.mb = diff_src_d.dims()[0];
2588
2589     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2590     jcp.oc_without_padding = jcp.oc;
2591     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
2592
2593     jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
2594     jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
2595     jcp.iw = diff_src_d.dims()[ndims-1];
2596     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
2597     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
2598     jcp.ow = diff_dst_d.dims()[ndims-1];
2599
2600     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
2601     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
2602     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2603
2604     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
2605     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
2606     jcp.l_pad = cd.padding[0][ndims-3];
2607
2608     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
2609     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
2610     jcp.stride_w = cd.strides[ndims-3];
2611
2612     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
2613     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
2614     jcp.dilate_w = cd.dilates[ndims-3];
2615     if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
2616             || (jcp.dilate_d != 0 && jcp.stride_d != 1)
2617             || (jcp.dilate_h != 0 && jcp.stride_h != 1))
2618         return status::unimplemented;
2619
2620     jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
2621             - (jcp.iw + jcp.l_pad - 1);
2622     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
2623             - (jcp.ih + jcp.t_pad - 1);
2624     jcp.back_pad = (jcp.od - 1) * jcp.stride_d
2625             + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
2626
2627     jcp.aligned_threads = 0;
2628
2629     jcp.is_1stconv = false;
2630
2631     jcp.oc_block = jcp.simd_w;
2632     jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
2633
2634     bool ok_to_pad_channels = true
2635         && jcp.ngroups == 1
2636         && diff_src_d.data_type() == data_type::f32;
2637
2638     if (ok_to_pad_channels) {
2639         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2640         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2641     }
2642
2643     auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
2644     auto wei_format = with_groups
2645         ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
2646         : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
2647     bool args_ok = true
2648         && jcp.oc % jcp.oc_block == 0
2649         && jcp.ic % jcp.ic_block == 0
2650         && diff_src_d.format() == src_format
2651         && diff_dst_d.format() == src_format;
2652     if (!args_ok)
2653         return status::unimplemented;
2654
2655     jcp.nb_ic = jcp.ic / jcp.ic_block;
2656     jcp.nb_oc = jcp.oc / jcp.oc_block;
2657
2658     jcp.ur_w = jcp.stride_w;
2659
2660     int regs = 28;
2661     if (jcp.iw <= regs)
2662         jcp.ur_w = jcp.iw;
2663     else {
2664         for (int ur_w = regs; ur_w > 0; --ur_w)
2665             if (ur_w % jcp.stride_w == 0) {
2666                 jcp.ur_w = ur_w;
2667                 break;
2668             }
2669     }
2670     int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2671                     - jcp.l_pad) / jcp.stride_w);
2672     int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2673                     - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
2674     int n_oi = jcp.iw / jcp.ur_w;
2675     if (r_overflow1 > 0) n_oi--;
2676
2677     if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
2678            && jcp.stride_w == 1 && jcp.stride_h == 1
2679            && diff_dst_d.data_type() == data_type::s16
2680            && weights_d.data_type() == data_type::s16
2681            && diff_src_d.data_type() == data_type::s32) {
2682         if (weights_d.format() != (with_groups ? gOIhw8o16i2o : OIhw8o16i2o))
2683             return status::unimplemented;
2684         if (mayiuse(avx512_mic_4ops)) {
2685             jcp.ver = ver_4vnni;
2686         } else {
2687             jcp.ver = ver_vnni;
2688         }
2689         jcp.typesize_in = sizeof(int16_t);
2690         jcp.typesize_out = sizeof(int32_t);
2691     } else if (mayiuse(avx512_common)
2692          && diff_dst_d.data_type() == data_type::f32
2693          && weights_d.data_type() == data_type::f32
2694          && diff_src_d.data_type() == data_type::f32) {
2695         if (weights_d.format() != wei_format)
2696             return status::unimplemented;
2697         jcp.ver = ver_fma;
2698         jcp.typesize_in = sizeof(float);
2699         jcp.typesize_out = sizeof(float);
2700         if (mayiuse(avx512_mic_4ops)
2701             && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
2702                 jcp.ver = ver_4fma;
2703             }
2704     } else {
2705         return status::unimplemented;
2706     }
2707     if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
2708             && jcp.ver != ver_fma)
2709         return status::unimplemented;
2710
2711     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2712     if (jcp.ver == ver_4vnni) {
2713         jcp.kernel_kind = embd_bcast;
2714     }
2715     if (jcp.ver == ver_vnni) {
2716         // TODO: kernel_kind and nb_oc_blocking selection
2717         //       should be tuned on real HW
2718         if ((jcp.iw <= 56 && jcp.ih <= 56 && jcp.kh < 5)
2719             || (jcp.iw <= 17 && jcp.ih <= 17 && jcp.kh >= 5) ) {
2720             jcp.kernel_kind = expl_bcast;
2721             jcp.nb_ic_blocking = 4;
2722         } else {
2723             jcp.kernel_kind = embd_bcast;
2724             jcp.nb_ic_blocking = 2;
2725         }
2726         if (jcp.nb_ic_blocking > 1) {
2727             if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2728             if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2729                 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2730                     if (jcp.nb_ic % i == 0) {
2731                         jcp.nb_ic_blocking = i;
2732                         break;
2733                     }
2734             jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2735             if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2736         }
2737     }
2738     if (jcp.ver == ver_4fma) {
2739         if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
2740             jcp.nb_ic_blocking = 2;
2741         } else {
2742             for (int i = jcp.nb_ic; i > 0; i--)
2743                 if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
2744                     jcp.nb_ic_blocking = i;
2745                     break;
2746                 }
2747         }
2748     }
2749
2750     jcp.loop_order = loop_gnc;
2751
2752     bool large_code_size = (jcp.ur_w != jcp.ow)
2753          && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1))
2754          && (r_overflow1 > 0) && (l_overflow > 0);
2755     if (large_code_size) {
2756         const int max_code_size = 24 * 1024;
2757         const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
2758         int mult = 1;
2759         if (l_overflow > 0) mult += 1;
2760         if (r_overflow1 > 0) mult += 1;
2761         for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
2762             if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
2763                     < max_code_size) {
2764                 if (ur_w % jcp.stride_w == 0) {
2765                     jcp.ur_w = ur_w;
2766                     break;
2767                 }
2768             }
2769         }
2770     }
2771
2772     if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
2773         int try_nb_ic_blocking = 2;
2774         unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
2775             * try_nb_ic_blocking * jcp.kh;
2776         unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
2777         unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
2778             * jcp.oc_block * try_nb_ic_blocking;
2779         unsigned int ker_total_size = ker_inp_size + ker_out_size
2780             + ker_wei_size;
2781         if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
2782             || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
2783             || ker_total_size > L1_cache_size )))
2784                 || jcp.stride_h > 1 || jcp.stride_d > 1) {
2785             jcp.kernel_kind = embd_bcast;
2786             jcp.ur_w = nstl::min(jcp.iw, regs);
2787             jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2788             if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
2789                 && jcp.ow > 8)) && jcp.stride_h == 1)
2790                 if (jcp.nb_ic % try_nb_ic_blocking == 0) {
2791                     jcp.nb_ic_blocking = try_nb_ic_blocking;
2792                     jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2793                     if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2794                 }
2795          } else {
2796             jcp.kernel_kind = expl_bcast;
2797             jcp.nb_oc_blocking = 1;
2798             jcp.nb_ic_blocking = 4;
2799             if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2800             if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2801                 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2802                     if (jcp.nb_ic % i == 0) {
2803                         jcp.nb_ic_blocking = i;
2804                         break;
2805                     }
2806             jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2807             if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2808         }
2809     }
2810     jcp.ur_w_tail = jcp.iw % jcp.ur_w;
2811
2812     if (l_overflow * jcp.stride_w > jcp.ur_w)
2813         return status::unimplemented;
2814     int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2815                     - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
2816     if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
2817         return status::unimplemented;
2818     if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2819         return status::unimplemented;
2820
2821     pick_loop_order(jcp);
2822
2823     jcp.nb_oc_L2 = jcp.nb_oc;
2824     // TODO check for 4vnni
2825     if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
2826         for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
2827               divf++) {
2828             size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
2829                 * jcp.id;
2830             size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
2831             size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
2832                 * jcp.kd * jcp.nb_ic_blocking * temp_nb;
2833             if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
2834                 if (jcp.kh == 3 && jcp.ih == 7) {
2835                     jcp.nb_oc_L2 = 1;
2836                     break;
2837                 }
2838                 temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
2839                                 : jcp.nb_oc_L2);
2840             } else {
2841                 jcp.nb_oc_L2 = temp_nb;
2842                 break;
2843             }
2844         }
2845     }
2846
2847     args_ok = true
2848         && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
2849         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2850         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
2851         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
2852     if (!args_ok) return status::unimplemented;
2853
2854     return status::success;
2855 }
2856
2857 void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
2858         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
2859     UNUSED(scratchpad);
2860     UNUSED(jcp);
2861 }
2862
2863 const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
2864
2865 void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
2866 {
2867     Label kd_comeback_label;
2868
2869     /* 'depth' loop count bound by 'kd_work_size' */
2870     mov(kj, ptr[param + GET_OFF(kd_padding)]);
2871     L(kd_comeback_label); {
2872         int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2873         int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
2874             ? jcp.tr_iw : jcp.iw;
2875         sub(aux_reg_input,
2876                 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
2877         sub(aux_reg_kernel,
2878             jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
2879         dec(kj);
2880         cmp(kj, 0);
2881         jg(kd_comeback_label, T_NEAR);
2882     }
2883 }
2884
2885 void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
2886 {
2887     Label kh_comeback_label, kd_comeback_label;
2888     mov(kj, reg_kh);
2889     L(kh_comeback_label); {
2890         int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2891         int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
2892             ? jcp.tr_iw : jcp.iw;
2893         sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
2894         sub(reg_kernel,
2895             jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
2896         dec(kj);
2897         cmp(kj, 0);
2898         jg(kh_comeback_label, T_NEAR);
2899     }
2900 }
2901
2902 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
2903     int ur_w, int pad_l, int pad_r,
2904     int ic_block_step, int input_offset, int kernel_offset,
2905     int output_offset, bool input_wraparound)
2906 {
2907
2908     int kw = jcp.kw;
2909     int ic_block = jcp.ic_block;
2910     int oc_block = jcp.oc_block;
2911     for (int i_kw = 0; i_kw < kw; i_kw++)
2912         for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2913             vmovups(Zmm(i_kw * ic_block_step + i_ic),
2914                 EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block
2915                 + i_ic) * jcp.oc_block + kernel_offset));
2916
2917     for (int i_ur = 0; i_ur < ur_w; i_ur++) {
2918         if (i_ur == 0) {
2919             vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4),
2920                 EVEX_compress_addr(reg_output, typesize * (i_ur + 0)
2921                 * oc_block + output_offset));
2922             if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4),
2923                 EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block
2924                 + output_offset));
2925             if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4),
2926                 EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block
2927                 + output_offset));
2928             if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2929                 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2930                 + output_offset));
2931         } else if (i_ur + 3 < ur_w)
2932             vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2933                 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2934                 + output_offset));
2935
2936         for (int i_kw = 0; i_kw < kw; i_kw++) {
2937             int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
2938             if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
2939                     (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
2940             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2941                 const size_t i_offset = (size_t)input_offset
2942                     + (size_t)typesize * (jcp.ver == ver_4fma
2943                             ? (i_iw - pad_l + i_ic * jcp.tr_iw)
2944                             : (jcp.is_1stconv
2945                                 ? (i_iw - pad_l) + (size_t)i_ic
2946                                     * ((size_t)jcp.ih*jcp.iw*jcp.id)
2947                                 : (i_iw - pad_l) * ic_block + i_ic));
2948                 vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
2949                     Zmm(kw * ic_block_step + i_ur % 4),
2950                     EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
2951                         true));
2952             }
2953         }
2954     }
2955
2956     for (int i_kw = 0; i_kw < kw; i_kw++)
2957         for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2958             vmovups(EVEX_compress_addr(reg_kernel, typesize
2959                 * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset),
2960                 Zmm(i_kw * ic_block_step + i_ic));
2961 }
2962
2963 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma(
2964     int ur_w, int pad_l, int pad_r,
2965     int ic_block_step, int input_offset, int kernel_offset,
2966     int output_offset, bool input_wraparound)
2967 {
2968     // TODO: add prefetches to fma version as well
2969
2970     assert(jcp.ver == ver_4fma);
2971
2972     int kw = jcp.kw;
2973     int ic_block = jcp.ic_block;
2974     int oc_block = jcp.oc_block;
2975
2976     auto zmm_ker = [=](int i_kw, int i_ic) {
2977         return Zmm(i_kw * ic_block_step + i_ic);
2978     };
2979
2980     auto ker_addr = [=](int i_kw, int i_ic) {
2981         size_t local_offset
2982             = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
2983         return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2984     };
2985
2986     auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) {
2987         int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
2988         int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
2989         return EVEX_compress_addr(reg_input,
2990                 local_offset + input_offset + extra_offset);
2991     };
2992
2993     auto zmm_out = [=](int i_iw) {
2994         // TODO: move reg calc to global member funcs
2995         const int out_zmm_base_idx = 28;
2996         return Zmm(out_zmm_base_idx + i_iw % 4);
2997     };
2998
2999     auto out_addr = [=](int i_ur) {
3000         return EVEX_compress_addr(reg_output,
3001                 jcp.typesize_in * i_ur * oc_block + output_offset);
3002     };
3003
3004     auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
3005         assert(i_ur % 4 == 0);
3006         if (i_ur == 0)
3007             prefetcht1(ker_addr(i_kw, i_ic));
3008         if (i_ur + 4 >= ur_w)
3009             prefetcht0(ker_addr(i_kw, i_ic));
3010
3011         const ptrdiff_t next_input_block_offset
3012             = jcp.typesize_in * ic_block_step * jcp.tr_iw;
3013         if (i_ur % 16 == 4 && i_kw == 0) {
3014             if (i_ur + 16 < ur_w)
3015                 prefetcht0(inp_addr(i_ur + 16, i_ic));
3016             else
3017                 prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
3018         }
3019         if (i_ur % 16 == 4 && i_kw == 1) {
3020             if (input_wraparound)
3021                 prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
3022             else
3023                 prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
3024         }
3025     };
3026
3027     for (int i_kw = 0; i_kw < kw; i_kw++)
3028         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3029             auto zmm = zmm_ker(i_kw, i_ic);
3030             vpxord(zmm, zmm, zmm);
3031         }
3032
3033     for (int i_ur = 0; i_ur < ur_w; i_ur += 4) {
3034
3035         for (int i = 0; i < 4; i++) {
3036             auto zmm = zmm_out(i_ur + i);
3037             if (i_ur + i < ur_w)
3038                 vmovups(zmm, out_addr(i_ur + i));
3039             else
3040                 vpxord(zmm, zmm, zmm);
3041             prefetcht0(out_addr(i_ur + i + 4));
3042         }
3043
3044         for (int i_kw = 0; i_kw < kw; i_kw++)
3045             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3046                 int i_iw = i_ur + i_kw;
3047                 v4fmaddps(zmm_ker(i_kw, i_ic),
3048                         zmm_out(i_ur), inp_addr(i_iw, i_ic));
3049                 pf_callback(i_ur, i_kw, i_ic);
3050             }
3051     }
3052
3053     for (int i_kw = 0; i_kw < kw; i_kw++)
3054         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3055             auto addr = ker_addr(i_kw, i_ic);
3056             auto zmm = zmm_ker(i_kw, i_ic);
3057             vaddps(zmm, zmm, addr);
3058             vmovups(addr, zmm);
3059         }
3060 }
3061
3062 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_vnni(
3063     int ur_w, int pad_l, int pad_r,
3064     int ic_block_step, int input_offset, int kernel_offset,
3065     int output_offset, bool input_wraparound)
3066 {
3067     // TODO: add prefetches to fma version as well
3068     assert(jcp.ver == ver_4vnni || jcp.ver == ver_vnni);
3069
3070     int kw = jcp.kw;
3071     int ic_block = jcp.ic_block;
3072     int oc_block = jcp.oc_block;
3073
3074     auto zmm_ker = [=](int i_kw, int i_ic) {
3075         return Zmm(i_kw * ic_block_step + i_ic);
3076     };
3077
3078     auto ker_addr = [=](int i_kw, int i_ic) {
3079         size_t local_offset
3080             = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
3081         return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
3082     };
3083
3084     auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
3085                         bool vnni_bcast = false) {
3086         int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
3087         int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
3088         if (vnni_bcast)
3089             return EVEX_compress_addr(reg_input,
3090                     local_offset + input_offset + extra_offset, true);
3091         else
3092             return EVEX_compress_addr(reg_input,
3093                     local_offset + input_offset + extra_offset);
3094     };
3095
3096     auto zmm_out = [=](int i_iw) {
3097         // TODO: move reg calc to global member funcs
3098         const int out_zmm_base_idx = 28;
3099         return Zmm(out_zmm_base_idx + i_iw % 4);
3100     };
3101
3102     auto out_addr = [=](int i_ur) {
3103         assert(utils::one_of(jcp.ver, ver_4vnni, ver_4fma, ver_vnni));
3104         auto ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3105         return EVEX_compress_addr(reg_output,
3106                 jcp.typesize_in * i_ur * oc_block * ow_per_oc + output_offset);
3107     };
3108
3109     auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
3110         if (i_ur == 0)
3111             mic_prefetcht1(ker_addr(i_kw, i_ic));
3112         if (i_ur + 4 >= ur_w)
3113             mic_prefetcht0(ker_addr(i_kw, i_ic));
3114
3115         const ptrdiff_t next_input_block_offset
3116             = jcp.typesize_in * ic_block_step * jcp.tr_iw;
3117         if (i_ur % 16 == 4 && i_kw == 0) {
3118             if (i_ur + 16 < ur_w)
3119                 mic_prefetcht0(inp_addr(i_ur + 16, i_ic));
3120             else
3121                 mic_prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
3122         }
3123         if (i_ur % 16 == 4 && i_kw == 1) {
3124             if (input_wraparound)
3125                 mic_prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
3126             else
3127                 mic_prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
3128         }
3129     };
3130
3131     for (int i_kw = 0; i_kw < kw; i_kw++)
3132         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3133             auto zmm = zmm_ker(i_kw, i_ic);
3134             vpxord(zmm, zmm, zmm);
3135         }
3136     auto steps = ur_w / 2;
3137     auto numloads = (jcp.ver == ver_vnni) ? 1 : 4;
3138     for (int i_ur = 0; i_ur < steps; i_ur += numloads) {
3139
3140         for (int i = 0; i < numloads; i++) {
3141             int oi = i_ur + i;
3142             auto zmm = zmm_out(oi);
3143             if (oi < ur_w / 2)
3144                 vmovups(zmm, out_addr(oi));
3145             else
3146                 vpxord(zmm, zmm, zmm);
3147             mic_prefetcht0(out_addr(2 * i_ur + i + 4));
3148         }
3149
3150         for (int i_kw = 0; i_kw < kw; i_kw++)
3151             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3152                 int i_iw = 2 * i_ur + i_kw;
3153                 if (jcp.ver == ver_4vnni)
3154                     vp4dpwssd(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
3155                         inp_addr(i_iw, i_ic));
3156                 else if (jcp.ver == ver_vnni)
3157                     vpdpwssd(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
3158                         inp_addr(i_iw, i_ic, 0, true));
3159                 else
3160                     assert(!"unknown convolution version");
3161                 pf_callback(2 * i_ur, i_kw, i_ic);
3162             }
3163     }
3164
3165     for (int i_kw = 0; i_kw < kw; i_kw++) {
3166         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3167             auto addr = ker_addr(i_kw, i_ic);
3168             auto zmm = zmm_ker(i_kw, i_ic);
3169             vpaddd(zmm, zmm, addr);
3170             vmovups(addr, zmm);
3171         }
3172     }
3173 }
3174
3175 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
3176     int ur_w, int pad_l, int pad_r,
3177     int ic_block_step, int input_offset, int kernel_offset,
3178     int output_offset, bool input_wraparound)
3179 {
3180     if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
3181         compute_ic_block_step_vnni(ur_w, pad_l, pad_r,
3182                 ic_block_step, input_offset, kernel_offset, output_offset,
3183                 input_wraparound);
3184     else if (jcp.ver == ver_4fma)
3185         compute_ic_block_step_4fma(ur_w, pad_l, pad_r,
3186                 ic_block_step, input_offset, kernel_offset, output_offset,
3187                 input_wraparound);
3188     else if (jcp.ver == ver_fma)
3189         compute_ic_block_step_fma(ur_w, pad_l, pad_r,
3190                 ic_block_step, input_offset, kernel_offset, output_offset,
3191                 input_wraparound);
3192     else
3193         assert(!"unknown convolution version");
3194 }
3195
3196 void jit_avx512_common_conv_bwd_weights_kernel_f32
3197     ::compute_oh_step_unroll_ow_icblock(
3198     int ic_block_step, int max_ur_w)
3199 {
3200     UNUSED(max_ur_w);
3201
3202     Label kh_label, kd_label;
3203
3204     int ic_block = jcp.ic_block;
3205     int oc_block = jcp.oc_block;
3206     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
3207     int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
3208         ? jcp.tr_iw : jcp.iw;
3209     int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3210
3211     int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
3212             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
3213     int l_pad = jcp.l_pad;
3214
3215     if (jcp.ndims == 5) {
3216         mov(aux_reg_input, reg_input);
3217         mov(aux_reg_kernel, reg_kernel);
3218         mov(ki, ptr[param + GET_OFF(kd_padding)]);
3219         L(kd_label);
3220         mov(reg_input, aux_reg_input);
3221         mov(reg_kernel, aux_reg_kernel);
3222     }
3223
3224     mov(kj, reg_kh);
3225     L(kh_label);
3226     {
3227         for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
3228             const int input_offset = jcp.typesize_in
3229                 * (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3230                    ? i_b_ic * iw : i_b_ic);
3231             compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
3232                 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
3233                 i_b_ic + ic_block_step >= jcp.ic_block);
3234         }
3235         add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
3236         add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
3237         dec(kj);
3238         cmp(kj, 0);
3239         jg(kh_label, T_NEAR);
3240     }
3241
3242     if (jcp.ndims == 5) {
3243         add(aux_reg_input,
3244                 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
3245         add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3246             * oc_block);
3247         dec(ki);
3248         cmp(ki, 0);
3249         jg(kd_label, T_NEAR);
3250     }
3251 }
3252
3253 void jit_avx512_common_conv_bwd_weights_kernel_f32
3254     ::compute_oh_step_unroll_ow(
3255     int ic_block_step, int max_ur_w)
3256 {
3257     Label kh_label, ic_block_label, kd_label;
3258
3259     UNUSED(max_ur_w);
3260
3261     int ic_block = jcp.ic_block;
3262     int oc_block = jcp.oc_block;
3263
3264     int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3265
3266     int r_pad = nstl::max(0,
3267         (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
3268         - (jcp.iw + jcp.l_pad - 1));
3269     int l_pad = jcp.l_pad;
3270
3271     if (jcp.ndims == 5) {
3272         mov(aux_reg_input, reg_input);
3273         mov(aux_reg_kernel, reg_kernel);
3274         mov(ki, ptr[param + GET_OFF(kd_padding)]);
3275         L(kd_label);
3276         mov(reg_input, aux_reg_input);
3277         mov(reg_kernel, aux_reg_kernel);
3278     }
3279
3280     mov(kj, reg_kh);
3281     L(kh_label);
3282     {
3283         xor_(b_ic, b_ic);
3284         L(ic_block_label); {
3285             compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
3286                 0, 0, 0);
3287             size_t inp_icblk_stride = jcp.is_1stconv
3288                 ? (size_t)jcp.ih * jcp.iw * jcp.id
3289                 : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3290                    ? jcp.tr_iw : 1);
3291             size_t input_offset
3292                 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
3293             safe_add(reg_input, input_offset, reg_long_offt);
3294             add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
3295             add(b_ic, ic_block_step);
3296             cmp(b_ic, jcp.ic_block);
3297             jl(ic_block_label, T_NEAR);
3298         }
3299
3300         if (jcp.is_1stconv) {
3301             size_t input_offset
3302                 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
3303             safe_sub(reg_input, input_offset, reg_long_offt);
3304             add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
3305         } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
3306             add(reg_input, jcp.typesize_in
3307                     * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
3308         }
3309         add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
3310         dec(kj);
3311         cmp(kj, 0);
3312         jg(kh_label, T_NEAR);
3313     }
3314     if (jcp.ndims == 5) {
3315         add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
3316                 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
3317         add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3318             * oc_block);
3319         dec(ki);
3320         cmp(ki, 0);
3321         jg(kd_label, T_NEAR);
3322     }
3323 }
3324
3325 void jit_avx512_common_conv_bwd_weights_kernel_f32
3326     ::compute_oh_step_common(
3327     int ic_block_step, int max_ur_w)
3328 {
3329     Label kh_label, ic_block_label, ow_block_label, kd_label;
3330
3331     int ic_block = jcp.ic_block;
3332     int oc_block = jcp.oc_block;
3333
3334     int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3335     int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
3336             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
3337     int l_pad = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni
3338                  || jcp.ver == ver_vnni) ? 0 : jcp.l_pad;
3339
3340     int ur_w = nstl::min(ow, max_ur_w);
3341     int ur_w_trips = ow / ur_w;
3342     int ur_w_tail = ow % ur_w;
3343     if ((ur_w_tail == 0 && r_pad != 0)
3344         || r_pad >= ur_w_tail) {
3345         if (ur_w_trips > 1) {
3346             ur_w_tail += ur_w;
3347             ur_w_trips--;
3348         } else {
3349             ur_w_tail += (ur_w - ur_w / 2);
3350             ur_w = ur_w / 2;
3351         }
3352     }
3353
3354     int inp_mult = (jcp.is_1stconv ||
3355         utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) ? 1 : ic_block;
3356     int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult;
3357     int output_comeback = ur_w_trips * ur_w * oc_block;
3358
3359     if (jcp.ndims == 5) {
3360         mov(aux_reg_input, reg_input);
3361         mov(aux_reg_kernel, reg_kernel);
3362         mov(ki, ptr[param + GET_OFF(kd_padding)]);
3363         L(kd_label);
3364         mov(reg_input, aux_reg_input);
3365         mov(reg_kernel, aux_reg_kernel);
3366     }
3367
3368     mov(kj, reg_kh);
3369     L(kh_label); {
3370         xor_(b_ic, b_ic);
3371         L(ic_block_label); {
3372             if (l_pad != 0) {
3373                 ur_w_trips--;
3374                 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
3375                 add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad)
3376                     * inp_mult);
3377                 add(reg_output, jcp.typesize_in * ur_w * oc_block);
3378             }
3379
3380             if (ur_w_trips > 0) {
3381                 xor_(reg_ur_w_trips, reg_ur_w_trips);
3382                 L(ow_block_label); {
3383                     compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
3384                     add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w
3385                         * inp_mult);
3386                     add(reg_output, jcp.typesize_in * ur_w * oc_block);
3387
3388                     inc(reg_ur_w_trips);
3389                     cmp(reg_ur_w_trips, ur_w_trips);
3390                     jl(ow_block_label, T_NEAR);
3391                 }
3392             }
3393
3394             if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad,
3395                 ic_block_step, 0, 0, 0);
3396
3397             sub(reg_input, jcp.typesize_in * input_comeback);
3398             sub(reg_output, jcp.typesize_in * output_comeback);
3399             int inp_icblk_stride = jcp.is_1stconv
3400                 ? jcp.ih * jcp.iw * jcp.id
3401                 : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3402                    ? jcp.tr_iw : 1);
3403             size_t input_offset
3404                 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
3405             safe_add(reg_input, input_offset, reg_long_offt);
3406             add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
3407
3408             add(b_ic, ic_block_step);
3409             cmp(b_ic, jcp.ic_block);
3410             jl(ic_block_label, T_NEAR);
3411         }
3412         if (jcp.is_1stconv) {
3413             size_t input_offset
3414                 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
3415             safe_sub(reg_input, input_offset, reg_long_offt);
3416             add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
3417         } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
3418             add(reg_input, jcp.typesize_in
3419                     * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block);
3420         }
3421         add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
3422         dec(kj);
3423         cmp(kj, 0);
3424         jg(kh_label, T_NEAR);
3425     }
3426     if (jcp.ndims == 5) {
3427         add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
3428                 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
3429         add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3430             * oc_block);
3431         dec(ki);
3432         cmp(ki, 0);
3433         jg(kd_label, T_NEAR);
3434     }
3435 }
3436
3437 void jit_avx512_common_conv_bwd_weights_kernel_f32
3438     ::compute_oh_step_disp()
3439 {
3440     int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
3441     if (jcp.is_1stconv) {
3442         bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
3443         ic_block_step
3444             = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1;
3445     }
3446
3447     bool too_large_to_unroll
3448         = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
3449         && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
3450
3451     int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3452     if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
3453         compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
3454     else if (ow <= max_ur_w)
3455         compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
3456     else
3457         compute_oh_step_common(ic_block_step, max_ur_w);
3458
3459     if (jcp.ndims == 5) {
3460         od_step_comeback_pointers();
3461         mov(reg_input, aux_reg_input);
3462         mov(reg_kernel, aux_reg_kernel);
3463     } else {
3464         oh_step_comeback_pointers();
3465     }
3466 }
3467
3468 void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
3469 {
3470     Label skip_zeroing, zeroing_loop;
3471
3472     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3473     cmp(reg_tmp, 0);
3474     jz(skip_zeroing, T_NEAR);
3475
3476     Zmm zero = Zmm(0);
3477     vpxord(zero, zero, zero);
3478     xor_(reg_tmp, reg_tmp);
3479     L(zeroing_loop); {
3480         assert(jcp.oc_block * jcp.typesize_out
3481             == cpu_isa_traits<avx512_common>::vlen);
3482         for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3483             vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
3484                 * jcp.typesize_out], zero);
3485         add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
3486         cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd
3487             * jcp.typesize_out);
3488         jnz(zeroing_loop);
3489     }
3490
3491     L(skip_zeroing);
3492 }
3493
3494 void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
3495 {
3496     Label skip_bias, bias_loop, skip_load_bias;
3497
3498     mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3499     test(reg_tmp,reg_tmp);
3500     jne(skip_bias, T_NEAR);
3501
3502     mov(reg_bias, ptr[param + GET_OFF(bias)]);
3503     mov(reg_output, ptr[param + GET_OFF(dst)]);
3504     vpxord(Zmm(1), Zmm(1), Zmm(1));
3505
3506     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3507     cmp(reg_tmp, 0);
3508     jne(skip_load_bias, T_NEAR);
3509     vmovups(Zmm(1), ptr[reg_bias]);
3510
3511     L(skip_load_bias);
3512
3513     mov(reg_oi, ptr[param + GET_OFF(d_worksize)]);
3514     sub(reg_oi, ptr[param + GET_OFF(d_index)]);
3515     mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out);
3516     imul(reg_oi, reg_tmp);
3517
3518     xor_(reg_tmp, reg_tmp);
3519     L(bias_loop); {
3520         vmovups(Zmm(0), ptr[reg_output + reg_tmp]);
3521         vaddps(Zmm(1), Zmm(1), Zmm(0));
3522         add(reg_tmp, jcp.oc_block * jcp.typesize_out);
3523         cmp(reg_tmp, reg_oi);
3524         jl(bias_loop);
3525     }
3526     vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1));
3527
3528     L(skip_bias);
3529 }
3530
3531 void jit_avx512_common_conv_bwd_weights_kernel_f32
3532     ::compute_oh_loop_common()
3533 {
3534     int ic_block = jcp.ic_block;
3535     int oc_block = jcp.oc_block;
3536     int back_pad = jcp.back_pad;
3537     int b_pad = jcp.b_pad;
3538     int t_pad = jcp.t_pad;
3539     bool is_dilated = jcp.dilate_h != 0;
3540     int dilate_h = jcp.dilate_h + 1;
3541     int stride_h = jcp.stride_h;
3542     const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
3543     int iw = utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw
3544         : jcp.iw;
3545     const size_t io_overlap = jcp.od - back_pad;
3546     Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
3547             oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
3548             oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end,
3549             skip_neg_overlap_label, skip_fpad_label, skip_input_label;
3550
3551     maybe_zero_kernel();
3552     if (jcp.ndims == 5 && jcp.with_bias) bias_kernel();
3553
3554     /* initially offset 'kd' by f_pad */
3555     if (jcp.ndims == 5) add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3556
3557     int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3558
3559     if (jcp.ndims == 5) {
3560         mov(reg_input_d, ptr[param + GET_OFF(src)]);
3561         mov(reg_output_d, ptr[param + GET_OFF(dst)]);
3562         mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
3563         L(od_label);
3564
3565         mov(reg_input, reg_input_d);
3566         mov(reg_output, reg_output_d);
3567         push(reg_input_d);
3568         push(reg_output_d);
3569         push(reg_d_index);
3570     }
3571
3572     mov(reg_kh, jcp.kh);
3573     xor_(reg_ih_count, reg_ih_count);
3574     xor_(reg_oj, reg_oj);
3575     /* Compute 'top' edge */
3576     if (t_pad > 0) {
3577         const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
3578         const int overflow
3579             = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3580         const int underflow = div_up(t_pad, dilate_h);
3581         const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
3582         mov(reg_kh, initial_inp_ker_overlap);
3583         add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
3584             * jcp.oc_block);
3585         // generate loop to process kernel while it remains within t_pad + ih
3586         if (kh_range < t_pad + jcp.ih) {
3587             if (is_dilated) {
3588                 const int tail = t_pad % dilate_h;
3589                 const int shift = tail == 0 ? 0 : dilate_h - tail;
3590                 mov(reg_tmp, shift);
3591                 if (tail != 0)
3592                     add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
3593             }
3594             L(oh_tpad_label); {
3595                 compute_oh_step_disp();
3596                 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3597                 if (is_dilated) {
3598                     inc(reg_tmp);
3599                     cmp(reg_tmp, dilate_h);
3600                     jl(oh_dilate_label_shift, T_NEAR);
3601                     // unshift input as new kernel element enters
3602                     sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
3603                     xor_(reg_tmp, reg_tmp);
3604                 }
3605                 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3606                 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3607                                 * jcp.ic_block * jcp.oc_block);
3608                 add(reg_kh, stride_h);
3609                 if (is_dilated) {
3610                     jmp(oh_dilate_label_noshift, T_NEAR);
3611                     L(oh_dilate_label_shift);
3612                     // shift input as old kernel element progresses
3613                     add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3614                     L(oh_dilate_label_noshift);
3615                 }
3616                 inc(reg_oj);
3617                 add(reg_ih_count, stride_h);
3618
3619                 // final number of kernel elements that overlap with input
3620                 const int final_inp_ker_overlap
3621                     = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
3622                 cmp(reg_kh, final_inp_ker_overlap);
3623                 jl(oh_tpad_label, T_NEAR);
3624             }
3625         }
3626         // need second loop to process kernel if it is larger than the input
3627         // (does not apply to dilations as they must have unit stride)
3628         if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
3629                                                         t_pad % stride_h)) {
3630             assert(!is_dilated);
3631             mov(reg_kh, jcp.ih);
3632             L(oh_tpad_tail_label); {
3633                 compute_oh_step_disp();
3634                 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3635                 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3636                                 * jcp.ic_block * jcp.oc_block);
3637
3638                 inc(reg_oj);
3639                 add(reg_ih_count, stride_h);
3640
3641                 cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
3642                 jl(oh_tpad_tail_label, T_NEAR);
3643             }
3644         }
3645         // correct any excess shifts to kernel and input
3646         // (does not apply to dilations as they must have unit stride,
3647         //  kernel must fit inside input, and padding is smaller than input)
3648         if (t_pad <= jcp.oh * stride_h) {
3649             // kernel has moved beyond padding (adjust for stride effects)
3650             if (t_pad % stride_h != 0) {
3651                 assert(!is_dilated);
3652                 int inp_corr = stride_h - t_pad % stride_h;
3653                 add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
3654                                 * jcp.ic_block * jcp.oc_block);
3655                 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
3656             }
3657         } else {
3658             // kernel still overlaps padding (complete reset)
3659             assert(!is_dilated);
3660             sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
3661                             * jcp.kw * jcp.ic_block * jcp.oc_block);
3662         }
3663     }
3664
3665     cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3666     jge(oh_label_end, T_NEAR);
3667     cmp(reg_oj, jcp.oh);
3668     jge(oh_label, T_NEAR);
3669
3670     /* Compute middle block(s) */
3671     mov(reg_kh, jcp.kh);
3672     L(oh_label); {
3673         compute_oh_step_disp();
3674         add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3675         add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3676
3677         inc(reg_oj);
3678         add(reg_ih_count, stride_h);
3679
3680         cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3681         jge(oh_label_end, T_NEAR);
3682
3683         cmp(reg_oj, jcp.oh);
3684         jl(oh_label, T_NEAR);
3685     }
3686     L(oh_label_end);
3687
3688     /* Compute bottom edge */
3689     if (b_pad > 0) {
3690         cmp(reg_oj, jcp.oh);
3691         jge(oh_bpad_label_end, T_NEAR);
3692
3693         if (is_dilated) {
3694             mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
3695             mov(reg_tmp, 0);
3696         } else {
3697             mov(reg_kh, jcp.ihp - b_pad);
3698             sub(reg_kh, reg_ih_count);
3699         }
3700         L(oh_bpad_label);
3701         {
3702             compute_oh_step_disp();
3703             add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3704             add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3705             if (is_dilated) {
3706                 inc(reg_tmp);
3707                 cmp(reg_tmp, dilate_h);
3708                 jl(oh_dilate_label_end, T_NEAR);
3709                 xor_(reg_tmp, reg_tmp);
3710             }
3711             sub(reg_kh, stride_h);
3712             cmp(reg_kh, 0);
3713             jle(oh_bpad_label_end, T_NEAR);
3714             if (is_dilated)
3715                 L(oh_dilate_label_end);
3716
3717             inc(reg_oj);
3718             cmp(reg_oj, jcp.oh);
3719             jl(oh_bpad_label, T_NEAR);
3720         }
3721         L(oh_bpad_label_end);
3722     }
3723
3724     if (jcp.ndims == 5) {
3725         pop(reg_d_index);
3726         pop(reg_output_d);
3727         pop(reg_input_d);
3728
3729         mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3730
3731         /* 'outer-depth loop' offset into next 'depth' index */
3732         add(reg_output_d, jcp.typesize_in * jcp.oh * ow * jcp.oc_block);
3733
3734         /* only increase input address when convolution is not within the
3735          * 'f_pad' region */
3736         if (jcp.f_pad > 0) {
3737             cmp(reg_d_index, jcp.f_pad);
3738             jl(skip_input_label);
3739         }
3740         add(reg_input_d,
3741                 jcp.typesize_in * jcp.stride_d * jcp.ih * iw * inp_mult);
3742         L(skip_input_label);
3743
3744         inc(reg_d_index);
3745         cmp(reg_d_index, io_overlap);
3746         jl(skip_neg_overlap_label);
3747
3748         /* Reduce 'kd' count as convolution steps within 'back_pad' region */
3749         dec(reg_kd_count);
3750         jmp(skip_fpad_label);
3751
3752         L(skip_neg_overlap_label);
3753         cmp(reg_kd_count, jcp.kd);
3754         jge(skip_fpad_label);
3755
3756         /* increase 'kd' count as convolution steps out of 'f_pad' region */
3757         inc(reg_kd_count);
3758         sub(reg_kernel,
3759                 jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block);
3760
3761         L(skip_fpad_label);
3762         mov(ptr[param + GET_OFF(kd_padding)], reg_kd_count);
3763
3764         cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
3765         jl(od_label, T_NEAR);
3766
3767         L(od_label_end);
3768     }
3769 }
3770
3771 bool jit_avx512_common_conv_bwd_weights_kernel_f32
3772     ::compute_full_spat_loop()
3773 {
3774     // FIXME: use register mapping from the class declaration
3775     bool ok = one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3776         && (jcp.ver == ver_4fma || !one_of(1, jcp.kh, jcp.kw))
3777         && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
3778         && everyone_is(1, jcp.stride_h, jcp.stride_w);
3779     if (!ok) return false;
3780     if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
3781         return false;
3782
3783     // General code layout:
3784     //
3785     // Blocking over OH -- top level
3786     // (Reduces L2 pressure; not very useful right now)
3787     //  Loop over all KHxKW kernel -- emit_kh_kw_loop()
3788     //    Loop over OH block -- emit_h_loop()
3789     //      Loop over OW blocks -- emit_fma_block()
3790     //      (Supports both fully unrolled and partially unrolled versions to
3791     //      reduce code size)
3792     //          Loop over OW block -- emit_fma_step()
3793
3794     int max_working_set_size = 128 * 1024;
3795     int pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)? jcp.tr_ow
3796         : jcp.ow;
3797
3798     int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
3799     int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in;
3800     int row_size = inp_row_size + out_row_size;
3801
3802     int h_block_size = jcp.oh;
3803     int working_set_size = row_size * h_block_size;
3804
3805     if (working_set_size > max_working_set_size) {
3806         int opt_working_set_size = 48 * 1024;
3807         assert(opt_working_set_size < max_working_set_size);
3808
3809         while (working_set_size > opt_working_set_size) {
3810             for (int i = 2; i <= h_block_size; i++)
3811                 if (i == h_block_size)
3812                     h_block_size = h_block_size / 2;
3813                 else if (h_block_size % i == 0) {
3814                     h_block_size = h_block_size / i;
3815                     break;
3816                 }
3817             working_set_size = row_size * h_block_size;
3818
3819             if (h_block_size == 1 && working_set_size > opt_working_set_size)
3820                 return false;
3821         }
3822     }
3823
3824     // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below)
3825     if (h_block_size < nstl::max(1, jcp.t_pad)
3826             || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size
3827                                                        : jcp.oh % h_block_size))
3828         return false;
3829
3830     // check that we can use simple arithmetic for prefetch address
3831     // calculations
3832     // TODO: we need some traits for this check (Roma)
3833     int cache_line_size = 64;
3834     assert(jcp.ic_block * typesize == 64);
3835     assert(jcp.oc_block * typesize == 64);
3836
3837     int num_inp_l2_pfs = jcp.tr_iw * h_block_size;
3838     int avg_h_loop_len = h_block_size;
3839     int num_inp_l2_pfs_per_fma_block
3840         = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3841     int num_out_l2_pfs = pad_ow * h_block_size;
3842     int num_out_l2_pfs_per_fma_block
3843         = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3844
3845     Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors
3846     Reg64 reg_kh = rax;
3847     Reg64 reg_kw = rbx;
3848     Reg64 reg_tmp = abi_not_param1;
3849     Reg32 reg_tmp_w = reg_tmp.cvt32();
3850     Reg64 reg_ohs = rdx;
3851     Reg64 reg_ihs = rsi;
3852     Reg64 reg_h = r8;
3853     Reg64 reg_i = r9;
3854     Reg64 reg_j = r10;
3855
3856     Reg64 reg_inp = r13;
3857     Reg64 reg_out = r14;
3858     Reg64 reg_ker = r15;
3859
3860     Reg64 reg_inp_pf_l1 = rbp;
3861
3862     Reg64 reg_inp_pf_l2 = r11;
3863     Reg64 reg_out_pf_l2 = r12;
3864
3865     Xmm reg_inp_pf_save = xmm17;
3866     Xmm reg_out_pf_save = xmm18;
3867
3868     Reg64 reg_inp_save = abi_param1;
3869     Reg64 reg_out_save = reg_tmp;
3870
3871     auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); };
3872     auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3873     auto inp_addr = [&](int oi, int ic1, bool vnni_bcast = false) {
3874         if (vnni_bcast)
3875             return zword_b[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
3876         else
3877             return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
3878     };
3879     auto out_addr = [&](int oi, int oj = 0) {
3880         assert(utils::one_of(jcp.ver, ver_4vnni, ver_4fma, ver_vnni));
3881         auto ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3882         auto pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow
3883             : jcp.ow;
3884         return ptr[reg_out
3885             + ((oi + oj * pad_ow / ow_per_oc) * jcp.oc_block * ow_per_oc)
3886             * jcp.typesize_in];
3887     };
3888     auto ker_addr = [&](int ic1) {
3889         return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out];
3890     };
3891
3892     auto emit_block = [&](int h_block_size,
3893             bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row)
3894     {
3895         // TODO: add an fma version (Roma)
3896         auto pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow
3897             : jcp.ow;
3898
3899         int ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3900         int ow4u = rnd_up(pad_ow, 4 * ow_per_oc);
3901         int def_step_size = 16;
3902
3903         bool has_w_tail = (pad_ow % def_step_size != 0
3904                 || pad_ow % (4 * ow_per_oc) != 0);
3905         bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3906
3907         auto emit_step = [&](int ur_ow,
3908                 int num_inp_l1_pfs_per_fma_step,
3909                 int num_inp_l2_pfs_per_fma_step,
3910                 int num_out_l2_pfs_per_fma_step, bool is_w_tail)
3911         {
3912             bool block_wraparound = is_w_tail && is_last_row;
3913
3914             assert(ur_ow % 4 == 0);
3915             int tail_size = ow4u % ur_ow;
3916             int this_ur_ow
3917                 = (is_w_tail && tail_size) ? tail_size : ur_ow;
3918             int ow_last_chunk4 = pad_ow % (4 * ow_per_oc);
3919             int ow_zero_tail4 = ow_last_chunk4
3920                 ? (4 * ow_per_oc) - ow_last_chunk4 : 0;
3921
3922             auto emit_out_pf = [&](int oi) {
3923 #if 1
3924                 if (oi + def_step_size < ur_ow / ow_per_oc || !block_wraparound)
3925                     mic_prefetcht0(ptr[reg_out
3926                             + ((def_step_size + oi)
3927                                 * ow_per_oc * jcp.oc_block * jcp.typesize_in)]);
3928                 else {
3929                     assert(block_wraparound);
3930                     assert(oi + def_step_size >= ur_ow / ow_per_oc);
3931                     mic_prefetcht0(ptr[reg_out_save
3932                             + ((oi + def_step_size - ur_ow / ow_per_oc)
3933                                 * ow_per_oc * jcp.oc_block * jcp.typesize_in)]);
3934                 }
3935 #else
3936                 // XXX: This is an alternative prefetching strategy that
3937                 // always prefetches the next row. Keeping it here for
3938                 // future experiments (Roma)
3939                 if (!block_wraparound)
3940                     mic_prefetcht0(ptr[reg_out
3941                             + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]);
3942                 else
3943                     mic_prefetcht0(ptr[reg_out + reg_ohs
3944                             - ((h_block_size - 1) * jcp.ow
3945                                 - oi) * jcp.oc_block * jcp.typesize_in]);
3946 #endif
3947                 if (oi < num_out_l2_pfs_per_fma_step)
3948                     mic_prefetcht1(ptr[reg_out_pf_l2
3949                             + oi * jcp.oc_block * jcp.typesize_in]);
3950             };
3951
3952             auto emit_inp_pf = [&](int oi4, int ic1) {
3953                 int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block;
3954                 int num_pf_slots = jcp.ic_block * ur_ow / 4;
3955
3956                 int num_pfs = num_inp_l1_pfs_per_fma_step
3957                     + num_inp_l2_pfs_per_fma_step;
3958                 int pf_freq = nstl::max(1, num_pf_slots / num_pfs);
3959
3960                 if (pf_slot_idx % pf_freq)
3961                     return;
3962
3963                 int pf_idx = pf_slot_idx / pf_freq;
3964
3965                 if (pf_idx < num_inp_l2_pfs_per_fma_step)
3966                     mic_prefetcht1(ptr[reg_inp_pf_l2
3967                             + pf_idx * jcp.ic_block * jcp.typesize_in]);
3968                 else {
3969                     pf_idx -= num_inp_l2_pfs_per_fma_step;
3970                     // prefetch the 'tail' of the cache line because most of
3971                     // the accesses are not aligned
3972                     mic_prefetcht0(ptr[reg_inp_pf_l1
3973                             + pf_idx * jcp.ic_block * jcp.typesize_in
3974                             + cache_line_size - jcp.typesize_in]);
3975                 }
3976             };
3977
3978             auto numloads = (jcp.ver == ver_vnni) ? 1 : 4;
3979
3980             int steps = this_ur_ow / ow_per_oc;
3981             for (int oi4 = 0; oi4 < steps; oi4 += numloads) {
3982                 for (int oi1 = 0; oi1 < numloads; oi1++) {
3983                     int oi = oi4 + oi1;
3984                     if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)/ow_per_oc) {
3985                         vmovups(zmm_out(oi), out_addr(oi));
3986                         emit_out_pf(oi);
3987                     } else {
3988                         auto zmm = zmm_out(oi);
3989                         vpxord(zmm, zmm, zmm);
3990                     }
3991                 }
3992
3993                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
3994                     if (jcp.ver == ver_4fma) {
3995                         v4fmaddps(zmm_ker(ic1),
3996                                 zmm_out(oi4), inp_addr(oi4, ic1));
3997                     } else if (jcp.ver == ver_4vnni) {
3998                         vp4dpwssd(zmm_ker(ic1),
3999                                 zmm_out(oi4), inp_addr(ow_per_oc*oi4, ic1));
4000                     } else if (jcp.ver == ver_vnni) {
4001                         vpdpwssd(zmm_ker(ic1),
4002                             zmm_out(oi4), inp_addr(ow_per_oc*oi4, ic1, true));
4003                     } else {
4004                         assert(!"unknown convolution version");
4005                     }
4006                         emit_inp_pf(ow_per_oc * oi4, ic1);
4007                 }
4008             }
4009         };
4010
4011         // Input is transposed and padded but we only access about jcp.iw
4012         // elements so use that to compute the # of cache lines in each 'row'
4013         int num_inp_l1_pfs
4014             = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block;
4015
4016         if (full_w_unroll) {
4017             emit_step(ow4u, num_inp_l1_pfs,
4018                     num_inp_l2_pfs_per_fma_block,
4019                     num_out_l2_pfs_per_fma_block, true);
4020             add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size);
4021             add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size);
4022         } else {
4023             Label w_loop;
4024             int num_w_iters = pad_ow / def_step_size;
4025             int num_w_iters_full = num_w_iters + has_w_tail;
4026             int num_inp_l1_pfs_per_fma_step
4027                 = div_up(num_inp_l1_pfs, num_w_iters_full);
4028             int num_inp_l2_pfs_per_fma_step
4029                 = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full);
4030             int num_out_l2_pfs_per_fma_step
4031                 = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full);
4032             mov(reg_i, num_w_iters);
4033             L(w_loop); {
4034                 emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
4035                         num_inp_l2_pfs_per_fma_step,
4036                         num_out_l2_pfs_per_fma_step, false);
4037                 add(reg_inp, def_step_size * jcp.typesize_in);
4038                 add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in);
4039                 add(reg_inp_pf_l1,
4040                         num_inp_l1_pfs_per_fma_step * cache_line_size);
4041                 add(reg_inp_pf_l2,
4042                         num_inp_l2_pfs_per_fma_step * cache_line_size);
4043                 add(reg_out_pf_l2,
4044                         num_out_l2_pfs_per_fma_step * cache_line_size);
4045                 sub(reg_i, 1);
4046                 jnz(w_loop);
4047             }
4048             if (has_w_tail) {
4049                 emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
4050                         num_inp_l2_pfs_per_fma_step,
4051                         num_out_l2_pfs_per_fma_step, true);
4052                 add(reg_inp_pf_l2,
4053                         num_inp_l2_pfs_per_fma_step * cache_line_size);
4054                 add(reg_out_pf_l2,
4055                         num_out_l2_pfs_per_fma_step * cache_line_size);
4056             }
4057             // reset reg_inp and reg_out because emit_h_loop expects
4058             // unmodified pointers
4059             int w_offset = num_w_iters * def_step_size;
4060             sub(reg_inp, w_offset * jcp.typesize_in);
4061             sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in);
4062         }
4063     };
4064
4065     auto emit_h_loop = [&](int h_block_size,
4066             bool is_last_block, bool is_last_kh_kw_iter)
4067     {
4068         Label h_loop, skip_h_loop;
4069         mov(reg_j, 1);
4070         cmp(reg_j, reg_h);
4071         je(skip_h_loop, T_NEAR);
4072         L(h_loop); {
4073
4074             lea(reg_inp_pf_l1,
4075                     ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]);
4076             emit_block(h_block_size,
4077                     is_last_block, is_last_kh_kw_iter, false);
4078
4079             add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
4080             add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in);
4081             add(reg_j, 1);
4082             cmp(reg_j, reg_h);
4083             jb(h_loop);
4084         }
4085
4086         L(skip_h_loop);
4087
4088         for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
4089             mic_prefetcht0(ker_addr(ic1));
4090
4091         lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]);
4092         emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true);
4093     };
4094
4095     auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block,
4096             int h_block_size)
4097     {
4098         xor_(reg_kh, reg_kh);
4099         Label kh_loop, kh_loop_end;
4100
4101         int last_oh_block_size
4102             = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size);
4103         int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size;
4104         // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size
4105         int ih_block_size = oh_block_size - 1 + jcp.kh
4106                 - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad;
4107
4108         L(kh_loop); {
4109             // determine starting indices for this block
4110             if (is_first_block) {
4111                 xor_(reg_tmp, reg_tmp);
4112                 mov(reg_ohs, jcp.t_pad);
4113                 sub(reg_ohs, reg_kh);
4114                 cmovb(reg_ohs, reg_tmp);
4115
4116                 mov(reg_ihs, reg_ohs);
4117                 sub(reg_ihs, jcp.t_pad);
4118                 add(reg_ihs, reg_kh);
4119             } else {
4120                 xor_(reg_ohs, reg_ohs);
4121                 mov(reg_ihs, reg_kh);
4122             }
4123
4124             // determine effective size of block based on padding
4125             mov(reg_tmp, oh_block_size);
4126             sub(reg_tmp, reg_ohs);
4127             mov(reg_h, ih_block_size);
4128             sub(reg_h, reg_ihs);
4129             cmp(reg_tmp, reg_h);
4130             cmovb(reg_h, reg_tmp);
4131
4132             Label kh_loop_work;
4133             cmp(reg_h, 0);
4134             jg(kh_loop_work, T_NEAR);
4135
4136             // empty h loop for this jcp.kh:
4137             // - set the output to 0 if necessary
4138             // - move ker pt
4139             // - jump to the end
4140             sub(reg_h, 1);
4141             Label skip_ker_zeroing;
4142
4143             // The reg_ker ptr has highest bit set if the output needs to be
4144             // zeroed. Those who have byte-aligned their data will suffer the
4145             // consiquences :(
4146             // TODO: move the flag to a mask register? (Roma)
4147             test(reg_ker, 1);
4148             jz(skip_ker_zeroing, T_NEAR);
4149
4150             Label zeroing_loop;
4151             vpxord(zmm0, zmm0, zmm0);
4152             and_(reg_ker, ~1); // temporarily clear the zeroing flag
4153             mov(reg_tmp, jcp.kw);
4154             L(zeroing_loop); {
4155                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
4156                     vmovups(ker_addr(ic1), zmm0);
4157                 add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out);
4158                 sub(reg_tmp, 1);
4159                 jnz(zeroing_loop, T_NEAR);
4160             }
4161             // restore the zeroing flag (it will be cleared after the end of
4162             // emit_kh_kw_loop, but we may need it until then)
4163             or_(reg_ker, 1);
4164             jmp(kh_loop_end, T_NEAR);
4165
4166             L(skip_ker_zeroing);
4167             add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw
4168                 * jcp.typesize_out);
4169             jmp(kh_loop_end, T_NEAR);
4170
4171             L(kh_loop_work);
4172
4173             mul_by_const(reg_ihs, reg_tmp,
4174                     jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
4175             mul_by_const(reg_ohs, reg_tmp,
4176                     pad_ow * jcp.oc_block * jcp.typesize_in);
4177
4178             add(reg_inp, reg_ihs);
4179             add(reg_out, reg_ohs);
4180
4181             Label kw_loop;
4182             xor_(reg_kw, reg_kw);
4183             L(kw_loop); {
4184                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
4185                     auto zmm = zmm_ker(ic1);
4186                     vpxord(zmm, zmm, zmm);
4187                     mic_prefetcht1(ker_addr(ic1));
4188                 }
4189
4190                 mov(reg_out_save, reg_out);
4191                 mov(reg_inp_save, reg_inp);
4192                 lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]);
4193
4194 #if 0
4195                 // XXX: Generate code with special prefetches when switching
4196                 // blocks or at the end of the last block. Disabled to reduce
4197                 // code size and because there's no performance benefit (Roma)
4198                 Label regular_h_loop, end_h_loop;
4199                 cmp(reg_kw, jcp.kw - 1);
4200                 jne(regular_h_loop, T_NEAR);
4201                 cmp(reg_kh, jcp.kh - 1);
4202                 jne(regular_h_loop, T_NEAR);
4203
4204                 emit_h_loop(oh_block_size, is_last_block, true);
4205                 jmp(end_h_loop, T_NEAR);
4206
4207                 L(regular_h_loop);
4208                 emit_h_loop(oh_block_size, is_last_block, false);
4209
4210                 L(end_h_loop);
4211 #else
4212                 emit_h_loop(oh_block_size, is_last_block, false);
4213 #endif
4214
4215                 mov(reg_out, reg_out_save);
4216                 mov(reg_inp, reg_inp_save);
4217
4218                 Label do_store;
4219                 // The reg_ker ptr has highest bit set if the output needs to
4220                 // be zeroed. Those who have byte-aligned their data will
4221                 // suffer the consiquences :(
4222                 mov(reg_tmp, reg_ker);
4223                 and_(reg_ker, ~1);
4224                 test(reg_tmp, 1);
4225                 jnz(do_store, T_NEAR);
4226
4227                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
4228                     auto zmm = zmm_ker(ic1);
4229                     if (jcp.ver == ver_4fma) {
4230                         vaddps(zmm, ker_addr(ic1));
4231                     } else if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
4232                         vpaddd(zmm, zmm, ker_addr(ic1));
4233                     } else {
4234                         assert(!"unknown convolution version");
4235                     }
4236                 }
4237
4238                 L(do_store);
4239                 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
4240                     auto zmm = zmm_ker(ic1);
4241                     vmovups(ker_addr(ic1), zmm);
4242                 }
4243
4244                 mov(reg_ker, reg_tmp);
4245                 add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
4246                 add(reg_kw, 1);
4247                 cmp(reg_kw, jcp.kw);
4248                 jl(kw_loop);
4249             }
4250
4251             sub(reg_inp, reg_ihs);
4252             sub(reg_out, reg_ohs);
4253
4254
4255             L(kh_loop_end);
4256             add(reg_kh, 1);
4257             cmp(reg_kh, jcp.kh);
4258             jl(kh_loop);
4259         }
4260     };
4261
4262     mov(reg_inp, ptr[param + GET_OFF(src)]);
4263     mov(reg_out, ptr[param + GET_OFF(dst)]);
4264     mov(reg_ker, ptr[param + GET_OFF(filt)]);
4265     mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]);
4266     mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]);
4267     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
4268     or_(reg_ker, reg_tmp);
4269
4270     bool single_kh_kw_loop = (h_block_size == jcp.oh);
4271
4272     size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in;
4273     size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad);
4274     size_t inp_block_step = inp_row_step * h_block_size;
4275     size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in
4276         * h_block_size;
4277
4278     if (!single_kh_kw_loop) {
4279         // Save the original prefetch pointers from the OpenMP driver
4280         vmovq(reg_inp_pf_save, reg_inp_pf_l2);
4281         vmovq(reg_out_pf_save, reg_out_pf_l2);
4282         mov(reg_inp_pf_l2, reg_inp);
4283         add(reg_inp_pf_l2, first_inp_block_step);
4284         mov(reg_out_pf_l2, reg_out);
4285         add(reg_out_pf_l2, out_block_step);
4286     }
4287     emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size);
4288
4289     if (!single_kh_kw_loop) {
4290         size_t ker_reset_offset
4291             = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh;
4292         sub(reg_ker, ker_reset_offset);
4293         and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
4294
4295         add(reg_inp, first_inp_block_step);
4296         add(reg_out, out_block_step);
4297         mov(reg_inp_pf_l2, reg_inp);
4298         add(reg_inp_pf_l2, inp_block_step);
4299         mov(reg_out_pf_l2, reg_out);
4300         add(reg_out_pf_l2, out_block_step);
4301
4302         int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2;
4303         if (num_innermost_iters > 0) {
4304             Label h_block_loop;
4305
4306             mov(reg_tmp_w, num_innermost_iters);
4307             kmovw(reg_h_block, reg_tmp_w);
4308             L(h_block_loop); {
4309                 emit_kh_kw_loop(false, false, h_block_size);
4310                 sub(reg_ker, ker_reset_offset);
4311                 add(reg_inp, inp_row_step * h_block_size);
4312                 add(reg_out, out_block_step);
4313                 mov(reg_inp_pf_l2, reg_inp);
4314                 add(reg_inp_pf_l2, inp_block_step);
4315                 mov(reg_out_pf_l2, reg_out);
4316                 add(reg_out_pf_l2, out_block_step);
4317                 kmovw(reg_tmp_w, reg_h_block);
4318                 sub(reg_tmp_w, 1);
4319                 kmovw(reg_h_block, reg_tmp_w);
4320                 jnz(h_block_loop);
4321             }
4322         }
4323
4324         // Restore the original prefetch pointers that came from the OpenMP
4325         // driver
4326         vmovq(reg_inp_pf_l2, reg_inp_pf_save);
4327         vmovq(reg_out_pf_l2, reg_out_pf_save);
4328         emit_kh_kw_loop(false, true, h_block_size);
4329     }
4330
4331     return true;
4332 }
4333
4334 bool jit_avx512_common_conv_bwd_weights_kernel_f32
4335     ::flat_4ops_compute() {
4336     const auto &j = jcp;
4337     const bool ok = j.ver == ver_4fma && j.is_1stconv
4338         && everyone_is(0, j.dilate_h, j.dilate_w);
4339     if (!ok) return false;
4340
4341     Reg64 reg_ptr_tr_src = r8;
4342     Reg64 reg_ptr_dst = r9;
4343     Reg64 reg_ptr_wei = r10;
4344     Reg64 reg_ptr_bia = r11;
4345
4346     Reg64 reg_kh_step = rax;
4347     Reg64 reg_oh = abi_not_param1;
4348     Reg64 reg_kh = rdx;
4349
4350     Reg32 reg_flag_save = ebx;
4351     Reg32 reg_flag = esi;
4352
4353     Zmm vbia(31);
4354
4355     auto zmm_wei = [&](int kh, int kw) {
4356         return Zmm(8 + kh * j.kw + kw);
4357     };
4358     auto zmm_dst = [&](int ow) {
4359         return Zmm(ow % 8);
4360     };
4361
4362     auto addr_tr_src = [&](int kh, int iw) {
4363         return ptr[reg_ptr_tr_src
4364             + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in];
4365     };
4366     auto addr_dst = [&](int ow) {
4367         return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in];
4368     };
4369     auto addr_wei = [&](int kh, int kw) {
4370         return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block
4371             * jcp.typesize_out];
4372     };
4373
4374     auto emit_fma_block = [&](int kh_step) {
4375         for (int kh = 0; kh < kh_step; ++kh) {
4376             for (int kw = 0; kw < j.kw; ++kw) {
4377                 auto vwei = zmm_wei(kh, kw);
4378                 vpxord(vwei, vwei, vwei);
4379             }
4380         }
4381
4382         for (int ow = 0; ow < j.ow; ow += 4) {
4383             for (int _ow = ow; _ow < ow + 4; ++_ow) {
4384                 auto vdst = zmm_dst(_ow);
4385                 if (_ow < j.ow)
4386                     vmovups(vdst, addr_dst(_ow));
4387                 else
4388                     vpxord(vdst, vdst, vdst);
4389             }
4390
4391             for (int kh = 0; kh < kh_step; ++kh) {
4392                 for (int kw = 0; kw < j.kw; ++kw) {
4393                     const int iw = ow + (kw % j.stride_w) * j.tr_ld
4394                         + (kw / j.stride_w);
4395                     v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow),
4396                             addr_tr_src(kh, iw));
4397                     if (1 && kh == 0 && kw < 4) {
4398                         prefetcht1(ptr[reg_ptr_dst
4399                             + (j.ow + ow + kw) * jcp.oc_block
4400                             * jcp.typesize_in]);
4401                     }
4402                     if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */
4403                         const int off = kw + 4 - j.kw;
4404                         if (off >= 0 && ow + off < j.ow)
4405                             vaddps(vbia, vbia, zmm_dst(ow + off));
4406                     }
4407                 }
4408             }
4409         }
4410
4411         Label l_store;
4412         test(reg_flag, FLAG_MB_FIRST);
4413         jnz(l_store, T_NEAR);
4414         for (int kh = 0; kh < kh_step; ++kh) {
4415             for (int kw = 0; kw < j.kw; ++kw)
4416                 vaddps(zmm_wei(kh, kw), addr_wei(kh, kw));
4417         }
4418         L(l_store);
4419         for (int kh = 0; kh < kh_step; ++kh) {
4420             for (int kw = 0; kw < j.kw; ++kw)
4421                 vmovups(addr_wei(kh, kw), zmm_wei(kh, kw));
4422         }
4423     };
4424
4425     auto emit_kh_loop = [&]() {
4426         const int kh_step_rem = j.kh % j.kh_step;
4427         xor_(reg_kh, reg_kh);
4428         mov(reg_kh_step, j.kh_step);
4429
4430         Label l_kh_loop;
4431         L(l_kh_loop); {
4432             Label l_done;
4433
4434             if (kh_step_rem != 0) {
4435                 Label l_keep_kh_step;
4436                 cmp(reg_kh, j.kh - j.kh_step);
4437                 jle(l_keep_kh_step, T_NEAR);
4438
4439                 mov(reg_kh_step, kh_step_rem);
4440                 emit_fma_block(kh_step_rem);
4441                 jmp(l_done, T_NEAR);
4442
4443                 L(l_keep_kh_step);
4444             }
4445
4446             emit_fma_block(j.kh_step);
4447
4448             L(l_done);
4449
4450             add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld
4451                 * jcp.typesize_in);
4452             add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out);
4453             add(reg_kh, j.kh_step);
4454
4455             cmp(reg_kh, j.kh);
4456             jl(l_kh_loop, T_NEAR);
4457         }
4458
4459         const int kh_steps = rnd_up(j.kh, j.kh_step);
4460         sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in);
4461         sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out);
4462     };
4463
4464     auto emit_oh_loop = [&]() {
4465         mov(reg_oh, j.oh);
4466
4467         Label l_oh_loop;
4468         L(l_oh_loop); {
4469             Label l_restore_mb_flag, l_jump;
4470
4471             cmp(reg_oh, j.oh);
4472             je(l_restore_mb_flag, T_NEAR);
4473
4474             and_(reg_flag, ~FLAG_MB_FIRST);
4475             jmp(l_jump, T_NEAR);
4476
4477             L(l_restore_mb_flag);
4478             mov(reg_flag, reg_flag_save);
4479
4480             L(l_jump);
4481
4482             emit_kh_loop();
4483
4484             add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld
4485                 * jcp.typesize_in);
4486             add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in);
4487
4488             dec(reg_oh);
4489             jnz(l_oh_loop, T_NEAR);
4490         }
4491     };
4492
4493     auto emit_bia_store = [&]() {
4494         if (!j.with_bias) return;
4495
4496         Label l_bia_store, l_bia_skip;
4497         test(reg_flag, FLAG_IC_FIRST);
4498         jz(l_bia_skip);
4499
4500         test(reg_flag, FLAG_MB_FIRST);
4501         jnz(l_bia_store, T_NEAR);
4502         vaddps(vbia, ptr[reg_ptr_bia]);
4503         L(l_bia_store);
4504         vmovups(ptr[reg_ptr_bia], vbia);
4505         L(l_bia_skip);
4506     };
4507
4508     mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]);
4509     mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]);
4510     mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]);
4511     mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]);
4512     mov(reg_flag_save, ptr[param + GET_OFF(flags)]);
4513
4514     vpxord(vbia, vbia, vbia);
4515     emit_oh_loop();
4516     emit_bia_store();
4517
4518     return true;
4519 }
4520
4521 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop()
4522 {
4523     if (flat_4ops_compute())
4524         return;
4525     if (compute_full_spat_loop())
4526         return;
4527     compute_oh_loop_common();
4528 }
4529
4530 void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
4531 {
4532     preamble();
4533
4534     mov(reg_input, ptr[param + GET_OFF(src)]);
4535     mov(reg_output, ptr[param + GET_OFF(dst)]);
4536     mov(reg_kernel, ptr[param + GET_OFF(filt)]);
4537
4538     compute_loop();
4539
4540     postamble();
4541 }
4542
4543 status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
4544     jit_conv_conf_t &jcp, const convolution_desc_t &cd,
4545     cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
4546     cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd) {
4547     if (!mayiuse(avx512_common))
4548         return status::unimplemented;
4549
4550     const memory_desc_wrapper src_d(&src_pd);
4551     const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
4552     const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
4553     const memory_desc_wrapper diff_dst_d(&diff_dst_pd);
4554
4555     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
4556     int ndims = src_d.ndims();
4557
4558     jcp = zero<decltype(jcp)>();
4559
4560     jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
4561     jcp.ndims = ndims;
4562     jcp.prop_kind = cd.prop_kind;
4563
4564     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
4565     jcp.mb = src_d.dims()[0];
4566
4567     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
4568     jcp.oc_without_padding = jcp.oc;
4569     jcp.ic = src_d.dims()[1] / jcp.ngroups;
4570
4571     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
4572     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
4573     jcp.iw = src_d.dims()[ndims-1];
4574     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
4575     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
4576     jcp.ow = diff_dst_d.dims()[ndims-1];
4577
4578     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
4579     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
4580     jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
4581
4582     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
4583     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
4584     jcp.l_pad = cd.padding[0][ndims-3];
4585
4586     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
4587     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
4588     jcp.stride_w = cd.strides[ndims-3];
4589
4590     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
4591     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
4592     jcp.dilate_w = cd.dilates[ndims-3];
4593
4594     const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
4595     bool ok = true
4596         // general condition to simplify dilations
4597         && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
4598         && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
4599         && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
4600         // special condition to simplify dilations in compute_oh_loop_common
4601         && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
4602     if (!ok)
4603         return status::unimplemented;
4604
4605     jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
4606             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
4607     jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
4608             + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
4609     jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
4610             + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));
4611
4612     /* XXX: currently, does not support stride_d > 1 or dilation > 0 */
4613     if (ndims == 5)
4614         if (jcp.stride_d > 1 || jcp.dilate_d > 0)
4615             return status::unimplemented;
4616
4617     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
4618     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
4619     jcp.ohp = jcp.oh;
4620     jcp.owp = jcp.ow;
4621     jcp.aligned_threads = 0;
4622
4623     /* check for the 1st convolution */
4624     jcp.is_1stconv = is_1stconv(jcp);
4625
4626     jcp.oc_block = jcp.simd_w;
4627
4628     bool ok_to_pad_channels = true
4629         && jcp.ngroups == 1
4630         && src_d.data_type() == data_type::f32;
4631
4632     if (ok_to_pad_channels)
4633         jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
4634
4635     if (jcp.oc % jcp.oc_block)
4636         return status::unimplemented;
4637
4638     auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
4639     auto wei_format = with_groups
4640         ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
4641         : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
4642     /* conditions on bias memory */
4643     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
4644     if (jcp.with_bias) {
4645         if (diff_bias_d.format() == any)
4646             CHECK(diff_bias_pd.set_format(x));
4647         if (diff_bias_d.format() != x)
4648             return status::unimplemented;
4649     }
4650
4651     jcp.nb_oc = jcp.oc / jcp.oc_block;
4652
4653     if (diff_dst_d.format() == any)
4654         CHECK(diff_dst_pd.set_format(src_format));
4655     if (diff_dst_d.format() != src_format)
4656         return status::unimplemented;
4657
4658     /* kernel applicability check wrt boundaries
4659      * the conditions are quite general across the kernels we have,
4660      * but ideally the check should belong to a specific kernel... */
4661     const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
4662     const bool boundaries_ok = true
4663         && jcp.t_pad <= max_pad
4664         && jcp.b_pad <= max_pad;
4665     if (!boundaries_ok)
4666         return status::unimplemented;
4667
4668     /* yet another common check */
4669     if (jcp.kw > 14)
4670         return status::unimplemented;
4671
4672     /* setting register strategy */
4673     for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
4674         if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
4675     }
4676
4677     if (jcp.is_1stconv) {
4678         const auto want_src_format = pick(ndims - 3, ncw, nchw, ncdhw);
4679         if (src_d.format() == any)
4680             CHECK(src_pd.set_format(want_src_format));
4681
4682         const bool src_ok = true
4683             && utils::everyone_is(data_type::f32,
4684                 src_d.data_type(), diff_weights_d.data_type(),
4685                 diff_dst_d.data_type())
4686             && one_of(jcp.ic, 1, 3)
4687             && IMPLICATION(jcp.ic == 1, one_of(src_d.format(), want_src_format,
4688                 pick(ndims - 3, nwc, nhwc, ndhwc)))
4689             && IMPLICATION(jcp.ic != 1, src_d.format() == want_src_format)
4690             && jcp.ngroups == 1;
4691         if (!src_ok)
4692             return status::unimplemented;
4693
4694         const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad,
4695                     jcp.stride_w), 16);
4696         const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1);
4697         const int kh_step_rem = jcp.kh % kh_step;
4698         const auto want_4fma_wfmt = with_groups
4699             ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
4700             : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
4701         const bool use_4fma = true
4702             && one_of(ndims, 3, 4)
4703             && mayiuse(avx512_mic_4ops)
4704             && mkldnn_thr_syncable()
4705             && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
4706             && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
4707             && jcp.kw <= 28 - jcp.with_bias
4708             && jcp.stride_w == 4
4709             && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
4710             && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
4711             && IMPLICATION(diff_weights_d.format() != any,
4712                     diff_weights_d.format() == want_4fma_wfmt);
4713
4714         if (use_4fma) {
4715             jcp.ver = ver_4fma;
4716             jcp.kh_step = kh_step;
4717             jcp.tr_ld = tr_ld;
4718             jcp.ic_block = 1;
4719             if (diff_weights_d.format() == any)
4720                 CHECK(diff_weights_pd.set_format(want_4fma_wfmt));
4721         } else {
4722             jcp.ver = ver_fma;
4723             jcp.ic_block = jcp.ic;
4724
4725             const auto want_wfmt = with_groups
4726                 ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
4727                 : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
4728             if (diff_weights_d.format() == any)
4729                 CHECK(diff_weights_pd.set_format(want_wfmt));
4730             if (diff_weights_d.format() != want_wfmt)
4731                 return status::unimplemented;
4732         }
4733
4734         jcp.nb_ic = jcp.ic / jcp.ic_block;
4735         jcp.src_fmt = src_d.format();
4736     } else {
4737         if (src_d.format() == any)
4738             CHECK(src_pd.set_format(src_format));
4739         if (diff_weights_d.format() == any)
4740             CHECK(diff_weights_pd.set_format(wei_format));
4741
4742         const bool ok = true
4743             && src_d.format() == src_format
4744             && diff_weights_d.format() == (wei_format);
4745         if (!ok)
4746             return status::unimplemented;
4747
4748         jcp.ic_block = jcp.simd_w;
4749         if (ok_to_pad_channels)
4750             jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
4751         jcp.nb_ic = jcp.ic / jcp.ic_block;
4752         jcp.src_fmt = src_d.format();
4753         if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
4754             && mkldnn_thr_syncable()
4755             && one_of(ndims, 3, 4)
4756             && jcp.stride_w == 1
4757             && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
4758             && ((src_d.data_type() == data_type::s16
4759             && diff_weights_d.data_type() == data_type::s32
4760             && diff_dst_d.data_type() == data_type::s16))) {
4761             if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni;
4762             else jcp.ver = ver_4vnni;
4763         } else if ((mayiuse(avx512_mic) || mayiuse(avx512_core))
4764                 && utils::everyone_is(data_type::f32,
4765                     src_d.data_type(), diff_weights_d.data_type(),
4766                     diff_dst_d.data_type())) {
4767             jcp.ver = ver_fma;
4768             if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 &&
4769                     everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) &&
4770                     mkldnn_thr_syncable()) {
4771                 jcp.ver = ver_4fma;
4772             }
4773         } else {
4774             return status::unimplemented;
4775         }
4776         if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
4777             jcp.ur_w = jcp.ow;
4778             // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to
4779             // cross the right boundary. The only requirement is not to have
4780             // NaNs there because another multiplicand is always guaranteed to
4781             // be zero. This also may require the top-level driver to allocate
4782             // four extra guarding elements at the very end of the buffer.
4783             // I'm not proud of this hack, but it improves performance by
4784             // about 5-10% depending on the dimensions (Roma)
4785
4786             // for vnni, that's results of performance tuning
4787             const int tr_round = (utils::one_of(jcp.ver, ver_4fma, ver_vnni))
4788                 ? 4 : 8;
4789
4790             jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round);
4791             jcp.tr_src_num_guard_elems = tr_round; // upper bound
4792
4793             if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
4794                 jcp.tr_ow = rnd_up(jcp.ow, 2);
4795                 jcp.ur_w = jcp.tr_ow;
4796             }
4797         }
4798     }
4799
4800     if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
4801         jcp.typesize_in = sizeof(int16_t);
4802         jcp.typesize_out = sizeof(int32_t);
4803     } else if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) {
4804         jcp.typesize_in = sizeof(float);
4805         jcp.typesize_out = sizeof(float);
4806     } else
4807         return status::unimplemented;
4808
4809     bool args_ok = true
4810         && jcp.ic % jcp.ic_block == 0
4811         && jcp.oc % jcp.oc_block == 0
4812         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
4813         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
4814         && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
4815         && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
4816     if (!args_ok) return status::unimplemented;
4817
4818     {   // balancing
4819         int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
4820         balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
4821         jcp.nthr = nthr;
4822         jcp.nthr_mb = nthr_mb;
4823         jcp.nthr_g = nthr_g;
4824         jcp.nthr_oc_b = nthr_oc_b;
4825         jcp.nthr_ic_b = nthr_ic_b;
4826     }
4827
4828     return status::success;
4829 }
4830
4831 void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
4832         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
4833     if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
4834         if (jcp.is_1stconv) {
4835             const size_t tr_src_size =
4836                 jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
4837             scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
4838         } else {
4839             // XXX: See the comment about tr_iw and guarding elements in
4840             // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
4841             const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
4842             const size_t min_tr_src_size_per_thr
4843                 = jcp.ih * jcp.ic_block * jcp.tr_iw;
4844             const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
4845                 + jcp.tr_src_num_guard_elems;
4846             scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
4847         }
4848
4849         /* prepare synchronization contexts */
4850         if (jcp.nthr_oc_b > 1) {
4851             const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
4852             scratchpad.book(key_conv_tr_src_bctx,
4853                     sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
4854         }
4855
4856         if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
4857             const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
4858                 * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
4859             scratchpad.book(key_conv_tr_diff_dst,
4860                     jcp.typesize_in * tr_diff_dst_size);
4861
4862             /* prepare synchronization contexts */
4863             if (jcp.nthr_ic_b > 1) {
4864                 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
4865                 scratchpad.book(key_conv_tr_diff_dst_bctx,
4866                         sizeof(simple_barrier::ctx_t) * tr_diff_dst_bctx_size);
4867             }
4868         }
4869     }
4870
4871     if (jcp.nthr_mb > 1) {
4872         const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
4873             * jcp.kh * jcp.kw * jcp.kd;
4874         const int bia_size = jcp.ngroups * jcp.oc;
4875         const size_t wei_bia_reduction_size = wei_size + bia_size;
4876
4877         scratchpad.book(key_conv_wei_bia_reduction,
4878                 jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
4879         scratchpad.book(key_conv_wei_bia_reduction_bctx,
4880                 sizeof(simple_barrier::ctx_t));
4881     }
4882
4883     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
4884         scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
4885 }
4886
4887 void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
4888         const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
4889         int &nthr_oc_b_, int &nthr_ic_b_)
4890 {
4891     nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
4892
4893     const int max_threads = mkldnn_get_max_threads();
4894
4895     if (max_threads < j.ngroups) {
4896         /* simplification... fortunately it doesn't hurt much */
4897         return;
4898     }
4899
4900     if (!mkldnn_thr_syncable()
4901             && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
4902         // should not happen -- the driver is not ready
4903         // for TBB-like non-synchronous threading yet
4904         return;
4905     }
4906
4907     if (j.ver == ver_4fma && j.is_1stconv) {
4908         nthr_g_ = 1;
4909         nthr_oc_b_ = 1;
4910         nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
4911         nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
4912         nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
4913         return;
4914     }
4915
4916     nthr_g_ = j.ngroups;
4917     const int nthr = max_threads / nthr_g_;
4918
4919     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4920         /* calculate per thread memory cost (read/write). high level optimizer
4921          * tries to minimize memory consumption. few notes:
4922          *  (n1) unclear why, but that essentially helps first convolution...
4923          *  (n2) assuming the reduction over minibatch is always there:
4924          *    - instead of 8 it should be 5 here (write ~= 2 read):
4925          *      kernel: temporal workspace 1 write
4926          *      reduction: 1 read from workspace and 1 write to the diff_wei
4927          *    - but experiments showed 8 works better than 5 or 6... */
4928
4929         const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
4930         const int dst_coef = 1;
4931         const int wei_coef = j.ver == ver_vnni ? 4 : 8;
4932
4933         return 0
4934             + src_coef
4935             * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
4936             * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
4937             / j.stride_d / j.stride_h / j.stride_w /* (n1) */
4938             + dst_coef
4939             * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
4940             * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
4941             + wei_coef /* (n2) */
4942             * div_up(j.ngroups, nthr_g_)
4943             * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
4944             * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
4945     };
4946
4947     int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4948
4949     /* step 1: find the best thread distribution with lowest memory cost */
4950     const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
4951     for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4952         const int nthr_par = nthr / nthr_mb;
4953         const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4954         for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4955             int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4956
4957             int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4958             if (mem_cost <= best_mem_cost) {
4959                 best_mem_cost = mem_cost;
4960                 nthr_mb_ = nthr_mb;
4961                 nthr_oc_b_ = nthr_oc_b;
4962                 nthr_ic_b_ = nthr_ic_b;
4963             }
4964         }
4965
4966         if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
4967     }
4968
4969     if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
4970         auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
4971             return 1
4972                 * div_up(j.mb, nthr_mb)
4973                 * div_up(j.ngroups, nthr_g_)
4974                 * div_up(j.nb_oc, nthr_oc_b)
4975                 * div_up(j.nb_ic, nthr_ic_b);
4976         };
4977
4978         /* step 2: search for a thread distribution with lower compute cost.
4979          * the constrains:
4980          *  - memory cost cannot exceed 110% of the best found in the step 1
4981          *  - unless compute cost is 133% lower than the current best case
4982          * note: both constants were found empirically */
4983         int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
4984         for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
4985             const int nthr_par = nthr / nthr_mb;
4986             const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
4987             for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
4988                 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
4989                 int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4990                 int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
4991
4992                 const bool opt1 = comp_cost <= best_comp_cost
4993                     && mem_cost < 1.1 * best_mem_cost;
4994                 const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
4995
4996                 if (opt1 || opt2) {
4997                     best_comp_cost = comp_cost;
4998                     nthr_mb_ = nthr_mb;
4999                     nthr_oc_b_ = nthr_oc_b;
5000                     nthr_ic_b_ = nthr_ic_b;
5001                 }
5002             }
5003
5004             if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
5005         }
5006     }
5007
5008     if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
5009         nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
5010     nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
5011
5012     assert(nthr_ <= max_threads);
5013     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
5014 }
5015
5016 template struct  _jit_avx512_common_conv_fwd_kernel<Zmm>;
5017 template struct  _jit_avx512_common_conv_fwd_kernel<Xmm>;
5018
5019 }
5020 }
5021 }
5022
5023 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s