updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_bf16_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21
22 #include "cpu_barrier.hpp"
23 #include "cpu_memory.hpp"
24
25 #include "jit_avx512_core_bf16_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28 #define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
38
39 namespace {
40
41 constexpr auto small_spatial = 14;
42
43 inline void pick_loop_order(jit_conv_conf_t &jcp) {
44     using namespace prop_kind;
45     assert(one_of(jcp.prop_kind,
46                 forward_training, forward_inference, backward_data));
47     auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
48     auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
49
50     // ow-threading is currently implemented for forward only
51     // TODO: single code for fwd and bwd after ow-thr for bwd
52     // meaningless switch was removed
53     if (jcp.prop_kind == backward_data) {
54         jcp.loop_order = (w <= small_spatial && h <= small_spatial)
55             ? loop_cgn : loop_gnc;
56     } else {
57         jcp.loop_order = (w <= small_spatial && h <= small_spatial)
58             ? loop_cwgn : loop_gncw;
59     }
60 }
61 inline bool is_1D_conv(const jit_conv_conf_t &jcp) {
62     return (jcp.ih == 1 && jcp.kh == 1);
63 }
64 inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
65     return (is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)
66         && !(jcp.ver == ver_fma && mayiuse(avx512_mic)));
67 }
68 inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
69     return (jcp.nb_ow > 1);
70 }
71 }
72
73 void jit_avx512_core_bf16_fwd_kernel::prepare_output(int ur_w)
74 {
75     for (int k = 0; k < jcp.nb_oc_blocking; k++)
76         for (int j = 0; j < ur_w; j++) {
77             Zmm zmm = zmm_out(j, k);
78             vpxord(zmm, zmm, zmm);
79         }
80 }
81
82 void jit_avx512_core_bf16_fwd_kernel::store_output(int ur_w)
83 {
84     Label store_label;
85     if (!jcp.is_cpx)
86         bf16_emu_->init_vcvtneps2bf16();
87
88     if (jcp.with_sum) {
89         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
90             for (int j = 0; j < ur_w; j++) {
91                 Zmm zmm = zmm_out(j, k);
92                 size_t aux_output_offset = get_output_offset(j, k);
93                 if (jcp.dst_dt == data_type::bf16) {
94                     vpmovzxwd(zmm_prev_dst,
95                         make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
96                     vpslld(zmm_prev_dst, zmm_prev_dst, 16);
97                     vaddps(zmm, zmm_prev_dst);
98                 } else {
99                     vaddps(zmm,
100                         make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
101                 }
102             }
103         }
104     }
105
106     if (jcp.with_bias) {
107         mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
108         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
109             int bias_offset = sizeof(float) * k * jcp.oc_block;
110             for (int j = 0; j < ur_w; j++) {
111                 Zmm zmm = zmm_out(j, k);
112                 vaddps(zmm, EVEX_compress_addr(reg_bias, bias_offset));
113             }
114         }
115     }
116
117     if (jcp.with_eltwise) {
118         if (ur_w == jcp.ur_w) {
119             eltwise_injector_->compute_vector_range(0,
120                     jcp.nb_oc_blocking * jcp.ur_w);
121         } else {
122             for (int k = 0; k < jcp.nb_oc_blocking; k++)
123                 eltwise_injector_->compute_vector_range(k * jcp.ur_w,
124                         k * jcp.ur_w + ur_w);
125         }
126     }
127
128     L(store_label);
129     if (jcp.dst_dt == data_type::f32) {
130         for (int k = 0; k < jcp.nb_oc_blocking; k++)
131             for (int j = 0; j < ur_w; j++) {
132                 Zmm zmm = zmm_out(j, k);
133                 size_t aux_output_offset = jcp.typesize_out *
134                     ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
135                 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
136
137                 vmovups(addr, zmm);
138             }
139     } else if (jcp.dst_dt == data_type::bf16) {
140         if (jcp.is_cpx) {
141             for (int k = 0; k < jcp.nb_oc_blocking; k++) {
142                 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
143                 for (j = 0; j < n_2bf2ps; j += 2) {
144                     size_t aux_output_offset = jcp.typesize_out *
145                         ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
146                     auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
147
148                     auto zmm_str = zmm_inp(j, jcp.nb_oc_blocking);
149                     vcvtne2ps2bf16(zmm_str, zmm_out(j+1, k),
150                                             zmm_out(j, k));
151                     vmovups(addr, zmm_str);
152                 }
153                 if (j < ur_w) {
154                     size_t aux_output_offset = jcp.typesize_out *
155                         ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
156                     auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
157                     auto ymm_str = ymm_inp(j, jcp.nb_oc_blocking);
158                     vcvtneps2bf16(ymm_str, zmm_out(j, k));
159                     vmovups(addr, ymm_str);
160                 }
161             }
162         } else {
163             for (int k = 0; k < jcp.nb_oc_blocking; k++)
164                 for (int j = 0; j < ur_w; j++) {
165                     Zmm zmm = zmm_out(j, k);
166                     size_t aux_output_offset = jcp.typesize_out *
167                         ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
168                     auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
169                     Ymm ymm = ymm_inp(0, jcp.nb_oc_blocking);
170                     bf16_emu_->r_vcvtneps2bf16(ymm, zmm);
171                     vmovups(addr, ymm);
172                 }
173         }
174     } else
175         assert(!"unsupported destination type");
176 }
177
178 void jit_avx512_core_bf16_fwd_kernel::compute_loop(
179         int ur_w, int pad_l, int pad_r)
180 {
181     Label kh_label, kd_label;
182     const size_t shift_kernel_ptr = (size_t)jcp.typesize_in * jcp.kw
183                                * jcp.oc_block * jcp.ic_block;
184     const size_t shift_input_ptr
185             = (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
186
187     prepare_output(ur_w);
188
189     // IC loop
190     Label icb_label;
191     mov(reg_icb, jcp.nb_ic);
192     L(icb_label);
193
194     mov(aux_reg_inp, reg_inp);
195     mov(aux_reg_ker, reg_ker);
196
197     Label skip_kh_loop, skip_kd_loop;
198
199     mov(reg_kj, reg_kh);
200     if ((jcp.dilate_h >= jcp.ih)
201             || (jcp.kh - 1) * (jcp.dilate_h + 1)
202                     < nstl::max(jcp.t_pad, jcp.b_pad)) {
203         cmp(reg_kj, 0);
204         je(skip_kh_loop, T_NEAR);
205     }
206
207     L(kh_label); {
208         for (int ki = 0; ki < jcp.kw; ki++) {
209             int ow_start = get_ow_start(ki, pad_l);
210             int ow_end = get_ow_end(ur_w, ki, pad_r);
211             for (int ic = 0;
212                  ic < div_up(nstl::min(jcp.ic_block, jcp.ic), 2); ic++) {
213                 if (jcp.is_cpx) {
214                     for (int oi = ow_start; oi < ow_end; oi++) {
215                         size_t input_offset =
216                             get_input_offset(ki, ic, oi, pad_l);
217                         vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
218                             EVEX_compress_addr(aux_reg_inp, input_offset));
219                     }
220                 }
221                 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
222                     size_t kernel_offset = get_kernel_offset(ki, ic, kk, 0);
223                     if (jcp.is_cpx)
224                         vmovups(zmm_wei,
225                             EVEX_compress_addr(aux_reg_ker, kernel_offset));
226                     for (int oi = ow_start; oi < ow_end; oi++) {
227                         if (!jcp.is_cpx) {
228                             size_t input_offset =
229                                 get_input_offset(ki, ic, oi, pad_l);
230                             vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
231                                 EVEX_compress_addr(aux_reg_inp, input_offset));
232                             vmovups(zmm_wei,
233                                 EVEX_compress_addr(aux_reg_ker, kernel_offset));
234                             auto acc = zmm_out(oi, kk);
235                             auto inp = zmm_inp(oi, jcp.nb_oc_blocking);
236                             bf16_emu_->r_vdpbf16ps(acc, zmm_wei, inp);
237                         } else
238                             vdpbf16ps(zmm_out(oi, kk), zmm_wei,
239                                 zmm_inp(oi, jcp.nb_oc_blocking));
240                     }
241                 }
242             }
243         }
244         add(aux_reg_ker, shift_kernel_ptr);
245         add(aux_reg_inp, shift_input_ptr);
246
247         dec(reg_kj);
248         cmp(reg_kj, 0);
249         jg(kh_label, T_NEAR);
250     }
251
252     L(skip_kh_loop);
253
254     // End of IC Loop
255     size_t inp_step = (size_t)jcp.ih * jcp.iw * jcp.ic_block;
256     size_t ker_step = (size_t)jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
257     add(reg_inp, jcp.typesize_in * inp_step);
258     add(reg_ker, jcp.typesize_in * ker_step);
259
260     dec(reg_icb);
261     cmp(reg_icb, 0);
262     jg(icb_label, T_NEAR);
263
264     sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
265     sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
266
267     store_output(ur_w);
268 }
269
270 void jit_avx512_core_bf16_fwd_kernel::generate()
271 {
272     int iw = jcp.iw;
273     int ow = jcp.ow;
274     int ow_block = jcp.ow_block;
275     int nb_ow = jcp.nb_ow;
276     int kw = jcp.kw;
277     int l_pad = jcp.l_pad;
278     int ur_w = jcp.ur_w;
279     int ur_w_tail = jcp.ur_w_tail;
280     int dilate_w = jcp.dilate_w + 1;
281     int stride_w = jcp.stride_w;
282
283     int inp_mult = jcp.ic_block;
284
285     size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * inp_mult;
286     size_t out_shift = (size_t)jcp.typesize_out * ur_w * jcp.oc_block;
287
288     int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
289     int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
290
291     preamble();
292     mov(reg_inp, ptr[param1 + GET_OFF(src)]);
293     mov(reg_out, ptr[param1 + GET_OFF(dst)]);
294     mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
295     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
296
297     int r_pad = nstl::max(
298             0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
299     int n_oi = ow / ur_w;
300     int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
301             - (iw + l_pad - 1);
302
303     if (!is_ow_threading_on(jcp)) {
304         // ow is being processed as a whole - with left and right paddings
305         if (r_pad1 > 0)
306             n_oi--;
307
308         xor_(reg_oi, reg_oi);
309         if (ow == ur_w) {
310             compute_loop(ur_w, l_pad, r_pad);
311         } else {
312             if (n_oi == 0) {
313                 compute_loop(ur_w, l_pad, r_pad1);
314                 add(reg_inp, inp_shift_pad);
315                 add(reg_out, out_shift);
316                 if (ur_w_tail != 0) {
317                     compute_loop(ur_w_tail, 0, r_pad);
318                 }
319             } else {
320                 if (l_pad > 0) {
321                     compute_loop(ur_w, l_pad, 0);
322                     add(reg_inp, inp_shift_pad);
323                     add(reg_out, out_shift);
324                     inc(reg_oi);
325                 }
326                 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
327                     Label ow_loop_label;
328                     L(ow_loop_label);
329                     {
330                         compute_loop(ur_w, 0, 0);
331                         add(reg_inp, inp_shift);
332                         add(reg_out, out_shift);
333
334                         inc(reg_oi);
335                         cmp(reg_oi, n_oi);
336                         jl(ow_loop_label, T_NEAR);
337                     }
338                 }
339                 if (r_pad1 > 0) {
340                     compute_loop(ur_w, 0, r_pad1);
341                     add(reg_inp, inp_shift);
342                     add(reg_out, out_shift);
343                 }
344                 if (ur_w_tail != 0) {
345                     compute_loop(ur_w_tail, 0, r_pad);
346                 }
347             }
348         }
349     } else {
350         // ow block is only processed.
351         // Number of block is passed as parameter owb,
352         // and padding processing depends on this number.
353
354         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
355         Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
356
357         assert(ow_block % ur_w == 0);
358         int n_oi_not_last_ow_block = ow_block / ur_w;
359         // to simplify code (and general regs usage),
360         // size of ow block must be >= 2 * ur_w
361         assert(n_oi_not_last_ow_block > 1);
362         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
363         int n_oi_first_ow_block = n_oi_not_last_ow_block;
364
365         int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
366
367         // prepare right padding
368         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
369         bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
370         bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
371
372         if (last_ow_block_padded) n_oi_last_ow_block--;
373         else if (first_ow_block_padded) n_oi_first_ow_block--;
374         else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
375
376         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
377         cmp(reg_owb, 0); // is that the first ow-block ?
378         jg(middle_ow_blocks_label, T_NEAR);
379
380         // the first ow block, compute left padding
381
382         mov(reg_oi, n_oi_first_ow_block);
383         if (l_pad > 0) {
384             compute_loop(ur_w, l_pad, 0);
385             add(reg_inp, inp_shift_pad);
386             add(reg_out, out_shift);
387             dec(reg_oi);
388         }
389         jmp(oi_loop_label, T_NEAR);
390
391         // middle or last ow block entry
392
393         L(middle_ow_blocks_label);
394
395         if (l_pad > 0) {
396             // just to consider left padding, not compute
397             add(reg_inp, inp_shift_pad_second_block);
398         }
399
400         // set number of iteration for oi-loop
401         cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
402         mov(reg_oi, n_oi_last_ow_block);
403         je(oi_loop_label, T_NEAR);
404         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
405         mov(reg_oi, n_oi_next_last_ow_block);
406         je(oi_loop_label, T_NEAR);
407         mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
408
409         // oi loop w/o padding
410         L(oi_loop_label);
411         L(oi_loop_start_label);
412             cmp(reg_oi, 0);
413             jle(oi_loop_end_label, T_NEAR);
414
415             compute_loop(ur_w, 0, 0);
416             add(reg_inp, inp_shift);
417             add(reg_out, out_shift);
418             dec(reg_oi);
419             jmp(oi_loop_start_label, T_NEAR);
420         L(oi_loop_end_label);
421
422         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
423
424         cmp(reg_owb, 0); // first ow-block ?
425         if (first_ow_block_padded) {
426             je(last_oi_label, T_NEAR);
427         } else {
428             je(end_label, T_NEAR);
429         }
430         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
431         jl(end_label, T_NEAR);
432         if (next_last_ow_block_padded) {
433             je(last_oi_label, T_NEAR);
434         } else {
435             je(end_label, T_NEAR);
436         }
437         // that is last block
438         if (!last_ow_block_padded) {
439             jmp(tail_label, T_NEAR);
440         }
441
442         // last oi block with right padding
443         L(last_oi_label);
444         compute_loop(ur_w, 0, r_pad1);
445         add(reg_inp, inp_shift);
446         add(reg_out, out_shift);
447
448         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
449         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
450         jl(end_label, T_NEAR);
451
452         L(tail_label);
453         if (ur_w_tail != 0) {
454             compute_loop(ur_w_tail, 0, r_pad);
455         }
456         L(end_label);
457     }
458     postamble();
459
460     if (jcp.with_eltwise)
461         eltwise_injector_->prepare_table();
462 }
463
464 bool jit_avx512_core_bf16_fwd_kernel::post_ops_ok(
465         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
466     const auto &p = attr.post_ops_;
467
468     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
469     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
470
471     switch (p.len_) {
472     case 0: return true; // no post_ops
473     case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
474     case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
475     default: return false;
476     }
477
478     return false;
479 }
480
481 status_t jit_avx512_core_bf16_fwd_kernel::init_conf(
482             jit_conv_conf_t &jcp,
483             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
484             cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
485             cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
486             int nthreads)
487 {
488     using namespace prop_kind;
489
490     const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
491
492     const memory_desc_wrapper src_d(&src_pd);
493     const memory_desc_wrapper weights_d(&weights_pd);
494     const memory_desc_wrapper dst_d(&dst_pd);
495     const memory_desc_wrapper bias_d(&bias_pd);
496
497     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
498     int ndims = src_d.ndims();
499
500     jcp = zero<decltype(jcp)>();
501     jcp.ndims = ndims;
502     jcp.prop_kind = cd.prop_kind;
503     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
504     jcp.mb = src_d.dims()[0];
505     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
506     jcp.oc_without_padding = jcp.oc;
507     jcp.ic = src_d.dims()[1] / jcp.ngroups;
508     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
509     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
510     jcp.iw = src_d.dims()[ndims-1];
511     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
512     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
513     jcp.ow = dst_d.dims()[ndims-1];
514     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
515     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
516     jcp.kw = weights_d.dims()[with_groups + ndims-1];
517     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
518     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
519     jcp.l_pad = cd.padding[0][ndims-3];
520     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
521     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
522     jcp.stride_w = cd.strides[ndims-3];
523     jcp.src_fmt = src_d.format();
524     jcp.dst_dt = cd.dst_desc.data_type;
525
526     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
527     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
528     jcp.dilate_w = cd.dilates[ndims-3];
529
530     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
531             - (jcp.ih + jcp.t_pad - 1);
532     jcp.back_pad = (jcp.od - 1) * jcp.stride_d
533             + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
534
535     if (ndims != 4)
536         return status::unimplemented;
537
538     jcp.is_cpx = mayiuse(avx512_core_bf16);
539     const int regs = jcp.is_cpx ? 31 /* expl_bcast case */ : 26;
540
541     jcp.oc_block = simd_w;
542     jcp.ic_block = simd_w;
543     jcp.aligned_threads = 0;
544
545     bool ok_to_pad_channels = jcp.ngroups == 1;
546
547     if (ok_to_pad_channels) {
548         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
549         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
550     }
551     bool args_ok = true
552         && jcp.oc % jcp.oc_block == 0
553         && jcp.ic % jcp.ic_block == 0;
554     if (!args_ok)
555         return status::unimplemented;
556
557     if (!post_ops_ok(jcp, attr))
558         return status::unimplemented;
559
560     const auto &p = attr.post_ops_;
561     jcp.with_sum = p.find(primitive_kind::sum) != -1;
562     const int eltwise_ind = p.find(primitive_kind::eltwise);
563     jcp.with_eltwise = eltwise_ind != -1;
564     if (jcp.with_eltwise) {
565         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
566         if (dst_d.data_type() == data_type::s32) return status::unimplemented;
567     }
568
569     auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
570     auto dst_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
571
572     if (src_d.format() == any)
573         CHECK(src_pd.set_format(src_format));
574     if (src_d.format() != src_format)
575         return status::unimplemented;
576     if (dst_d.format() == any)
577         CHECK(dst_pd.set_format(dst_format));
578     if (dst_d.format() != dst_format)
579         return status::unimplemented;
580     const auto w_format = with_groups
581         ? pick(ndims - 3, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i)
582         : pick(ndims - 3, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i);
583     if (weights_d.format() == any)
584         CHECK(weights_pd.set_format(w_format));
585     if (weights_d.format() != w_format)
586         return status::unimplemented;
587
588     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
589     if (jcp.with_bias) {
590         if (bias_d.format() == any)
591             CHECK(bias_pd.set_format(x));
592         if (bias_d.format() != x)
593             return status::unimplemented;
594     }
595
596     jcp.ver = ver_vnni;
597     jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
598     jcp.typesize_out = (dst_d.data_type() == data_type::f32)
599         ? sizeof(float) : sizeof(mkldnn_bfloat16_t);
600
601     jcp.nb_ic = jcp.ic / jcp.ic_block;
602     jcp.nb_oc = jcp.oc / jcp.oc_block;
603     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
604
605     jcp.kernel_kind = expl_bcast;
606     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
607     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
608         int ur_w = regs / (jcp.nb_oc_blocking + 1);
609         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
610                 && (jcp.l_pad <= ur_w
611                          && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
612             break;
613     }
614
615     jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
616     if (jcp.ow < jcp.ur_w)
617         jcp.ur_w = jcp.ow;
618     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
619
620     jcp.ow_block = jcp.ow;
621     if (is_ow_threading_available(jcp)) {
622         const int L1_part = get_cache_size(1) * 5 / 8;
623         int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
624         int size_dst_chunk = jcp.typesize_out
625             * jcp.oc_block * jcp.nb_oc_blocking * jcp.ur_w;
626         int size_wei_chunk = jcp.typesize_in
627             * jcp.oc_block * jcp.ic_block * jcp.nb_oc_blocking * jcp.kw;
628         int nurw = (L1_part - size_wei_chunk)
629             / (size_dst_chunk + size_src_chunk);
630         // current design of generate() requires ow_block >= 2 * ur_w
631         jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
632     }
633     jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
634
635     args_ok = true
636         && jcp.l_pad <= jcp.ur_w
637         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
638         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
639         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
640         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
641     if (!args_ok)
642         return status::unimplemented;
643
644     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
645                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
646                     - (jcp.iw + jcp.l_pad - 1));
647     if (r_pad_no_tail > jcp.ur_w)
648         return status::unimplemented;
649
650     pick_loop_order(jcp);
651
652     jcp.nb_ic_L2 = jcp.nb_ic;
653
654     const int L2_size = get_cache_size(2, true) / sizeof(float);
655     // Source and output data needs to fit in L2,
656     // leaving some space for weights and prefetching.
657     int h_L2 = int(((0.6f * L2_size) / simd_w
658                            - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
659             / (jcp.stride_h * jcp.iw + jcp.ow));
660     jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
661
662     return status::success;
663 }
664
665 void jit_avx512_core_bf16_bwd_data_kernel::prepare_output(int ur_w)
666 {
667     for (int k = 0; k < jcp.nb_ic_blocking; k++) {
668         for (int j = 0; j < ur_w; j++) {
669             Zmm zmm = zmm_out(j, k);
670             vpxord(zmm, zmm, zmm);
671         }
672     }
673 }
674
675 void jit_avx512_core_bf16_bwd_data_kernel::store_output(int ur_w)
676 {
677     if (!jcp.is_cpx)
678         bf16_emu_->init_vcvtneps2bf16();
679
680     if (jcp.dsrc_dt == data_type::f32) {
681         for (int k = 0; k < jcp.nb_ic_blocking; k++)
682             for (int j = 0; j < ur_w; j++) {
683                 Zmm zmm = zmm_out(j, k);
684                 size_t aux_diff_src_offset = jcp.typesize_out *
685                     ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
686                 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
687
688                 vmovups(addr, zmm);
689             }
690     } else if (jcp.dsrc_dt == data_type::bf16) {
691         if (jcp.is_cpx) {
692             int store_idx = 0;
693             const int max_regs = 32;
694             const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
695             const int num_regs_available = max_regs - free_regs_start_idx;
696             int reg_idx = 0;
697             for (int k = 0; k < jcp.nb_ic_blocking; k++) {
698                 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
699                 for (j = 0; j < n_2bf2ps; j += 2) {
700                     reg_idx = free_regs_start_idx
701                         + store_idx % num_regs_available;
702                     assert(reg_idx < max_regs);
703                     size_t aux_diff_src_offset = jcp.typesize_out *
704                         ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) *
705                         jcp.ic_block;
706                     auto addr = EVEX_compress_addr(reg_src,
707                                     aux_diff_src_offset);
708
709                     auto zmm_str = Zmm(reg_idx);
710                     vcvtne2ps2bf16(zmm_str, zmm_out(j+1, k),
711                                             zmm_out(j, k));
712                     vmovups(addr, zmm_str);
713                     store_idx++;
714                 }
715                 if (j < ur_w) {
716                     reg_idx = free_regs_start_idx
717                         + store_idx % num_regs_available;
718                     assert(reg_idx < max_regs);
719
720                     size_t aux_diff_src_offset = jcp.typesize_out *
721                         ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
722                     auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
723                     auto ymm_str = Ymm(reg_idx);
724                     vcvtneps2bf16(ymm_str, zmm_out(j, k));
725                     vmovups(addr, ymm_str);
726                     store_idx++;
727                 }
728             }
729         } else {
730             for (int k = 0; k < jcp.nb_ic_blocking; k++)
731                 for (int j = 0; j < ur_w; j++) {
732                     Zmm zmm = zmm_out(j, k);
733                     size_t aux_diff_src_offset = jcp.typesize_out *
734                         ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
735                     auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
736                     Ymm ymm = ymm_inp(0);
737                     bf16_emu_->r_vcvtneps2bf16(ymm, zmm);
738                     vmovups(addr, ymm);
739                 }
740         }
741     } else
742         assert(!"unsupported diff_src type");
743 }
744
745 void jit_avx512_core_bf16_bwd_data_kernel::compute_loop(
746         int ur_w, int l_overflow, int r_overflow)
747 {
748     int ow = jcp.ow;
749     int kw = jcp.kw;
750     int ic_block = jcp.ic_block;
751     int oc_block = jcp.oc_block;
752     int dilate_w = jcp.dilate_w + 1;
753     int stride_w = jcp.stride_w;
754     int stride_h = jcp.stride_h;
755     Label kh_label, skip_compute_label;
756
757     auto kernel_offset = [=](int icb, int oc, int ki) {
758         size_t blk_idx = (size_t)icb * jcp.kh * jcp.kw + ki;
759         size_t blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
760         size_t oc_offset = (size_t)oc * jcp.oc_block;
761         return jcp.typesize_in * (blk_offset + oc_offset);
762     };
763
764     prepare_output(ur_w);
765     cmp(reg_kh, 0);
766     jle(skip_compute_label, T_NEAR);
767
768     // OC loop
769     Label ocb_label;
770     mov(reg_ocb, jcp.nb_oc);
771     L(ocb_label);
772
773     mov(aux_reg_dst, reg_dst);
774     mov(aux_reg_ker, reg_ker);
775
776     mov(reg_kj, reg_kh);
777     L(kh_label); {
778         for (int ki = 0; ki < kw; ki++) {
779             int jj_start = get_iw_start(ki, l_overflow);
780             int jj_end = get_iw_end(ur_w, ki, r_overflow);
781             assert(stride_w != 1
782                     || jj_start == nstl::max(0,
783                         l_overflow - (kw - 1 - ki) * dilate_w));
784             assert(stride_w != 1
785                     || jj_end == ur_w - nstl::max(0,
786                         r_overflow - ki * dilate_w));
787
788             for (int oc = 0;
789                  oc < div_up(nstl::min(oc_block, jcp.oc), 2); oc++) {
790                 if (jcp.is_cpx) {
791                     for (int jj = jj_start; jj < jj_end; jj += stride_w) {
792                         assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
793                         size_t aux_dst_offset = jcp.typesize_in
794                             * ((jj + jcp.l_pad - ki * dilate_w) / stride_w
795                                    * oc_block
796                                    + 2 * oc);
797                         auto inp = zmm_inp(jj / stride_w);
798                         vpbroadcastd(inp, ptr[aux_reg_dst + aux_dst_offset]);
799                     }
800                 }
801                 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
802                     size_t aux_kernel_offset = kernel_offset(kk, 2 * oc, ki);
803                     if (jcp.is_cpx) {
804                         vmovups(zmm_wei,
805                             EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
806                     }
807
808                     for (int jj = jj_start; jj < jj_end; jj += stride_w) {
809                         auto inp = zmm_inp(jj / stride_w);
810                         auto acc = zmm_out(jj, kk);
811
812                         if (!jcp.is_cpx) {
813                             size_t aux_dst_offset = jcp.typesize_in
814                                 * ((jj + jcp.l_pad - ki * dilate_w) / stride_w
815                                        * oc_block
816                                        + 2 * oc);
817                             vpbroadcastd(inp,
818                                 ptr[aux_reg_dst + aux_dst_offset]);
819                             vmovups(zmm_wei,
820                                 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
821                             bf16_emu_->r_vdpbf16ps(acc, zmm_wei, inp);
822                         } else
823                             vdpbf16ps(acc, zmm_wei, inp);
824                     }
825                 }
826             }
827         }
828
829         add(aux_reg_ker, jcp.typesize_in * stride_h * kw * oc_block * ic_block);
830         sub(aux_reg_dst, jcp.typesize_in * (jcp.dilate_h + 1) * ow * oc_block);
831
832         dec(reg_kj);
833         cmp(reg_kj, 0);
834         jg(kh_label, T_NEAR);
835     }
836
837     // End of OC Loop
838     size_t diff_dst_step = (size_t)jcp.oh * jcp.ow * jcp.oc_block;
839     size_t ker_step = (size_t)jcp.ic * jcp.kh * jcp.kw * jcp.oc_block;
840     add(reg_dst, jcp.typesize_in * diff_dst_step);
841     add(reg_ker, jcp.typesize_in * ker_step);
842
843     dec(reg_ocb);
844     cmp(reg_ocb, 0);
845     jg(ocb_label, T_NEAR);
846
847     sub(reg_dst, jcp.typesize_in * diff_dst_step * jcp.nb_oc);
848     sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_oc);
849
850     L(skip_compute_label);
851     store_output(ur_w);
852 }
853
854 void jit_avx512_core_bf16_bwd_data_kernel::generate()
855 {
856     int iw = jcp.iw;
857     int kw = jcp.kw;
858     int ur_w = jcp.ur_w;
859     int ic_block = jcp.ic_block;
860     int oc_block = jcp.oc_block;
861     int ur_w_tail = jcp.ur_w_tail;
862     int dilate_w = jcp.dilate_w + 1;
863     int stride_w = jcp.stride_w;
864
865     size_t dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
866     size_t src_shift = jcp.typesize_out * ur_w * oc_block;
867
868     preamble();
869
870     mov(reg_src, ptr[param + GET_OFF(src)]);
871     mov(reg_dst, ptr[param + GET_OFF(dst)]);
872     mov(reg_ker, ptr[param + GET_OFF(filt)]);
873
874     mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
875
876     int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
877     int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
878                     - nstl::max(0, jcp.r_pad)) / stride_w);
879     int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
880                     - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
881
882     int n_oi = iw / ur_w;
883     if (r_overflow1 > 0) n_oi--;
884
885     if (ur_w == iw) {
886         compute_loop(ur_w, l_overflow, r_overflow);
887     } else if (n_oi == 0) {
888         compute_loop(ur_w, l_overflow, r_overflow1);
889         add(reg_src, src_shift);
890         add(reg_dst, dst_shift);
891         if (ur_w_tail != 0)
892             compute_loop(ur_w_tail, 0, r_overflow);
893     } else {
894         xor_(reg_oi, reg_oi);
895         if (l_overflow > 0) {
896             compute_loop(ur_w, l_overflow, 0);
897             add(reg_src, src_shift);
898             add(reg_dst, dst_shift);
899
900             inc(reg_oi);
901         }
902         if ((l_overflow <= 0 && n_oi > 0)
903             || (l_overflow > 0 && n_oi > 1)) {
904             Label ow_loop_label;
905             L(ow_loop_label); {
906                 compute_loop(ur_w, 0, 0);
907                 add(reg_src, src_shift);
908                 add(reg_dst, dst_shift);
909
910                 inc(reg_oi);
911                 cmp(reg_oi, n_oi);
912                 jl(ow_loop_label, T_NEAR);
913             }
914         }
915         if (r_overflow1 > 0) {
916             compute_loop(ur_w, 0, r_overflow1);
917             add(reg_src, src_shift);
918             add(reg_dst, dst_shift);
919         }
920         if (ur_w_tail != 0) {
921             compute_loop(ur_w_tail, 0, r_overflow);
922         }
923     }
924
925     postamble();
926 }
927
928 status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(
929         jit_conv_conf_t &jcp,
930         const convolution_desc_t &cd,
931         const memory_desc_wrapper &diff_src_d,
932         const memory_desc_wrapper &weights_d,
933         const memory_desc_wrapper &diff_dst_d)
934 {
935     const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
936     const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
937     int ndims = diff_src_d.ndims();
938
939     jcp.ndims = ndims;
940     jcp.prop_kind = cd.prop_kind;
941
942     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
943     jcp.mb = diff_src_d.dims()[0];
944
945     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
946     jcp.oc_without_padding = jcp.oc;
947     jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
948
949     jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
950     jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
951     jcp.iw = diff_src_d.dims()[ndims-1];
952     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
953     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
954     jcp.ow = diff_dst_d.dims()[ndims-1];
955
956     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
957     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
958     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
959
960     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
961     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
962     jcp.l_pad = cd.padding[0][ndims-3];
963
964     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
965     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
966     jcp.stride_w = cd.strides[ndims-3];
967
968     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
969     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
970     jcp.dilate_w = cd.dilates[ndims-3];
971     jcp.dsrc_dt = cd.diff_src_desc.data_type;
972
973     /* Dilated convolutions supported with unit strides only */
974     if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
975             || (jcp.dilate_d != 0 && jcp.stride_d != 1)
976             || (jcp.dilate_h != 0 && jcp.stride_h != 1))
977         return status::unimplemented;
978
979     jcp.is_cpx = mayiuse(avx512_core_bf16);
980
981     jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
982             - (jcp.iw + jcp.l_pad - 1);
983     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
984             - (jcp.ih + jcp.t_pad - 1);
985     jcp.back_pad = (jcp.od - 1) * jcp.stride_d
986             + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
987
988     jcp.aligned_threads = 0;
989
990     jcp.oc_block = simd_w;
991     jcp.ic_block = simd_w;
992
993     bool ok_to_pad_channels = jcp.ngroups == 1;
994
995     if (ok_to_pad_channels) {
996         jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
997         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
998     }
999
1000     auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1001     auto wei_format = with_groups
1002         ? pick(ndims - 3, gOIw8o16i2o, gOIhw8o16i2o, gOIdhw8o16i2o)
1003         : pick(ndims - 3, OIw8o16i2o, OIhw8o16i2o, OIdhw8o16i2o);
1004     bool args_ok = true
1005         && jcp.oc % jcp.oc_block == 0
1006         && jcp.ic % jcp.ic_block == 0
1007         && diff_src_d.format() == src_format
1008         && diff_dst_d.format() == src_format
1009         && weights_d.format() == wei_format;
1010     if (!args_ok)
1011         return status::unimplemented;
1012
1013     jcp.nb_ic = jcp.ic / jcp.ic_block;
1014     jcp.nb_oc = jcp.oc / jcp.oc_block;
1015
1016     jcp.ur_w = jcp.stride_w;
1017
1018     /* Maximun number of registers available for result accumulation and delta
1019        dst data. One additional register is reserved for weights data. */
1020     const int max_regs = jcp.is_cpx ? 31 : 26; /* In case of cpx emulation
1021                                                   additional 5 registers are
1022                                                   reserved */
1023     int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1024                     - jcp.l_pad) / jcp.stride_w);
1025     int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1026                     - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
1027     int n_oi = jcp.iw / jcp.ur_w;
1028     if (r_overflow1 > 0) n_oi--;
1029
1030     jcp.ver = ver_vnni;
1031     jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
1032     jcp.typesize_out = (diff_src_d.data_type() == data_type::f32)
1033         ? sizeof(float) : sizeof(mkldnn_bfloat16_t);
1034
1035     if (ndims != 4)
1036         return status::unimplemented;
1037
1038     /* Find the best blocking with maximum number of compute instructions
1039        per ur_w * nb_ic_blocking compute loops. Number of required registers
1040        is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
1041        ur_w must be divisible by stride_w */
1042     if (jcp.stride_w + 1 > max_regs)  /* Minimal possible registers
1043                                          distribution exceeds max_regs */
1044         return status::unimplemented;
1045
1046     jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1047     {
1048         jcp.kernel_kind = expl_bcast;
1049         int best_compute_pipeline_length = 0;
1050         const int max_ic_blocks = 4;
1051         for (int b = 1; b <= max_ic_blocks; b++)
1052         {
1053             if (jcp.nb_ic % b != 0)
1054                 continue;
1055
1056             for (int u = jcp.stride_w;
1057                  u * b + u / jcp.stride_w <= max_regs
1058                      && u < jcp.iw + jcp.stride_w;
1059                  u += jcp.stride_w)
1060             {
1061                 int ur_w = nstl::min(u, jcp.iw);
1062                 /* maximum 1 step with l_overflow so far */
1063                 if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1064                     continue;
1065                 int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b;
1066                 if (pipeline_length > best_compute_pipeline_length
1067                    || (pipeline_length == best_compute_pipeline_length
1068                        && jcp.ur_w < ur_w)) {
1069                     jcp.ur_w = ur_w;
1070                     jcp.nb_ic_blocking = b;
1071                     best_compute_pipeline_length = pipeline_length;
1072                 }
1073             }
1074         }
1075         if (best_compute_pipeline_length == 0) /* can't find
1076                                                   appropriate blocking */
1077             return status::unimplemented;
1078     }
1079
1080     jcp.loop_order = loop_gnc;
1081
1082     jcp.ur_w_tail = jcp.iw % jcp.ur_w;
1083
1084     if (l_overflow * jcp.stride_w > jcp.ur_w)
1085         return status::unimplemented;
1086     int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1087                     - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
1088     if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
1089         return status::unimplemented;
1090     if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1091         return status::unimplemented;
1092
1093     pick_loop_order(jcp);
1094
1095     jcp.nb_oc_L2 = jcp.nb_oc;
1096
1097     args_ok = true
1098         && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
1099         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
1100         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
1101         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
1102
1103     return args_ok ? status::success : status::unimplemented;
1104 }
1105
1106 const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
1107
1108 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
1109 {
1110     Label kh_comeback_label, kd_comeback_label;
1111     mov(kj, reg_kh);
1112     L(kh_comeback_label); {
1113         int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1114         int iw = jcp.tr_iw;
1115         sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
1116         sub(reg_kernel,
1117             jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
1118         dec(kj);
1119         cmp(kj, 0);
1120         jg(kh_comeback_label, T_NEAR);
1121     }
1122 }
1123 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1124 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
1125     int ur_w, int pad_l, int pad_r,
1126     int ic_block_step, int input_offset, int kernel_offset,
1127     int output_offset, bool is_tail)
1128 {
1129     int kw = jcp.kw;
1130     int ic_block = jcp.ic_block;
1131     int oc_block = jcp.oc_block;
1132
1133     auto zmm_ker = [=](int i_kw, int i_ic) {
1134         return Zmm(i_kw * ic_block_step + i_ic);
1135     };
1136     auto zmm_out = [=](int i_iw) {
1137         // TODO: move reg calc to global member funcs
1138         const int out_zmm_base_idx = 24;
1139         const int num_out_zmm_regs = !jcp.is_cpx ? 2 : 4;
1140         return Zmm(out_zmm_base_idx + i_iw % num_out_zmm_regs);
1141     };
1142
1143     auto ker_addr = [=](int i_kw, int i_ic) {
1144         size_t local_offset
1145             = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
1146         return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
1147     };
1148     auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
1149                         bool vnni_bcast = false) {
1150         int stride = jcp.tr_iw;
1151         int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
1152         if (vnni_bcast)
1153             return EVEX_compress_addr(reg_input,
1154                     local_offset + input_offset + extra_offset, true);
1155         else
1156             return EVEX_compress_addr(reg_input,
1157                     local_offset + input_offset + extra_offset);
1158     };
1159     auto out_addr = [=](int i_ur) {
1160         auto ow_per_oc = 2;
1161         return EVEX_compress_addr(reg_output,
1162                 jcp.typesize_in * i_ur * oc_block * ow_per_oc + output_offset);
1163     };
1164
1165     for (int i_kw = 0; i_kw < kw; i_kw++)
1166         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1167             auto zmm = zmm_ker(i_kw, i_ic);
1168             vpxord(zmm, zmm, zmm);
1169         }
1170     assert(ur_w % 2 == 0);
1171     auto steps = ur_w / 2;
1172
1173     const int str_w = jcp.stride_w;
1174     for (int s = 0; s < str_w; s++) {
1175         const int kw_start = s;
1176         assert(jcp.tr_iw % str_w == 0);
1177         const int inp_stride_w_shift = jcp.tr_iw / str_w;
1178         for (int i_ur = 0; i_ur < steps; i_ur++) {
1179             auto zmm = zmm_out(i_ur);
1180             vmovdqu16(zmm, out_addr(i_ur));
1181
1182             for (int i_kw = kw_start; i_kw < kw; i_kw += str_w)
1183                 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1184                     int i_iw = 2 * i_ur + i_kw / str_w
1185                                  + s * inp_stride_w_shift;
1186                     if (!jcp.is_cpx) {
1187                         auto inp = Zmm(26);
1188                         vpbroadcastd(inp, inp_addr(i_iw, i_ic, 0));
1189                         auto acc = zmm_ker(i_kw, i_ic);
1190                         auto wei = zmm_out(i_ur);
1191                         bf16_emu_->r_vdpbf16ps(acc, wei, inp);
1192                     } else
1193                         vdpbf16ps(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
1194                             inp_addr(i_iw, i_ic, 0, true));
1195                 }
1196         }
1197         for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1198             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1199                 auto addr = ker_addr(i_kw, i_ic);
1200                 auto zmm = zmm_ker(i_kw, i_ic);
1201                 vaddps(zmm, zmm, addr);
1202                 vmovups(addr, zmm);
1203             }
1204         }
1205     }
1206 }
1207 #else
1208 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
1209     int ur_w, int pad_l, int pad_r,
1210     int ic_block_step, int input_offset, int kernel_offset,
1211     int output_offset, bool is_tail)
1212 {
1213     int kw = jcp.kw;
1214     int ic_block = jcp.ic_block;
1215     int oc_block = jcp.oc_block;
1216
1217     for (int i_kw = 0; i_kw < kw; i_kw++)
1218         for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
1219             vmovups(Zmm(i_kw * ic_block_step + i_ic),
1220                 EVEX_compress_addr(reg_kernel,typesize *
1221                     (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset));
1222
1223     Reg64 reg_trans_tmp = r11;
1224     mov(reg_trans_tmp, dst_prm_table);
1225     auto perm = Zmm(24);
1226     vmovups(perm, ptr[reg_trans_tmp]);
1227
1228     Opmask load_mask = Opmask(7);
1229     for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
1230         if (ur_w % 2 && i_ur + 2 >= ur_w)
1231             mov(reg_trans_tmp.cvt32(), 0x0000ffff);
1232         else
1233             mov(reg_trans_tmp.cvt32(), 0xffffffff);
1234         kmovd(load_mask, reg_trans_tmp.cvt32());
1235         auto zmm_dst = Zmm(25);
1236         vmovdqu16(zmm_dst | load_mask | T_z,
1237             EVEX_compress_addr(reg_output,
1238                 jcp.typesize_in * i_ur * oc_block + output_offset));
1239         vpermw(zmm_dst, perm, zmm_dst);
1240         for (int i_kw = 0; i_kw < kw; i_kw++) {
1241             int iw_1 = (i_ur + i_kw);
1242             int iw_2 = (i_ur + 1 == ur_w) ? -1 : (i_ur + 1) + i_kw;
1243             iw_1 = (iw_1 - pad_l < 0 || iw_1 > (ur_w - 1) + (kw - 1) - pad_r)
1244                 ? -1 : iw_1 - pad_l;
1245             iw_2 = (iw_2 - pad_l < 0 || iw_2 > (ur_w - 1) + (kw - 1) - pad_r)
1246                 ? -1 : iw_2 - pad_l;
1247
1248             int local_offset = i_ur + i_kw - pad_l;
1249             if (iw_1 == -1 && iw_2 == -1) continue;
1250             if (iw_1 != -1 && iw_2 != -1) mov(reg_trans_tmp.cvt32(), 0xffffffff);
1251             if (iw_1 != -1 && iw_2 == -1) mov(reg_trans_tmp.cvt32(), 0x0000ffff);
1252             if (iw_1 == -1 && iw_2 != -1) mov(reg_trans_tmp.cvt32(), 0xffff0000);
1253             kmovd(load_mask, reg_trans_tmp.cvt32());
1254
1255             const size_t i_offset = (size_t)input_offset +
1256                             (size_t)jcp.typesize_in * (local_offset) * ic_block;
1257             auto bcast_values = Zmm(26);
1258             vpxord(bcast_values, bcast_values, bcast_values);
1259             vmovdqu16(bcast_values| load_mask | T_z, ptr[reg_input + i_offset]);
1260             vpermw(bcast_values,perm, bcast_values);
1261             vmovups(ptr[rsp], bcast_values);
1262
1263             for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1264                 if (!jcp.is_cpx) {
1265                     auto zmm_src = Zmm(28);
1266                     vpbroadcastd(zmm_src, ptr[rsp + jcp.typesize_in * 2 * i_ic]);
1267                     bf16_emu_->r_vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
1268                         zmm_dst, zmm_src);
1269                 } else
1270                     vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic), zmm_dst,
1271                                 zword_b[rsp + jcp.typesize_in * 2 * i_ic]);
1272             }
1273         }
1274     }
1275     for (int i_kw = 0; i_kw < kw; i_kw++) {
1276         for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1277             int l_offset = jcp.typesize_out *
1278                 (i_kw * ic_block + i_ic) * jcp.oc_block;
1279             vmovups(EVEX_compress_addr(reg_kernel,  l_offset + kernel_offset),
1280                         Zmm(i_kw * ic_block_step + i_ic));
1281         }
1282     }
1283 }
1284 #endif
1285 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1286     ::compute_oh_step_unroll_ow_icblock(
1287     int ic_block_step, int max_ur_w)
1288 {
1289     UNUSED(max_ur_w);
1290
1291     Label kh_label, kd_label;
1292
1293     int ic_block = jcp.ic_block;
1294     int oc_block = jcp.oc_block;
1295     int inp_mul = !jcp.is_1stconv ? ic_block : 1;
1296     int iw = jcp.tr_iw;
1297 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1298     // physical padding exists
1299     int r_pad = 0;
1300     int l_pad = 0;
1301 #else
1302     int ow = jcp.tr_ow;
1303     // XXX: is it possible to use jcp.r_pad here?
1304     int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
1305             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1306     int l_pad = jcp.l_pad;
1307 #endif
1308
1309     mov(kj, reg_kh);
1310     L(kh_label);
1311     {
1312         for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
1313 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1314             const int input_offset = jcp.typesize_in * i_b_ic * iw;
1315 #else
1316             const int input_offset = jcp.typesize_in * i_b_ic;
1317 #endif
1318             compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
1319                 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
1320                 i_b_ic + ic_block_step >= jcp.ic_block);
1321         }
1322         add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
1323         add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
1324         dec(kj);
1325         cmp(kj, 0);
1326         jg(kh_label, T_NEAR);
1327     }
1328 }
1329
1330 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1331     ::compute_oh_step_unroll_ow(
1332     int ic_block_step, int max_ur_w)
1333 {
1334     Label kh_label, ic_block_label, kd_label;
1335
1336     UNUSED(max_ur_w);
1337
1338     int ic_block = jcp.ic_block;
1339     int oc_block = jcp.oc_block;
1340
1341     int ow = jcp.tr_ow;
1342 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1343     // physical padding exists
1344     int r_pad = 0;
1345     int l_pad = 0;
1346 #else
1347     // XXX: is it possible to use jcp.r_pad here?
1348     int r_pad = nstl::max(0,
1349         (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
1350         - (jcp.iw + jcp.l_pad - 1));
1351     int l_pad = jcp.l_pad;
1352 #endif
1353
1354     mov(kj, reg_kh);
1355     L(kh_label);
1356     {
1357         xor_(b_ic, b_ic);
1358         L(ic_block_label); {
1359             compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
1360                 0, 0, 0);
1361 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1362             size_t inp_icblk_stride = jcp.tr_iw;
1363 #else
1364             size_t inp_icblk_stride = jcp.is_1stconv
1365                 ? (size_t)jcp.ih * jcp.iw * jcp.id
1366                 : 1;
1367 #endif
1368             size_t input_offset
1369                 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
1370             safe_add(reg_input, input_offset, reg_long_offt);
1371             add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
1372             add(b_ic, ic_block_step);
1373             cmp(b_ic, jcp.ic_block);
1374             jl(ic_block_label, T_NEAR);
1375         }
1376 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1377         if (jcp.is_1stconv) {
1378             size_t input_offset
1379                 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
1380             safe_sub(reg_input, input_offset, reg_long_offt);
1381             add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
1382         } else {
1383             add(reg_input, jcp.typesize_in
1384                     * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
1385         }
1386 #endif
1387         add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
1388         dec(kj);
1389         cmp(kj, 0);
1390         jg(kh_label, T_NEAR);
1391     }
1392 }
1393
1394 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1395     ::compute_oh_step_common(
1396     int ic_block_step, int max_ur_w)
1397 {
1398     Label kh_label, ic_block_label, ow_block_label, kd_label;
1399
1400     int ic_block = jcp.ic_block;
1401     int oc_block = jcp.oc_block;
1402
1403     int ow = jcp.tr_ow;
1404 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1405     // physical padding exists
1406     int l_pad = 0;
1407     int r_pad = 0;
1408     int stride_w = 1;
1409 #else
1410     int l_pad = jcp.l_pad;
1411     // XXX: is it possible to use jcp.r_pad here?
1412     int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
1413             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1414     int stride_w = jcp.stride_w;
1415 #endif
1416     int ur_w = nstl::min(ow, max_ur_w);
1417     int ur_w_trips = ow / ur_w;
1418     int ur_w_tail = ow % ur_w;
1419     if ((ur_w_tail == 0 && r_pad != 0)
1420         || r_pad >= ur_w_tail) {
1421         if (ur_w_trips > 1) {
1422             ur_w_tail += ur_w;
1423             ur_w_trips--;
1424         } else {
1425             ur_w_tail += (ur_w - ur_w / 2);
1426             ur_w = ur_w / 2;
1427         }
1428     }
1429 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1430     int inp_mult = 1;
1431 #else
1432     int inp_mult = (jcp.is_1stconv) ? 1 : ic_block;
1433 #endif
1434     int input_comeback = (ur_w_trips * ur_w * stride_w - l_pad) * inp_mult;
1435     int output_comeback = ur_w_trips * ur_w * oc_block;
1436
1437     mov(kj, reg_kh);
1438     L(kh_label); {
1439         xor_(b_ic, b_ic);
1440         L(ic_block_label); {
1441             if (l_pad != 0) {
1442                 ur_w_trips--;
1443                 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
1444                 add(reg_input, jcp.typesize_in * (ur_w * stride_w - l_pad)
1445                     * inp_mult);
1446                 add(reg_output, jcp.typesize_in * ur_w * oc_block);
1447             }
1448
1449             if (ur_w_trips > 0) {
1450                 xor_(reg_ur_w_trips, reg_ur_w_trips);
1451                 L(ow_block_label); {
1452                     compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
1453                     add(reg_input, jcp.typesize_in * ur_w * stride_w
1454                         * inp_mult);
1455                     add(reg_output, jcp.typesize_in * ur_w * oc_block);
1456
1457                     inc(reg_ur_w_trips);
1458                     cmp(reg_ur_w_trips, ur_w_trips);
1459                     jl(ow_block_label, T_NEAR);
1460                 }
1461             }
1462
1463             if (ur_w_tail > 0) {
1464                 compute_ic_block_step(ur_w_tail, 0, r_pad,
1465                     ic_block_step, 0, 0, 0, true);
1466             }
1467
1468             sub(reg_input, jcp.typesize_in * input_comeback);
1469             sub(reg_output, jcp.typesize_in * output_comeback);
1470 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1471             int inp_icblk_stride = jcp.tr_iw;
1472 #else
1473             int inp_icblk_stride = jcp.is_1stconv
1474                 ? jcp.ih * jcp.iw * jcp.id
1475                 : 1;
1476 #endif
1477             size_t input_offset
1478                 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
1479             safe_add(reg_input, input_offset, reg_long_offt);
1480             add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
1481
1482             add(b_ic, ic_block_step);
1483             cmp(b_ic, jcp.ic_block);
1484             jl(ic_block_label, T_NEAR);
1485         }
1486 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1487         if (jcp.is_1stconv) {
1488             size_t input_offset
1489                 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
1490             safe_sub(reg_input, input_offset, reg_long_offt);
1491             add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
1492         } else {
1493             add(reg_input, jcp.typesize_in
1494                     * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block);
1495         }
1496 #endif
1497         add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
1498         dec(kj);
1499         cmp(kj, 0);
1500         jg(kh_label, T_NEAR);
1501     }
1502 }
1503
1504 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1505     ::compute_oh_step_disp()
1506 {
1507     int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : 2);
1508
1509     bool too_large_to_unroll
1510         = (jcp.kw > 1 || jcp.kh > 1)
1511         && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
1512
1513     int ow = jcp.tr_ow;
1514     if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
1515         compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
1516     } else if (ow <= max_ur_w) {
1517         compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
1518     } else {
1519         compute_oh_step_common(ic_block_step, max_ur_w);
1520     }
1521     oh_step_comeback_pointers();
1522 }
1523
1524 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
1525 {
1526     Label skip_zeroing, zeroing_loop;
1527
1528     mov(reg_tmp, ptr[param + GET_OFF(channel)]);
1529     cmp(reg_tmp, 0);
1530     jz(skip_zeroing, T_NEAR);
1531
1532     Zmm zero = Zmm(0);
1533     vpxord(zero, zero, zero);
1534     xor_(reg_tmp, reg_tmp);
1535     L(zeroing_loop); {
1536         assert(jcp.oc_block * jcp.typesize_out
1537             == cpu_isa_traits<avx512_core>::vlen);
1538         for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
1539             vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
1540                 * jcp.typesize_out], zero);
1541         add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
1542         cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh
1543             * jcp.typesize_out);
1544         jnz(zeroing_loop);
1545     }
1546
1547     L(skip_zeroing);
1548 }
1549
1550 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1551     ::compute_loop()
1552 {
1553     int b_pad = jcp.b_pad;
1554     int t_pad = jcp.t_pad;
1555     bool is_dilated = jcp.dilate_h != 0;
1556     int dilate_h = jcp.dilate_h + 1;
1557     int stride_h = jcp.stride_h;
1558     const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1559     int iw = jcp.tr_iw;
1560     int ow = jcp.tr_ow;
1561     Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
1562             oh_bpad_label, oh_bpad_label_end,
1563             oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end,
1564             skip_neg_overlap_label, skip_fpad_label, skip_input_label;
1565
1566     maybe_zero_kernel();
1567
1568     mov(reg_kh, jcp.kh);
1569     xor_(reg_ih_count, reg_ih_count);
1570     xor_(reg_oj, reg_oj);
1571     /* Compute 'top' edge */
1572     if (t_pad > 0) {
1573         const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
1574         const int overflow
1575             = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
1576         const int underflow = div_up(t_pad, dilate_h);
1577         const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
1578         mov(reg_kh, initial_inp_ker_overlap);
1579         add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
1580             * jcp.oc_block);
1581         // generate loop to process kernel while it remains within t_pad + ih
1582         if (kh_range < t_pad + jcp.ih) {
1583             if (is_dilated) {
1584                 const int tail = t_pad % dilate_h;
1585                 const int shift = tail == 0 ? 0 : dilate_h - tail;
1586                 mov(reg_tmp, shift);
1587                 if (tail != 0)
1588                     add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
1589             }
1590             L(oh_tpad_label); {
1591                 compute_oh_step_disp();
1592                 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1593                 if (is_dilated) {
1594                     inc(reg_tmp);
1595                     cmp(reg_tmp, dilate_h);
1596                     jl(oh_dilate_label_shift, T_NEAR);
1597                     // unshift input as new kernel element enters
1598                     sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
1599                     xor_(reg_tmp, reg_tmp);
1600                 }
1601                 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
1602                 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
1603                                 * jcp.ic_block * jcp.oc_block);
1604                 add(reg_kh, stride_h);
1605                 if (is_dilated) {
1606                     jmp(oh_dilate_label_noshift, T_NEAR);
1607                     L(oh_dilate_label_shift);
1608                     // shift input as old kernel element progresses
1609                     add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1610                     L(oh_dilate_label_noshift);
1611                 }
1612                 inc(reg_oj);
1613                 add(reg_ih_count, stride_h);
1614
1615                 // final number of kernel elements that overlap with input
1616                 const int final_inp_ker_overlap
1617                     = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
1618                 cmp(reg_kh, final_inp_ker_overlap);
1619                 jl(oh_tpad_label, T_NEAR);
1620             }
1621         }
1622         // need second loop to process kernel if it is larger than the input
1623         // (does not apply to dilations as they must have unit stride)
1624         if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
1625                                                         t_pad % stride_h)) {
1626             assert(!is_dilated);
1627             mov(reg_kh, jcp.ih);
1628             L(oh_tpad_tail_label); {
1629                 compute_oh_step_disp();
1630                 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1631                 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
1632                                 * jcp.ic_block * jcp.oc_block);
1633
1634                 inc(reg_oj);
1635                 add(reg_ih_count, stride_h);
1636
1637                 cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
1638                 jl(oh_tpad_tail_label, T_NEAR);
1639             }
1640         }
1641         // correct any excess shifts to kernel and input
1642         // (does not apply to dilations as they must have unit stride,
1643         //  kernel must fit inside input, and padding is smaller than input)
1644         if (t_pad <= jcp.oh * stride_h) {
1645             // kernel has moved beyond padding (adjust for stride effects)
1646             if (t_pad % stride_h != 0) {
1647                 assert(!is_dilated);
1648                 int inp_corr = stride_h - t_pad % stride_h;
1649                 add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
1650                                 * jcp.ic_block * jcp.oc_block);
1651                 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
1652             }
1653         } else {
1654             // kernel still overlaps padding (complete reset)
1655             assert(!is_dilated);
1656             sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
1657                             * jcp.kw * jcp.ic_block * jcp.oc_block);
1658         }
1659     }
1660
1661     cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
1662     jge(oh_label_end, T_NEAR);
1663     cmp(reg_oj, jcp.oh);
1664     jge(oh_label, T_NEAR);
1665
1666     /* Compute middle block(s) */
1667     mov(reg_kh, jcp.kh);
1668     L(oh_label); {
1669         compute_oh_step_disp();
1670         add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1671         add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1672
1673         inc(reg_oj);
1674         add(reg_ih_count, stride_h);
1675
1676         cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
1677         jge(oh_label_end, T_NEAR);
1678
1679         cmp(reg_oj, jcp.oh);
1680         jl(oh_label, T_NEAR);
1681     }
1682     L(oh_label_end);
1683
1684     /* Compute bottom edge */
1685     if (b_pad > 0) {
1686         cmp(reg_oj, jcp.oh);
1687         jge(oh_bpad_label_end, T_NEAR);
1688
1689         if (is_dilated) {
1690             mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
1691             mov(reg_tmp, 0);
1692         } else {
1693             mov(reg_kh, jcp.ihp - b_pad);
1694             sub(reg_kh, reg_ih_count);
1695         }
1696         L(oh_bpad_label);
1697         {
1698             compute_oh_step_disp();
1699             add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1700             add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1701             if (is_dilated) {
1702                 inc(reg_tmp);
1703                 cmp(reg_tmp, dilate_h);
1704                 jl(oh_dilate_label_end, T_NEAR);
1705                 xor_(reg_tmp, reg_tmp);
1706             }
1707             sub(reg_kh, stride_h);
1708             cmp(reg_kh, 0);
1709             jle(oh_bpad_label_end, T_NEAR);
1710             if (is_dilated)
1711                 L(oh_dilate_label_end);
1712
1713             inc(reg_oj);
1714             cmp(reg_oj, jcp.oh);
1715             jl(oh_bpad_label, T_NEAR);
1716         }
1717         L(oh_bpad_label_end);
1718     }
1719 }
1720
1721 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate()
1722 {
1723     preamble();
1724
1725 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1726     sub(rsp, stack_space_needed);
1727 #endif
1728
1729     mov(reg_input, ptr[param + GET_OFF(src)]);
1730     mov(reg_output, ptr[param + GET_OFF(dst)]);
1731     mov(reg_kernel, ptr[param + GET_OFF(filt)]);
1732
1733     compute_loop();
1734
1735 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1736     add(rsp, stack_space_needed);
1737 #endif
1738
1739     postamble();
1740
1741 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1742     align(64);
1743     L(dst_prm_table);
1744     const uint16_t dst_prm_array[32] =
1745         {0,16,  1,17,  2,18,  3,19,  4,20,  5,21,  6,22,  7,23,  8,24,
1746          9,25,  10,26,  11,27,  12,28,  13,29,  14,30,  15,31 };
1747
1748     for (size_t i = 0; i < 32; ++i)
1749         dw(dst_prm_array[i]);
1750 #endif
1751 }
1752
1753 status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
1754     jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1755     cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
1756     cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd)
1757 {
1758     const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
1759
1760     const memory_desc_wrapper src_d(&src_pd);
1761     const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
1762     const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
1763     const memory_desc_wrapper diff_dst_d(&diff_dst_pd);
1764
1765     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1766     int ndims = src_d.ndims();
1767
1768     jcp = zero<decltype(jcp)>();
1769     jcp.ndims = ndims;
1770     jcp.prop_kind = cd.prop_kind;
1771
1772     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
1773     jcp.mb = src_d.dims()[0];
1774
1775     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1776     jcp.oc_without_padding = jcp.oc;
1777     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1778
1779     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1780     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
1781     jcp.iw = src_d.dims()[ndims-1];
1782     jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1783     jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
1784     jcp.ow = diff_dst_d.dims()[ndims-1];
1785
1786     jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
1787     jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
1788     jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
1789
1790     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1791     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
1792     jcp.l_pad = cd.padding[0][ndims-3];
1793
1794     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1795     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
1796     jcp.stride_w = cd.strides[ndims-3];
1797
1798     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1799     jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
1800     jcp.dilate_w = cd.dilates[ndims-3];
1801
1802     const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
1803     bool ok = true
1804         // general condition to simplify dilations
1805         && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
1806         && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
1807         && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
1808         // special condition to simplify dilations in compute_oh_loop_common
1809         && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
1810     if (!ok)
1811         return status::unimplemented;
1812
1813     jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
1814             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1815     jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
1816             + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
1817
1818     /* XXX: currently, does not support stride_d > 1 or dilation > 0 */
1819     if (ndims == 5)
1820         if (jcp.stride_d > 1 || jcp.dilate_d > 0)
1821             return status::unimplemented;
1822
1823     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1824     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1825     jcp.ohp = jcp.oh;
1826     jcp.owp = jcp.ow;
1827     jcp.aligned_threads = 0;
1828
1829     jcp.oc_block = simd_w;
1830
1831     bool ok_to_pad_channels = jcp.ngroups == 1;
1832
1833     if (ok_to_pad_channels) {
1834         jcp.oc = rnd_up(jcp.oc, simd_w);
1835         jcp.ic = rnd_up(jcp.ic, simd_w);
1836     }
1837
1838     auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1839     auto wei_format = with_groups
1840         ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
1841         : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
1842     // temporary workaround until bf16 jit supports 1d
1843     if (wei_format == gOIw16i16o || wei_format == OIw16i16o)
1844         return status::unimplemented;
1845
1846     /* conditions on bias memory */
1847     jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
1848     if (jcp.with_bias) {
1849         if (diff_bias_d.format() == any)
1850             CHECK(diff_bias_pd.set_format(x));
1851         if (diff_bias_d.format() != x)
1852             return status::unimplemented;
1853     }
1854
1855     jcp.nb_oc = jcp.oc / jcp.oc_block;
1856
1857     if (diff_dst_d.format() == any)
1858         CHECK(diff_dst_pd.set_format(src_format));
1859     if (diff_dst_d.format() != src_format)
1860         return status::unimplemented;
1861
1862     /* kernel applicability check wrt boundaries
1863      * the conditions are quite general across the kernels we have,
1864      * but ideally the check should belong to a specific kernel... */
1865     const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
1866     const bool boundaries_ok = true
1867         && jcp.t_pad <= max_pad
1868         && jcp.b_pad <= max_pad;
1869     if (!boundaries_ok)
1870         return status::unimplemented;
1871
1872     /* yet another common check */
1873     if (jcp.kw > 14)
1874         return status::unimplemented;
1875
1876     /* setting register strategy */
1877     for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
1878         if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
1879     }
1880
1881     if (src_d.format() == any)
1882         CHECK(src_pd.set_format(src_format));
1883     if (diff_weights_d.format() == any)
1884         CHECK(diff_weights_pd.set_format(wei_format));
1885
1886     ok = true
1887         && src_d.format() == src_format
1888         && diff_weights_d.format() == (wei_format);
1889     if (!ok)
1890         return status::unimplemented;
1891     jcp.dwei_dt = diff_weights_d.data_type();
1892
1893     jcp.ic_block = simd_w;
1894     if (ok_to_pad_channels)
1895         jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1896     jcp.nb_ic = jcp.ic / jcp.ic_block;
1897     jcp.src_fmt = src_d.format();
1898     if (mkldnn_thr_syncable()
1899             && one_of(ndims, 3, 4)
1900             && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
1901             && everyone_is(data_type::bf16,
1902                                src_d.data_type(), diff_dst_d.data_type())
1903             && one_of(diff_weights_d.data_type(),
1904                           data_type::f32, data_type::bf16)) {
1905         jcp.ver = ver_vnni;
1906     } else {
1907         return status::unimplemented;
1908     }
1909     jcp.is_cpx = mayiuse(avx512_core_bf16);
1910 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1911     const int tr_round = 4;
1912     // TODO: try to optimize required memory size
1913     int tr_pad = rnd_up(nstl::max(1, nstl::max(jcp.l_pad, jcp.r_pad)),
1914                             tr_round);
1915     jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
1916                     * jcp.stride_w;
1917     jcp.tr_src_num_guard_elems = tr_pad; // upper bound
1918     jcp.tr_ow = rnd_up(jcp.ow, 2);
1919     jcp.ur_w = jcp.tr_ow;
1920 #else
1921     jcp.tr_ow = jcp.ow;
1922     jcp.tr_iw = jcp.iw;
1923     jcp.ur_w = jcp.ow;
1924     if (jcp.stride_w != 1)
1925         return status::unimplemented;
1926 #endif
1927     jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
1928     jcp.typesize_out = sizeof(float);
1929
1930     bool args_ok = true
1931         && jcp.ic % jcp.ic_block == 0
1932         && jcp.oc % jcp.oc_block == 0
1933         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1934         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
1935         && jcp.ic <=
1936                 diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
1937         && jcp.oc <=
1938                 diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
1939     if (!args_ok) return status::unimplemented;
1940
1941     {   // balancing
1942         int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
1943         balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
1944         jcp.nthr = nthr;
1945         jcp.nthr_mb = nthr_mb;
1946         jcp.nthr_g = nthr_g;
1947         jcp.nthr_oc_b = nthr_oc_b;
1948         jcp.nthr_ic_b = nthr_ic_b;
1949     }
1950
1951     return status::success;
1952 }
1953
1954 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
1955         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1956 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1957     // XXX: See the comment about tr_iw and guarding elements in
1958     // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
1959 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1960     const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
1961 #else
1962     const size_t max_nthr = jcp.nthr;
1963 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1964     const size_t min_tr_src_size_per_thr = jcp.ih * jcp.ic_block * jcp.tr_iw;
1965     const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
1966         + jcp.tr_src_num_guard_elems;
1967     scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
1968
1969 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1970     /* prepare synchronization contexts */
1971     if (jcp.nthr_oc_b > 1) {
1972         const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
1973         scratchpad.book(key_conv_tr_src_bctx,
1974                 sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
1975     }
1976 #endif // !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1977
1978 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1979     const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
1980         * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
1981 #else
1982     const size_t tr_diff_dst_size = jcp.nthr
1983         * jcp.oc_block * jcp.tr_ow * jcp.oh;
1984 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1985     scratchpad.book(key_conv_tr_diff_dst, jcp.typesize_in * tr_diff_dst_size);
1986
1987 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1988     /* prepare synchronization contexts */
1989     if (jcp.nthr_ic_b > 1) {
1990         const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
1991         scratchpad.book(key_conv_tr_diff_dst_bctx,
1992                 sizeof(simple_barrier::ctx_t) * tr_diff_dst_bctx_size);
1993     }
1994 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1995 #endif // BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1996
1997     if (jcp.nthr_mb > 1 || jcp.dwei_dt == data_type::bf16) {
1998         const size_t wei_size = jcp.ngroups * jcp.oc * jcp.ic
1999             * jcp.kh * jcp.kw * jcp.kd;
2000         const size_t bia_size = jcp.ngroups * jcp.oc;
2001
2002         const int num_wei_buffers = jcp.dwei_dt == data_type::bf16
2003             ? jcp.nthr_mb
2004             : jcp.nthr_mb - 1;
2005
2006         const size_t wei_bia_reduction_size = wei_size + bia_size;
2007
2008         scratchpad.book(key_conv_wei_bia_reduction,
2009                 jcp.typesize_out * wei_bia_reduction_size * num_wei_buffers);
2010         // TODO: don't use barrier for case
2011         // jcp.dwei_dt == data_type::bf16 && nthr_mb_ == 1
2012         scratchpad.book(key_conv_wei_bia_reduction_bctx,
2013                 sizeof(simple_barrier::ctx_t));
2014     }
2015
2016     if (jcp.with_bias) {
2017         const size_t dst_f32_size = (size_t)jcp.od * jcp.oh * jcp.ow
2018              * jcp.oc_block * jcp.typesize_out;
2019         scratchpad.book(key_conv_dst_bf16_convert_wsp, jcp.nthr * dst_f32_size);
2020     }
2021
2022     if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
2023         scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
2024 }
2025
2026 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance(
2027         const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
2028         int &nthr_oc_b_, int &nthr_ic_b_)
2029 {
2030     nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
2031
2032     const int max_threads = mkldnn_get_max_threads();
2033
2034     if (max_threads < j.ngroups) {
2035         /* simplification... fortunately it doesn't hurt much */
2036         return;
2037     }
2038
2039     if (!mkldnn_thr_syncable()) {
2040         // should not happen -- the driver is not ready
2041         // for TBB-like non-synchronous threading yet
2042         return;
2043     }
2044
2045     nthr_g_ = j.ngroups;
2046     const int nthr = max_threads / nthr_g_;
2047
2048     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
2049         /* calculate per thread memory cost (read/write). high level optimizer
2050          * tries to minimize memory consumption. few notes:
2051          *  (n1) unclear why, but that essentially helps first convolution...
2052          *  (n2) assuming the reduction over minibatch is always there:
2053          *    - instead of 8 it should be 5 here (write ~= 2 read):
2054          *      kernel: temporal workspace 1 write
2055          *      reduction: 1 read from workspace and 1 write to the diff_wei
2056          *    - but experiments showed 8 works better than 5 or 6... */
2057
2058         const int src_coef = 4;
2059         const int dst_coef = 1;
2060         const int wei_coef = 4;
2061
2062         return 0
2063             + src_coef
2064             * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
2065             * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
2066             / j.stride_d / j.stride_h / j.stride_w /* (n1) */
2067             + dst_coef
2068             * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
2069             * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
2070             + wei_coef /* (n2) */
2071             * div_up(j.ngroups, nthr_g_)
2072             * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
2073             * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
2074     };
2075
2076     int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
2077
2078     /* step 1: find the best thread distribution with lowest memory cost */
2079     const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
2080     for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
2081         const int nthr_par = nthr / nthr_mb;
2082         const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
2083         for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
2084             int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
2085
2086             int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
2087             if (mem_cost <= best_mem_cost) {
2088                 best_mem_cost = mem_cost;
2089                 nthr_mb_ = nthr_mb;
2090                 nthr_oc_b_ = nthr_oc_b;
2091                 nthr_ic_b_ = nthr_ic_b;
2092             }
2093         }
2094
2095         if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
2096     }
2097
2098     if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
2099         nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
2100     nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
2101
2102     assert(nthr_ <= max_threads);
2103     assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
2104 }
2105
2106 }
2107 }
2108 }
2109 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s