1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
22 #include "cpu_barrier.hpp"
23 #include "cpu_memory.hpp"
25 #include "jit_avx512_core_bf16_conv_kernel.hpp"
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28 #define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
41 constexpr auto small_spatial = 14;
43 inline void pick_loop_order(jit_conv_conf_t &jcp) {
44 using namespace prop_kind;
45 assert(one_of(jcp.prop_kind,
46 forward_training, forward_inference, backward_data));
47 auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
48 auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
50 // ow-threading is currently implemented for forward only
51 // TODO: single code for fwd and bwd after ow-thr for bwd
52 // meaningless switch was removed
53 if (jcp.prop_kind == backward_data) {
54 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
55 ? loop_cgn : loop_gnc;
57 jcp.loop_order = (w <= small_spatial && h <= small_spatial)
58 ? loop_cwgn : loop_gncw;
61 inline bool is_1D_conv(const jit_conv_conf_t &jcp) {
62 return (jcp.ih == 1 && jcp.kh == 1);
64 inline bool is_ow_threading_available(const jit_conv_conf_t &jcp) {
65 return (is_1D_conv(jcp) && one_of(jcp.ndims, 3, 4)
66 && !(jcp.ver == ver_fma && mayiuse(avx512_mic)));
68 inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
69 return (jcp.nb_ow > 1);
73 void jit_avx512_core_bf16_fwd_kernel::prepare_output(int ur_w)
75 for (int k = 0; k < jcp.nb_oc_blocking; k++)
76 for (int j = 0; j < ur_w; j++) {
77 Zmm zmm = zmm_out(j, k);
78 vpxord(zmm, zmm, zmm);
// Emits the epilogue of the forward kernel: optional sum with the previous
// dst contents, bias add, eltwise post-op, then the final store in either
// f32 or bf16 form (hardware vcvtne(2)ps2bf16 or emulated via bf16_emu_).
// NOTE(review): the leading number on each line is the original file's line
// number; gaps in that numbering mean lines (e.g. `if (jcp.with_sum)`,
// `if (jcp.with_bias)`, several stores and closing braces) are missing from
// this extract — the block below is an incomplete view of the function.
82 void jit_avx512_core_bf16_fwd_kernel::store_output(int ur_w)
86 bf16_emu_->init_vcvtneps2bf16();
// Sum post-op path: widen previous bf16 dst to f32 (zero-extend + shift
// into the high mantissa/exponent bits) and accumulate.
89 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
90 for (int j = 0; j < ur_w; j++) {
91 Zmm zmm = zmm_out(j, k);
92 size_t aux_output_offset = get_output_offset(j, k);
93 if (jcp.dst_dt == data_type::bf16) {
94 vpmovzxwd(zmm_prev_dst,
95 make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
// Shifting left by 16 turns a bf16 value into the equivalent f32 bit
// pattern (bf16 is the upper half of an IEEE f32).
96 vpslld(zmm_prev_dst, zmm_prev_dst, 16);
97 vaddps(zmm, zmm_prev_dst);
100 make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
// Bias path: bias is stored as f32, one oc_block chunk per k.
107 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
108 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
109 int bias_offset = sizeof(float) * k * jcp.oc_block;
110 for (int j = 0; j < ur_w; j++) {
111 Zmm zmm = zmm_out(j, k);
112 vaddps(zmm, EVEX_compress_addr(reg_bias, bias_offset));
// Eltwise post-op over the accumulator register range.
117 if (jcp.with_eltwise) {
118 if (ur_w == jcp.ur_w) {
119 eltwise_injector_->compute_vector_range(0,
120 jcp.nb_oc_blocking * jcp.ur_w);
122 for (int k = 0; k < jcp.nb_oc_blocking; k++)
123 eltwise_injector_->compute_vector_range(k * jcp.ur_w,
124 k * jcp.ur_w + ur_w);
// Final stores: f32 dst is written directly.
129 if (jcp.dst_dt == data_type::f32) {
130 for (int k = 0; k < jcp.nb_oc_blocking; k++)
131 for (int j = 0; j < ur_w; j++) {
132 Zmm zmm = zmm_out(j, k);
133 size_t aux_output_offset = jcp.typesize_out *
134 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
135 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
// bf16 dst on native hardware: pack two f32 accumulators per store with
// vcvtne2ps2bf16, handling an odd trailing element with vcvtneps2bf16.
139 } else if (jcp.dst_dt == data_type::bf16) {
141 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
142 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
143 for (j = 0; j < n_2bf2ps; j += 2) {
144 size_t aux_output_offset = jcp.typesize_out *
145 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
146 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
148 auto zmm_str = zmm_inp(j, jcp.nb_oc_blocking);
149 vcvtne2ps2bf16(zmm_str, zmm_out(j+1, k),
151 vmovups(addr, zmm_str);
154 size_t aux_output_offset = jcp.typesize_out *
155 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
156 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
157 auto ymm_str = ymm_inp(j, jcp.nb_oc_blocking);
158 vcvtneps2bf16(ymm_str, zmm_out(j, k));
159 vmovups(addr, ymm_str);
// bf16 dst without native support: convert via the bf16 emulator.
163 for (int k = 0; k < jcp.nb_oc_blocking; k++)
164 for (int j = 0; j < ur_w; j++) {
165 Zmm zmm = zmm_out(j, k);
166 size_t aux_output_offset = jcp.typesize_out *
167 ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
168 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
169 Ymm ymm = ymm_inp(0, jcp.nb_oc_blocking);
170 bf16_emu_->r_vcvtneps2bf16(ymm, zmm);
175 assert(!"unsupported destination type");
// Emits the main accumulation loop of the forward kernel for one ur_w-wide
// strip: iterates over ic blocks (reg_icb), kernel height (kh_label) and
// kernel width (ki), broadcasting pairs of bf16 inputs and multiplying by
// weights with vdpbf16ps (or its emulated form).
// NOTE(review): gaps in the embedded line numbering indicate dropped lines
// (loop headers such as `L(icb_label)`/`L(kh_label)`, the expl_bcast
// if/else split, and closing braces); incomplete view — treat with care.
178 void jit_avx512_core_bf16_fwd_kernel::compute_loop(
179 int ur_w, int pad_l, int pad_r)
181 Label kh_label, kd_label;
// Per-kh-step pointer advances: one kw row of weights, one (dilated) input
// row.
182 const size_t shift_kernel_ptr = (size_t)jcp.typesize_in * jcp.kw
183 * jcp.oc_block * jcp.ic_block;
184 const size_t shift_input_ptr
185 = (size_t)jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
187 prepare_output(ur_w);
191 mov(reg_icb, jcp.nb_ic);
194 mov(aux_reg_inp, reg_inp);
195 mov(aux_reg_ker, reg_ker);
197 Label skip_kh_loop, skip_kd_loop;
// Skip the kh loop entirely when padding/dilation can make the effective
// kernel height zero for this row.
200 if ((jcp.dilate_h >= jcp.ih)
201 || (jcp.kh - 1) * (jcp.dilate_h + 1)
202 < nstl::max(jcp.t_pad, jcp.b_pad)) {
204 je(skip_kh_loop, T_NEAR);
208 for (int ki = 0; ki < jcp.kw; ki++) {
209 int ow_start = get_ow_start(ki, pad_l);
210 int ow_end = get_ow_end(ur_w, ki, pad_r);
// ic advances by pairs: vdpbf16ps consumes two bf16 values per f32 lane.
212 ic < div_up(nstl::min(jcp.ic_block, jcp.ic), 2); ic++) {
214 for (int oi = ow_start; oi < ow_end; oi++) {
215 size_t input_offset =
216 get_input_offset(ki, ic, oi, pad_l);
217 vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
218 EVEX_compress_addr(aux_reg_inp, input_offset));
221 for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
222 size_t kernel_offset = get_kernel_offset(ki, ic, kk, 0);
225 EVEX_compress_addr(aux_reg_ker, kernel_offset));
226 for (int oi = ow_start; oi < ow_end; oi++) {
228 size_t input_offset =
229 get_input_offset(ki, ic, oi, pad_l);
230 vpbroadcastd(zmm_inp(oi, jcp.nb_oc_blocking),
231 EVEX_compress_addr(aux_reg_inp, input_offset));
233 EVEX_compress_addr(aux_reg_ker, kernel_offset));
234 auto acc = zmm_out(oi, kk);
235 auto inp = zmm_inp(oi, jcp.nb_oc_blocking);
// Emulated bf16 dot-product on pre-CPX hardware...
236 bf16_emu_->r_vdpbf16ps(acc, zmm_wei, inp);
// ...native AVX512_BF16 instruction otherwise.
238 vdpbf16ps(zmm_out(oi, kk), zmm_wei,
239 zmm_inp(oi, jcp.nb_oc_blocking));
244 add(aux_reg_ker, shift_kernel_ptr);
245 add(aux_reg_inp, shift_input_ptr);
249 jg(kh_label, T_NEAR);
// Advance to the next ic block; the subs at the end rewind the pointers
// after the whole nb_ic loop so reg_inp/reg_ker are restored.
255 size_t inp_step = (size_t)jcp.ih * jcp.iw * jcp.ic_block;
256 size_t ker_step = (size_t)jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
257 add(reg_inp, jcp.typesize_in * inp_step);
258 add(reg_ker, jcp.typesize_in * ker_step);
262 jg(icb_label, T_NEAR);
264 sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
265 sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
// Top-level code generator for the forward kernel. Two regimes:
//  * no ow-threading: the whole output width is processed in one pass with
//    explicit left/right padded compute_loop calls around an unpadded loop;
//  * ow-threading: only the ow block selected by the runtime `owb` argument
//    is processed, with padding handling chosen per block position.
// NOTE(review): embedded line numbers are non-contiguous — labels, preamble
// (preamble()/postamble()), some branches and closing braces are missing
// from this extract.
270 void jit_avx512_core_bf16_fwd_kernel::generate()
274 int ow_block = jcp.ow_block;
275 int nb_ow = jcp.nb_ow;
277 int l_pad = jcp.l_pad;
279 int ur_w_tail = jcp.ur_w_tail;
280 int dilate_w = jcp.dilate_w + 1;
281 int stride_w = jcp.stride_w;
283 int inp_mult = jcp.ic_block;
// Pointer advances per ur_w strip (input moves by stride_w pixels).
285 size_t inp_shift = (size_t)jcp.typesize_in * ur_w * stride_w * inp_mult;
286 size_t out_shift = (size_t)jcp.typesize_out * ur_w * jcp.oc_block;
288 int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
289 int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
// Load kernel arguments from the jit_conv_call_s parameter block.
292 mov(reg_inp, ptr[param1 + GET_OFF(src)]);
293 mov(reg_out, ptr[param1 + GET_OFF(dst)]);
294 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
295 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
297 int r_pad = nstl::max(
298 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
299 int n_oi = ow / ur_w;
// r_pad1: right padding seen by the last full ur_w strip (before the tail).
300 int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
303 if (!is_ow_threading_on(jcp)) {
304 // ow is being processed as a whole - with left and right paddings
308 xor_(reg_oi, reg_oi);
310 compute_loop(ur_w, l_pad, r_pad);
313 compute_loop(ur_w, l_pad, r_pad1);
314 add(reg_inp, inp_shift_pad);
315 add(reg_out, out_shift);
316 if (ur_w_tail != 0) {
317 compute_loop(ur_w_tail, 0, r_pad);
321 compute_loop(ur_w, l_pad, 0);
322 add(reg_inp, inp_shift_pad);
323 add(reg_out, out_shift);
326 if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
330 compute_loop(ur_w, 0, 0);
331 add(reg_inp, inp_shift);
332 add(reg_out, out_shift);
336 jl(ow_loop_label, T_NEAR);
340 compute_loop(ur_w, 0, r_pad1);
341 add(reg_inp, inp_shift);
342 add(reg_out, out_shift);
344 if (ur_w_tail != 0) {
345 compute_loop(ur_w_tail, 0, r_pad);
350 // ow block is only processed.
351 // Number of block is passed as parameter owb,
352 // and padding processing depends on this number.
354 Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
355 Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
357 assert(ow_block % ur_w == 0);
358 int n_oi_not_last_ow_block = ow_block / ur_w;
359 // to simplify code (and general regs usage),
360 // size of ow block must be >= 2 * ur_w
361 assert(n_oi_not_last_ow_block > 1);
362 int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
363 int n_oi_first_ow_block = n_oi_not_last_ow_block;
365 int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
367 // prepare right padding
368 bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
369 bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
370 bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
// Whichever block absorbs the padded strip loses one unpadded iteration.
372 if (last_ow_block_padded) n_oi_last_ow_block--;
373 else if (first_ow_block_padded) n_oi_first_ow_block--;
374 else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
376 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
377 cmp(reg_owb, 0); // is that the first ow-block ?
378 jg(middle_ow_blocks_label, T_NEAR);
380 // the first ow block, compute left padding
382 mov(reg_oi, n_oi_first_ow_block);
384 compute_loop(ur_w, l_pad, 0);
385 add(reg_inp, inp_shift_pad);
386 add(reg_out, out_shift);
389 jmp(oi_loop_label, T_NEAR);
391 // middle or last ow block entry
393 L(middle_ow_blocks_label);
396 // just to consider left padding, not compute
397 add(reg_inp, inp_shift_pad_second_block);
400 // set number of iteration for oi-loop
401 cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
402 mov(reg_oi, n_oi_last_ow_block);
403 je(oi_loop_label, T_NEAR);
404 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
405 mov(reg_oi, n_oi_next_last_ow_block);
406 je(oi_loop_label, T_NEAR);
407 mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
409 // oi loop w/o padding
411 L(oi_loop_start_label);
413 jle(oi_loop_end_label, T_NEAR);
415 compute_loop(ur_w, 0, 0);
416 add(reg_inp, inp_shift);
417 add(reg_out, out_shift);
419 jmp(oi_loop_start_label, T_NEAR);
420 L(oi_loop_end_label);
422 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
// Dispatch on block position to decide whether the padded strip and/or the
// ur_w tail still need to be emitted for this block.
424 cmp(reg_owb, 0); // first ow-block ?
425 if (first_ow_block_padded) {
426 je(last_oi_label, T_NEAR);
428 je(end_label, T_NEAR);
430 cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
431 jl(end_label, T_NEAR);
432 if (next_last_ow_block_padded) {
433 je(last_oi_label, T_NEAR);
435 je(end_label, T_NEAR);
437 // that is last block
438 if (!last_ow_block_padded) {
439 jmp(tail_label, T_NEAR);
442 // last oi block with right padding
444 compute_loop(ur_w, 0, r_pad1);
445 add(reg_inp, inp_shift);
446 add(reg_out, out_shift);
448 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
449 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
450 jl(end_label, T_NEAR);
453 if (ur_w_tail != 0) {
454 compute_loop(ur_w_tail, 0, r_pad);
// The eltwise injector's lookup table must be emitted once after the code.
460 if (jcp.with_eltwise)
461 eltwise_injector_->prepare_table();
464 bool jit_avx512_core_bf16_fwd_kernel::post_ops_ok(
465 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
466 const auto &p = attr.post_ops_;
468 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
469 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
472 case 0: return true; // no post_ops
473 case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
474 case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
475 default: return false;
// Fills jcp (the jit_conv_conf_t descriptor) for the forward kernel from
// the convolution desc and memory descriptors, fixes memory formats,
// chooses register blocking (nb_oc_blocking / ur_w), decides ow-threading,
// and validates everything; returns unimplemented on any unsupported case.
// NOTE(review): non-contiguous embedded line numbers — several guard
// conditions (e.g. the check preceding the first `return unimplemented`)
// and closing braces are missing from this extract.
481 status_t jit_avx512_core_bf16_fwd_kernel::init_conf(
482 jit_conv_conf_t &jcp,
483 const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
484 cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
485 cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
488 using namespace prop_kind;
// 16 f32 lanes per 512-bit vector.
490 const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
492 const memory_desc_wrapper src_d(&src_pd);
493 const memory_desc_wrapper weights_d(&weights_pd);
494 const memory_desc_wrapper dst_d(&dst_pd);
495 const memory_desc_wrapper bias_d(&bias_pd);
497 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
498 int ndims = src_d.ndims();
500 jcp = zero<decltype(jcp)>();
// Geometry: dims are laid out as {N, C, [D,] [H,] W}; the (ndims == 3)
// cases collapse the missing H (1D conv), (ndims == 5) adds depth.
502 jcp.prop_kind = cd.prop_kind;
503 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
504 jcp.mb = src_d.dims()[0];
505 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
506 jcp.oc_without_padding = jcp.oc;
507 jcp.ic = src_d.dims()[1] / jcp.ngroups;
508 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
509 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
510 jcp.iw = src_d.dims()[ndims-1];
511 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
512 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
513 jcp.ow = dst_d.dims()[ndims-1];
514 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
515 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
516 jcp.kw = weights_d.dims()[with_groups + ndims-1];
517 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
518 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
519 jcp.l_pad = cd.padding[0][ndims-3];
520 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
521 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
522 jcp.stride_w = cd.strides[ndims-3];
523 jcp.src_fmt = src_d.format();
524 jcp.dst_dt = cd.dst_desc.data_type;
526 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
527 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
528 jcp.dilate_w = cd.dilates[ndims-3];
// Derived bottom/back padding from output size, stride, dilation.
530 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
531 - (jcp.ih + jcp.t_pad - 1);
532 jcp.back_pad = (jcp.od - 1) * jcp.stride_d
533 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
536 return status::unimplemented;
// is_cpx: native AVX512_BF16 available; otherwise 5 regs are reserved for
// the software emulation, leaving 26 instead of 31.
538 jcp.is_cpx = mayiuse(avx512_core_bf16);
539 const int regs = jcp.is_cpx ? 31 /* expl_bcast case */ : 26;
541 jcp.oc_block = simd_w;
542 jcp.ic_block = simd_w;
543 jcp.aligned_threads = 0;
545 bool ok_to_pad_channels = jcp.ngroups == 1;
547 if (ok_to_pad_channels) {
548 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
549 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
552 && jcp.oc % jcp.oc_block == 0
553 && jcp.ic % jcp.ic_block == 0;
555 return status::unimplemented;
557 if (!post_ops_ok(jcp, attr))
558 return status::unimplemented;
560 const auto &p = attr.post_ops_;
561 jcp.with_sum = p.find(primitive_kind::sum) != -1;
562 const int eltwise_ind = p.find(primitive_kind::eltwise);
563 jcp.with_eltwise = eltwise_ind != -1;
564 if (jcp.with_eltwise) {
565 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
566 if (dst_d.data_type() == data_type::s32) return status::unimplemented;
// Expected blocked layouts: nC{w,hw,dhw}16c activations, 8i16o2i weights
// (2i interleaving matches the bf16 pair consumed by vdpbf16ps).
569 auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
570 auto dst_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
572 if (src_d.format() == any)
573 CHECK(src_pd.set_format(src_format));
574 if (src_d.format() != src_format)
575 return status::unimplemented;
576 if (dst_d.format() == any)
577 CHECK(dst_pd.set_format(dst_format));
578 if (dst_d.format() != dst_format)
579 return status::unimplemented;
580 const auto w_format = with_groups
581 ? pick(ndims - 3, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i)
582 : pick(ndims - 3, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i);
583 if (weights_d.format() == any)
584 CHECK(weights_pd.set_format(w_format));
585 if (weights_d.format() != w_format)
586 return status::unimplemented;
588 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
590 if (bias_d.format() == any)
591 CHECK(bias_pd.set_format(x));
592 if (bias_d.format() != x)
593 return status::unimplemented;
597 jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
598 jcp.typesize_out = (dst_d.data_type() == data_type::f32)
599 ? sizeof(float) : sizeof(mkldnn_bfloat16_t);
601 jcp.nb_ic = jcp.ic / jcp.ic_block;
602 jcp.nb_oc = jcp.oc / jcp.oc_block;
603 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
// Pick the largest nb_oc_blocking (<= 4) that divides nb_oc and still
// leaves a usable ur_w within the register budget.
605 jcp.kernel_kind = expl_bcast;
606 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
607 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
608 int ur_w = regs / (jcp.nb_oc_blocking + 1);
609 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
610 && (jcp.l_pad <= ur_w
611 && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
615 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
616 if (jcp.ow < jcp.ur_w)
618 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
// ow-threading: size the ow block so a src/dst/weights working set fits in
// 5/8 of L1.
620 jcp.ow_block = jcp.ow;
621 if (is_ow_threading_available(jcp)) {
622 const int L1_part = get_cache_size(1) * 5 / 8;
623 int size_src_chunk = jcp.typesize_in * jcp.ic_block * jcp.ur_w;
624 int size_dst_chunk = jcp.typesize_out
625 * jcp.oc_block * jcp.nb_oc_blocking * jcp.ur_w;
626 int size_wei_chunk = jcp.typesize_in
627 * jcp.oc_block * jcp.ic_block * jcp.nb_oc_blocking * jcp.kw;
628 int nurw = (L1_part - size_wei_chunk)
629 / (size_dst_chunk + size_src_chunk);
630 // current design of generate() requires ow_block >= 2 * ur_w
631 jcp.ow_block = jcp.ur_w * nstl::max(2, nurw);
633 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
636 && jcp.l_pad <= jcp.ur_w
637 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
638 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
639 && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
640 && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
642 return status::unimplemented;
644 int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
645 + (jcp.kw - 1) * (jcp.dilate_w + 1)
646 - (jcp.iw + jcp.l_pad - 1));
647 if (r_pad_no_tail > jcp.ur_w)
648 return status::unimplemented;
650 pick_loop_order(jcp);
652 jcp.nb_ic_L2 = jcp.nb_ic;
654 const int L2_size = get_cache_size(2, true) / sizeof(float);
655 // Source and output data needs to fit in L2,
656 // leaving some space for weights and prefetching.
657 int h_L2 = int(((0.6f * L2_size) / simd_w
658 - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
659 / (jcp.stride_h * jcp.iw + jcp.ow));
660 jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
662 return status::success;
665 void jit_avx512_core_bf16_bwd_data_kernel::prepare_output(int ur_w)
667 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
668 for (int j = 0; j < ur_w; j++) {
669 Zmm zmm = zmm_out(j, k);
670 vpxord(zmm, zmm, zmm);
// Stores the accumulated diff_src strip, as f32 directly or as bf16 via
// the native convert instructions (cycling through the registers left free
// above the accumulators) or via the bf16 emulator.
// NOTE(review): non-contiguous embedded line numbers — declarations such
// as `reg_idx`/`store_idx`, some stores and closing braces are missing
// from this extract.
675 void jit_avx512_core_bf16_bwd_data_kernel::store_output(int ur_w)
678 bf16_emu_->init_vcvtneps2bf16();
680 if (jcp.dsrc_dt == data_type::f32) {
681 for (int k = 0; k < jcp.nb_ic_blocking; k++)
682 for (int j = 0; j < ur_w; j++) {
683 Zmm zmm = zmm_out(j, k);
684 size_t aux_diff_src_offset = jcp.typesize_out *
685 ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
686 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
690 } else if (jcp.dsrc_dt == data_type::bf16) {
// Registers above the accumulator block are reused round-robin as staging
// registers for the bf16 conversions.
693 const int max_regs = 32;
694 const int free_regs_start_idx = jcp.ur_w * jcp.nb_ic_blocking;
695 const int num_regs_available = max_regs - free_regs_start_idx;
697 for (int k = 0; k < jcp.nb_ic_blocking; k++) {
// Pairs of f32 accumulators are packed into one bf16 zmm per store.
698 int n_2bf2ps = (ur_w / 2) * 2, j = 0;
699 for (j = 0; j < n_2bf2ps; j += 2) {
700 reg_idx = free_regs_start_idx
701 + store_idx % num_regs_available;
702 assert(reg_idx < max_regs);
703 size_t aux_diff_src_offset = jcp.typesize_out *
704 ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) *
706 auto addr = EVEX_compress_addr(reg_src,
707 aux_diff_src_offset);
709 auto zmm_str = Zmm(reg_idx);
710 vcvtne2ps2bf16(zmm_str, zmm_out(j+1, k),
712 vmovups(addr, zmm_str);
// Odd trailing element: single-register convert to a ymm half-store.
716 reg_idx = free_regs_start_idx
717 + store_idx % num_regs_available;
718 assert(reg_idx < max_regs);
720 size_t aux_diff_src_offset = jcp.typesize_out *
721 ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
722 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
723 auto ymm_str = Ymm(reg_idx);
724 vcvtneps2bf16(ymm_str, zmm_out(j, k));
725 vmovups(addr, ymm_str);
// Emulated path for hardware without AVX512_BF16 conversions.
730 for (int k = 0; k < jcp.nb_ic_blocking; k++)
731 for (int j = 0; j < ur_w; j++) {
732 Zmm zmm = zmm_out(j, k);
733 size_t aux_diff_src_offset = jcp.typesize_out *
734 ((size_t)k * jcp.id * jcp.ih * jcp.iw + j) * jcp.ic_block;
735 auto addr = EVEX_compress_addr(reg_src, aux_diff_src_offset);
736 Ymm ymm = ymm_inp(0);
737 bf16_emu_->r_vcvtneps2bf16(ymm, zmm);
742 assert(!"unsupported diff_src type");
// Emits the backward-data accumulation loop for one ur_w strip of diff_src:
// iterates over oc blocks (reg_ocb) and kernel height (kh_label), mapping
// each kernel-width tap to the diff_dst pixels it touches, accumulating
// with vdpbf16ps (or its emulated form).
// NOTE(review): non-contiguous embedded line numbers — the l/r_overflow
// asserts' leading conditions, `L(kh_label)`/`L(ocb_label)` and several
// closing braces are missing from this extract.
745 void jit_avx512_core_bf16_bwd_data_kernel::compute_loop(
746 int ur_w, int l_overflow, int r_overflow)
750 int ic_block = jcp.ic_block;
751 int oc_block = jcp.oc_block;
752 int dilate_w = jcp.dilate_w + 1;
753 int stride_w = jcp.stride_w;
754 int stride_h = jcp.stride_h;
755 Label kh_label, skip_compute_label;
// Offset of one (ic-block kk, oc pair, kw tap) slice inside the weights.
757 auto kernel_offset = [=](int icb, int oc, int ki) {
758 size_t blk_idx = (size_t)icb * jcp.kh * jcp.kw + ki;
759 size_t blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
760 size_t oc_offset = (size_t)oc * jcp.oc_block;
761 return jcp.typesize_in * (blk_offset + oc_offset);
764 prepare_output(ur_w);
766 jle(skip_compute_label, T_NEAR);
770 mov(reg_ocb, jcp.nb_oc);
773 mov(aux_reg_dst, reg_dst);
774 mov(aux_reg_ker, reg_ker);
778 for (int ki = 0; ki < kw; ki++) {
779 int jj_start = get_iw_start(ki, l_overflow);
780 int jj_end = get_iw_end(ur_w, ki, r_overflow);
782 || jj_start == nstl::max(0,
783 l_overflow - (kw - 1 - ki) * dilate_w));
785 || jj_end == ur_w - nstl::max(0,
786 r_overflow - ki * dilate_w));
// oc advances in pairs: each vdpbf16ps lane consumes two bf16 values.
789 oc < div_up(nstl::min(oc_block, jcp.oc), 2); oc++) {
791 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
792 assert((jj + jcp.l_pad - ki * dilate_w) % stride_w == 0);
793 size_t aux_dst_offset = jcp.typesize_in
794 * ((jj + jcp.l_pad - ki * dilate_w) / stride_w
797 auto inp = zmm_inp(jj / stride_w);
798 vpbroadcastd(inp, ptr[aux_reg_dst + aux_dst_offset]);
801 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
802 size_t aux_kernel_offset = kernel_offset(kk, 2 * oc, ki);
805 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
808 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
809 auto inp = zmm_inp(jj / stride_w);
810 auto acc = zmm_out(jj, kk);
813 size_t aux_dst_offset = jcp.typesize_in
814 * ((jj + jcp.l_pad - ki * dilate_w) / stride_w
818 ptr[aux_reg_dst + aux_dst_offset]);
820 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
// Emulated vs native bf16 dot-product, as in the forward kernel.
821 bf16_emu_->r_vdpbf16ps(acc, zmm_wei, inp);
823 vdpbf16ps(acc, zmm_wei, inp);
// kh step: weights advance forward, diff_dst walks upward (backward pass).
829 add(aux_reg_ker, jcp.typesize_in * stride_h * kw * oc_block * ic_block);
830 sub(aux_reg_dst, jcp.typesize_in * (jcp.dilate_h + 1) * ow * oc_block);
834 jg(kh_label, T_NEAR);
// Next oc block; the trailing subs rewind reg_dst/reg_ker afterwards.
838 size_t diff_dst_step = (size_t)jcp.oh * jcp.ow * jcp.oc_block;
839 size_t ker_step = (size_t)jcp.ic * jcp.kh * jcp.kw * jcp.oc_block;
840 add(reg_dst, jcp.typesize_in * diff_dst_step);
841 add(reg_ker, jcp.typesize_in * ker_step);
845 jg(ocb_label, T_NEAR);
847 sub(reg_dst, jcp.typesize_in * diff_dst_step * jcp.nb_oc);
848 sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_oc);
850 L(skip_compute_label);
// Top-level code generator for the backward-data kernel: walks diff_src
// width in ur_w strips, with overflow-aware compute_loop calls at the left
// edge, an unpadded middle loop over reg_oi, a possibly-overflowing last
// strip, and a ur_w_tail strip.
// NOTE(review): non-contiguous embedded line numbers — preamble/postamble,
// loop labels and the branch that selects among the first three cases are
// missing from this extract.
854 void jit_avx512_core_bf16_bwd_data_kernel::generate()
859 int ic_block = jcp.ic_block;
860 int oc_block = jcp.oc_block;
861 int ur_w_tail = jcp.ur_w_tail;
862 int dilate_w = jcp.dilate_w + 1;
863 int stride_w = jcp.stride_w;
// Per-strip pointer advances for diff_dst (reg_dst) and diff_src (reg_src).
865 size_t dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
866 size_t src_shift = jcp.typesize_out * ur_w * oc_block;
870 mov(reg_src, ptr[param + GET_OFF(src)]);
871 mov(reg_dst, ptr[param + GET_OFF(dst)]);
872 mov(reg_ker, ptr[param + GET_OFF(filt)]);
874 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
// Overflow counts: how many strip positions at each edge need clipped
// kernel-width ranges (analogue of l_pad/r_pad in the forward pass).
876 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
877 int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
878 - nstl::max(0, jcp.r_pad)) / stride_w);
879 int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
880 - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
882 int n_oi = iw / ur_w;
883 if (r_overflow1 > 0) n_oi--;
886 compute_loop(ur_w, l_overflow, r_overflow);
887 } else if (n_oi == 0) {
888 compute_loop(ur_w, l_overflow, r_overflow1);
889 add(reg_src, src_shift);
890 add(reg_dst, dst_shift);
892 compute_loop(ur_w_tail, 0, r_overflow);
894 xor_(reg_oi, reg_oi);
895 if (l_overflow > 0) {
896 compute_loop(ur_w, l_overflow, 0);
897 add(reg_src, src_shift);
898 add(reg_dst, dst_shift);
902 if ((l_overflow <= 0 && n_oi > 0)
903 || (l_overflow > 0 && n_oi > 1)) {
906 compute_loop(ur_w, 0, 0);
907 add(reg_src, src_shift);
908 add(reg_dst, dst_shift);
912 jl(ow_loop_label, T_NEAR);
915 if (r_overflow1 > 0) {
916 compute_loop(ur_w, 0, r_overflow1);
917 add(reg_src, src_shift);
918 add(reg_dst, dst_shift);
920 if (ur_w_tail != 0) {
921 compute_loop(ur_w_tail, 0, r_overflow);
// Fills jcp for the backward-data kernel: geometry, derived paddings,
// format checks (nC16c activations, 8o16i2o weights), then a search for
// the (ur_w, nb_ic_blocking) pair that maximizes the compute pipeline
// length within the register budget.
// NOTE(review): non-contiguous embedded line numbers — several guard
// conditions (e.g. before some `return unimplemented` lines) and closing
// braces are missing from this extract.
928 status_t jit_avx512_core_bf16_bwd_data_kernel::init_conf(
929 jit_conv_conf_t &jcp,
930 const convolution_desc_t &cd,
931 const memory_desc_wrapper &diff_src_d,
932 const memory_desc_wrapper &weights_d,
933 const memory_desc_wrapper &diff_dst_d)
935 const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
936 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
937 int ndims = diff_src_d.ndims();
940 jcp.prop_kind = cd.prop_kind;
942 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
943 jcp.mb = diff_src_d.dims()[0];
945 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
946 jcp.oc_without_padding = jcp.oc;
947 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
949 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
950 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
951 jcp.iw = diff_src_d.dims()[ndims-1];
952 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
953 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
954 jcp.ow = diff_dst_d.dims()[ndims-1];
956 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
957 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
958 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
960 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
961 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
962 jcp.l_pad = cd.padding[0][ndims-3];
964 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
965 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
966 jcp.stride_w = cd.strides[ndims-3];
968 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
969 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
970 jcp.dilate_w = cd.dilates[ndims-3];
971 jcp.dsrc_dt = cd.diff_src_desc.data_type;
973 /* Dilated convolutions supported with unit strides only */
974 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
975 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
976 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
977 return status::unimplemented;
979 jcp.is_cpx = mayiuse(avx512_core_bf16);
// Derived right/bottom/back paddings from the output geometry.
981 jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
982 - (jcp.iw + jcp.l_pad - 1);
983 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
984 - (jcp.ih + jcp.t_pad - 1);
985 jcp.back_pad = (jcp.od - 1) * jcp.stride_d
986 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
988 jcp.aligned_threads = 0;
990 jcp.oc_block = simd_w;
991 jcp.ic_block = simd_w;
993 bool ok_to_pad_channels = jcp.ngroups == 1;
995 if (ok_to_pad_channels) {
996 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
997 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
// 8o16i2o weights: the 2o interleave matches the bf16 pairs consumed by
// vdpbf16ps in the backward-data loop.
1000 auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1001 auto wei_format = with_groups
1002 ? pick(ndims - 3, gOIw8o16i2o, gOIhw8o16i2o, gOIdhw8o16i2o)
1003 : pick(ndims - 3, OIw8o16i2o, OIhw8o16i2o, OIdhw8o16i2o);
1005 && jcp.oc % jcp.oc_block == 0
1006 && jcp.ic % jcp.ic_block == 0
1007 && diff_src_d.format() == src_format
1008 && diff_dst_d.format() == src_format
1009 && weights_d.format() == wei_format;
1011 return status::unimplemented;
1013 jcp.nb_ic = jcp.ic / jcp.ic_block;
1014 jcp.nb_oc = jcp.oc / jcp.oc_block;
1016 jcp.ur_w = jcp.stride_w;
1018 /* Maximun number of registers available for result accumulation and delta
1019 dst data. One additional register is reserved for weights data. */
1020 const int max_regs = jcp.is_cpx ? 31 : 26; /* In case of cpx emulation
1021 additional 5 registers are
1023 int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1024 - jcp.l_pad) / jcp.stride_w);
1025 int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1026 - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
1027 int n_oi = jcp.iw / jcp.ur_w;
1028 if (r_overflow1 > 0) n_oi--;
1031 jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
1032 jcp.typesize_out = (diff_src_d.data_type() == data_type::f32)
1033 ? sizeof(float) : sizeof(mkldnn_bfloat16_t);
1036 return status::unimplemented;
1038 /* Find the best blocking with maximum number of compute instructions
1039 per ur_w * nb_ic_blocking compute loops. Number of required registers
1040 is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
1041 ur_w must be divisible by stride_w */
1042 if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
1043 distribution exceeds max_regs */
1044 return status::unimplemented;
1046 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
1048 jcp.kernel_kind = expl_bcast;
1049 int best_compute_pipeline_length = 0;
1050 const int max_ic_blocks = 4;
// Exhaustive search over (ic-block count b, unroll u) pairs within the
// register budget, preferring longer pipelines and then larger ur_w.
1051 for (int b = 1; b <= max_ic_blocks; b++)
1053 if (jcp.nb_ic % b != 0)
1056 for (int u = jcp.stride_w;
1057 u * b + u / jcp.stride_w <= max_regs
1058 && u < jcp.iw + jcp.stride_w;
1061 int ur_w = nstl::min(u, jcp.iw);
1062 /* maximum 1 step with l_overflow so far */
1063 if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
1065 int pipeline_length = utils::div_up(ur_w, jcp.stride_w) * b;
1066 if (pipeline_length > best_compute_pipeline_length
1067 || (pipeline_length == best_compute_pipeline_length
1068 && jcp.ur_w < ur_w)) {
1070 jcp.nb_ic_blocking = b;
1071 best_compute_pipeline_length = pipeline_length;
1075 if (best_compute_pipeline_length == 0) /* can't find
1076 appropriate blocking */
1077 return status::unimplemented;
1080 jcp.loop_order = loop_gnc;
1082 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
// Final feasibility checks against the chosen unroll.
1084 if (l_overflow * jcp.stride_w > jcp.ur_w)
1085 return status::unimplemented;
1086 int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
1087 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
1088 if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
1089 return status::unimplemented;
1090 if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
1091 return status::unimplemented;
1093 pick_loop_order(jcp);
1095 jcp.nb_oc_L2 = jcp.nb_oc;
1098 && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
1099 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
1100 && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
1101 && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
1103 return args_ok ? status::success : status::unimplemented;
// Compile-time cap on the output-width unroll factor used by the
// backward-weights kernel (per the member's name; the consuming code is
// outside this extract).
1106 const int jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::max_ur_w = 28;
// Emits code that rewinds the input and kernel pointers after one oh step
// of the backward-weights kernel: loops kh times, stepping reg_input back
// one (dilated) row and reg_kernel back one kw slice per iteration.
// NOTE(review): non-contiguous embedded line numbers — the loop counter
// setup, the `sub(reg_kernel, ...)` line that pairs with the offset on
// line 1117, and closing braces are missing from this extract.
1108 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
1110 Label kh_comeback_label, kd_comeback_label;
1112 L(kh_comeback_label); {
// First-conv layouts keep src unblocked, hence the unit multiplier.
1113 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1115 sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
1117 jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
1120 jg(kh_comeback_label, T_NEAR);
// Non-permw (transposed-source) flavor of the inner ic-block step for the
// bf16 backward-weights kernel: accumulates diff-weights for one group of
// `ic_block_step` input channels over all `kw` kernel positions, reading
// bf16 inputs (from the transposed src buffer, stride jcp.tr_iw) and bf16
// diff_dst values, and accumulating in f32 zmm registers via vdpbf16ps (or
// the bf16 software-emulation path when the ISA lacks it).
//
// ur_w           -- number of output columns processed (must be even, see
//                   the assert below)
// pad_l / pad_r  -- left/right padding of this step (unused here; padding
//                   is pre-resolved in the transposed layout -- TODO confirm)
// *_offset       -- byte offsets added to reg_input/reg_kernel/reg_output
// is_tail        -- tail-step flag (not referenced in the visible lines)
// NOTE(review): this listing elides some lines (lambda closers, else
// branches, closing braces); comments below cover only what is visible.
1123 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1124 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
1125 int ur_w, int pad_l, int pad_r,
1126 int ic_block_step, int input_offset, int kernel_offset,
1127 int output_offset, bool is_tail)
1130 int ic_block = jcp.ic_block;
1131 int oc_block = jcp.oc_block;
// Accumulator register for kernel position i_kw / input channel i_ic.
1133 auto zmm_ker = [=](int i_kw, int i_ic) {
1134 return Zmm(i_kw * ic_block_step + i_ic);
// Registers holding diff_dst data, rotated from a small pool starting at
// zmm24 (pool size depends on the is_cpx codepath).
1136 auto zmm_out = [=](int i_iw) {
1137 // TODO: move reg calc to global member funcs
1138 const int out_zmm_base_idx = 24;
1139 const int num_out_zmm_regs = !jcp.is_cpx ? 2 : 4;
1140 return Zmm(out_zmm_base_idx + i_iw % num_out_zmm_regs);
// EVEX address of the f32 diff-weights element (i_kw, i_ic).
1143 auto ker_addr = [=](int i_kw, int i_ic) {
1145 = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
1146 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
// EVEX address of a bf16 input element in the transposed src buffer
// (row stride jcp.tr_iw); the `true` variant requests an embedded
// broadcast for the vnni consumer.
1148 auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
1149 bool vnni_bcast = false) {
1150 int stride = jcp.tr_iw;
1151 int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
1153 return EVEX_compress_addr(reg_input,
1154 local_offset + input_offset + extra_offset, true);
1156 return EVEX_compress_addr(reg_input,
1157 local_offset + input_offset + extra_offset);
// EVEX address of the bf16 diff_dst data for unroll position i_ur.
1159 auto out_addr = [=](int i_ur) {
1161 return EVEX_compress_addr(reg_output,
1162 jcp.typesize_in * i_ur * oc_block * ow_per_oc + output_offset);
// Zero all accumulators before the reduction.
1165 for (int i_kw = 0; i_kw < kw; i_kw++)
1166 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1167 auto zmm = zmm_ker(i_kw, i_ic);
1168 vpxord(zmm, zmm, zmm);
// Each vdpbf16ps consumes a pair of bf16 values, so ur_w must be even.
1170 assert(ur_w % 2 == 0);
1171 auto steps = ur_w / 2;
// Iterate over the stride_w phases; the transposed src is laid out so that
// each phase occupies a contiguous chunk of tr_iw / stride_w columns.
1173 const int str_w = jcp.stride_w;
1174 for (int s = 0; s < str_w; s++) {
1175 const int kw_start = s;
1176 assert(jcp.tr_iw % str_w == 0);
1177 const int inp_stride_w_shift = jcp.tr_iw / str_w;
1178 for (int i_ur = 0; i_ur < steps; i_ur++) {
// Load a pair of bf16 diff_dst vectors for this unroll position.
1179 auto zmm = zmm_out(i_ur);
1180 vmovdqu16(zmm, out_addr(i_ur));
1182 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w)
1183 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1184 int i_iw = 2 * i_ur + i_kw / str_w
1185 + s * inp_stride_w_shift;
// Emulation path: broadcast the bf16 pair and use the software dot product.
1188 vpbroadcastd(inp, inp_addr(i_iw, i_ic, 0));
1189 auto acc = zmm_ker(i_kw, i_ic);
1190 auto wei = zmm_out(i_ur);
1191 bf16_emu_->r_vdpbf16ps(acc, wei, inp);
// Native path: fused bf16 dot-product with embedded broadcast.
1193 vdpbf16ps(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
1194 inp_addr(i_iw, i_ic, 0, true));
// Add the accumulated partial sums into the diff-weights in memory.
1197 for (int i_kw = kw_start; i_kw < kw; i_kw += str_w) {
1198 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1199 auto addr = ker_addr(i_kw, i_ic);
1200 auto zmm = zmm_ker(i_kw, i_ic);
1201 vaddps(zmm, zmm, addr);
// Permw-transposition flavor of the inner ic-block step: instead of a
// pre-transposed source buffer, it interleaves bf16 data on the fly with
// vpermw (using the dst_prm_table permutation emitted in generate()),
// stages the permuted input on the stack, and accumulates diff-weights
// with vdpbf16ps (or the bf16 emulation path).
// Existing f32 diff-weights are loaded first and stored back at the end,
// so this step accumulates into memory.
// NOTE(review): this listing elides some lines (else branches, closing
// braces, parts of multi-line statements); comments cover only what is
// visible here.
1208 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::compute_ic_block_step(
1209 int ur_w, int pad_l, int pad_r,
1210 int ic_block_step, int input_offset, int kernel_offset,
1211 int output_offset, bool is_tail)
1214 int ic_block = jcp.ic_block;
1215 int oc_block = jcp.oc_block;
// Preload current f32 diff-weights values into the accumulator registers.
1217 for (int i_kw = 0; i_kw < kw; i_kw++)
1218 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
1219 vmovups(Zmm(i_kw * ic_block_step + i_ic),
1220 EVEX_compress_addr(reg_kernel,typesize *
1221 (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset));
// Load the even/odd interleaving pattern (dst_prm_table) into zmm24.
1223 Reg64 reg_trans_tmp = r11;
1224 mov(reg_trans_tmp, dst_prm_table);
1225 auto perm = Zmm(24);
1226 vmovups(perm, ptr[reg_trans_tmp]);
// Process diff_dst two output columns at a time; k7 masks off the second
// column on an odd tail.
1228 Opmask load_mask = Opmask(7);
1229 for (int i_ur = 0; i_ur < ur_w; i_ur += 2) {
1230 if (ur_w % 2 && i_ur + 2 >= ur_w)
1231 mov(reg_trans_tmp.cvt32(), 0x0000ffff);
1233 mov(reg_trans_tmp.cvt32(), 0xffffffff);
1234 kmovd(load_mask, reg_trans_tmp.cvt32());
// Masked load of two bf16 diff_dst vectors, then interleave them.
1235 auto zmm_dst = Zmm(25);
1236 vmovdqu16(zmm_dst | load_mask | T_z,
1237 EVEX_compress_addr(reg_output,
1238 jcp.typesize_in * i_ur * oc_block + output_offset));
1239 vpermw(zmm_dst, perm, zmm_dst);
1240 for (int i_kw = 0; i_kw < kw; i_kw++) {
// Map the two output columns to input columns; -1 marks a padded
// (out-of-range) position that must not be read.
1241 int iw_1 = (i_ur + i_kw);
1242 int iw_2 = (i_ur + 1 == ur_w) ? -1 : (i_ur + 1) + i_kw;
1243 iw_1 = (iw_1 - pad_l < 0 || iw_1 > (ur_w - 1) + (kw - 1) - pad_r)
1244 ? -1 : iw_1 - pad_l;
1245 iw_2 = (iw_2 - pad_l < 0 || iw_2 > (ur_w - 1) + (kw - 1) - pad_r)
1246 ? -1 : iw_2 - pad_l;
1248 int local_offset = i_ur + i_kw - pad_l;
1249 if (iw_1 == -1 && iw_2 == -1) continue;
// Build the 32-lane mask covering whichever of the two columns is valid.
1250 if (iw_1 != -1 && iw_2 != -1) mov(reg_trans_tmp.cvt32(), 0xffffffff);
1251 if (iw_1 != -1 && iw_2 == -1) mov(reg_trans_tmp.cvt32(), 0x0000ffff);
1252 if (iw_1 == -1 && iw_2 != -1) mov(reg_trans_tmp.cvt32(), 0xffff0000);
1253 kmovd(load_mask, reg_trans_tmp.cvt32());
// Masked load of the two input columns, interleave, and spill to the
// stack so pairs can later be re-broadcast per input channel.
1255 const size_t i_offset = (size_t)input_offset +
1256 (size_t)jcp.typesize_in * (local_offset) * ic_block;
1257 auto bcast_values = Zmm(26);
1258 vpxord(bcast_values, bcast_values, bcast_values);
1259 vmovdqu16(bcast_values| load_mask | T_z, ptr[reg_input + i_offset]);
1260 vpermw(bcast_values,perm, bcast_values);
1261 vmovups(ptr[rsp], bcast_values);
// Accumulate: one bf16 pair per input channel, broadcast from the stack.
1263 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
// Emulation path (no native vdpbf16ps).
1265 auto zmm_src = Zmm(28);
1266 vpbroadcastd(zmm_src, ptr[rsp + jcp.typesize_in * 2 * i_ic]);
1267 bf16_emu_->r_vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic),
// Native path: dot-product with embedded broadcast from the stack slot.
1270 vdpbf16ps(Zmm(i_kw * ic_block_step + i_ic), zmm_dst,
1271 zword_b[rsp + jcp.typesize_in * 2 * i_ic]);
// Store the updated f32 diff-weights back to memory.
1275 for (int i_kw = 0; i_kw < kw; i_kw++) {
1276 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
1277 int l_offset = jcp.typesize_out *
1278 (i_kw * ic_block + i_ic) * jcp.oc_block;
1279 vmovups(EVEX_compress_addr(reg_kernel, l_offset + kernel_offset),
1280 Zmm(i_kw * ic_block_step + i_ic));
// Fully-unrolled variant for small kernels (chosen by compute_oh_step_disp
// when kw <= 3 and ow <= 16): unrolls over the whole output width and the
// ic-block groups, looping only over kernel height (kh_label).
// NOTE(review): this listing elides some lines (loop prologue, else
// branches of the #ifndef, closing braces); comments cover what is visible.
1285 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1286 ::compute_oh_step_unroll_ow_icblock(
1287 int ic_block_step, int max_ur_w)
1291 Label kh_label, kd_label;
1293 int ic_block = jcp.ic_block;
1294 int oc_block = jcp.oc_block;
1295 int inp_mul = !jcp.is_1stconv ? ic_block : 1;
1297 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1298 // physical padding exists
1303 // XXX: is it possible to use jcp.r_pad here?
// Right padding recomputed from the convolution geometry.
1304 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
1305 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1306 int l_pad = jcp.l_pad;
// One compute_ic_block_step per ic_block_step-sized channel group; the
// input offset differs between the transposed and permw layouts.
1312 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
1313 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1314 const int input_offset = jcp.typesize_in * i_b_ic * iw;
1316 const int input_offset = jcp.typesize_in * i_b_ic;
1318 compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
1319 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
1320 i_b_ic + ic_block_step >= jcp.ic_block);
// Advance to the next (dilated) input row and the next kernel row.
1322 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
1323 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
1326 jg(kh_label, T_NEAR);
// Variant for ow <= max_ur_w: unrolls over the whole output width but loops
// at runtime over the ic-block groups (ic_block_label) inside the kernel
// height loop (kh_label).
// NOTE(review): this listing elides some lines (loop prologues, else
// branches, parts of multi-line statements, closing braces).
1330 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1331 ::compute_oh_step_unroll_ow(
1332 int ic_block_step, int max_ur_w)
1334 Label kh_label, ic_block_label, kd_label;
1338 int ic_block = jcp.ic_block;
1339 int oc_block = jcp.oc_block;
1342 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1343 // physical padding exists
1347 // XXX: is it possible to use jcp.r_pad here?
1348 int r_pad = nstl::max(0,
1349 (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
1350 - (jcp.iw + jcp.l_pad - 1));
1351 int l_pad = jcp.l_pad;
// Runtime loop over ic_block_step-sized channel groups (counter b_ic).
1358 L(ic_block_label); {
1359 compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
// Between groups, step the input by the per-channel-group stride; the
// stride depends on the data layout (transposed vs 1st-conv vs blocked).
1361 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1362 size_t inp_icblk_stride = jcp.tr_iw;
1364 size_t inp_icblk_stride = jcp.is_1stconv
1365 ? (size_t)jcp.ih * jcp.iw * jcp.id
// safe_add/safe_sub handle offsets wider than 32 bits via reg_long_offt.
1369 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
1370 safe_add(reg_input, input_offset, reg_long_offt);
1371 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
1372 add(b_ic, ic_block_step);
1373 cmp(b_ic, jcp.ic_block);
1374 jl(ic_block_label, T_NEAR);
// After finishing all channel groups, rewind the per-group input advance
// and move to the next (dilated) input row and next kernel row.
1376 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1377 if (jcp.is_1stconv) {
1379 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
1380 safe_sub(reg_input, input_offset, reg_long_offt);
1381 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
1383 add(reg_input, jcp.typesize_in
1384 * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
1387 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
1390 jg(kh_label, T_NEAR);
// General variant for large output widths: tiles ow into ur_w-sized pieces
// (first piece handles l_pad, middle pieces run in a runtime loop, the tail
// piece handles r_pad), nested inside the runtime ic-block loop and the
// kernel-height loop.
// NOTE(review): this listing elides some lines (loop prologues, else
// branches, parts of multi-line statements, closing braces).
1394 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1395 ::compute_oh_step_common(
1396 int ic_block_step, int max_ur_w)
1398 Label kh_label, ic_block_label, ow_block_label, kd_label;
1400 int ic_block = jcp.ic_block;
1401 int oc_block = jcp.oc_block;
1404 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1405 // physical padding exists
1410 int l_pad = jcp.l_pad;
1411 // XXX: is it possible to use jcp.r_pad here?
1412 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
1413 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1414 int stride_w = jcp.stride_w;
// Split ow into ur_w-sized trips plus a tail; if the right padding would
// spill into a full trip, rebalance by enlarging the tail.
1416 int ur_w = nstl::min(ow, max_ur_w);
1417 int ur_w_trips = ow / ur_w;
1418 int ur_w_tail = ow % ur_w;
1419 if ((ur_w_tail == 0 && r_pad != 0)
1420 || r_pad >= ur_w_tail) {
1421 if (ur_w_trips > 1) {
1425 ur_w_tail += (ur_w - ur_w / 2);
1429 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1432 int inp_mult = (jcp.is_1stconv) ? 1 : ic_block;
// Amount to rewind input/output pointers after the ow sweep.
1434 int input_comeback = (ur_w_trips * ur_w * stride_w - l_pad) * inp_mult;
1435 int output_comeback = ur_w_trips * ur_w * oc_block;
// Runtime loop over ic_block_step-sized channel groups (counter b_ic).
1440 L(ic_block_label); {
// First trip absorbs the left padding.
1443 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
1444 add(reg_input, jcp.typesize_in * (ur_w * stride_w - l_pad)
1446 add(reg_output, jcp.typesize_in * ur_w * oc_block);
// Middle trips: no padding, counted at runtime in reg_ur_w_trips.
1449 if (ur_w_trips > 0) {
1450 xor_(reg_ur_w_trips, reg_ur_w_trips);
1451 L(ow_block_label); {
1452 compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
1453 add(reg_input, jcp.typesize_in * ur_w * stride_w
1455 add(reg_output, jcp.typesize_in * ur_w * oc_block);
1457 inc(reg_ur_w_trips);
1458 cmp(reg_ur_w_trips, ur_w_trips);
1459 jl(ow_block_label, T_NEAR);
// Tail trip absorbs the right padding.
1463 if (ur_w_tail > 0) {
1464 compute_ic_block_step(ur_w_tail, 0, r_pad,
1465 ic_block_step, 0, 0, 0, true);
// Rewind pointers, then advance to the next channel group / kernel row
// (same stepping scheme as compute_oh_step_unroll_ow).
1468 sub(reg_input, jcp.typesize_in * input_comeback);
1469 sub(reg_output, jcp.typesize_in * output_comeback);
1470 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1471 int inp_icblk_stride = jcp.tr_iw;
1473 int inp_icblk_stride = jcp.is_1stconv
1474 ? jcp.ih * jcp.iw * jcp.id
1478 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
1479 safe_add(reg_input, input_offset, reg_long_offt);
1480 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
1482 add(b_ic, ic_block_step);
1483 cmp(b_ic, jcp.ic_block);
1484 jl(ic_block_label, T_NEAR);
1486 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1487 if (jcp.is_1stconv) {
1489 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
1490 safe_sub(reg_input, input_offset, reg_long_offt);
1491 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
1493 add(reg_input, jcp.typesize_in
1494 * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block);
1497 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
1500 jg(kh_label, T_NEAR);
// Dispatch between the three oh-step code generators based on kernel width
// and output width, then rewind the pointers for the next output row.
// ic_block_step shrinks as kw grows so that the kw * ic_block_step
// accumulator registers used by compute_ic_block_step stay within budget.
// NOTE(review): this listing elides a few lines (else keyword, braces).
1504 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1505 ::compute_oh_step_disp()
1507 int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw < 7 ? 4 : 2);
// Full unrolling is skipped for strided non-1x1 kernels.
1509 bool too_large_to_unroll
1510 = (jcp.kw > 1 || jcp.kh > 1)
1511 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
1514 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll) {
1515 compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
1516 } else if (ow <= max_ur_w) {
1517 compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
1519 compute_oh_step_common(ic_block_step, max_ur_w);
// Restore reg_input/reg_kernel for the caller's next oh iteration.
1521 oh_step_comeback_pointers();
// Conditionally zero-initialize the f32 diff-weights buffer before
// accumulation, gated at runtime on the `channel` field of the call
// params (presumably a first-iteration flag -- TODO confirm against the
// driver code; the cmp preceding the jz is elided from this listing).
1524 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
1526 Label skip_zeroing, zeroing_loop;
1528 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
1530 jz(skip_zeroing, T_NEAR);
1533 vpxord(zero, zero, zero);
// reg_tmp becomes the running byte offset into the kernel buffer.
1534 xor_(reg_tmp, reg_tmp);
// One vmovups covers exactly one oc_block of f32 -- i.e. one full zmm.
1536 assert(jcp.oc_block * jcp.typesize_out
1537 == cpu_isa_traits<avx512_core>::vlen);
// Zero ic_block * oc_block floats per loop iteration...
1538 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
1539 vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
1540 * jcp.typesize_out], zero);
// ...until the whole kw * kh * ic_block * oc_block tile is cleared.
1541 add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
1542 cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh
1543 * jcp.typesize_out);
// Outer loop over output rows (oh) for the backward-weights kernel.
// Generates three phases: the top-padding edge (kernel partially above the
// input, including a dilation-aware sub-path), the middle rows (full
// kernel overlap), and the bottom-padding edge (kernel partially below the
// input). reg_kh carries the number of kernel rows that currently overlap
// the input; reg_oj counts output rows; reg_ih_count tracks the input row
// position in stride units.
// NOTE(review): this listing elides many lines (the method name line,
// loop-label placements, increments of reg_oj, else branches and closing
// braces); the comments below only annotate what is visible.
1550 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32
1553 int b_pad = jcp.b_pad;
1554 int t_pad = jcp.t_pad;
1555 bool is_dilated = jcp.dilate_h != 0;
1556 int dilate_h = jcp.dilate_h + 1;
1557 int stride_h = jcp.stride_h;
1558 const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
1561 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
1562 oh_bpad_label, oh_bpad_label_end,
1563 oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end,
1564 skip_neg_overlap_label, skip_fpad_label, skip_input_label;
1566 maybe_zero_kernel();
1568 mov(reg_kh, jcp.kh);
1569 xor_(reg_ih_count, reg_ih_count);
1570 xor_(reg_oj, reg_oj);
1571 /* Compute 'top' edge */
// Effective (dilated) kernel height and the initial number of kernel rows
// that overlap the input when the kernel starts inside the top padding.
1573 const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
1575 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
1576 const int underflow = div_up(t_pad, dilate_h);
1577 const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
1578 mov(reg_kh, initial_inp_ker_overlap);
// Skip the kernel rows that lie entirely in the top padding.
1579 add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
1581 // generate loop to process kernel while it remains within t_pad + ih
1582 if (kh_range < t_pad + jcp.ih) {
// Dilated case: reg_tmp tracks the phase within one dilation period.
1584 const int tail = t_pad % dilate_h;
1585 const int shift = tail == 0 ? 0 : dilate_h - tail;
1586 mov(reg_tmp, shift);
1588 add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
1591 compute_oh_step_disp();
1592 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1595 cmp(reg_tmp, dilate_h);
1596 jl(oh_dilate_label_shift, T_NEAR);
1597 // unshift input as new kernel element enters
1598 sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
1599 xor_(reg_tmp, reg_tmp);
1601 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
1602 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
1603 * jcp.ic_block * jcp.oc_block);
1604 add(reg_kh, stride_h);
1606 jmp(oh_dilate_label_noshift, T_NEAR);
1607 L(oh_dilate_label_shift);
1608 // shift input as old kernel element progresses
1609 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1610 L(oh_dilate_label_noshift);
1613 add(reg_ih_count, stride_h);
1615 // final number of kernel elements that overlap with input
1616 const int final_inp_ker_overlap
1617 = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
1618 cmp(reg_kh, final_inp_ker_overlap);
1619 jl(oh_tpad_label, T_NEAR);
1622 // need second loop to process kernel if it is larger than the input
1623 // (does not apply to dilations as they must have unit stride)
1624 if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
1625 t_pad % stride_h)) {
1626 assert(!is_dilated);
1627 mov(reg_kh, jcp.ih);
1628 L(oh_tpad_tail_label); {
1629 compute_oh_step_disp();
1630 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1631 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
1632 * jcp.ic_block * jcp.oc_block);
1635 add(reg_ih_count, stride_h);
1637 cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
1638 jl(oh_tpad_tail_label, T_NEAR);
1641 // correct any excess shifts to kernel and input
1642 // (does not apply to dilations as they must have unit stride,
1643 // kernel must fit inside input, and padding is smaller than input)
1644 if (t_pad <= jcp.oh * stride_h) {
1645 // kernel has moved beyond padding (adjust for stride effects)
1646 if (t_pad % stride_h != 0) {
1647 assert(!is_dilated);
1648 int inp_corr = stride_h - t_pad % stride_h;
1649 add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
1650 * jcp.ic_block * jcp.oc_block);
1651 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
1654 // kernel still overlaps padding (complete reset)
1655 assert(!is_dilated);
1656 sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
1657 * jcp.kw * jcp.ic_block * jcp.oc_block);
// Exit early if the middle phase has no rows to process.
1661 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
1662 jge(oh_label_end, T_NEAR);
1663 cmp(reg_oj, jcp.oh);
1664 jge(oh_label, T_NEAR);
1666 /* Compute middle block(s) */
1667 mov(reg_kh, jcp.kh);
1669 compute_oh_step_disp();
1670 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1671 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1674 add(reg_ih_count, stride_h);
1676 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
1677 jge(oh_label_end, T_NEAR);
1679 cmp(reg_oj, jcp.oh);
1680 jl(oh_label, T_NEAR);
1684 /* Compute bottom edge */
1686 cmp(reg_oj, jcp.oh);
1687 jge(oh_bpad_label_end, T_NEAR);
// Dilated path starts with kh - 1 overlapping rows; non-dilated path
// derives the overlap from the remaining padded input height.
1690 mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
1693 mov(reg_kh, jcp.ihp - b_pad);
1694 sub(reg_kh, reg_ih_count);
1698 compute_oh_step_disp();
1699 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
1700 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
1703 cmp(reg_tmp, dilate_h);
1704 jl(oh_dilate_label_end, T_NEAR);
1705 xor_(reg_tmp, reg_tmp);
// One fewer kernel row overlaps each stride step; stop at zero.
1707 sub(reg_kh, stride_h);
1709 jle(oh_bpad_label_end, T_NEAR);
1711 L(oh_dilate_label_end);
1714 cmp(reg_oj, jcp.oh);
1715 jl(oh_bpad_label, T_NEAR);
1717 L(oh_bpad_label_end);
// Kernel entry point: emits the prologue/epilogue around the oh loop.
// Loads the src/dst/filt pointers from the jit_conv_call_s argument block,
// reserves stack scratch in the permw build (used by compute_ic_block_step
// to stage permuted bf16 pairs), and emits the dst_prm_table -- the 32-lane
// even/odd interleave pattern consumed by vpermw.
// NOTE(review): this listing elides lines (preamble()/postamble() calls,
// the compute call, align/label placement for the table, closing braces).
1721 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::generate()
1725 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
// Reserve the stack scratch area used by the permw ic-block step.
1726 sub(rsp, stack_space_needed);
1729 mov(reg_input, ptr[param + GET_OFF(src)]);
1730 mov(reg_output, ptr[param + GET_OFF(dst)]);
1731 mov(reg_kernel, ptr[param + GET_OFF(filt)]);
1735 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1736 add(rsp, stack_space_needed);
1741 #ifdef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
// Interleave pattern {0,16, 1,17, ...}: zips the 16 lanes of two vectors.
1744 const uint16_t dst_prm_array[32] =
1745 {0,16, 1,17, 2,18, 3,19, 4,20, 5,21, 6,22, 7,23, 8,24,
1746 9,25, 10,26, 11,27, 12,28, 13,29, 14,30, 15,31 };
// Emit the table as data words directly into the code buffer.
1748 for (size_t i = 0; i < 32; ++i)
1749 dw(dst_prm_array[i]);
// Populate the jit_conv_conf_t descriptor for the bf16 backward-weights
// kernel from the convolution desc and memory descriptors; decides memory
// formats, blocking, padding applicability and the threading split, and
// returns unimplemented for any shape this kernel cannot handle.
// NOTE(review): this listing elides many lines (condition heads such as
// `if (!ok)`/`if (!boundaries_ok)`, several assignments, closing braces);
// comments below annotate only the visible lines.
1753 status_t jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf(
1754 jit_conv_conf_t &jcp, const convolution_desc_t &cd,
1755 cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
1756 cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd)
// 16 f32 lanes per zmm -- the oc/ic blocking unit.
1758 const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
1760 const memory_desc_wrapper src_d(&src_pd);
1761 const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
1762 const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
1763 const memory_desc_wrapper diff_dst_d(&diff_dst_pd);
// Grouped weights have one extra leading dimension.
1765 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1766 int ndims = src_d.ndims();
1768 jcp = zero<decltype(jcp)>();
1770 jcp.prop_kind = cd.prop_kind;
1772 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
1773 jcp.mb = src_d.dims()[0];
1775 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1776 jcp.oc_without_padding = jcp.oc;
1777 jcp.ic = src_d.dims()[1] / jcp.ngroups;
// Spatial sizes; ndims == 3 means 1D (no height), ndims == 5 means 3D.
1779 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
1780 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
1781 jcp.iw = src_d.dims()[ndims-1];
1782 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
1783 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
1784 jcp.ow = diff_dst_d.dims()[ndims-1];
1786 jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
1787 jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
1788 jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
1790 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
1791 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
1792 jcp.l_pad = cd.padding[0][ndims-3];
1794 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
1795 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
1796 jcp.stride_w = cd.strides[ndims-3];
1798 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
1799 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
1800 jcp.dilate_w = cd.dilates[ndims-3];
// Effective (dilated) kernel height; used by the dilation constraints.
1802 const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
1804 // general condition to simplify dilations
1805 && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
1806 && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
1807 && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
1808 // special condition to simplify dilations in compute_oh_loop_common
1809 && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih)
1811 return status::unimplemented;
// Derived right/bottom padding from the convolution geometry.
1813 jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
1814 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
1815 jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
1816 + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
1818 /* XXX: currently, does not support stride_d > 1 or dilation > 0 */
1820 if (jcp.stride_d > 1 || jcp.dilate_d > 0)
1821 return status::unimplemented;
// Padded input extents (used by compute_oh_loop_common).
1823 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1824 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1827 jcp.aligned_threads = 0;
1829 jcp.oc_block = simd_w;
// Channel padding up to simd_w is only valid without groups.
1831 bool ok_to_pad_channels = jcp.ngroups == 1;
1833 if (ok_to_pad_channels) {
1834 jcp.oc = rnd_up(jcp.oc, simd_w);
1835 jcp.ic = rnd_up(jcp.ic, simd_w);
// Required blocked memory formats, chosen by dimensionality.
1838 auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
1839 auto wei_format = with_groups
1840 ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
1841 : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
1842 // temporary workaround until bf16 jit supports 1d
1843 if (wei_format == gOIw16i16o || wei_format == OIw16i16o)
1844 return status::unimplemented;
1846 /* conditions on bias memory */
1847 jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
1848 if (jcp.with_bias) {
1849 if (diff_bias_d.format() == any)
1850 CHECK(diff_bias_pd.set_format(x));
1851 if (diff_bias_d.format() != x)
1852 return status::unimplemented;
1855 jcp.nb_oc = jcp.oc / jcp.oc_block;
1857 if (diff_dst_d.format() == any)
1858 CHECK(diff_dst_pd.set_format(src_format));
1859 if (diff_dst_d.format() != src_format)
1860 return status::unimplemented;
1862 /* kernel applicability check wrt boundaries
1863 * the conditions are quite general across the kernels we have,
1864 * but ideally the check should belong to a specific kernel... */
1865 const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
1866 const bool boundaries_ok = true
1867 && jcp.t_pad <= max_pad
1868 && jcp.b_pad <= max_pad;
1870 return status::unimplemented;
1872 /* yet another common check */
1874 return status::unimplemented;
1876 /* setting register strategy */
// Pick the largest ur_w <= max_ur_w that evenly divides ow.
1877 for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
1878 if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
1881 if (src_d.format() == any)
1882 CHECK(src_pd.set_format(src_format));
1883 if (diff_weights_d.format() == any)
1884 CHECK(diff_weights_pd.set_format(wei_format));
1887 && src_d.format() == src_format
1888 && diff_weights_d.format() == (wei_format);
1890 return status::unimplemented;
1891 jcp.dwei_dt = diff_weights_d.data_type();
1893 jcp.ic_block = simd_w;
1894 if (ok_to_pad_channels)
1895 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1896 jcp.nb_ic = jcp.ic / jcp.ic_block;
1897 jcp.src_fmt = src_d.format();
// Data-type / threading applicability: bf16 src and diff_dst; f32 or bf16
// diff_weights; syncable threading runtime required.
1898 if (mkldnn_thr_syncable()
1899 && one_of(ndims, 3, 4)
1900 && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
1901 && everyone_is(data_type::bf16,
1902 src_d.data_type(), diff_dst_d.data_type())
1903 && one_of(diff_weights_d.data_type(),
1904 data_type::f32, data_type::bf16)) {
1907 return status::unimplemented;
1909 jcp.is_cpx = mayiuse(avx512_core_bf16);
1910 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
// Sizing of the transposed-src buffer (tr_iw rounded so each stride phase
// stays aligned; tr_pad covers the larger of l_pad/r_pad).
1911 const int tr_round = 4;
1912 // TODO: try to optimize required memory size
1913 int tr_pad = rnd_up(nstl::max(1, nstl::max(jcp.l_pad, jcp.r_pad)),
1915 jcp.tr_iw = rnd_up(div_up(jcp.iw, jcp.stride_w) + tr_pad, tr_round)
1917 jcp.tr_src_num_guard_elems = tr_pad; // upper bound
1918 jcp.tr_ow = rnd_up(jcp.ow, 2);
1919 jcp.ur_w = jcp.tr_ow;
// The permw path only supports unit width stride.
1924 if (jcp.stride_w != 1)
1925 return status::unimplemented;
1927 jcp.typesize_in = sizeof(mkldnn_bfloat16_t);
1928 jcp.typesize_out = sizeof(float);
// Channel counts must be fully blocked and fit the padded dims.
1931 && jcp.ic % jcp.ic_block == 0
1932 && jcp.oc % jcp.oc_block == 0
1933 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1934 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
1936 diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
1938 diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
1939 if (!args_ok) return status::unimplemented;
// Split the work across threads along mb, groups, oc- and ic-blocks.
1942 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
1943 balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
1945 jcp.nthr_mb = nthr_mb;
1946 jcp.nthr_g = nthr_g;
1947 jcp.nthr_oc_b = nthr_oc_b;
1948 jcp.nthr_ic_b = nthr_ic_b;
1951 return status::success;
// Register all scratchpad buffers needed by the bwd-weights driver:
// transposed-src / transposed-diff-dst staging areas (non-permw build),
// their barrier contexts (when barriers are enabled), the per-thread
// weights/bias reduction workspace, a bf16->f32 conversion workspace for
// the bias, and the padded-bias buffer.
// NOTE(review): this listing elides a few lines (some closing braces and
// the num_wei_buffers else-branch), so sizes are annotated as visible.
1954 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_scratchpad(
1955 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
1956 #ifndef BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
1957 // XXX: See the comment about tr_iw and guarding elements in
1958 // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
1959 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1960 const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
1962 const size_t max_nthr = jcp.nthr;
1963 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
// One transposed-src slab per (potential) thread, plus guard elements.
1964 const size_t min_tr_src_size_per_thr = jcp.ih * jcp.ic_block * jcp.tr_iw;
1965 const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
1966 + jcp.tr_src_num_guard_elems;
1967 scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
1969 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1970 /* prepare synchronization contexts */
1971 if (jcp.nthr_oc_b > 1) {
1972 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
1973 scratchpad.book(key_conv_tr_src_bctx,
1974 sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
1976 #endif // !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
// Transposed diff_dst buffer; sizing differs with/without barriers.
1978 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1979 const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
1980 * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
1982 const size_t tr_diff_dst_size = jcp.nthr
1983 * jcp.oc_block * jcp.tr_ow * jcp.oh;
1984 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1985 scratchpad.book(key_conv_tr_diff_dst, jcp.typesize_in * tr_diff_dst_size);
1987 #if !defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1988 /* prepare synchronization contexts */
1989 if (jcp.nthr_ic_b > 1) {
1990 const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
1991 scratchpad.book(key_conv_tr_diff_dst_bctx,
1992 sizeof(simple_barrier::ctx_t) * tr_diff_dst_bctx_size);
1994 #endif // defined(BF16_CONV_BWD_W_DOES_NOT_USE_BARRIERS)
1995 #endif // BF16_CONV_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
// f32 reduction workspace: needed whenever a cross-mb reduction happens
// or the final diff_weights are bf16 (kernel accumulates in f32).
1997 if (jcp.nthr_mb > 1 || jcp.dwei_dt == data_type::bf16) {
1998 const size_t wei_size = jcp.ngroups * jcp.oc * jcp.ic
1999 * jcp.kh * jcp.kw * jcp.kd;
2000 const size_t bia_size = jcp.ngroups * jcp.oc;
2002 const int num_wei_buffers = jcp.dwei_dt == data_type::bf16
2006 const size_t wei_bia_reduction_size = wei_size + bia_size;
2008 scratchpad.book(key_conv_wei_bia_reduction,
2009 jcp.typesize_out * wei_bia_reduction_size * num_wei_buffers);
2010 // TODO: don't use barrier for case
2011 // jcp.dwei_dt == data_type::bf16 && nthr_mb_ == 1
2012 scratchpad.book(key_conv_wei_bia_reduction_bctx,
2013 sizeof(simple_barrier::ctx_t));
// Per-thread f32 copy of diff_dst used for the bias reduction.
2016 if (jcp.with_bias) {
2017 const size_t dst_f32_size = (size_t)jcp.od * jcp.oh * jcp.ow
2018 * jcp.oc_block * jcp.typesize_out;
2019 scratchpad.book(key_conv_dst_bf16_convert_wsp, jcp.nthr * dst_f32_size);
2022 if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
2023 scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
// Choose the thread decomposition (minibatch x groups x oc-blocks x
// ic-blocks) for the backward-weights pass by exhaustively searching the
// factorizations of the available threads and minimizing a heuristic
// per-thread memory-traffic cost. Outputs are written through the
// reference parameters; nthr_ is the total thread count actually used.
// NOTE(review): this listing elides several lines (early returns in the
// guard clauses, parts of the cost expression, loop braces).
2026 void jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::balance(
2027 const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
2028 int &nthr_oc_b_, int &nthr_ic_b_)
2030 nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
2032 const int max_threads = mkldnn_get_max_threads();
// Guard: fewer threads than groups -- fall back (body elided here).
2034 if (max_threads < j.ngroups) {
2035 /* simplification... fortunately it doesn't hurt much */
2039 if (!mkldnn_thr_syncable()) {
2040 // should not happen -- the driver is not ready
2041 // for TBB-like non-synchronous threading yet
// One thread team per group; the remainder is split by the search below.
2045 nthr_g_ = j.ngroups;
2046 const int nthr = max_threads / nthr_g_;
2048 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
2049 /* calculate per thread memory cost (read/write). high level optimizer
2050 * tries to minimize memory consumption. few notes:
2051 * (n1) unclear why, but that essentially helps first convolution...
2052 * (n2) assuming the reduction over minibatch is always there:
2053 * - instead of 8 it should be 5 here (write ~= 2 read):
2054 * kernel: temporal workspace 1 write
2055 * reduction: 1 read from workspace and 1 write to the diff_wei
2056 * - but experiments showed 8 works better than 5 or 6... */
// Empirically-tuned weights for src-, dst- and wei-traffic terms.
2058 const int src_coef = 4;
2059 const int dst_coef = 1;
2060 const int wei_coef = 4;
2064 * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
2065 * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
2066 / j.stride_d / j.stride_h / j.stride_w /* (n1) */
2068 * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
2069 * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
2070 + wei_coef /* (n2) */
2071 * div_up(j.ngroups, nthr_g_)
2072 * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
2073 * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
2076 int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
2078 /* step 1: find the best thread distribution with lowest memory cost */
2079 const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
2080 for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
2081 const int nthr_par = nthr / nthr_mb;
2082 const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
2083 for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
// The ic split takes whatever parallelism is left.
2084 int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
2086 int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
2087 if (mem_cost <= best_mem_cost) {
2088 best_mem_cost = mem_cost;
2090 nthr_oc_b_ = nthr_oc_b;
2091 nthr_ic_b_ = nthr_ic_b;
// Non-syncable runtimes cannot reduce over mb -- stop at nthr_mb == 1.
2095 if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
// If the mb split is already more than half the machine, take it all.
2098 if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
2099 nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
2100 nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
2102 assert(nthr_ <= max_threads);
2103 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
2109 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s