/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_barrier.hpp"
#include "cpu_memory.hpp"

#include "jit_avx512_common_conv_kernel.hpp"
#define GET_OFF(field) offsetof(jit_conv_call_s, field)
#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
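// Note (illustrative, not from the original source): (512 - 64) * 1024
// = 458752 bytes, i.e. ~448 KB. The constant appears to model the per-core
// share of the KNL/KNM L2 with 64 KB kept as headroom; the nb_ic_L2 blocking
// in init_conf() compares 4x its working-set estimate against this capacity.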
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;

namespace {

constexpr auto small_spatial = 14;
unsigned int L1_cache_size = get_cache_size(1, true);
inline void pick_loop_order(jit_conv_conf_t &jcp) {
    using namespace prop_kind;
    assert(one_of(jcp.prop_kind,
                forward_training, forward_inference, backward_data));
    auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
    auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;

    // ow-threading is currently implemented for forward only
    // TODO: single code for fwd and bwd after ow-thr for bwd
    // meaningless switch was removed
    if (jcp.prop_kind == backward_data) {
        jcp.loop_order = (w <= small_spatial && h <= small_spatial)
            ? loop_cgn : loop_gnc;
    } else {
        jcp.loop_order = (w <= small_spatial && h <= small_spatial)
            ? loop_cwgn : loop_gncw;
    }
}
inline bool is_1stconv(const jit_conv_conf_t &jcp) {
    if (mayiuse(avx512_core) && !mayiuse(avx512_core_vnni))
        return (jcp.ic < 16 && jcp.ngroups == 1);
    else
        return one_of(jcp.ic, 1, 3);
}
inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
    return (jcp.nb_ow > 1);
}

inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
    return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
}

}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
{
    for (int k = 0; k < jcp.nb_oc_blocking; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            vpxord(vmm, vmm, vmm);
            if (!is_owb_prefetching(jcp)) {
                size_t aux_output_offset = get_output_offset(j, k);
                mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
                            aux_output_offset, reg_out_long_offt));
            }
        }
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
{
    Label no_update_label, store_label, postproc_label;

    mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
    if (jcp.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    }

    if (!jcp.with_sum) {
        cmp(reg_channel, 0);
        je(no_update_label, T_NEAR);
    }

    for (int k = 0; k < jcp.nb_oc_blocking; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            size_t aux_output_offset = get_output_offset(j, k);
            vadd(vmm,
                make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
        }

    if (!jcp.with_sum) {
        jmp(postproc_label, T_NEAR);
    } else {
        cmp(reg_channel, 0);
        jne(postproc_label, T_NEAR);
    }

    L(no_update_label);
    if (jcp.with_bias) {
        for (int k = 0; k < jcp.nb_oc_blocking; k++) {
            int bias_offset = jcp.typesize_out * k * jcp.oc_block;
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_out(j, k);
                vadd(vmm, EVEX_compress_addr(reg_bias, bias_offset));
            }
            mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
        }
    }

    L(postproc_label);

    cmp(reg_channel, jcp.nb_ic - 1);
    jl(store_label, T_NEAR);

    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    const auto &p = attr_.post_ops_;

    for (int i = 0; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
                Vmm vmm_zero = vmm_wei;
                vpxord(vmm_zero, vmm_zero, vmm_zero);

                for (int k = 0; k < jcp.nb_oc_blocking; k++)
                    for (int j = 0; j < ur_w; j++) {
                        Vmm vmm = vmm_out(j, k);
                        vpcmpd(k1, vmm, vmm_zero, _cmp_lt_os);
                        vpmulld(vmm | k1, vmm, vmm_zero);
                    }
            } else {
                if (ur_w == jcp.ur_w) {
                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0,
                            jcp.nb_oc_blocking * jcp.ur_w);
                } else {
                    for (int k = 0; k < jcp.nb_oc_blocking; k++)
                        eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w,
                                k * jcp.ur_w + ur_w);
                }
            }

            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, ptr[this->param1 + GET_OFF(oc_off)]);
            add(reg_d_bias, ptr[this->param1 + GET_OFF(oc_off)]);

            for (int k = 0; k < jcp.nb_oc_blocking; k++) {
                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                        k*jcp.ur_w, k*jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);

                add(reg_d_weights, jcp.oc_block * sizeof(float));
                add(reg_d_bias, jcp.oc_block * sizeof(float));
            }

            depthwise_inj_idx++;
        }
    }

    L(store_label);
    for (int k = 0; k < jcp.nb_oc_blocking; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            size_t aux_output_offset = (size_t)typesize *
                ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
            vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
                        reg_out_long_offt), vmm);
            if (!is_owb_prefetching(jcp))
                mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
                            aux_output_offset, reg_out_long_offt));
        }
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
        int pad_l, int pad_r)
{
}

template<>
void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
        int pad_l, int pad_r)
{
    assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);

    int iw = jcp.iw;
    int ih = jcp.ih;
    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;

    Label kh_label, kd_label;

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_inp_prf, reg_inp_prf);
    }

    size_t max_input_offset = (size_t)jcp.typesize_in
        * ((size_t)(kw + ur_w * stride_w - pad_l)
                + (size_t)ic_block * iw * ih * jcp.id);
    assert(reg_inp_prf == reg_long_offt);
    if (max_input_offset > INT_MAX) push(reg_inp_prf);

    if (jcp.ndims == 5) {
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    L(kh_label);
    for (int ki = 0; ki < kw; ki += 4) {
        for (int ic = 0; ic < ic_block; ic++) {
            for (int i = 0; i < 4; i++) {
                int aux_ker_offset
                        = jcp.typesize_in
                        * ((ki + i) * oc_block
                                + ic * kw * jcp.kh * jcp.kd * oc_block);
                if (ki + i < kw)
                    vmovups(vmm_ker(i),
                        EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
                else
                    vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
            }

            int j_start = get_ow_start(ki, pad_l);
            int j_end = get_ow_end(ur_w, ki, pad_r);

            for (int j = j_start, prf_count=0; j < j_end; j++) {
                size_t aux_input_offset = (size_t)jcp.typesize_in
                        * ((size_t)(ki + j * stride_w
                            - pad_l) + (size_t)ic * iw * ih * jcp.id);
                v4fmaddps(vmm_out(j, 0), vmm_ker(0),
                        EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
                            reg_long_offt));
                if (ki + prf_count < kw && prf_count < 4
                    && ((ki < 2 && j % 4) || j % 2)) {
                    int aux_ker_offset = jcp.typesize_in
                        * ((ki + prf_count) * oc_block
                        + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block);
                    mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
                            aux_ker_offset));
                    prf_count++;
                }
                if (ki == 0
                    && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
                    mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
                            aux_input_offset, reg_long_offt));
                }
                if (ki == 1
                    && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
                    mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
                            aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
                }
            }
        }
    }
    add(aux_reg_ker, jcp.typesize_in * kw * oc_block);
    add(aux_reg_inp, jcp.typesize_in * iw);
    add(aux_reg_inp_prf, jcp.typesize_in * iw);

    dec(reg_kj);
    cmp(reg_kj, 0);
    jg(kh_label, T_NEAR);

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
        add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }

    if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
        int pad_l, int pad_r)
{
}

template<>
void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
        int pad_l, int pad_r)
{
    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    Label kh_label, last_iter_label, loop_end_label, kd_label;
    int ker_load_number = 4;
    int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
    int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;

    bool check_last_kh = (jcp.kh > 3);
    bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);

    int oi_ipref_t0 = get_ow_start(0, pad_l);
    int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);

    assert(jcp.oc % jcp.nb_oc_blocking == 0);

    auto kernel_offset = [=](int ocb, int ic, int ki) {
        int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
        int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
        int ic_offset = ic * jcp.oc_block;
        return typesize * (blk_offset + ic_offset);
    };
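    // Worked example (illustrative values, not from the original source):
    // with nb_ic = 2, kh = kw = 3, kd = 1, oc_block = ic_block = 16 and
    // typesize = 4, kernel_offset(ocb = 1, ic = 5, ki = 2) gives
    // blk_idx = 1*2*3*3*1 + 2 = 20, blk_offset = 20*16*16 = 5120 and
    // ic_offset = 5*16 = 80, i.e. a byte offset of 4 * (5120 + 80) = 20800.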
    auto kernel_loads = [=](int ki, int ic, int kk) {
        for (int ii = 0; ii < ker_load_number; ii++) {
            int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
            vmovups(vmm_ker(ii),
                EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
        }
    };
    auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
        if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
                && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
            int aux_inp_offset
                    = typesize
                    * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
                            + (jcp.dilate_h + 1) * jcp.iw * ic_block);
            prefetcht0(EVEX_compress_addr(aux_reg_inp,
                    aux_inp_offset));
            oi_ipref_t0++;
        }
    };

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_ker_prf, reg_ker_prf);
        mov(aux_reg_inp_prf, reg_inp_prf);
    }

    if (jcp.ndims == 5) {
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    align(16);
    L(kh_label);
    if (check_last_kh) {
        for (int ki = 0; ki < kw; ki++)
        for (int ic = 0; ic < ic_block; ic += 4)
        for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
            bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
                && ki == kw - 1 && (ic + 4) == ic_block);

            if (last_kernel_loads) {
                cmp(reg_kj, 1);
                je(last_iter_label, T_NEAR);
            }

            kernel_loads(ki, ic, kk);
            for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
                    prf_count_t0 = 0;
                    oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                int aux_input_offset = typesize
                        * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                            - pad_l) * ic_block + ic);
                v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                        EVEX_compress_addr(aux_reg_inp, aux_input_offset));

                if (oi % 2) {
                    if (prf_count_t0 < 4) {
                        int aux_kernel_prf;
                        if (last_kernel_loads)
                            aux_kernel_prf= kernel_offset(0,
                                    prf_count_t0 + ic + 4
                                    - ic_block, 0) + typesize * kw
                                    * oc_block * ic_block;
                        else
                            aux_kernel_prf = kernel_offset(kk, ic + 4
                                    + prf_count_t0, ki);
                        mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
                                aux_kernel_prf));
                        prf_count_t0++;
                    } else if (prf_count_t1 < 4) {
                        mic_prefetcht1(EVEX_compress_addr(
                                aux_reg_ker_prf, kernel_offset(kk, ic
                                + prf_count_t1, ki)));
                        prf_count_t1++;
                    }
                } else
                    prefetch_inp_next_kh(ki, 2, prf_count_t0,
                            prf_count_t1);
            }

            if (last_kernel_loads) {
                jmp(loop_end_label, T_NEAR);

                L(last_iter_label);

                kernel_loads(ki, ic, kk);
                for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
                        prf_count_t0 = 0;
                        oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                    int aux_input_offset = typesize
                            * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                                - pad_l) * ic_block + ic);
                    v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                            EVEX_compress_addr(aux_reg_inp,
                                aux_input_offset));
                    if (oi % 2) {
                        if (prf_count_t0 < 4) {
                            mic_prefetcht0(EVEX_compress_addr(
                                    aux_reg_ker_prf, kernel_offset(0,
                                    prf_count_t0, 0)));
                            prf_count_t0++;
                        } else if (prf_count_t1 < 4) {
                            mic_prefetcht1(EVEX_compress_addr(
                                    aux_reg_ker_prf, kernel_offset(kk,
                                    ic + prf_count_t1, ki)));
                            prf_count_t1++;
                        }
                    }
                }
                L(loop_end_label);
            }
        }
    } else {
        for (int ki = 0; ki < kw; ki++)
        for (int ic = 0; ic < ic_block; ic += 4)
        for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
            kernel_loads(ki, ic, kk);
            for (int oi = get_ow_start(ki, pad_l),
                    prf_count_t1 = 0, prf_count_t0 = 0;
                    oi < get_ow_end(ur_w, ki, pad_r); oi++) {
                int aux_input_offset = typesize
                        * ((ki * (jcp.dilate_w + 1) + oi * stride_w
                            - pad_l) * ic_block + ic);
                v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
                        EVEX_compress_addr(aux_reg_inp,
                            aux_input_offset));

                if (!is_owb_prefetching(jcp)) {
                    if ((oi % 2) && (prf_count_t1 < 4)) {
                        mic_prefetcht1(EVEX_compress_addr(
                                aux_reg_ker_prf, kernel_offset(kk,
                                ic + prf_count_t1, ki)));
                        prf_count_t1++;
                    }
                } else {
                    if (!(ki == 0 && ic == 0)
                        && !(ki == kw-1 && ic == 0) &&
                        (oi % 2) && (prf_count_t1 < 4)
                        ) {
                        mic_prefetcht0(EVEX_compress_addr(
                                aux_reg_ker, kernel_offset(kk,
                                ic + 4 + prf_count_t0, ki)));
                        prf_count_t0++;
                    }
                }

                if (!is_owb_prefetching(jcp)) {
                    if (pref_current_inp) {
                        if (ki == 0 && ic == 0 && kk == 0)
                            mic_prefetcht0(EVEX_compress_addr(
                                    aux_reg_inp,
                                    aux_input_offset + shift_input_ptr));
                    } else {
                        if (ki == 1 && ic == 0 && kk == 0)
                            mic_prefetcht1(EVEX_compress_addr(
                                    aux_reg_inp_prf, aux_input_offset));
                    }
                } else {
                    int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
                    int inp_shift
                            = jcp.typesize_in * ur_w * stride_w * inp_mult;
                    bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
                    if (ki == 0 && ic == 0 && kk_pref_slot)
                        mic_prefetcht1(EVEX_compress_addr(
                                aux_reg_inp,
                                aux_input_offset + inp_shift));

                    if (ki == kw - 1 && ic == 0 && kk_pref_slot)
                        mic_prefetcht0(EVEX_compress_addr(
                                aux_reg_inp,
                                aux_input_offset + inp_shift));
                }
            }
        }
    }

    add(aux_reg_ker, shift_kernel_ptr);
    add(aux_reg_inp, shift_input_ptr);
    add(aux_reg_ker_prf, shift_kernel_ptr);
    add(aux_reg_inp_prf, shift_input_ptr);

    dec(reg_kj);
    cmp(reg_kj, 0);
    jg(kh_label, T_NEAR);

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d,
            typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block
            * ic_block);
        add(aux_reg_inp_d_prf,
            typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * oc_block
            * ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
        int pad_l, int pad_r)
{
    bool prf_ker = true;
    bool prf_inp = true;
    int ih = jcp.ih;
    int stride_w = jcp.stride_w;
    int id = jcp.id;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int nb_oc_block = jcp.nb_oc_blocking;
    Label kh_label, kd_label;

    int ker_pipeline_depth = 4;
    assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
    assert(oc_block >= ker_pipeline_depth);

    int num_ker_loads = ic_block * nb_oc_block * kw;
    int num_ker_prfs = prf_ker ? num_ker_loads : 0;
    int num_inp_prfs = prf_inp ?
        ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
        0;
    if (jcp.is_1stconv && prf_inp) {
        num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
    }
    int num_prfs = num_ker_prfs + num_inp_prfs;
    int num_fmas = num_ker_loads * ur_w;
    int prf_inst_spacing
        = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
    int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
    int inp_mul = !jcp.is_1stconv ? ic_block : 1;
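    // Worked example (illustrative, non-1st-conv case): ic_block = 16,
    // nb_oc_block = 1, kw = 3, stride_w = 1, ur_w = 14 gives
    // num_ker_loads = 48, num_inp_prfs = 14*1 + 2 = 16, num_prfs = 64 and
    // num_fmas = 48*14 = 672, so prf_inst_spacing = 672/64 = 10 and
    // prf_inst_trigger = (672 % 10) / 2 = 1: roughly one software prefetch
    // is interleaved per ten FMAs.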
    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_inp_prf, reg_inp_prf);
        mov(aux_reg_ker_prf, reg_ker_prf);
    }

    size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
    assert(reg_inp_prf == reg_long_offt);
    if (max_input_offset > INT_MAX) push(reg_inp_prf);

    if (jcp.ndims == 5) {
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    align(16);
    L(kh_label);
    {
        int step = 0;
        int ker_prfs = 0;
        for (int ki = 0; ki < kw; ki++) {
            for (int ic = 0; ic < ic_block; ic++) {
                int aux_kernel_offset = 0;
                if (step == 0) {
                    for (int i = 0; i < ker_pipeline_depth; i++) {
                        aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
                        vmovups(vmm_ker(i), EVEX_compress_addr(
                                        aux_reg_ker, aux_kernel_offset));
                    }
                } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
                    int load_offset = ker_pipeline_depth - 1;
                    int ker_load_reg_idx
                        = (step + load_offset) % ker_pipeline_depth;
                    aux_kernel_offset
                        = get_kernel_offset(ki, ic, 0, load_offset);
                    vmovups(vmm_ker(ker_load_reg_idx),
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                }

                bool ker_prf_inserted = false;
                Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
                int j_start = get_ow_start(ki, pad_l);
                int j_end = get_ow_end(ur_w, ki, pad_r);
                for (int j = j_start; j < j_end; j++) {
                    size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
                    auto addr = EVEX_compress_addr_safe(aux_reg_inp,
                            aux_input_offset, reg_long_offt, true);
                    vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
                    int fma_idx = step * ur_w + j;
                    int prf_slot_idx = fma_idx / prf_inst_spacing;
                    if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
                        if (prf_ker && !ker_prf_inserted
                                && ker_prfs < num_ker_prfs) {
                            int ker_prf_offset
                                = jcp.typesize_in * ker_prfs * jcp.oc_block;
                            mic_prefetcht2(EVEX_compress_addr(
                                    aux_reg_ker_prf, ker_prf_offset));
                            ker_prf_inserted = true;
                            ker_prfs++;
                        } else if (prf_inp) {
                            int inp_prf_idx = prf_slot_idx - ker_prfs;
                            if (inp_prf_idx < num_inp_prfs) {
                                size_t inp_prf_stride = nstl::max(kw, stride_w);
                                size_t inp_prf_offset;
                                if (!jcp.is_1stconv) {
                                    inp_prf_offset
                                        = ic_block * jcp.typesize_in
                                        * ((inp_prf_idx / kw)
                                        * inp_prf_stride
                                        + (inp_prf_idx % kw));
                                } else {
                                    size_t ic_prf_stride =
                                        (size_t)jcp.typesize_in * iw * ih * id;
                                    size_t iw_prf_stride
                                        = jcp.typesize_in * jcp.simd_w;
                                    inp_prf_offset = ((inp_prf_idx / ic_block)
                                            * iw_prf_stride
                                            + (inp_prf_idx % ic_block)
                                            * ic_prf_stride);
                                }
                                mic_prefetcht0(EVEX_compress_addr_safe(
                                        aux_reg_inp_prf, inp_prf_offset,
                                        reg_long_offt));
                            }
                        }
                    }
                }
                step++;
            }
        }

        add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
        if (prf_ker)
            add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
        add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
        if (prf_inp)
            add(aux_reg_inp_prf,
                    jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);
        add(aux_reg_inp_d_prf,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }
    if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
        int pad_l, int pad_r)
{
    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int nb_oc_block = jcp.nb_oc_blocking;
    Label kh_label, kd_label;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
        * jcp.ic_block;
    int inp_mul = !jcp.is_1stconv ? ic_block : 1;
    int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
        * inp_mul;

    auto input_offset = [=](int oi, int ic, int ki) {
        return (size_t)jcp.typesize_in
                * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                * inp_mul + (size_t)ic
                * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
    };

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
    }

    if (jcp.ndims == 5) {
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    }

    L(kh_label);
    {
        for (int ki = 0; ki < kw; ki++) {
            int jj_start = get_ow_start(ki, pad_l);
            int jj_end = get_ow_end(ur_w, ki, pad_r);
            for (int ic = 0; ic < ic_block; ic++) {
                if (jcp.kernel_kind == expl_bcast) {
                    for (int jj = jj_start; jj < jj_end; jj++) {
                        size_t aux_input_offset = input_offset(jj, ic, ki);
                        vbroadcastss(vmm_inp(jj, nb_oc_block),
                                EVEX_compress_addr_safe(aux_reg_inp,
                                aux_input_offset, reg_long_offt));
                    }
                }
                for (int ii = 0; ii < nb_oc_block; ii++) {
                    int aux_kernel_offset = jcp.typesize_in
                        * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
                        * oc_block + ki * ic_block * oc_block + ic * oc_block);
                    if (jj_end - jj_start > 0)
                        vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
                                aux_kernel_offset));
                    for (int jj = jj_start; jj < jj_end; jj++)
                        if (jcp.kernel_kind == expl_bcast)
                            vfmadd231ps(vmm_out(jj, ii),
                                    vmm_inp(jj, nb_oc_block), vmm_wei);
                        else {
                            size_t aux_input_offset = input_offset(jj, ic, ki);
                            vfmadd231ps(vmm_out(jj, ii), vmm_wei,
                                    EVEX_compress_addr_safe(aux_reg_inp,
                                    aux_input_offset, reg_long_offt, true));
                        }
                }
            }
        }
        add(aux_reg_ker, shift_kernel_ptr);
        add(aux_reg_inp, shift_input_ptr);
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d,
                typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
        add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
    }
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_vnni(
        int ur_w, int pad_l, int pad_r)
{
}

template<>
void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_vnni(
        int ur_w, int pad_l, int pad_r)
{
    Label kh_label, kd_label;
    const int ker_reg_base_idx = 28;
    const int channel_inc = jcp.ver == ver_4vnni ? 4 : 1;
    const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1;
    const int shift_kernel_ptr = jcp.typesize_in * jcp.kw
        * jcp.oc_block * jcp.ic_block;
    const int shift_input_ptr
        = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;

    size_t max_input_offset = (size_t)jcp.typesize_in
        * jcp.ic_block * jcp.iw * jcp.ih * jcp.id;
    assert(reg_inp_prf == reg_long_offt);
    if (max_input_offset > INT_MAX) push(reg_inp_prf);

    if (one_of(jcp.ndims, 3, 4)) {
        mov(aux_reg_inp, reg_inp);
        mov(aux_reg_ker, reg_ker);
        mov(aux_reg_ker_prf, reg_ker_prf);
        mov(aux_reg_inp_prf, reg_inp_prf);
    }

    if (jcp.ndims == 5) {
        push(reg_out_prf);
        push(reg_out);

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
        mov(aux_reg_inp_d, reg_inp);
        mov(aux_reg_inp_d_prf, reg_inp_prf);
        mov(aux_reg_ker_d_prf, reg_ker_prf);

        L(kd_label);
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    } else {
        mov(reg_kj, reg_kh);
    }
    if (jcp.ndims == 5) {
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
        mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
        mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
    }

    L(kh_label); {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int ow_start = get_ow_start(ki, pad_l);
            int ow_end = get_ow_end(ur_w, ki, pad_r);
            for (int ic = 0; ic < jcp.ic_block / 2; ic += channel_inc) {
                if (jcp.kernel_kind == expl_bcast) {
                    for (int oi = ow_start; oi < ow_end; oi++) {
                        size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
                        vpbroadcastd(vmm_inp(oi, jcp.nb_oc_blocking),
                            EVEX_compress_addr_safe(aux_reg_inp, input_offset,
                                reg_long_offt));
                    }
                }
                for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
                    if (jcp.kernel_kind == expl_bcast) {
                        int kernel_offset = get_kernel_offset(ki, ic, kk, 0);
                        vmovups(vmm_wei,
                            EVEX_compress_addr(aux_reg_ker, kernel_offset));
                    } else {
                        for (int ii = 0; ii < ker_load_number; ii++) {
                            int kernel_offset
                                = get_kernel_offset(ki, ic, kk, ii);
                            vmovups(Zmm(ker_reg_base_idx + ii),
                                EVEX_compress_addr(
                                    aux_reg_ker, kernel_offset));
                        }
                    }
                    for (int oi = ow_start, prf_count = 0; oi < ow_end; oi++) {
                        size_t input_offset = get_input_offset(ki, ic, oi, pad_l);
                        if (jcp.kernel_kind == expl_bcast) {
                            vpdpwssd(vmm_out(oi, kk), vmm_wei,
                                vmm_inp(oi, jcp.nb_oc_blocking));
                        } else {
                            if (jcp.ver == ver_4vnni)
                                vp4dpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
                                    EVEX_compress_addr_safe(aux_reg_inp,
                                        input_offset, reg_long_offt, false));
                            else
                                vpdpwssd(vmm_out(oi, kk), Zmm(ker_reg_base_idx),
                                    EVEX_compress_addr_safe(aux_reg_inp,
                                        input_offset, reg_long_offt, true));
                        }
                        if ((oi % 2) && (prf_count < ker_load_number)) {
                            int kernel_offset = get_kernel_offset(
                                ki, ic, kk, prf_count++);
                            mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
                                kernel_offset));
                        }
                        if (!(oi % 2) && ki == 0 && ic == 0 && kk == 0) {
                            mic_prefetcht1(EVEX_compress_addr_safe(
                                aux_reg_inp_prf, input_offset, reg_long_offt));
                        }
                        if (!(oi % 2) && ki == 1 && ic == 0 && kk == 0) {
                            mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
                                input_offset + shift_input_ptr, reg_long_offt));
                        }
                    }
                }
            }
        }
        add(aux_reg_ker_prf, shift_kernel_ptr);
        add(aux_reg_inp_prf, shift_input_ptr);
        add(aux_reg_ker, shift_kernel_ptr);
        add(aux_reg_inp, shift_input_ptr);

        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }

    if (jcp.ndims == 5) {
        add(aux_reg_inp_d, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);
        add(aux_reg_inp_d_prf, jcp.typesize_in * jcp.ih * jcp.iw * jcp.ic_block);
        add(aux_reg_ker_d_prf, jcp.typesize_in * jcp.kw * jcp.kh * jcp.oc_block
                * jcp.ic_block);

        dec(reg_ki);
        cmp(reg_ki, 0);
        jg(kd_label, T_NEAR);

        pop(reg_out);
        pop(reg_out_prf);
    }

    if (max_input_offset > INT_MAX) pop(reg_inp_prf);
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
        int pad_l, int pad_r)
{
    if (jcp.ndims == 5) push(reg_oi);

    prepare_output(ur_w);

    Label skip_compute_loop;
    if (jcp.ndims == 5) {
        if ((jcp.dilate_d >= jcp.id)
                || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
            mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
            cmp(reg_kj, 0);
            je(skip_compute_loop, T_NEAR);
        }
    }
    if ((jcp.dilate_h >= jcp.ih)
            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
        mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
        cmp(reg_kj, 0);
        je(skip_compute_loop, T_NEAR);
    }

    if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
        compute_loop_vnni(ur_w, pad_l, pad_r);
    else if (jcp.ver == ver_4fma)
        if (jcp.is_1stconv)
            compute_loop_4fma_1st(ur_w, pad_l, pad_r);
        else
            compute_loop_4fma(ur_w, pad_l, pad_r);
    else if (jcp.ver == ver_fma)
        if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
                || mayiuse(avx512_mic))
            compute_loop_fma(ur_w, pad_l, pad_r);
        else
            if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
                compute_loop_fma(ur_w, pad_l, pad_r);
            else
                compute_loop_fma_core(ur_w, pad_l, pad_r);
    else
        assert(!"unknown convolution version");

    L(skip_compute_loop);
    store_output(ur_w);
    if (jcp.ndims == 5) pop(reg_oi);
}
template<typename Vmm>
void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
{
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    int iw = jcp.iw;
    int ow = jcp.ow;
    int ow_block = jcp.ow_block;
    int nb_ow = jcp.nb_ow;
    int kw = jcp.kw;
    int l_pad = jcp.l_pad;
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
    int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
    int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
    int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
    int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;

    preamble();
    mov(reg_inp, ptr[param1 + GET_OFF(src)]);
    mov(reg_out, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
    mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
    mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);

    int r_pad = nstl::max(
            0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
    int n_oi = ow / ur_w;
    int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1);
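    // Worked example (illustrative): a 3x3 "same" convolution with
    // iw = ow = 14, stride_w = 1, dilate_w = 1 (i.e. jcp.dilate_w = 0) and
    // l_pad = 1 gives r_pad = max(0, 13 + 2 - 14) = 1, so the rightmost
    // output column needs one implicit zero column on the right.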
    if (!is_ow_threading_on(jcp)) {
        // ow is processed as a whole - with left and right paddings
        if (r_pad1 > 0) n_oi--;

        if (ow == ur_w) {
            mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]);
            mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]);
            compute_loop(ur_w, l_pad, r_pad);
        } else {
            mov(reg_inp_prf, reg_inp);
            mov(reg_out_prf, reg_out);
            if (n_oi == 0) {
                add(reg_inp_prf, inp_shift_pad);
                add(reg_out_prf, out_shift);
                compute_loop(ur_w, l_pad, r_pad1);
                add(reg_inp, inp_shift_pad);
                add(reg_out, out_shift);
                if (ur_w_tail != 0) {
                    add(reg_inp_prf, inp_shift);
                    add(reg_out_prf, out_shift);
                    compute_loop(ur_w_tail, 0, r_pad);
                }
            } else {
                if (l_pad > 0) {
                    n_oi--;
                    add(reg_inp_prf, inp_shift_pad);
                    add(reg_out_prf, out_shift);
                    compute_loop(ur_w, l_pad, 0);
                    add(reg_inp, inp_shift_pad);
                    add(reg_out, out_shift);
                }
                if (n_oi > 0) {
                    xor_(reg_oi, reg_oi);
                    Label ow_loop_label;
                    L(ow_loop_label);
                    {
                        add(reg_inp_prf, inp_shift);
                        add(reg_out_prf, out_shift);
                        compute_loop(ur_w, 0, 0);
                        add(reg_inp, inp_shift);
                        add(reg_out, out_shift);
                        inc(reg_oi);
                        cmp(reg_oi, n_oi);
                        jl(ow_loop_label, T_NEAR);
                    }
                }
                if (r_pad1 > 0) {
                    add(reg_inp_prf, inp_shift);
                    add(reg_out_prf, out_shift);
                    compute_loop(ur_w, 0, r_pad1);
                    add(reg_inp, inp_shift);
                    add(reg_out, out_shift);
                }
                if (ur_w_tail != 0) {
                    add(reg_inp_prf, inp_shift);
                    add(reg_out_prf, out_shift);
                    compute_loop(ur_w_tail, 0, r_pad);
                }
            }
        }
    } else {
        // Only one ow block is processed per kernel call.
        // The block number is passed as the owb parameter,
        // and the padding handling depends on that number.

        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
        Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;

        assert(ow_block % ur_w == 0);
        int n_oi_not_last_ow_block = ow_block / ur_w;
        // to simplify code (and general regs usage),
        // size of ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;

        int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;

        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;

        if (last_ow_block_padded) n_oi_last_ow_block--;
        else if (first_ow_block_padded) n_oi_first_ow_block--;
        else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
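        // Worked example (illustrative): ow = 28, ur_w = 7, ow_block = 14
        // gives nb_ow = 2, n_oi_not_last_ow_block = 2 and
        // n_oi_last_ow_block = (28 - 14) / 7 = 2. If r_pad1 > 0 the last
        // ow-block is flagged as padded and its counter drops to 1, so its
        // final iteration runs under last_oi_label with right padding
        // instead of inside the padding-free oi-loop.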
        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is that the first ow-block ?
        jg(middle_ow_blocks_label, T_NEAR);

        // the first ow block, compute left padding

        mov(reg_oi, n_oi_first_ow_block);
        mov(reg_inp_prf, reg_inp);
        mov(reg_out_prf, reg_out);

        if (l_pad > 0) {
            mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
            add(reg_inp_prf, inp_shift_pad);
            add(reg_out_prf, out_shift);
            compute_loop(ur_w, l_pad, 0);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);
            sub(reg_oi, 1); // this iteration is consumed by the left padding
        }
        jmp(oi_loop_label, T_NEAR);

        // middle or last ow block entry

        L(middle_ow_blocks_label);

        if (l_pad > 0) {
            // just to consider left padding, not compute
            add(reg_inp, inp_shift_pad_second_block);
            add(reg_inp_prf, inp_shift_pad_second_block);
        }

        // set number of iteration for oi-loop
        cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
        mov(reg_oi, n_oi_last_ow_block);
        je(oi_loop_label, T_NEAR);
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        mov(reg_oi, n_oi_next_last_ow_block);
        je(oi_loop_label, T_NEAR);
        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks

        // oi loop w/o padding
        L(oi_loop_label);
        mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
        L(oi_loop_start_label);
            cmp(reg_oi, 0);
            jle(oi_loop_end_label, T_NEAR);

            add(reg_inp_prf, inp_shift);
            add(reg_out_prf, out_shift);
            compute_loop(ur_w, 0, 0);
            add(reg_inp, inp_shift);
            add(reg_out, out_shift);
            sub(reg_oi, 1);
            jmp(oi_loop_start_label, T_NEAR);
        L(oi_loop_end_label);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);

        cmp(reg_owb, 0); // first ow-block ?
        if (first_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        jl(end_label, T_NEAR);
        if (next_last_ow_block_padded) {
            je(last_oi_label, T_NEAR);
        } else {
            je(end_label, T_NEAR);
        }
        // that is last block
        if (!last_ow_block_padded) {
            jmp(tail_label, T_NEAR);
        }

        // last oi block with right padding
        L(last_oi_label);
        mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
        add(reg_inp_prf, inp_shift);
        add(reg_out_prf, out_shift);
        compute_loop(ur_w, 0, r_pad1);
        add(reg_inp, inp_shift);
        add(reg_out, out_shift);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        jl(end_label, T_NEAR);

        // ur_w tail
        L(tail_label);
        mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
        if (ur_w_tail != 0) {
            add(reg_inp_prf, inp_shift);
            add(reg_out_prf, out_shift);
            compute_loop(ur_w_tail, 0, r_pad);
        }
        L(end_label);
    }

    postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}
bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    switch (p.len_) {
    case 0: return true;
    case 1: return is_simple(0) || is_sum(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
    default: return false;
    }

    return false;
}
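// In other words, the accepted post-op chains are: nothing; a single sum,
// eltwise or depthwise op; sum followed by one "simple" (eltwise/depthwise)
// op, or two simple ops; and sum followed by two simple ops.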
status_t jit_avx512_common_conv_fwd_kernel::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
        cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
        const primitive_attr_t &attr, int nthreads)
{
    using namespace prop_kind;

    if (!mayiuse(avx512_common))
        return status::unimplemented;

    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper weights_d(&weights_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper bias_d(&bias_pd);

    const int regs = 28;
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
    jcp.ow = dst_d.dims()[ndims-1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
    jcp.kw = weights_d.dims()[with_groups + ndims-1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];
    jcp.src_fmt = src_d.format();

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];

    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);
    jcp.back_pad = (jcp.od - 1) * jcp.stride_d
            + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
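    // Worked example (illustrative): oh = ih = 14, stride_h = 1, kh = 3,
    // dilate_h = 0 and t_pad = 1 give jcp.b_pad = 13 + 2 - 14 = 1, i.e. a
    // 3x3 stride-1 "same" convolution implies one implicit bottom row;
    // back_pad is the same computation along the depth dimension.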
    jcp.is_1stconv = is_1stconv(jcp);

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1
        && src_d.data_type() == data_type::f32;

    const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
    jcp.simd_w = full_simd_w;
    bool ok_to_try_xmm = true
        && mayiuse(avx512_core)
        && src_d.data_type() == data_type::f32
        && !jcp.is_1stconv
        && !ok_to_pad_channels
        && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
        && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
    if (ok_to_try_xmm)
        jcp.simd_w = 4;

    jcp.oc_block = jcp.simd_w;
    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
    jcp.aligned_threads = 0;

    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }
    bool args_ok = true
        && jcp.oc % jcp.oc_block == 0
        && jcp.ic % jcp.ic_block == 0;
    if (!args_ok)
        return status::unimplemented;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;
    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
    }

    auto src_format = jcp.is_1stconv
        ? pick(ndims - 3, ncw, nchw, ncdhw)
        : ((jcp.simd_w == 4)
            ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
            : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
    auto dst_format = (jcp.simd_w == 4)
        ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
        : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto wei_format = with_groups
        ? ((jcp.simd_w == 4)
            ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
            : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
        : ((jcp.simd_w == 4)
            ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
            : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));

    if (src_d.format() == any)
        CHECK(src_pd.set_format(src_format));
    if (src_d.format() != src_format)
        return status::unimplemented;

    if (dst_d.format() == any)
        CHECK(dst_pd.set_format(dst_format));
    if (dst_d.format() != dst_format)
        return status::unimplemented;

    jcp.with_bias = cd.bias_desc.format != memory_format::undef;
    if (jcp.with_bias) {
        if (bias_d.format() == any)
            CHECK(bias_pd.set_format(x));
        if (bias_d.format() != x)
            return status::unimplemented;
    }
    if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
        && src_d.data_type() == data_type::s16
        && weights_d.data_type() == data_type::s16
        && dst_d.data_type() == data_type::s32)
    {
        if (jcp.is_1stconv)
            return status::unimplemented;

        if (mayiuse(avx512_mic_4ops)) {
            jcp.ver = ver_4vnni;
        } else {
            jcp.ver = ver_vnni;
        }
        jcp.typesize_in = sizeof(int16_t);
        jcp.typesize_out = sizeof(int32_t);

        const auto w_format = with_groups
            ? pick(ndims - 3, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i)
            : pick(ndims - 3, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i);
        if (weights_d.format() == any)
            CHECK(weights_pd.set_format(w_format));
        if (weights_d.format() != w_format)
            return status::unimplemented;
    } else if (mayiuse(avx512_common) &&
            src_d.data_type() == data_type::f32
            && weights_d.data_type() == data_type::f32
            && dst_d.data_type() == data_type::f32) {
        jcp.ver = ver_fma;
        jcp.typesize_in = sizeof(float);
        jcp.typesize_out = sizeof(float);
        if (mayiuse(avx512_mic_4ops))
            jcp.ver = ver_4fma;

        if (jcp.is_1stconv) {
            // TODO: fix & remove constraints below
            bool not_for_4fma
                    = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad),
                            nstl::max(jcp.kw, jcp.kh) < 7);
            bool is_dilated
                    = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
            if (one_of(true, not_for_4fma, is_dilated))
                jcp.ver = ver_fma;
            if (jcp.ver == ver_4fma) {
                const auto w_format = with_groups
                    ? ((jcp.simd_w == 4)
                        ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
                        : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
                    : ((jcp.simd_w == 4)
                        ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
                        : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
                if (weights_d.format() == any)
                    CHECK(weights_pd.set_format(w_format));
                if (weights_d.format() != w_format)
                    return status::unimplemented;
            } else {
                const auto w_format = with_groups
                    ? ((jcp.simd_w == 4)
                        ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
                        : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
                    : ((jcp.simd_w == 4)
                        ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
                        : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
                if (weights_d.format() == any)
                    CHECK(weights_pd.set_format(w_format));
                if (weights_d.format() != w_format)
                    return status::unimplemented;
            }
        } else {
            if (weights_d.format() == any)
                CHECK(weights_pd.set_format(wei_format));
            if (weights_d.format() != wei_format)
                return status::unimplemented;
        }
    } else {
        return status::unimplemented;
    }
    if (jcp.is_1stconv) {
        jcp.ur_w = nstl::min(jcp.ow, regs);
    } else {
        // avx512_core guard - just to avoid possible regression for other archs
        if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
            jcp.ur_w = nstl::min(jcp.ow, regs);
        } else {
            for (int ur_w = regs; ur_w > 0; --ur_w) {
                if (jcp.ow % ur_w == 0) {
                    jcp.ur_w = ur_w;
                    break;
                }
            }
        }
        if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
            jcp.ur_w = nstl::min(jcp.ow, regs);
        }
    }
    // TODO (Tanya): currently applied to Segnet convolutions only.
    // Need to try for other topologies
    if (jcp.ow > 150 && jcp.ur_w < regs/2)
        jcp.ur_w = regs;

    int n_oi = (jcp.ow / jcp.ur_w);
    int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
    if (jcp.l_pad > 0 && r_pad > 0)
        n_oi--;

    bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
            && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
    if (large_code_size) {
        const int max_code_size = 24 * 1024;
        const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
        int mult = 1;
        if (jcp.l_pad > 0) mult += 1;
        if (r_pad > 0) mult += 1;
        for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
            if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
                jcp.ur_w = ur_w;
                break;
            }
        }
    }

    /* Grouped channel offset to support 'non-blocked data' format for
     * convolution sizes with '(input_channel / ngroups) < simd' */
    jcp.nonblk_group_off
            = (jcp.ngroups > 1 && one_of(src_d.format(), ncw, nchw, ncdhw))
            ? jcp.ic
            : 1;

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
    auto is_ow_threading_applicable = [=]() {
        return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
                && IMPLICATION(mayiuse(avx512_mic),
                        jcp.ver == ver_4fma
                        && IMPLICATION(jcp.mb != 1,
                                jcp.ih == 1 && jcp.kh == 1)));
    };

    if (jcp.ver == ver_4vnni) {
        jcp.kernel_kind = embd_bcast;
    }
    if (jcp.ver == ver_vnni) {
        // TODO: kernel_kind and nb_oc_blocking selection
        // should be tuned on real HW
        if (jcp.ow <= 8 && jcp.oh <= 8 && jcp.od <= 8) {
            jcp.kernel_kind = expl_bcast;
            jcp.nb_oc_blocking = 2;
        } else {
            jcp.kernel_kind = embd_bcast;
            jcp.nb_oc_blocking = 2;
        }
        if (jcp.nb_oc_blocking > 1) {
            if (jcp.nb_oc < jcp.nb_oc_blocking) jcp.nb_oc_blocking = jcp.nb_oc;
            if (jcp.nb_oc % jcp.nb_oc_blocking != 0)
                for (int i = jcp.nb_oc_blocking; i > 0; i--)
                    if (jcp.nb_oc % i == 0) {
                        jcp.nb_oc_blocking = i;
                        break;
                    }
            jcp.ur_w = 31 / (jcp.nb_oc_blocking + 1);
            if (jcp.ow < jcp.ur_w)
                jcp.ur_w = jcp.ow;
        }
    }

    if (one_of(jcp.ver, ver_4vnni, ver_4fma) && !jcp.is_1stconv) {
        if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
                    && jcp.oh <= 8 && jcp.ow == jcp.oh)
                || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
            if (jcp.nb_oc % 2 == 0) {
                jcp.nb_oc_blocking = 2;
                jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
            }
        } else {
            for (int i = jcp.nb_oc; i > 0; i--)
                if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
                    jcp.nb_oc_blocking = i;
                    jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
                    break;
                }
        }
    }

    if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
        if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
                && jcp.ow != 2 * jcp.ur_w) {
            jcp.nb_oc_blocking = 2;
            jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
        }
    }

    jcp.ow_block = jcp.ow;
    auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
        int nb_ow = div_up(jcp.ow, ow_block);
        int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
        int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
        float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
        float thr_eff = disbalance * (float)work_amount
            / rnd_up(work_amount, nthreads);
        return thr_eff;
    };

    auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
        int res_ow_block = jcp.ow;
        eff = get_thr_eff(nb_oc_blocking, res_ow_block);
        if (!is_ow_threading_applicable())
            return res_ow_block;

        int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
        if (jcp.ver == ver_4fma)
            L2_part /= 2;
        int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
        int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
        int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
            * jcp.kw * jcp.kh;
        int nurw_cache = (L2_part - 2 * size_wei_chunk)
            / (2 * size_dst_chunk + 2 * size_src_chunk);
        // current design of generate() requires ow_block >= 2 * ur_w
        int ow_block_cache = ur_w * nstl::max(2, nurw_cache);

        int ow_block_thr = ow_block_cache;
        eff = get_thr_eff(nb_oc_blocking, ow_block_thr);

        int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
        int start_nb_ow = div_up(jcp.ow, ow_block_thr);
        for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
            int ow_block
                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
            float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
            if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
                break;
            if (div_up(jcp.ow, ow_block) != nb_ow)
                continue;
            float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
            float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
            if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
                ow_block_thr = ow_block;
                eff = thr_eff;
            }
            eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
            if (eff > eff_threshold)
                break;
        }
        res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
        eff = get_thr_eff(nb_oc_blocking, res_ow_block);
        return res_ow_block;
    };
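    // Worked example (illustrative) for get_thr_eff: ow = 56, ow_block = 28
    // gives nb_ow = 2; with nb_oc = 4 and nb_oc_blocking = 2, nb_oc_chunks = 2,
    // so for mb = 1 and oh = 56 the work_amount is 1*56*2*2 = 224 units.
    // With nthreads = 28 that divides evenly (rnd_up(224, 28) = 224) and
    // disbalance = 56/56 = 1, so thr_eff = 1.0, a perfectly balanced split.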
    if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
        int try_nb_oc_blocking = 2;
        unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
            * jcp.ic_block * jcp.kh * jcp.kd;
        unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block
            * try_nb_oc_blocking;
        unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
            * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
        unsigned int ker_total_size = ker_inp_size + ker_out_size
            + ker_wei_size;

        bool embd_bcast_condition = true
            && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size)
            && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
            && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);

        if (jcp.mb == 1) {
            unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
                * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
            unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;

            // Estimate whether we need to limit the number of threads
            // and calculate this number. Includes some heuristic.
            int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
            int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
            int job_size_min = work_amount / nthreads;
            int job_size_max = div_up(work_amount, nthreads);
            int ch_max = rnd_up(jcp.oh, job_size_max);
            int ch_min = (job_size_min == 0)
                ? jcp.oh
                : rnd_up(jcp.oh, job_size_min);
            bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
                    && (jcp.oh != 8 || ch_max / jcp.oh > 1);
            bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
                    && (jcp.oh != 8 || ch_min / jcp.oh > 1);
            bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
                    || nthreads > oc_chunks;
            if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
                    && wei_size / inp_size > 24
                    && (not_aligned_max || not_aligned_min)
                    && eligible_case) {
                jcp.aligned_threads = nthreads;
                for (int i = nthreads; i > 0; i--) {
                    if (oc_chunks % i == 0 || i % oc_chunks == 0) {
                        jcp.aligned_threads = i;
                        break;
                    }
                }
            }
        }

        if (jcp.kw > 3
                || (jcp.stride_w == 1 && jcp.stride_h == 1
                        && embd_bcast_condition)
                || ((jcp.stride_w != 1 || jcp.stride_h != 1)
                        && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
                                && embd_bcast_condition)))
                || (jcp.mb == 1
                        && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
                                || (jcp.ow <= 147 && jcp.oc <= 96)))) {
            jcp.kernel_kind = embd_bcast;
            jcp.ur_w = nstl::min(jcp.ow, regs);
            jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
            if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
                    && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
                    && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
                    && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
                jcp.nb_oc_blocking = try_nb_oc_blocking;
                jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
            }
        } else {
            jcp.kernel_kind = expl_bcast;
            jcp.nb_ic_blocking = 1;
            if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
                float best_thr_eff = 0.f;
                int best_nb_oc_blocking = 1;
                for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
                    if (jcp.nb_oc % i == 0) {
                        float thr_eff;
                        int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
                        get_ow_block(i, ur_w, thr_eff);
                        if (thr_eff > 1.05f * best_thr_eff) {
                            best_nb_oc_blocking = i;
                            best_thr_eff = thr_eff;
                        }
                    }
                }
                jcp.nb_oc_blocking = best_nb_oc_blocking;
                jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
            }
        }
    }

    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    args_ok = true
        && jcp.l_pad <= jcp.ur_w
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
        && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
        && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
    if (!args_ok)
        return status::unimplemented;

    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));
    if (r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    pick_loop_order(jcp);

    jcp.nb_ic_L2 = jcp.nb_ic;

    float thr_eff;
    jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
    const int L2_size = get_cache_size(2, true) / sizeof(float);
    // Source and output data need to fit in L2,
    // leaving some space for weights and prefetching.
    int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
                    - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
            / (jcp.stride_h * jcp.iw + jcp.ow));
    jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
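    // Worked example (illustrative): a 1 MB L2 gives L2_size = 262144 floats;
    // with simd_w = 16, iw = ow = 28, stride_h = 1 and kh >= stride_h the
    // correction term vanishes, so h_L2 = (0.6 * 262144 / 16) / (28 + 28)
    // ~= 175 and h_blocking is simply capped at jcp.oh here.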
    // TODO check for 4vnni
    if (jcp.ver == ver_4fma) {
        if (!is_ow_threading_on(jcp)) {
            for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
                    divf++) {
                size_t l2_src
                    = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
                size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
                    * jcp.oh * jcp.od;
                size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
                    * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
                if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
                    if (jcp.kh == 3 && jcp.oh == 7) {
                        jcp.nb_ic_L2 = 1;
                        break;
                    }
                    temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf
                            : jcp.nb_ic_L2);
                } else {
                    jcp.nb_ic_L2 = temp_nb;
                    break;
                }
            }
        } else if (jcp.ic > 64) {
            jcp.nb_ic_L2 = 2; /* according to performance data */
        }
    }

    return status::success;
}
void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
}
void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
{
    for (int k = 0; k < jcp.nb_ic_blocking; k++) {
        for (int j = 0; j < ur_w; j++) {
            Zmm zmm = zmm_out(j, k);
            vpxord(zmm, zmm, zmm);
            size_t aux_src_offset
                = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
                * jcp.ic_block;
            mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
                        reg_long_offt));
        }
    }
}
void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
{
    Label no_update_label;

    mov(reg_channel, ptr[param + GET_OFF(channel)]);
    cmp(reg_channel, 0);
    je(no_update_label, T_NEAR);
    for (int k = 0; k < jcp.nb_ic_blocking; k++) {
        for (int j = 0; j < ur_w; j++) {
            Zmm zmm = zmm_out(j, k);
            size_t aux_src_offset = (size_t)typesize
                * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
            vadd(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
                        reg_long_offt));
        }
    }

    L(no_update_label);
    for (int k = 0; k < jcp.nb_ic_blocking; k++) {
        for (int j = 0; j < ur_w; j++) {
            Zmm zmm = zmm_out(j, k);
            size_t aux_src_offset = (size_t)typesize
                * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
            vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
                        reg_long_offt), zmm);
            mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
                        reg_long_offt));
        }
    }
}
1905 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
1906 int ur_w, int l_overflow, int r_overflow)
1910 int ic_block = jcp.ic_block;
1911 int oc_block = jcp.oc_block;
1912 Label kh_label, last_iter_label, loop_end_label, kd_label;
1913 int ker_load_number = 4;
1914 int shift_ker_ptr = typesize * kw * oc_block * ic_block;
1915 int shift_dst_ptr = typesize * ow * oc_block;
1916 int ii_dpref_t0 = get_iw_start(0, l_overflow);
1917 int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow);
1919 bool check_last_kh = (jcp.kh > 3);
1920 auto kernel_offset = [=](int icb, int oc, int ki) {
1921 int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
1922 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
1923 int oc_offset = oc * jcp.oc_block;
1924 return typesize * (blk_offset + oc_offset);
1926 auto kernel_loads = [=](int ki, int oc, int kk) {
1927 for (int ii = 0; ii < ker_load_number; ii++) {
1928 int aux_kernel_offset = kernel_offset(kk, oc + ii, ki);
1929 vmovups(zmm_ker(ii),
1930 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
1933 auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
1934 if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
1935 && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) {
1936 int aux_dst_offset = typesize * ((ii_dpref_t0
1937 + jcp.l_pad) * oc_block + jcp.ow * oc_block);
1938 prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1943 if (one_of(jcp.ndims, 3, 4)) {
1944 mov(aux_reg_dst, reg_dst);
1945 mov(aux_reg_ker, reg_ker);
1946 mov(aux_reg_dst_prf, reg_dst_prf);
1947 mov(aux_reg_ker_prf, reg_ker_prf);
1950 if (jcp.ndims == 5) {
1954 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
1955 mov(aux_reg_dst_d, reg_dst);
1956 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
1957 mov(aux_reg_dst_d_prf, reg_dst_prf);
1958 mov(aux_reg_ker_d_prf, reg_ker_prf);
1961 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
1963 mov(reg_kj, reg_kh);
1966 if (jcp.ndims == 5) {
1967 mov(aux_reg_dst, aux_reg_dst_d);
1968 mov(aux_reg_ker, aux_reg_ker_d);
1969 mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
1970 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
1975 if (check_last_kh) {
1976 for (int ki = 0; ki < kw; ki++)
1977 for (int oc = 0; oc < oc_block; oc += 4)
1978 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
1979 bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1
1980 && ki == kw - 1 && (oc + 4) == oc_block);
1982 if (last_kernel_loads) {
1984 je(last_iter_label, T_NEAR);
1987 kernel_loads(ki, oc, kk);
1988 for (int ii = get_iw_start(ki, l_overflow),
1989 prf_count_t0 = 0, prf_count_t1 = 0;
1990 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
1991 int aux_dst_offset = typesize
1992 * ((ii + jcp.l_pad - ki) * oc_block + oc);
1993 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
1994 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
1997 if (prf_count_t0 < 4) {
1999 if (last_kernel_loads)
2000 aux_kernel_prf= kernel_offset(0, prf_count_t0
2001 + oc + 4 - oc_block, 0) + typesize * kw
2002 * oc_block * ic_block;
2004 aux_kernel_prf = kernel_offset(kk, oc + 4
2005 + prf_count_t0, ki);
2006 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
2009 } else if (prf_count_t1 < 4) {
2010 mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
2011 kernel_offset(kk, oc + prf_count_t1, ki)));
2015 prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1);
2017 if (last_kernel_loads) {
2018 jmp(loop_end_label, T_NEAR);
2021 L(last_iter_label);
2022 kernel_loads(ki, oc, kk);
2023 for (int ii = get_iw_start(ki, l_overflow),
2024 prf_count_t0 = 0, prf_count_t1 = 0;
2025 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
2026 int aux_dst_offset = typesize
2027 * ((ii + jcp.l_pad - ki) * oc_block + oc);
2028 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
2029 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
2031 if (prf_count_t0 < 4) {
2032 mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
2033 kernel_offset(0, prf_count_t0, 0)));
2034 prf_count_t0++;
2035 } else if (prf_count_t1 < 4) {
2036 mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
2037 kernel_offset(kk, oc + prf_count_t1, ki)));
2038 prf_count_t1++;
2039 }
2046 for (int ki = 0; ki < kw; ki++)
2047 for (int oc = 0; oc < oc_block; oc += 4)
2048 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
2049 kernel_loads(ki, oc, kk);
2051 for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0;
2052 ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
2053 int aux_dst_offset = typesize
2054 * ((ii + jcp.l_pad - ki) * oc_block + oc);
2055 v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
2056 EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
2057 if ((ii % 2) && (prf_count_t1 < 4)) {
2058 mic_prefetcht1(EVEX_compress_addr(
2059 aux_reg_ker_prf, kernel_offset(kk,
2060 oc + prf_count_t1, ki)));
2061 prf_count_t1++;
2062 }
2063 if (ki == 1 && oc == 0 && kk == 0)
2064 mic_prefetcht1(EVEX_compress_addr(
2065 aux_reg_dst_prf, aux_dst_offset));
2070 add(aux_reg_ker, shift_ker_ptr);
2071 sub(aux_reg_dst, shift_dst_ptr);
2072 add(aux_reg_ker_prf, shift_ker_ptr);
2073 sub(aux_reg_dst_prf, shift_dst_ptr);
2075 dec(reg_kj);
2076 cmp(reg_kj, 0);
2077 jg(kh_label, T_NEAR);
2079 if (jcp.ndims == 5) {
2080 sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block);
2081 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2082 sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block);
2083 add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2085 dec(reg_ki);
2086 cmp(reg_ki, 0);
2087 jg(kd_label, T_NEAR);
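/* Editor's note (sketch): the vnni path below works on s16 diff_dst and
 * weights with s32 accumulation. vpdpwssd multiplies pairs of adjacent
 * int16 elements and accumulates int32 dot products, so each lane covers
 * two oc values -- hence the "oc_block / 2" bound and the "2 * oc"
 * offsets in the loop. ver_4vnni additionally batches four such loads,
 * mirroring the 4fma variant above. */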
2094 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_vnni(
2095 int ur_w, int l_overflow, int r_overflow)
2099 int ic_block = jcp.ic_block;
2100 int oc_block = jcp.oc_block;
2101 const int channel_inc = jcp.ver == ver_4vnni ? 4 : 1;
2102 const int ker_load_number = jcp.ver == ver_4vnni ? 4 : 1;
2105 auto kernel_offset = [=](int icb, int oc, int ki) {
2106 int blk_idx = icb * jcp.kh * jcp.kw + ki;
2107 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
2108 int oc_offset = oc * jcp.oc_block;
2109 return jcp.typesize_in * (blk_offset + oc_offset);
2112 mov(aux_reg_dst, reg_dst);
2113 mov(aux_reg_ker, reg_ker);
2114 mov(aux_reg_dst_prf, reg_dst_prf);
2115 mov(aux_reg_ker_prf, reg_ker_prf);
2117 mov(reg_kj, reg_kh);
2118 L(kh_label); {
2119 for (int ki = 0; ki < kw; ki++) {
2120 int jj_start = get_iw_start(ki, l_overflow);
2121 int jj_end = get_iw_end(ur_w, ki, r_overflow);
2122 for (int oc = 0; oc < oc_block / 2; oc += channel_inc) {
2123 if (jcp.kernel_kind == expl_bcast) {
2124 for (int jj = jj_start; jj < jj_end; jj++) {
2125 int aux_dst_offset = jcp.typesize_in
2126 * ((jj + jcp.l_pad - ki) * oc_block + 2 * oc);
2127 vpbroadcastd(zmm_inp(jj, jcp.nb_ic_blocking),
2128 ptr[aux_reg_dst + aux_dst_offset]);
2131 for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
2132 if (jcp.kernel_kind == expl_bcast) {
2133 int aux_kernel_offset = kernel_offset(kk, 2 * oc, ki);
2134 vmovups(zmm_wei,
2135 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
2136 } else {
2137 for (int ii = 0; ii < ker_load_number; ii++) {
2138 int aux_kernel_offset
2139 = kernel_offset(kk, 2 * (oc + ii), ki);
2140 vmovups(zmm_ker(ii),
2141 EVEX_compress_addr(aux_reg_ker,
2142 aux_kernel_offset));
2146 for (int jj = jj_start, prf_count = 0; jj < jj_end; jj++) {
2147 int aux_dst_offset = jcp.typesize_in
2148 * ((jj + jcp.l_pad - ki) * oc_block + 2 * oc);
2149 if (jcp.kernel_kind == expl_bcast) {
2150 vpdpwssd(zmm_out(jj, kk), zmm_wei,
2151 zmm_inp(jj, jcp.nb_ic_blocking));
2152 } else {
2153 vpXdpwssd(zmm_out(jj, kk), zmm_ker(0),
2154 aux_reg_dst, aux_dst_offset);
2157 if ((jj % 2) && (prf_count < 4)) {
2158 int aux_kernel_prf
2159 = kernel_offset(kk, oc + prf_count, ki);
2160 mic_prefetcht1(EVEX_compress_addr(
2161 aux_reg_ker_prf, aux_kernel_prf));
2162 prf_count++;
2163 }
2164 if (!(jj % 2) && ki == 0 && oc == 0 && kk == 0) {
2165 mic_prefetcht1(EVEX_compress_addr(aux_reg_dst_prf,
2166 aux_dst_offset));
2167 }
2168 if (!(jj % 2) && ki == 1 && oc == 0 && kk == 0) {
2169 mic_prefetcht0(EVEX_compress_addr(aux_reg_dst,
2170 aux_dst_offset + jcp.typesize_in
2178 add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
2179 sub(aux_reg_dst, jcp.typesize_in * ow * oc_block);
2180 add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
2181 sub(aux_reg_dst_prf, jcp.typesize_in * ow * oc_block);
2183 dec(reg_kj);
2184 cmp(reg_kj, 0);
2185 jg(kh_label, T_NEAR);
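/* Editor's note: compute_loop_fma software-pipelines the weight loads --
 * ker_pipeline_depth (4) zmm registers are pre-loaded and then rotated so
 * each vfmadd231ps consumes a register loaded several steps earlier.
 * Prefetches are spread evenly over the FMA stream: with num_fmas FMAs
 * and num_prfs prefetches, one prefetch is issued roughly every
 * prf_inst_spacing = max(1, num_fmas / num_prfs) FMAs. */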
2189 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
2190 int ur_w, int l_overflow, int r_overflow)
2192 Label kh_label, kd_label;
2196 int ic_block = jcp.ic_block;
2197 int oc_block = jcp.oc_block;
2198 int l_pad = jcp.l_pad;
2199 int dilate_w = jcp.dilate_w + 1;
2200 int stride_w = jcp.stride_w;
2201 int stride_h = jcp.stride_h;
2203 int ker_pipeline_depth = 4;
2204 assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
2205 assert(oc_block >= ker_pipeline_depth);
2207 int num_ker_loads = oc_block * kw;
2208 int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
2209 + nstl::max(0, kw - stride_w);
2210 int num_prfs = num_ker_loads + num_inp_prfs;
2211 int num_fmas = num_ker_loads * ur_w / stride_w;
2212 int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
2213 int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
2215 if (one_of(jcp.ndims, 3, 4)) {
2216 mov(aux_reg_dst, reg_dst);
2217 mov(aux_reg_ker, reg_ker);
2219 mov(aux_reg_dst_prf, reg_dst_prf);
2220 mov(aux_reg_ker_prf, reg_ker_prf);
2223 if (jcp.ndims == 5) {
2227 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
2228 mov(aux_reg_dst_d, reg_dst);
2229 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
2230 mov(aux_reg_dst_d_prf, reg_dst_prf);
2231 mov(aux_reg_ker_d_prf, reg_ker_prf);
2234 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2235 else
2236 mov(reg_kj, reg_kh);
2238 L(kd_label);
2239 if (jcp.ndims == 5) {
2240 mov(aux_reg_dst, aux_reg_dst_d);
2241 mov(aux_reg_ker, aux_reg_ker_d);
2242 mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
2243 mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
2245 int step = 0;
2246 int ker_prfs = 0;
2247 L(kh_label); {
2249 for (int ki = 0; ki < kw; ki++) {
2250 for (int oc = 0; oc < oc_block; oc++) {
2251 if (step == 0) {
2252 for (int i = 0; i < ker_pipeline_depth; i++) {
2253 int aux_kernel_offset = typesize * ((oc + i) * oc_block
2254 + ki * ic_block * oc_block);
2255 vmovups(zmm_ker(i), EVEX_compress_addr(
2256 aux_reg_ker, aux_kernel_offset));
2258 } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
2259 int load_offset = ker_pipeline_depth - 1;
2260 int ker_load_reg_idx
2261 = (step + load_offset) % ker_pipeline_depth;
2262 int aux_kernel_offset = typesize * ((oc + load_offset)
2263 * oc_block + ki * ic_block * oc_block);
2264 vmovups(zmm_ker(ker_load_reg_idx),
2265 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
2268 bool ker_prf_inserted = false;
2269 auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);
2271 int jj_start = get_iw_start(ki, l_overflow);
2272 int jj_end = get_iw_end(ur_w, ki, r_overflow);
2273 assert(stride_w != 1
2274 || jj_start == nstl::max(0,
2275 l_overflow - (kw - 1 - ki) * dilate_w));
2276 assert(stride_w != 1
2277 || jj_end == ur_w - nstl::max(0,
2278 r_overflow - ki * dilate_w));
2280 for (int jj = jj_start; jj < jj_end; jj += stride_w) {
2281 assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
2282 int aux_dst_offset = typesize *
2283 (((jj + l_pad - ki * dilate_w)
2284 / stride_w) * jcp.oc_block + oc);
2285 vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
2286 EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));
2288 int fma_idx = (step * ur_w + jj) / stride_w;
2289 int prf_slot_idx = fma_idx / prf_inst_spacing;
2290 if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
2291 if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
2292 int ker_prf_offset = typesize
2293 * ker_prfs * jcp.oc_block;
2294 mic_prefetcht1(EVEX_compress_addr(
2295 aux_reg_ker_prf, ker_prf_offset));
2296 ker_prf_inserted = true;
2297 ker_prfs++;
2298 } else {
2299 int inp_prf_idx = prf_slot_idx - ker_prfs;
2300 if (inp_prf_idx < num_inp_prfs) {
2301 int inp_prf_offset
2302 = ic_block * typesize
2303 * ((inp_prf_idx / kw) * kw
2304 + (inp_prf_idx % kw));
2305 mic_prefetcht0(EVEX_compress_addr(
2306 aux_reg_dst_prf, inp_prf_offset));
2315 add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
2316 sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
2317 add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
2318 sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);
2320 dec(reg_kj);
2321 cmp(reg_kj, 0);
2322 jg(kh_label, T_NEAR);
2324 if (jcp.ndims == 5) {
2325 sub(aux_reg_dst_d,
2326 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2327 add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
2328 * oc_block * ic_block);
2329 sub(aux_reg_dst_d_prf,
2330 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2331 add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
2332 * oc_block * ic_block);
2334 dec(reg_ki);
2335 cmp(reg_ki, 0);
2336 jg(kd_label, T_NEAR);
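/* Editor's note: the "core" fma variant below targets avx512_core parts,
 * which drop the explicit software-prefetch tuning of the MIC path. With
 * kernel_kind == expl_bcast it broadcasts one diff_dst scalar per output
 * point (vbroadcastss) against a full-width weight vector in zmm_wei;
 * otherwise it broadcasts from memory directly inside vfmadd231ps. */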
2346 void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
2347 int ur_w, int l_overflow, int r_overflow)
2351 int dilate_w = jcp.dilate_w + 1;
2352 int stride_w = jcp.stride_w;
2353 int ic_block = jcp.ic_block;
2354 int oc_block = jcp.oc_block;
2355 int nb_ic_block = jcp.nb_ic_blocking;
2356 Label kh_label, kd_label;
2358 int shift_ker_ptr = typesize * kw * oc_block * ic_block;
2359 int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
2361 auto output_offset = [=](int oi, int oc, int ki) {
2363 (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc);
2365 auto kernel_offset = [=](int icb, int oc, int ki) {
2366 int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
2367 int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
2368 int oc_offset = oc * jcp.oc_block;
2369 return typesize * (blk_offset + oc_offset);
2372 if (one_of(jcp.ndims, 3, 4)) {
2373 mov(aux_reg_dst, reg_dst);
2374 mov(aux_reg_ker, reg_ker);
2377 if (jcp.ndims == 5) {
2381 mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
2382 mov(aux_reg_dst_d, reg_dst);
2383 mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
2386 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2387 else
2388 mov(reg_kj, reg_kh);
2390 L(kd_label);
2391 if (jcp.ndims == 5) {
2392 mov(aux_reg_dst, aux_reg_dst_d);
2393 mov(aux_reg_ker, aux_reg_ker_d);
2398 for (int ki = 0; ki < kw; ki++) {
2399 int jj_start = get_iw_start(ki, l_overflow);
2400 int jj_end = get_iw_end(ur_w, ki, r_overflow);
2401 for (int oc = 0; oc < oc_block; oc++) {
2402 if (jcp.kernel_kind == expl_bcast) {
2403 for (int jj = jj_start; jj < jj_end; jj++) {
2404 int aux_output_offset = output_offset(jj, oc, ki);
2405 vbroadcastss(zmm_inp(jj, nb_ic_block),
2406 ptr[aux_reg_dst + aux_output_offset]);
2409 for (int ii = 0; ii < nb_ic_block; ii++) {
2410 int aux_kernel_offset = kernel_offset(ii, oc, ki);
2411 if (jj_end - jj_start > 0)
2412 vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
2413 aux_kernel_offset));
2414 for (int jj = jj_start; jj < jj_end; jj += stride_w)
2415 if (jcp.kernel_kind == expl_bcast)
2416 vfmadd231ps(zmm_out(jj, ii),
2417 zmm_inp(jj, nb_ic_block), zmm_wei);
2418 else
2419 vfmadd231ps(zmm_out(jj, ii), zmm_wei,
2420 EVEX_compress_addr(aux_reg_dst,
2421 output_offset(jj, oc, ki), true));
2425 add(aux_reg_ker, shift_ker_ptr);
2426 sub(aux_reg_dst, shift_dst_ptr);
2427 dec(reg_kj);
2428 cmp(reg_kj, 0);
2429 jg(kh_label, T_NEAR);
2432 if (jcp.ndims == 5) {
2433 sub(aux_reg_dst_d,
2434 typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
2435 add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
2437 dec(reg_ki);
2438 cmp(reg_ki, 0);
2439 jg(kd_label, T_NEAR);
2446 inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
2447 int ur_w, int l_overflow, int r_overflow)
2449 if (jcp.ndims == 5) push(reg_oi);
2451 prepare_output(ur_w);
2453 Label skip_compute_loop;
2454 if (jcp.ndims == 5) {
2455 mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
2456 cmp(reg_kj, 0);
2457 je(skip_compute_loop, T_NEAR);
2459 mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
2460 cmp(reg_kj, 0);
2461 je(skip_compute_loop, T_NEAR);
2463 if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
2464 compute_loop_vnni(ur_w, l_overflow, r_overflow);
2465 else if (jcp.ver == ver_4fma)
2466 compute_loop_4fma(ur_w, l_overflow, r_overflow);
2467 else if (jcp.ver == ver_fma)
2468 if (mayiuse(avx512_mic))
2469 compute_loop_fma(ur_w, l_overflow, r_overflow);
2470 else
2471 if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
2472 compute_loop_fma(ur_w, l_overflow, r_overflow);
2473 else
2474 compute_loop_fma_core(ur_w, l_overflow, r_overflow);
2476 assert("!unknown convolution version");
2478 L(skip_compute_loop);
2480 if (jcp.ndims == 5) pop(reg_oi);
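/* Editor's note -- how generate() tiles the width: iw is processed in
 * blocks of ur_w pixels, n_oi = iw / ur_w full blocks plus a ur_w_tail
 * remainder. The first block absorbs l_overflow and the last full block
 * r_overflow1. A worked example, assuming iw = 17, ur_w = 4, kw = 3,
 * stride_w = 1, l_pad = r_pad = 1: n_oi = 4, ur_w_tail = 1,
 * l_overflow = (2 - 1) / 1 = 1, r_overflow1 = (2 - 1 - 1) / 1 = 0, and
 * the middle blocks run with no overflow at all. */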
2483 void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
2485 int iw = jcp.iw;
2486 int kw = jcp.kw;
2487 int ur_w = jcp.ur_w;
2488 int ic_block = jcp.ic_block;
2489 int oc_block = jcp.oc_block;
2490 int ur_w_tail = jcp.ur_w_tail;
2491 int dilate_w = jcp.dilate_w + 1;
2492 int stride_w = jcp.stride_w;
2494 int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
2495 int src_shift = jcp.typesize_out * ur_w * oc_block;
2497 preamble();
2499 mov(reg_src, ptr[param + GET_OFF(src)]);
2500 mov(reg_dst, ptr[param + GET_OFF(dst)]);
2501 mov(reg_ker, ptr[param + GET_OFF(filt)]);
2503 mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
2504 mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]);
2505 mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]);
2506 mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]);
2508 int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
2509 int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
2510 - nstl::max(0, jcp.r_pad)) / stride_w);
2511 int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
2512 - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
2514 int n_oi = iw / ur_w;
2515 if (r_overflow1 > 0) n_oi--;
2517 if (ur_w == iw) {
2518 compute_loop(ur_w, l_overflow, r_overflow);
2519 } else if (n_oi == 0) {
2520 compute_loop(ur_w, l_overflow, r_overflow1);
2521 add(reg_src, src_shift);
2522 add(reg_dst, dst_shift);
2523 add(reg_src_prf, src_shift);
2524 add(reg_dst_prf, dst_shift);
2526 compute_loop(ur_w_tail, 0, r_overflow);
2527 } else {
2528 xor_(reg_oi, reg_oi);
2529 if (l_overflow > 0) {
2530 compute_loop(ur_w, l_overflow, 0);
2531 add(reg_src, src_shift);
2532 add(reg_dst, dst_shift);
2533 add(reg_src_prf, src_shift);
2534 add(reg_dst_prf, dst_shift);
2538 if ((l_overflow <= 0 && n_oi > 0)
2539 || (l_overflow > 0 && n_oi > 1)) {
2540 Label ow_loop_label;
2541 L(ow_loop_label); {
2542 compute_loop(ur_w, 0, 0);
2543 add(reg_src, src_shift);
2544 add(reg_dst, dst_shift);
2545 add(reg_src_prf, src_shift);
2546 add(reg_dst_prf, dst_shift);
2548 inc(reg_oi);
2549 cmp(reg_oi, n_oi);
2550 jl(ow_loop_label, T_NEAR);
2553 if (r_overflow1 > 0) {
2554 compute_loop(ur_w, 0, r_overflow1);
2555 add(reg_src, src_shift);
2556 add(reg_dst, dst_shift);
2557 add(reg_src_prf, src_shift);
2558 add(reg_dst_prf, dst_shift);
2560 if (ur_w_tail != 0) {
2561 compute_loop(ur_w_tail, 0, r_overflow);
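/* Editor's note: init_conf below derives the jit_conv_conf_t blocking
 * from the descriptors and rejects unsupported cases. The implicit right
 * padding follows from the convolution geometry, e.g. for ow = 56,
 * stride_w = 1, kw = 3, no dilation, iw = 56, l_pad = 1:
 *     r_pad = (56 - 1) * 1 + (3 - 1) * 1 - (56 + 1 - 1) = 1. */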
2568 status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
2569 jit_conv_conf_t &jcp,
2570 const convolution_desc_t &cd,
2571 const memory_desc_wrapper &diff_src_d,
2572 const memory_desc_wrapper &weights_d,
2573 const memory_desc_wrapper &diff_dst_d)
2575 if (!mayiuse(avx512_common)) return status::unimplemented;
2577 jcp = zero<decltype(jcp)>();
2579 jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
2580 const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
2581 int ndims = diff_src_d.ndims();
2584 jcp.prop_kind = cd.prop_kind;
2585 jcp.ndims = ndims;
2586 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
2587 jcp.mb = diff_src_d.dims()[0];
2589 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2590 jcp.oc_without_padding = jcp.oc;
2591 jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
2593 jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
2594 jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
2595 jcp.iw = diff_src_d.dims()[ndims-1];
2596 jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
2597 jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
2598 jcp.ow = diff_dst_d.dims()[ndims-1];
2600 jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
2601 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
2602 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
2604 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
2605 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
2606 jcp.l_pad = cd.padding[0][ndims-3];
2608 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
2609 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
2610 jcp.stride_w = cd.strides[ndims-3];
2612 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
2613 jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
2614 jcp.dilate_w = cd.dilates[ndims-3];
2615 if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
2616 || (jcp.dilate_d != 0 && jcp.stride_d != 1)
2617 || (jcp.dilate_h != 0 && jcp.stride_h != 1))
2618 return status::unimplemented;
2620 jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
2621 - (jcp.iw + jcp.l_pad - 1);
2622 jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
2623 - (jcp.ih + jcp.t_pad - 1);
2624 jcp.back_pad = (jcp.od - 1) * jcp.stride_d
2625 + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
2627 jcp.aligned_threads = 0;
2629 jcp.is_1stconv = false;
2631 jcp.oc_block = jcp.simd_w;
2632 jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
2634 bool ok_to_pad_channels = true
2635 && jcp.ngroups == 1
2636 && diff_src_d.data_type() == data_type::f32;
2638 if (ok_to_pad_channels) {
2639 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
2640 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
2643 auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
2644 auto wei_format = with_groups
2645 ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
2646 : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
2647 bool args_ok = true
2648 && jcp.oc % jcp.oc_block == 0
2649 && jcp.ic % jcp.ic_block == 0
2650 && diff_src_d.format() == src_format
2651 && diff_dst_d.format() == src_format;
2652 if (!args_ok)
2653 return status::unimplemented;
2655 jcp.nb_ic = jcp.ic / jcp.ic_block;
2656 jcp.nb_oc = jcp.oc / jcp.oc_block;
2658 jcp.ur_w = jcp.stride_w;
2660 int regs = 28;
2661 if (jcp.iw <= regs)
2662 jcp.ur_w = jcp.iw;
2663 else
2664 for (int ur_w = regs; ur_w > 0; --ur_w)
2665 if (ur_w % jcp.stride_w == 0) {
2666 jcp.ur_w = ur_w;
2667 break;
2668 }
2670 int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2671 - jcp.l_pad) / jcp.stride_w);
2672 int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2673 - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
2674 int n_oi = jcp.iw / jcp.ur_w;
2675 if (r_overflow1 > 0) n_oi--;
2677 if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
2678 && jcp.stride_w == 1 && jcp.stride_h == 1
2679 && diff_dst_d.data_type() == data_type::s16
2680 && weights_d.data_type() == data_type::s16
2681 && diff_src_d.data_type() == data_type::s32) {
2682 if (weights_d.format() != (with_groups ? gOIhw8o16i2o : OIhw8o16i2o))
2683 return status::unimplemented;
2684 if (mayiuse(avx512_mic_4ops)) {
2685 jcp.ver = ver_4vnni;
2686 } else {
2687 jcp.ver = ver_vnni;
2688 }
2689 jcp.typesize_in = sizeof(int16_t);
2690 jcp.typesize_out = sizeof(int32_t);
2691 } else if (mayiuse(avx512_common)
2692 && diff_dst_d.data_type() == data_type::f32
2693 && weights_d.data_type() == data_type::f32
2694 && diff_src_d.data_type() == data_type::f32) {
2695 if (weights_d.format() != wei_format)
2696 return status::unimplemented;
2697 jcp.ver = ver_fma;
2698 jcp.typesize_in = sizeof(float);
2699 jcp.typesize_out = sizeof(float);
2700 if (mayiuse(avx512_mic_4ops)
2701 && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
2702 jcp.ver = ver_4fma;
2703 }
2704 } else {
2705 return status::unimplemented;
2706 }
2707 if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
2708 && jcp.ver != ver_fma)
2709 return status::unimplemented;
2711 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2712 if (jcp.ver == ver_4vnni) {
2713 jcp.kernel_kind = embd_bcast;
2715 if (jcp.ver == ver_vnni) {
2716 // TODO: kernel_kind and nb_oc_blocking selection
2717 // should be tuned on real HW
2718 if ((jcp.iw <= 56 && jcp.ih <= 56 && jcp.kh < 5)
2719 || (jcp.iw <= 17 && jcp.ih <= 17 && jcp.kh >= 5) ) {
2720 jcp.kernel_kind = expl_bcast;
2721 jcp.nb_ic_blocking = 4;
2723 jcp.kernel_kind = embd_bcast;
2724 jcp.nb_ic_blocking = 2;
2726 if (jcp.nb_ic_blocking > 1) {
2727 if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2728 if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2729 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2730 if (jcp.nb_ic % i == 0) {
2731 jcp.nb_ic_blocking = i;
2732 break;
2733 }
2734 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2735 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2738 if (jcp.ver == ver_4fma) {
2739 if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
2740 jcp.nb_ic_blocking = 2;
2741 } else {
2742 for (int i = jcp.nb_ic; i > 0; i--)
2743 if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
2744 jcp.nb_ic_blocking = i;
2745 break;
2746 }
2750 jcp.loop_order = loop_gnc;
2752 bool large_code_size = (jcp.ur_w != jcp.ow)
2753 && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1))
2754 && (r_overflow1 > 0) && (l_overflow > 0);
2755 if (large_code_size) {
2756 const int max_code_size = 24 * 1024;
2757 const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
2758 int mult = 1;
2759 if (l_overflow > 0) mult += 1;
2760 if (r_overflow1 > 0) mult += 1;
2761 for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
2762 if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
2763 < max_code_size) {
2764 if (ur_w % jcp.stride_w == 0) {
2765 jcp.ur_w = ur_w;
2766 break;
2767 }
2768 }
2772 if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
2773 int try_nb_ic_blocking = 2;
2774 unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
2775 * try_nb_ic_blocking * jcp.kh;
2776 unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
2777 unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
2778 * jcp.oc_block * try_nb_ic_blocking;
2779 unsigned int ker_total_size = ker_inp_size + ker_out_size
2780 + ker_wei_size;
2781 if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
2782 || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
2783 || ker_total_size > L1_cache_size )))
2784 || jcp.stride_h > 1 || jcp.stride_d > 1) {
2785 jcp.kernel_kind = embd_bcast;
2786 jcp.ur_w = nstl::min(jcp.iw, regs);
2787 jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
2788 if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
2789 && jcp.ow > 8)) && jcp.stride_h == 1)
2790 if (jcp.nb_ic % try_nb_ic_blocking == 0) {
2791 jcp.nb_ic_blocking = try_nb_ic_blocking;
2792 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2793 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2794 }
2795 } else {
2796 jcp.kernel_kind = expl_bcast;
2797 jcp.nb_oc_blocking = 1;
2798 jcp.nb_ic_blocking = 4;
2799 if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
2800 if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
2801 for (int i = jcp.nb_ic_blocking; i > 0; i--)
2802 if (jcp.nb_ic % i == 0) {
2803 jcp.nb_ic_blocking = i;
2804 break;
2805 }
2806 jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
2807 if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
2810 jcp.ur_w_tail = jcp.iw % jcp.ur_w;
2812 if (l_overflow * jcp.stride_w > jcp.ur_w)
2813 return status::unimplemented;
2814 int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
2815 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
2816 if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
2817 return status::unimplemented;
2818 if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
2819 return status::unimplemented;
2821 pick_loop_order(jcp);
2823 jcp.nb_oc_L2 = jcp.nb_oc;
2824 // TODO check for 4vnni
2825 if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
2826 for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
2827 divf++) {
2828 size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
2829 * jcp.id;
2830 size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
2831 size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
2832 * jcp.kd * jcp.nb_ic_blocking * temp_nb;
2833 if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
2834 if (jcp.kh == 3 && jcp.ih == 7) {
2835 jcp.nb_oc_L2 = 1;
2836 break;
2837 }
2838 temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
2839 : jcp.nb_oc_L2);
2840 } else {
2841 jcp.nb_oc_L2 = temp_nb;
2842 break;
2843 }
2847 args_ok = true
2848 && jcp.ic <= diff_src_d.blocking_desc().padding_dims[1]
2849 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2850 && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
2851 && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
2852 if (!args_ok) return status::unimplemented;
2854 return status::success;
2857 void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
2858 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
2863 const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
2865 void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
2867 Label kd_comeback_label;
2869 /* 'depth' loop count bound by 'kd_work_size' */
2870 mov(kj, ptr[param + GET_OFF(kd_padding)]);
2871 L(kd_comeback_label); {
2872 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2873 int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
2874 ? jcp.tr_iw : jcp.iw;
2875 sub(reg_input,
2876 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
2877 sub(reg_kernel,
2878 jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
2879 dec(kj);
2880 cmp(kj, 0);
2881 jg(kd_comeback_label, T_NEAR);
2885 void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
2887 Label kh_comeback_label, kd_comeback_label;
2888 mov(kj, reg_kh);
2889 L(kh_comeback_label); {
2890 int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
2891 int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
2892 ? jcp.tr_iw : jcp.iw;
2893 sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
2894 sub(reg_kernel,
2895 jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
2896 dec(kj);
2897 cmp(kj, 0);
2898 jg(kh_comeback_label, T_NEAR);
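/* Editor's note on the register tiling below: kw * ic_block_step zmm
 * registers hold the weight-gradient accumulators, and four more
 * (Zmm(kw * ic_block_step + i_ur % 4)) rotate over diff_dst vectors, so
 * each input scalar is broadcast into an FMA against the matching output
 * vector; see compute_oh_step_disp() for how ic_block_step is chosen. */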
2902 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
2903 int ur_w, int pad_l, int pad_r,
2904 int ic_block_step, int input_offset, int kernel_offset,
2905 int output_offset, bool input_wraparound)
2909 int ic_block = jcp.ic_block;
2910 int oc_block = jcp.oc_block;
2911 for (int i_kw = 0; i_kw < kw; i_kw++)
2912 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2913 vmovups(Zmm(i_kw * ic_block_step + i_ic),
2914 EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block
2915 + i_ic) * jcp.oc_block + kernel_offset));
2917 for (int i_ur = 0; i_ur < ur_w; i_ur++) {
2918 if (i_ur == 0) {
2919 vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4),
2920 EVEX_compress_addr(reg_output, typesize * (i_ur + 0)
2921 * oc_block + output_offset));
2922 if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4),
2923 EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block
2924 + output_offset));
2925 if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4),
2926 EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block
2927 + output_offset));
2928 if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2929 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2930 + output_offset));
2931 } else if (i_ur + 3 < ur_w)
2932 vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
2933 EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
2934 + output_offset));
2936 for (int i_kw = 0; i_kw < kw; i_kw++) {
2937 int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
2938 if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
2939 (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
2940 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
2941 const size_t i_offset = (size_t)input_offset
2942 + (size_t)typesize * (jcp.ver == ver_4fma
2943 ? (i_iw - pad_l + i_ic * jcp.tr_iw)
2944 : (jcp.is_1stconv
2945 ? (i_iw - pad_l) + (size_t)i_ic
2946 * ((size_t)jcp.ih*jcp.iw*jcp.id)
2947 : (i_iw - pad_l) * ic_block + i_ic));
2948 vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
2949 Zmm(kw * ic_block_step + i_ur % 4),
2950 EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
2956 for (int i_kw = 0; i_kw < kw; i_kw++)
2957 for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
2958 vmovups(EVEX_compress_addr(reg_kernel, typesize
2959 * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset),
2960 Zmm(i_kw * ic_block_step + i_ic));
2963 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma(
2964 int ur_w, int pad_l, int pad_r,
2965 int ic_block_step, int input_offset, int kernel_offset,
2966 int output_offset, bool input_wraparound)
2968 // TODO: add prefetches to fma version as well
2970 assert(jcp.ver == ver_4fma);
2973 int ic_block = jcp.ic_block;
2974 int oc_block = jcp.oc_block;
2976 auto zmm_ker = [=](int i_kw, int i_ic) {
2977 return Zmm(i_kw * ic_block_step + i_ic);
2980 auto ker_addr = [=](int i_kw, int i_ic) {
2981 int local_offset
2982 = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
2983 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
2986 auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) {
2987 int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
2988 int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
2989 return EVEX_compress_addr(reg_input,
2990 local_offset + input_offset + extra_offset);
2993 auto zmm_out = [=](int i_iw) {
2994 // TODO: move reg calc to global member funcs
2995 const int out_zmm_base_idx = 28;
2996 return Zmm(out_zmm_base_idx + i_iw % 4);
2999 auto out_addr = [=](int i_ur) {
3000 return EVEX_compress_addr(reg_output,
3001 jcp.typesize_in * i_ur * oc_block + output_offset);
3004 auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
3005 assert(i_ur % 4 == 0);
3006 if (i_ur == 0)
3007 prefetcht1(ker_addr(i_kw, i_ic));
3008 if (i_ur + 4 >= ur_w)
3009 prefetcht0(ker_addr(i_kw, i_ic));
3011 const ptrdiff_t next_input_block_offset
3012 = jcp.typesize_in * ic_block_step * jcp.tr_iw;
3013 if (i_ur % 16 == 4 && i_kw == 0) {
3014 if (i_ur + 16 < ur_w)
3015 prefetcht0(inp_addr(i_ur + 16, i_ic));
3016 else
3017 prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
3019 if (i_ur % 16 == 4 && i_kw == 1) {
3020 if (input_wraparound)
3021 prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
3022 else
3023 prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
3027 for (int i_kw = 0; i_kw < kw; i_kw++)
3028 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3029 auto zmm = zmm_ker(i_kw, i_ic);
3030 vpxord(zmm, zmm, zmm);
3033 for (int i_ur = 0; i_ur < ur_w; i_ur += 4) {
3035 for (int i = 0; i < 4; i++) {
3036 auto zmm = zmm_out(i_ur + i);
3037 if (i_ur + i < ur_w)
3038 vmovups(zmm, out_addr(i_ur + i));
3039 else
3040 vpxord(zmm, zmm, zmm);
3041 prefetcht0(out_addr(i_ur + i + 4));
3044 for (int i_kw = 0; i_kw < kw; i_kw++)
3045 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3046 int i_iw = i_ur + i_kw;
3047 v4fmaddps(zmm_ker(i_kw, i_ic),
3048 zmm_out(i_ur), inp_addr(i_iw, i_ic));
3049 pf_callback(i_ur, i_kw, i_ic);
3053 for (int i_kw = 0; i_kw < kw; i_kw++)
3054 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3055 auto addr = ker_addr(i_kw, i_ic);
3056 auto zmm = zmm_ker(i_kw, i_ic);
3057 vaddps(zmm, zmm, addr);
3058 vmovups(addr, zmm);
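/* Editor's note: the vnni flavour of the ic-block step below is
 * structurally the same as the 4fma one, but outputs are stored as int16
 * pairs (ow_per_oc == 2), accumulators are int32, and the final merge
 * into the kernel buffer uses vpaddd instead of vaddps. */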
3062 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_vnni(
3063 int ur_w, int pad_l, int pad_r,
3064 int ic_block_step, int input_offset, int kernel_offset,
3065 int output_offset, bool input_wraparound)
3067 // TODO: add prefetches to fma version as well
3068 assert(jcp.ver == ver_4vnni || jcp.ver == ver_vnni);
3071 int ic_block = jcp.ic_block;
3072 int oc_block = jcp.oc_block;
3074 auto zmm_ker = [=](int i_kw, int i_ic) {
3075 return Zmm(i_kw * ic_block_step + i_ic);
3078 auto ker_addr = [=](int i_kw, int i_ic) {
3079 int local_offset
3080 = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
3081 return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
3084 auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0,
3085 bool vnni_bcast = false) {
3086 int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
3087 int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
3088 if (vnni_bcast)
3089 return EVEX_compress_addr(reg_input,
3090 local_offset + input_offset + extra_offset, true);
3091 else
3092 return EVEX_compress_addr(reg_input,
3093 local_offset + input_offset + extra_offset);
3096 auto zmm_out = [=](int i_iw) {
3097 // TODO: move reg calc to global member funcs
3098 const int out_zmm_base_idx = 28;
3099 return Zmm(out_zmm_base_idx + i_iw % 4);
3102 auto out_addr = [=](int i_ur) {
3103 assert(utils::one_of(jcp.ver, ver_4vnni, ver_4fma, ver_vnni));
3104 auto ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3105 return EVEX_compress_addr(reg_output,
3106 jcp.typesize_in * i_ur * oc_block * ow_per_oc + output_offset);
3109 auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
3110 if (i_ur == 0)
3111 mic_prefetcht1(ker_addr(i_kw, i_ic));
3112 if (i_ur + 4 >= ur_w)
3113 mic_prefetcht0(ker_addr(i_kw, i_ic));
3115 const ptrdiff_t next_input_block_offset
3116 = jcp.typesize_in * ic_block_step * jcp.tr_iw;
3117 if (i_ur % 16 == 4 && i_kw == 0) {
3118 if (i_ur + 16 < ur_w)
3119 mic_prefetcht0(inp_addr(i_ur + 16, i_ic));
3120 else
3121 mic_prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
3123 if (i_ur % 16 == 4 && i_kw == 1) {
3124 if (input_wraparound)
3125 mic_prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
3126 else
3127 mic_prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
3131 for (int i_kw = 0; i_kw < kw; i_kw++)
3132 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3133 auto zmm = zmm_ker(i_kw, i_ic);
3134 vpxord(zmm, zmm, zmm);
3136 auto steps = ur_w / 2;
3137 auto numloads = (jcp.ver == ver_vnni) ? 1 : 4;
3138 for (int i_ur = 0; i_ur < steps; i_ur += numloads) {
3140 for (int i = 0; i < numloads; i++) {
3141 int oi = i_ur + i;
3142 auto zmm = zmm_out(oi);
3143 if (oi < steps)
3144 vmovups(zmm, out_addr(oi));
3145 else
3146 vpxord(zmm, zmm, zmm);
3147 mic_prefetcht0(out_addr(2 * i_ur + i + 4));
3150 for (int i_kw = 0; i_kw < kw; i_kw++)
3151 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3152 int i_iw = 2 * i_ur + i_kw;
3153 if (jcp.ver == ver_4vnni)
3154 vp4dpwssd(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
3155 inp_addr(i_iw, i_ic));
3156 else if (jcp.ver == ver_vnni)
3157 vpdpwssd(zmm_ker(i_kw, i_ic), zmm_out(i_ur),
3158 inp_addr(i_iw, i_ic, 0, true));
3160 assert(!"unknown convolution version");
3161 pf_callback(2 * i_ur, i_kw, i_ic);
3165 for (int i_kw = 0; i_kw < kw; i_kw++) {
3166 for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
3167 auto addr = ker_addr(i_kw, i_ic);
3168 auto zmm = zmm_ker(i_kw, i_ic);
3169 vpaddd(zmm, zmm, addr);
3175 void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
3176 int ur_w, int pad_l, int pad_r,
3177 int ic_block_step, int input_offset, int kernel_offset,
3178 int output_offset, bool input_wraparound)
3180 if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
3181 compute_ic_block_step_vnni(ur_w, pad_l, pad_r,
3182 ic_block_step, input_offset, kernel_offset, output_offset,
3183 input_wraparound);
3184 else if (jcp.ver == ver_4fma)
3185 compute_ic_block_step_4fma(ur_w, pad_l, pad_r,
3186 ic_block_step, input_offset, kernel_offset, output_offset,
3187 input_wraparound);
3188 else if (jcp.ver == ver_fma)
3189 compute_ic_block_step_fma(ur_w, pad_l, pad_r,
3190 ic_block_step, input_offset, kernel_offset, output_offset,
3191 input_wraparound);
3192 else
3193 assert(!"unknown convolution version");
3196 void jit_avx512_common_conv_bwd_weights_kernel_f32
3197 ::compute_oh_step_unroll_ow_icblock(
3198 int ic_block_step, int max_ur_w)
3202 Label kh_label, kd_label;
3204 int ic_block = jcp.ic_block;
3205 int oc_block = jcp.oc_block;
3206 int inp_mul = !jcp.is_1stconv ? ic_block : 1;
3207 int iw = (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni))
3208 ? jcp.tr_iw : jcp.iw;
3209 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3211 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
3212 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
3213 int l_pad = jcp.l_pad;
3215 if (jcp.ndims == 5) {
3216 mov(aux_reg_input, reg_input);
3217 mov(aux_reg_kernel, reg_kernel);
3218 mov(ki, ptr[param + GET_OFF(kd_padding)]);
3219 L(kd_label);
3220 mov(reg_input, aux_reg_input);
3221 mov(reg_kernel, aux_reg_kernel);
3224 mov(kj, reg_kh);
3225 L(kh_label);
3227 for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
3228 const int input_offset = jcp.typesize_in
3229 * (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3230 ? i_b_ic * iw : i_b_ic);
3231 compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
3232 input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
3233 i_b_ic + ic_block_step >= jcp.ic_block);
3235 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
3236 add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
3237 dec(kj);
3238 cmp(kj, 0);
3239 jg(kh_label, T_NEAR);
3242 if (jcp.ndims == 5) {
3243 add(aux_reg_input,
3244 jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
3245 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3246 * oc_block);
3247 dec(ki);
3248 cmp(ki, 0);
3249 jg(kd_label, T_NEAR);
3253 void jit_avx512_common_conv_bwd_weights_kernel_f32
3254 ::compute_oh_step_unroll_ow(
3255 int ic_block_step, int max_ur_w)
3257 Label kh_label, ic_block_label, kd_label;
3261 int ic_block = jcp.ic_block;
3262 int oc_block = jcp.oc_block;
3264 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3266 int r_pad = nstl::max(0,
3267 (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
3268 - (jcp.iw + jcp.l_pad - 1));
3269 int l_pad = jcp.l_pad;
3271 if (jcp.ndims == 5) {
3272 mov(aux_reg_input, reg_input);
3273 mov(aux_reg_kernel, reg_kernel);
3274 mov(ki, ptr[param + GET_OFF(kd_padding)]);
3275 L(kd_label);
3276 mov(reg_input, aux_reg_input);
3277 mov(reg_kernel, aux_reg_kernel);
3280 mov(kj, reg_kh);
3281 L(kh_label);
3283 xor_(b_ic, b_ic);
3284 L(ic_block_label); {
3285 compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
3286 0, 0, 0);
3287 size_t inp_icblk_stride = jcp.is_1stconv
3288 ? (size_t)jcp.ih * jcp.iw * jcp.id
3289 : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3290 ? jcp.tr_iw : 1);
3291 size_t input_offset
3292 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
3293 safe_add(reg_input, input_offset, reg_long_offt);
3294 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
3295 add(b_ic, ic_block_step);
3296 cmp(b_ic, jcp.ic_block);
3297 jl(ic_block_label, T_NEAR);
3300 if (jcp.is_1stconv) {
3301 size_t input_offset
3302 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
3303 safe_sub(reg_input, input_offset, reg_long_offt);
3304 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
3305 } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
3306 add(reg_input, jcp.typesize_in
3307 * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
3309 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
3310 dec(kj);
3311 cmp(kj, 0);
3312 jg(kh_label, T_NEAR);
3314 if (jcp.ndims == 5) {
3315 add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
3316 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
3317 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3318 * oc_block);
3319 dec(ki);
3320 cmp(ki, 0);
3321 jg(kd_label, T_NEAR);
3325 void jit_avx512_common_conv_bwd_weights_kernel_f32
3326 ::compute_oh_step_common(
3327 int ic_block_step, int max_ur_w)
3329 Label kh_label, ic_block_label, ow_block_label, kd_label;
3331 int ic_block = jcp.ic_block;
3332 int oc_block = jcp.oc_block;
3334 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3335 int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
3336 + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
3337 int l_pad = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni
3338 || jcp.ver == ver_vnni) ? 0 : jcp.l_pad;
3340 int ur_w = nstl::min(ow, max_ur_w);
3341 int ur_w_trips = ow / ur_w;
3342 int ur_w_tail = ow % ur_w;
3343 if ((ur_w_tail == 0 && r_pad != 0)
3344 || r_pad >= ur_w_tail) {
3345 if (ur_w_trips > 1) {
3346 ur_w_tail += ur_w;
3347 ur_w_trips--;
3348 } else {
3349 ur_w_tail += (ur_w - ur_w / 2);
3350 ur_w = ur_w / 2;
3351 }
3352 }
3354 int inp_mult = (jcp.is_1stconv ||
3355 utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) ? 1 : ic_block;
3356 int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult;
3357 int output_comeback = ur_w_trips * ur_w * oc_block;
3359 if (jcp.ndims == 5) {
3360 mov(aux_reg_input, reg_input);
3361 mov(aux_reg_kernel, reg_kernel);
3362 mov(ki, ptr[param + GET_OFF(kd_padding)]);
3363 L(kd_label);
3364 mov(reg_input, aux_reg_input);
3365 mov(reg_kernel, aux_reg_kernel);
3367 mov(kj, reg_kh);
3368 L(kh_label);
3370 xor_(b_ic, b_ic);
3371 L(ic_block_label); {
3373 if (l_pad != 0) {
3374 compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
3375 add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad)
3376 * inp_mult);
3377 add(reg_output, jcp.typesize_in * ur_w * oc_block);
3378 }
3380 if (ur_w_trips > 0) {
3381 xor_(reg_ur_w_trips, reg_ur_w_trips);
3382 L(ow_block_label); {
3383 compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
3384 add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w
3385 * inp_mult);
3386 add(reg_output, jcp.typesize_in * ur_w * oc_block);
3388 inc(reg_ur_w_trips);
3389 cmp(reg_ur_w_trips, ur_w_trips);
3390 jl(ow_block_label, T_NEAR);
3394 if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad,
3395 ic_block_step, 0, 0, 0);
3397 sub(reg_input, jcp.typesize_in * input_comeback);
3398 sub(reg_output, jcp.typesize_in * output_comeback);
3399 int inp_icblk_stride = jcp.is_1stconv
3400 ? jcp.ih * jcp.iw * jcp.id
3401 : (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3402 ? jcp.tr_iw : 1);
3403 size_t input_offset
3404 = inp_icblk_stride * jcp.typesize_in * ic_block_step;
3405 safe_add(reg_input, input_offset, reg_long_offt);
3406 add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
3408 add(b_ic, ic_block_step);
3409 cmp(b_ic, jcp.ic_block);
3410 jl(ic_block_label, T_NEAR);
3412 if (jcp.is_1stconv) {
3413 size_t input_offset
3414 = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
3415 safe_sub(reg_input, input_offset, reg_long_offt);
3416 add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
3417 } else if (!utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
3418 add(reg_input, jcp.typesize_in
3419 * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
3421 add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
3422 dec(kj);
3423 cmp(kj, 0);
3424 jg(kh_label, T_NEAR);
3426 if (jcp.ndims == 5) {
3427 add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
3428 * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
3429 add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
3430 * oc_block);
3431 dec(ki);
3432 cmp(ki, 0);
3433 jg(kd_label, T_NEAR);
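/* Editor's note: ic_block_step trades accumulator registers against code
 * size -- 8 input channels per step for kw <= 3 (24 accumulators plus 4
 * output registers), 4 for kw <= 7 (28 + 4 = exactly 32 zmm), and
 * otherwise 2. */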
3437 void jit_avx512_common_conv_bwd_weights_kernel_f32
3438 ::compute_oh_step_disp()
3440 int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
3441 if (jcp.is_1stconv) {
3442 bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
3443 ic_block_step
3444 = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1;
3447 bool too_large_to_unroll
3448 = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
3449 && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
3451 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3452 if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
3453 compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
3454 else if (ow <= max_ur_w)
3455 compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
3456 else
3457 compute_oh_step_common(ic_block_step, max_ur_w);
3459 if (jcp.ndims == 5) {
3460 od_step_comeback_pointers();
3461 mov(reg_input, aux_reg_input);
3462 mov(reg_kernel, aux_reg_kernel);
3464 oh_step_comeback_pointers();
3468 void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
3470 Label skip_zeroing, zeroing_loop;
3472 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3473 cmp(reg_tmp, 0);
3474 jz(skip_zeroing, T_NEAR);
3476 Zmm zero = Zmm(0);
3477 vpxord(zero, zero, zero);
3478 xor_(reg_tmp, reg_tmp);
3479 L(zeroing_loop); {
3480 assert(jcp.oc_block * jcp.typesize_out
3481 == cpu_isa_traits<avx512_common>::vlen);
3482 for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
3483 vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
3484 * jcp.typesize_out], zero);
3485 add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
3486 cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd
3487 * jcp.typesize_out);
3488 jnz(zeroing_loop, T_NEAR);
3489 }
3491 L(skip_zeroing);
3494 void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
3496 Label skip_bias, bias_loop, skip_load_bias;
3498 mov(reg_tmp, ptr[param + GET_OFF(flags)]);
3499 test(reg_tmp, reg_tmp);
3500 jne(skip_bias, T_NEAR);
3502 mov(reg_bias, ptr[param + GET_OFF(bias)]);
3503 mov(reg_output, ptr[param + GET_OFF(dst)]);
3504 vpxord(Zmm(1), Zmm(1), Zmm(1));
3506 mov(reg_tmp, ptr[param + GET_OFF(channel)]);
3507 cmp(reg_tmp, 0);
3508 jne(skip_load_bias, T_NEAR);
3509 vmovups(Zmm(1), ptr[reg_bias]);
3511 L(skip_load_bias);
3513 mov(reg_oi, ptr[param + GET_OFF(d_worksize)]);
3514 sub(reg_oi, ptr[param + GET_OFF(d_index)]);
3515 mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out);
3516 imul(reg_oi, reg_tmp);
3518 xor_(reg_tmp, reg_tmp);
3519 L(bias_loop); {
3520 vmovups(Zmm(0), ptr[reg_output + reg_tmp]);
3521 vaddps(Zmm(1), Zmm(1), Zmm(0));
3522 add(reg_tmp, jcp.oc_block * jcp.typesize_out);
3523 cmp(reg_tmp, reg_oi);
3524 jl(bias_loop, T_NEAR);
3525 }
3526 vmovups(EVEX_compress_addr(reg_bias, 0), Zmm(1));
3528 L(skip_bias);
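/* Editor's note: compute_oh_loop_common walks output rows in three
 * phases: a 'top' phase where the kernel still intersects t_pad (the kh
 * overlap grows each row), a steady 'middle' phase using the full kh
 * range, and a 'bottom' phase where the overlap shrinks into b_pad.
 * reg_kh always holds the current number of valid kernel rows. */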
3531 void jit_avx512_common_conv_bwd_weights_kernel_f32
3532 ::compute_oh_loop_common()
3534 int ic_block = jcp.ic_block;
3535 int oc_block = jcp.oc_block;
3536 int back_pad = jcp.back_pad;
3537 int b_pad = jcp.b_pad;
3538 int t_pad = jcp.t_pad;
3539 bool is_dilated = jcp.dilate_h != 0;
3540 int dilate_h = jcp.dilate_h + 1;
3541 int stride_h = jcp.stride_h;
3542 const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
3543 int iw = utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni) ? jcp.tr_iw
3545 const size_t io_overlap = jcp.od - back_pad;
3546 Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
3547 oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
3548 oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end,
3549 skip_neg_overlap_label, skip_fpad_label, skip_input_label;
3551 maybe_zero_kernel();
3552 if (jcp.ndims == 5 && jcp.with_bias) bias_kernel();
3554 /* initially offset 'kd' by f_pad */
3555 if (jcp.ndims == 5) add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
3557 int ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow : jcp.ow;
3559 if (jcp.ndims == 5) {
3560 mov(reg_input_d, ptr[param + GET_OFF(src)]);
3561 mov(reg_output_d, ptr[param + GET_OFF(dst)]);
3562 mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
3564 L(od_label);
3565 mov(reg_input, reg_input_d);
3566 mov(reg_output, reg_output_d);
3567 } else {
3568 mov(reg_input, ptr[param + GET_OFF(src)]);
3569 mov(reg_output, ptr[param + GET_OFF(dst)]);
3570 }
3572 mov(reg_kh, jcp.kh);
3573 xor_(reg_ih_count, reg_ih_count);
3574 xor_(reg_oj, reg_oj);
3575 /* Compute 'top' edge */
3576 if (t_pad > 0) {
3577 const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
3578 const int overflow
3579 = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
3580 const int underflow = div_up(t_pad, dilate_h);
3581 const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
3582 mov(reg_kh, initial_inp_ker_overlap);
3583 add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
3584 * jcp.oc_block);
3585 // generate loop to process kernel while it remains within t_pad + ih
3586 if (kh_range < t_pad + jcp.ih) {
3587 if (is_dilated) {
3588 const int tail = t_pad % dilate_h;
3589 const int shift = tail == 0 ? 0 : dilate_h - tail;
3590 mov(reg_tmp, shift);
3592 add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
3593 }
3594 L(oh_tpad_label); {
3595 compute_oh_step_disp();
3596 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3597 if (is_dilated) {
3598 inc(reg_tmp);
3599 cmp(reg_tmp, dilate_h);
3600 jl(oh_dilate_label_shift, T_NEAR);
3601 // unshift input as new kernel element enters
3602 sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
3603 xor_(reg_tmp, reg_tmp);
3604 }
3605 // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
3606 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3607 * jcp.ic_block * jcp.oc_block);
3608 add(reg_kh, stride_h);
3609 if (is_dilated) {
3610 jmp(oh_dilate_label_noshift, T_NEAR);
3611 L(oh_dilate_label_shift);
3612 // shift input as old kernel element progresses
3613 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3614 L(oh_dilate_label_noshift);
3615 }
3616 inc(reg_oj);
3617 add(reg_ih_count, stride_h);
3619 // final number of kernel elements that overlap with input
3620 const int final_inp_ker_overlap
3621 = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
3622 cmp(reg_kh, final_inp_ker_overlap);
3623 jl(oh_tpad_label, T_NEAR);
3626 // need second loop to process kernel if it is larger than the input
3627 // (does not apply to dilations as they must have unit stride)
3628 if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
3629 t_pad % stride_h)) {
3630 assert(!is_dilated);
3631 mov(reg_kh, jcp.ih);
3632 L(oh_tpad_tail_label); {
3633 compute_oh_step_disp();
3634 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3635 sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
3636 * jcp.ic_block * jcp.oc_block);
3638 inc(reg_oj);
3639 add(reg_ih_count, stride_h);
3641 cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
3642 jl(oh_tpad_tail_label, T_NEAR);
3645 // correct any excess shifts to kernel and input
3646 // (does not apply to dilations as they must have unit stride,
3647 // kernel must fit inside input, and padding is smaller than input)
3648 if (t_pad <= jcp.oh * stride_h) {
3649 // kernel has moved beyond padding (adjust for stride effects)
3650 if (t_pad % stride_h != 0) {
3651 assert(!is_dilated);
3652 int inp_corr = stride_h - t_pad % stride_h;
3653 add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
3654 * jcp.ic_block * jcp.oc_block);
3655 add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
3656 }
3657 } else {
3658 // kernel still overlaps padding (complete reset)
3659 assert(!is_dilated);
3660 sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
3661 * jcp.kw * jcp.ic_block * jcp.oc_block);
3665 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3666 jge(oh_label_end, T_NEAR);
3667 cmp(reg_oj, jcp.oh);
3668 jge(oh_label, T_NEAR);
3670 /* Compute middle block(s) */
3671 mov(reg_kh, jcp.kh);
3673 compute_oh_step_disp();
3674 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3675 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3678 add(reg_ih_count, stride_h);
3680 cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
3681 jge(oh_label_end, T_NEAR);
3683 cmp(reg_oj, jcp.oh);
3684 jl(oh_label, T_NEAR);
3688 /* Compute bottom edge */
3690 cmp(reg_oj, jcp.oh);
3691 jge(oh_bpad_label_end, T_NEAR);
3693 if (is_dilated) {
3694 mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
3695 mov(reg_tmp, 0);
3696 } else {
3697 mov(reg_kh, jcp.ihp - b_pad);
3698 sub(reg_kh, reg_ih_count);
3699 }
3701 L(oh_bpad_label); {
3702 compute_oh_step_disp();
3703 add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
3704 add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
3705 if (is_dilated) {
3706 inc(reg_tmp);
3707 cmp(reg_tmp, dilate_h);
3708 jl(oh_dilate_label_end, T_NEAR);
3709 xor_(reg_tmp, reg_tmp);
3710 }
3711 sub(reg_kh, stride_h);
3712 cmp(reg_kh, 0);
3713 jle(oh_bpad_label_end, T_NEAR);
3715 L(oh_dilate_label_end);
3717 inc(reg_oj);
3718 cmp(reg_oj, jcp.oh);
3719 jl(oh_bpad_label, T_NEAR);
3720 }
3721 L(oh_bpad_label_end);
3724 if (jcp.ndims == 5) {
3729 mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
3731 /* 'outer-depth loop' offset into next 'depth' index */
3732 add(reg_output_d, jcp.typesize_in * jcp.oh * ow * jcp.oc_block);
3734 /* only increase input address when convolution is not within the
3735 * 'f_pad' region */
3736 if (jcp.f_pad > 0) {
3737 cmp(reg_d_index, jcp.f_pad);
3738 jl(skip_input_label);
3740 add(reg_input_d,
3741 jcp.typesize_in * jcp.stride_d * jcp.ih * iw * inp_mult);
3742 L(skip_input_label);
3745 cmp(reg_d_index, io_overlap);
3746 jl(skip_neg_overlap_label);
3748 /* Reduce 'kd' count as convolution steps within 'back_pad' region */
3749 dec(reg_kd_count);
3750 jmp(skip_fpad_label);
3752 L(skip_neg_overlap_label);
3753 cmp(reg_kd_count, jcp.kd);
3754 jge(skip_fpad_label);
3756 /* increase 'kd' count as convolution steps out of 'f_pad' region */
3757 inc(reg_kd_count);
3758 sub(reg_kernel,
3759 jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block);
3761 L(skip_fpad_label);
3762 mov(ptr[param + GET_OFF(kd_padding)], reg_kd_count);
3763 inc(reg_d_index);
3764 cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
3765 jl(od_label, T_NEAR);
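/* Editor's note: compute_full_spat_loop is an opportunistic fast path
 * that keeps a whole OH block of transposed input and output resident
 * while looping over the kh x kw taps; it bails out (returns false)
 * whenever the shape, stride, dilation, or padding does not match its
 * assumptions. */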
3771 bool jit_avx512_common_conv_bwd_weights_kernel_f32
3772 ::compute_full_spat_loop()
3774 // FIXME: use register mapping from the class declaration
3775 bool ok = one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)
3776 && (jcp.ver == ver_4fma || !one_of(1, jcp.kh, jcp.kw))
3777 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
3778 && everyone_is(1, jcp.stride_h, jcp.stride_w);
3779 if (!ok) return false;
3780 if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
3781 return false;
3783 // General code layout:
3785 // Blocking over OH -- top level
3786 // (Reduces L2 pressure; not very useful right now)
3787 // Loop over all KHxKW kernel -- emit_kh_kw_loop()
3788 // Loop over OH block -- emit_h_loop()
3789 // Loop over OW blocks -- emit_fma_block()
3790 // (Supports both fully unrolled and partially unrolled versions to
3791 // reduce code size)
3792 // Loop over OW block -- emit_fma_step()
3794 int max_working_set_size = 128 * 1024;
3795 int pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow
3796 : jcp.ow;
3798 int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
3799 int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in;
3800 int row_size = inp_row_size + out_row_size;
3802 int h_block_size = jcp.oh;
3803 int working_set_size = row_size * h_block_size;
3805 if (working_set_size > max_working_set_size) {
3806 int opt_working_set_size = 48 * 1024;
3807 assert(opt_working_set_size < max_working_set_size);
3809 while (working_set_size > opt_working_set_size) {
3810 for (int i = 2; i <= h_block_size; i++)
3811 if (i == h_block_size)
3812 h_block_size = h_block_size / 2;
3813 else if (h_block_size % i == 0) {
3814 h_block_size = h_block_size / i;
3815 break;
3816 }
3817 working_set_size = row_size * h_block_size;
3819 if (h_block_size == 1 && working_set_size > opt_working_set_size)
3820 return false;
3824 // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below)
3825 if (h_block_size < nstl::max(1, jcp.t_pad)
3826 || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size
3827 : jcp.oh % h_block_size))
3828 return false;
3830 // check that we can use simple arithmetic for prefetch address
3832 // TODO: we need some traits for this check (Roma)
3833 int cache_line_size = 64;
3834 assert(jcp.ic_block * typesize == 64);
3835 assert(jcp.oc_block * typesize == 64);
3837 int num_inp_l2_pfs = jcp.tr_iw * h_block_size;
3838 int avg_h_loop_len = h_block_size;
3839 int num_inp_l2_pfs_per_fma_block
3840 = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3841 int num_out_l2_pfs = pad_ow * h_block_size;
3842 int num_out_l2_pfs_per_fma_block
3843 = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
3845 Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors
3848 Reg64 reg_tmp = abi_not_param1;
3849 Reg32 reg_tmp_w = reg_tmp.cvt32();
3850 Reg64 reg_ohs = rdx;
3851 Reg64 reg_ihs = rsi;
3856 Reg64 reg_inp = r13;
3857 Reg64 reg_out = r14;
3858 Reg64 reg_ker = r15;
3860 Reg64 reg_inp_pf_l1 = rbp;
3862 Reg64 reg_inp_pf_l2 = r11;
3863 Reg64 reg_out_pf_l2 = r12;
3865 Xmm reg_inp_pf_save = xmm17;
3866 Xmm reg_out_pf_save = xmm18;
3868 Reg64 reg_inp_save = abi_param1;
3869 Reg64 reg_out_save = reg_tmp;
3871 auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); };
3872 auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
3873 auto inp_addr = [&](int oi, int ic1, bool vnni_bcast = false) {
3874 if (vnni_bcast)
3875 return zword_b[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
3876 else
3877 return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
3879 auto out_addr = [&](int oi, int oj = 0) {
3880 assert(utils::one_of(jcp.ver, ver_4vnni, ver_4fma, ver_vnni));
3881 auto ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3882 auto pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow
3883 : jcp.ow;
3884 return ptr[reg_out
3885 + ((oi + oj * pad_ow / ow_per_oc) * jcp.oc_block * ow_per_oc)
3886 * jcp.typesize_in];
3888 auto ker_addr = [&](int ic1) {
3889 return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out];
3892 auto emit_block = [&](int h_block_size,
3893 bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row)
3895 // TODO: add an fma version (Roma)
3896 auto pad_ow = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? jcp.tr_ow
3897 : jcp.ow;
3899 int ow_per_oc = (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) ? 2 : 1;
3900 int ow4u = rnd_up(pad_ow, 4 * ow_per_oc);
3901 int def_step_size = 16;
3903 bool has_w_tail = (pad_ow % def_step_size != 0
3904 || pad_ow % (4 * ow_per_oc) != 0);
3905 bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
3907 auto emit_step = [&](int ur_ow,
3908 int num_inp_l1_pfs_per_fma_step,
3909 int num_inp_l2_pfs_per_fma_step,
3910 int num_out_l2_pfs_per_fma_step, bool is_w_tail)
3912 bool block_wraparound = is_w_tail && is_last_row;
3914 assert(ur_ow % 4 == 0);
3915 int tail_size = ow4u % ur_ow;
3916 int this_ur_ow
3917 = (is_w_tail && tail_size) ? tail_size : ur_ow;
3918 int ow_last_chunk4 = pad_ow % (4 * ow_per_oc);
3919 int ow_zero_tail4 = ow_last_chunk4
3920 ? (4 * ow_per_oc) - ow_last_chunk4 : 0;
3922 auto emit_out_pf = [&](int oi) {
3923 #if 1
3924 if (oi + def_step_size < ur_ow / ow_per_oc || !block_wraparound)
3925 mic_prefetcht0(ptr[reg_out
3926 + ((def_step_size + oi)
3927 * ow_per_oc * jcp.oc_block * jcp.typesize_in)]);
3928 else {
3929 assert(block_wraparound);
3930 assert(oi + def_step_size >= ur_ow / ow_per_oc);
3931 mic_prefetcht0(ptr[reg_out_save
3932 + ((oi + def_step_size - ur_ow / ow_per_oc)
3933 * ow_per_oc * jcp.oc_block * jcp.typesize_in)]);
3934 }
3935 #else
3936 // XXX: This is an alternative prefetching strategy that
3937 // always prefetches the next row. Keeping it here for
3938 // future experiments (Roma)
3939 if (!block_wraparound)
3940 mic_prefetcht0(ptr[reg_out
3941 + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]);
3942 else
3943 mic_prefetcht0(ptr[reg_out + reg_ohs
3944 - ((h_block_size - 1) * jcp.ow
3945 - oi) * jcp.oc_block * jcp.typesize_in]);
3946 #endif
3947 if (oi < num_out_l2_pfs_per_fma_step)
3948 mic_prefetcht1(ptr[reg_out_pf_l2
3949 + oi * jcp.oc_block * jcp.typesize_in]);
3952 auto emit_inp_pf = [&](int oi4, int ic1) {
3953 int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block;
3954 int num_pf_slots = jcp.ic_block * ur_ow / 4;
3956 int num_pfs = num_inp_l1_pfs_per_fma_step
3957 + num_inp_l2_pfs_per_fma_step;
3958 int pf_freq = nstl::max(1, num_pf_slots / num_pfs);
3960 if (pf_slot_idx % pf_freq)
3961 return;
3963 int pf_idx = pf_slot_idx / pf_freq;
3965 if (pf_idx < num_inp_l2_pfs_per_fma_step)
3966 mic_prefetcht1(ptr[reg_inp_pf_l2
3967 + pf_idx * jcp.ic_block * jcp.typesize_in]);
3969 pf_idx -= num_inp_l2_pfs_per_fma_step;
3970 // prefetch the 'tail' of the cache line because most of
3971 // the accesses are not aligned
3972 mic_prefetcht0(ptr[reg_inp_pf_l1
3973 + pf_idx * jcp.ic_block * jcp.typesize_in
3974 + cache_line_size - jcp.typesize_in]);
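            // Spread example: ic_block = 16 and ur_ow = 16 give
            // 16 * 16 / 4 = 64 fma slots; with, say, 4 L1 + 2 L2 prefetches
            // per step, pf_freq = max(1, 64 / 6) = 10, so roughly every 10th
            // fma carries one prefetch and they never cluster.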
            auto numloads = (jcp.ver == ver_vnni) ? 1 : 4;

            int steps = this_ur_ow / ow_per_oc;
            for (int oi4 = 0; oi4 < steps; oi4 += numloads) {
                for (int oi1 = 0; oi1 < numloads; oi1++) {
                    int oi = oi4 + oi1;
                    if (!is_w_tail
                            || oi < (this_ur_ow - ow_zero_tail4)
                                / ow_per_oc) {
                        vmovups(zmm_out(oi), out_addr(oi));
                        emit_out_pf(oi);
                    } else {
                        auto zmm = zmm_out(oi);
                        vpxord(zmm, zmm, zmm);
                    }
                }

                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    if (jcp.ver == ver_4fma) {
                        v4fmaddps(zmm_ker(ic1),
                                zmm_out(oi4), inp_addr(oi4, ic1));
                    } else if (jcp.ver == ver_4vnni) {
                        vp4dpwssd(zmm_ker(ic1),
                                zmm_out(oi4), inp_addr(ow_per_oc*oi4, ic1));
                    } else if (jcp.ver == ver_vnni) {
                        vpdpwssd(zmm_ker(ic1), zmm_out(oi4),
                                inp_addr(ow_per_oc*oi4, ic1, true));
                    } else {
                        assert(!"unknown convolution version");
                    }
                    emit_inp_pf(ow_per_oc * oi4, ic1);
                }
            }
        };
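        // Register map used above: zmm0..zmm15 hold the per-ic kernel
        // accumulators (zmm_ker) and zmm24..zmm31 rotate over up to 8 output
        // pixels (zmm_out). The 4-wide ops (v4fmaddps / vp4dpwssd) consume
        // the register quadruple starting at zmm_out(oi4) together with four
        // consecutive input elements from memory; the vnni path instead
        // issues plain vpdpwssd with a broadcast input (vnni_bcast == true).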
        // Input is transposed and padded but we only access about jcp.iw
        // elements so use that to compute the # of cache lines in each 'row'
        int num_inp_l1_pfs
            = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block;

        if (full_w_unroll) {
            emit_step(ow4u, num_inp_l1_pfs,
                    num_inp_l2_pfs_per_fma_block,
                    num_out_l2_pfs_per_fma_block, true);
            add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size);
            add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size);
        } else {
            Label w_loop;
            int num_w_iters = pad_ow / def_step_size;
            int num_w_iters_full = num_w_iters + has_w_tail;
            int num_inp_l1_pfs_per_fma_step
                = div_up(num_inp_l1_pfs, num_w_iters_full);
            int num_inp_l2_pfs_per_fma_step
                = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full);
            int num_out_l2_pfs_per_fma_step
                = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full);
            mov(reg_i, num_w_iters);
            L(w_loop); {
                emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
                        num_inp_l2_pfs_per_fma_step,
                        num_out_l2_pfs_per_fma_step, false);
                add(reg_inp, def_step_size * jcp.typesize_in);
                add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in);
                add(reg_inp_pf_l1,
                        num_inp_l1_pfs_per_fma_step * cache_line_size);
                add(reg_inp_pf_l2,
                        num_inp_l2_pfs_per_fma_step * cache_line_size);
                add(reg_out_pf_l2,
                        num_out_l2_pfs_per_fma_step * cache_line_size);
                sub(reg_i, 1);
                jnz(w_loop);
            }
            if (has_w_tail) {
                emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
                        num_inp_l2_pfs_per_fma_step,
                        num_out_l2_pfs_per_fma_step, true);
                add(reg_inp_pf_l2,
                        num_inp_l2_pfs_per_fma_step * cache_line_size);
                add(reg_out_pf_l2,
                        num_out_l2_pfs_per_fma_step * cache_line_size);
            }
            // reset reg_inp and reg_out because emit_h_loop expects
            // unmodified pointers
            int w_offset = num_w_iters * def_step_size;
            sub(reg_inp, w_offset * jcp.typesize_in);
            sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in);
        }
    };
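    // Unlike reg_inp/reg_out, the L2 prefetch pointers are intentionally not
    // reset at the end of emit_block(): the div_up()-based per-step counts
    // make them sweep every cache line of the next block at least once over
    // the course of the current one.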
    auto emit_h_loop = [&](int h_block_size,
            bool is_last_block, bool is_last_kh_kw_iter)
    {
        Label h_loop, skip_h_loop;
        mov(reg_j, 1);
        cmp(reg_j, reg_h);
        je(skip_h_loop, T_NEAR);
        L(h_loop); {
            lea(reg_inp_pf_l1,
                    ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]);
            emit_block(h_block_size,
                    is_last_block, is_last_kh_kw_iter, false);

            add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
            add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in);
            add(reg_j, 1);
            cmp(reg_j, reg_h);
            jb(h_loop);
        }
        L(skip_h_loop);

        for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
            mic_prefetcht0(ker_addr(ic1));

        lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]);
        emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true);
    };
    auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block,
            int h_block_size)
    {
        xor_(reg_kh, reg_kh);
        Label kh_loop, kh_loop_end;

        int last_oh_block_size
            = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size);
        int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size;
        // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size
        int ih_block_size = oh_block_size - 1 + jcp.kh
            - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad;

        L(kh_loop); {
            // determine starting indices for this block
            if (is_first_block) {
                xor_(reg_tmp, reg_tmp);
                mov(reg_ohs, jcp.t_pad);
                sub(reg_ohs, reg_kh);
                cmovb(reg_ohs, reg_tmp);

                mov(reg_ihs, reg_ohs);
                sub(reg_ihs, jcp.t_pad);
                add(reg_ihs, reg_kh);
            } else {
                xor_(reg_ohs, reg_ohs);
                mov(reg_ihs, reg_kh);
            }

            // determine effective size of block based on padding
            mov(reg_tmp, oh_block_size);
            sub(reg_tmp, reg_ohs);
            mov(reg_h, ih_block_size);
            sub(reg_h, reg_ihs);
            cmp(reg_tmp, reg_h);
            cmovb(reg_h, reg_tmp);
            Label kh_loop_work;
            cmp(reg_h, 0);
            jg(kh_loop_work, T_NEAR);

            // empty h loop for this jcp.kh:
            // - set the output to 0 if necessary
            // - advance the kernel pointer
            // - jump to the end
            Label skip_ker_zeroing;

            // The reg_ker ptr has its lowest bit set if the output needs to
            // be zeroed. Those who have byte-aligned their data will suffer
            // the consequences :(
            // TODO: move the flag to a mask register? (Roma)
            test(reg_ker, 1);
            jz(skip_ker_zeroing, T_NEAR);

            Label zeroing_loop;
            vpxord(zmm0, zmm0, zmm0);
            and_(reg_ker, ~1); // temporarily clear the zeroing flag
            mov(reg_tmp, jcp.kw);
            L(zeroing_loop); {
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
                    vmovups(ker_addr(ic1), zmm0);
                add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out);
                sub(reg_tmp, 1);
                jnz(zeroing_loop, T_NEAR);
            }
            // restore the zeroing flag (it will be cleared after the end of
            // emit_kh_kw_loop, but we may need it until then)
            or_(reg_ker, 1);
            jmp(kh_loop_end, T_NEAR);

            L(skip_ker_zeroing);
            add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw
                    * jcp.typesize_out);
            jmp(kh_loop_end, T_NEAR);
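            // The pointer tagging above is safe because the weights buffer
            // is aligned well past 1 byte, so bit 0 of reg_ker is otherwise
            // always zero; and_(reg_ker, ~1) recovers the real address
            // whenever the register has to be dereferenced.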
            L(kh_loop_work);

            mul_by_const(reg_ihs, reg_tmp,
                    jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
            mul_by_const(reg_ohs, reg_tmp,
                    pad_ow * jcp.oc_block * jcp.typesize_in);

            add(reg_inp, reg_ihs);
            add(reg_out, reg_ohs);

            Label kw_loop;
            xor_(reg_kw, reg_kw);
            L(kw_loop); {
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    vpxord(zmm, zmm, zmm);
                    mic_prefetcht1(ker_addr(ic1));
                }

                mov(reg_out_save, reg_out);
                mov(reg_inp_save, reg_inp);
                lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]);

#if 0
                // XXX: Generate code with special prefetches when switching
                // blocks or at the end of the last block. Disabled to reduce
                // code size and because there's no performance benefit (Roma)
                Label regular_h_loop, end_h_loop;
                cmp(reg_kw, jcp.kw - 1);
                jne(regular_h_loop, T_NEAR);
                cmp(reg_kh, jcp.kh - 1);
                jne(regular_h_loop, T_NEAR);

                emit_h_loop(oh_block_size, is_last_block, true);
                jmp(end_h_loop, T_NEAR);

                L(regular_h_loop);
                emit_h_loop(oh_block_size, is_last_block, false);

                L(end_h_loop);
#else
                emit_h_loop(oh_block_size, is_last_block, false);
#endif

                mov(reg_out, reg_out_save);
                mov(reg_inp, reg_inp_save);
                Label do_store;
                // The reg_ker ptr has its lowest bit set if the output needs
                // to be zeroed. Those who have byte-aligned their data will
                // suffer the consequences :(
                mov(reg_tmp, reg_ker);
                and_(reg_ker, ~1);
                test(reg_tmp, 1);
                jnz(do_store, T_NEAR);

                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    if (jcp.ver == ver_4fma) {
                        vaddps(zmm, ker_addr(ic1));
                    } else if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni) {
                        vpaddd(zmm, zmm, ker_addr(ic1));
                    } else {
                        assert(!"unknown convolution version");
                    }
                }

                L(do_store);
                for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
                    auto zmm = zmm_ker(ic1);
                    vmovups(ker_addr(ic1), zmm);
                }

                mov(reg_ker, reg_tmp);
                add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
                add(reg_kw, 1);
                cmp(reg_kw, jcp.kw);
                jl(kw_loop, T_NEAR);
            }

            sub(reg_inp, reg_ihs);
            sub(reg_out, reg_ohs);

            L(kh_loop_end);
            add(reg_kh, 1);
            cmp(reg_kh, jcp.kh);
            jl(kh_loop, T_NEAR);
        }
    };
    mov(reg_inp, ptr[param + GET_OFF(src)]);
    mov(reg_out, ptr[param + GET_OFF(dst)]);
    mov(reg_ker, ptr[param + GET_OFF(filt)]);
    mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]);
    mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]);
    mov(reg_tmp, ptr[param + GET_OFF(channel)]);
    or_(reg_ker, reg_tmp);

    bool single_kh_kw_loop = (h_block_size == jcp.oh);

    size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in;
    size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad);
    size_t inp_block_step = inp_row_step * h_block_size;
    size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in
        * h_block_size;
    if (!single_kh_kw_loop) {
        // Save the original prefetch pointers from the OpenMP driver
        vmovq(reg_inp_pf_save, reg_inp_pf_l2);
        vmovq(reg_out_pf_save, reg_out_pf_l2);

        mov(reg_inp_pf_l2, reg_inp);
        add(reg_inp_pf_l2, first_inp_block_step);
        mov(reg_out_pf_l2, reg_out);
        add(reg_out_pf_l2, out_block_step);
    }
    emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size);
    if (!single_kh_kw_loop) {
        size_t ker_reset_offset
            = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh;
        sub(reg_ker, ker_reset_offset);
        and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates

        add(reg_inp, first_inp_block_step);
        add(reg_out, out_block_step);
        mov(reg_inp_pf_l2, reg_inp);
        add(reg_inp_pf_l2, inp_block_step);
        mov(reg_out_pf_l2, reg_out);
        add(reg_out_pf_l2, out_block_step);

        int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2;
        if (num_innermost_iters > 0) {
            Label h_block_loop;

            mov(reg_tmp_w, num_innermost_iters);
            kmovw(reg_h_block, reg_tmp_w);
            L(h_block_loop); {
                emit_kh_kw_loop(false, false, h_block_size);
                sub(reg_ker, ker_reset_offset);
                add(reg_inp, inp_row_step * h_block_size);
                add(reg_out, out_block_step);
                mov(reg_inp_pf_l2, reg_inp);
                add(reg_inp_pf_l2, inp_block_step);
                mov(reg_out_pf_l2, reg_out);
                add(reg_out_pf_l2, out_block_step);
                kmovw(reg_tmp_w, reg_h_block);
                sub(reg_tmp_w, 1);
                kmovw(reg_h_block, reg_tmp_w);
                jnz(h_block_loop);
            }
        }

        // Restore the original prefetch pointers that came from the OpenMP
        // driver
        vmovq(reg_inp_pf_l2, reg_inp_pf_save);
        vmovq(reg_out_pf_l2, reg_out_pf_save);
        emit_kh_kw_loop(false, true, h_block_size);
    }

    return true;
}
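// Blocking arithmetic: with, say, jcp.oh = 28 and h_block_size = 10 there
// are div_up(28, 10) = 3 blocks; the first and the last are emitted
// separately (they handle t_pad/b_pad), leaving 3 - 2 = 1 iteration for the
// generated h_block_loop above. The loop counter lives in the k-register
// reg_h_block because every general-purpose register is already allocated
// at that point.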
bool jit_avx512_common_conv_bwd_weights_kernel_f32
    ::flat_4ops_compute() {
    const auto &j = jcp;
    const bool ok = j.ver == ver_4fma && j.is_1stconv
        && everyone_is(0, j.dilate_h, j.dilate_w);
    if (!ok) return false;

    Reg64 reg_ptr_tr_src = r8;
    Reg64 reg_ptr_dst = r9;
    Reg64 reg_ptr_wei = r10;
    Reg64 reg_ptr_bia = r11;

    Reg64 reg_kh_step = rax;
    Reg64 reg_oh = abi_not_param1;
    Reg64 reg_kh = rdx;

    Reg32 reg_flag_save = ebx;
    Reg32 reg_flag = esi;

    Zmm vbia(31);
    auto zmm_wei = [&](int kh, int kw) {
        return Zmm(8 + kh * j.kw + kw);
    };
    auto zmm_dst = [&](int ow) {
        return Zmm(ow % 8);
    };

    auto addr_tr_src = [&](int kh, int iw) {
        return ptr[reg_ptr_tr_src
            + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in];
    };
    auto addr_dst = [&](int ow) {
        return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in];
    };
    auto addr_wei = [&](int kh, int kw) {
        return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block
            * jcp.typesize_out];
    };
    auto emit_fma_block = [&](int kh_step) {
        for (int kh = 0; kh < kh_step; ++kh) {
            for (int kw = 0; kw < j.kw; ++kw) {
                auto vwei = zmm_wei(kh, kw);
                vpxord(vwei, vwei, vwei);
            }
        }

        for (int ow = 0; ow < j.ow; ow += 4) {
            for (int _ow = ow; _ow < ow + 4; ++_ow) {
                auto vdst = zmm_dst(_ow);
                if (_ow < j.ow)
                    vmovups(vdst, addr_dst(_ow));
                else
                    vpxord(vdst, vdst, vdst);
            }

            for (int kh = 0; kh < kh_step; ++kh) {
                for (int kw = 0; kw < j.kw; ++kw) {
                    const int iw = ow + (kw % j.stride_w) * j.tr_ld
                        + (kw / j.stride_w);
                    v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow),
                            addr_tr_src(kh, iw));
                    if (1 && kh == 0 && kw < 4) {
                        prefetcht1(ptr[reg_ptr_dst
                                + (j.ow + ow + kw) * jcp.oc_block
                                * jcp.typesize_in]);
                    }
                    if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */
                        const int off = kw + 4 - j.kw;
                        if (off >= 0 && ow + off < j.ow)
                            vaddps(vbia, vbia, zmm_dst(ow + off));
                    }
                }
            }
        }

        Label l_store;
        test(reg_flag, FLAG_MB_FIRST);
        jnz(l_store, T_NEAR);
        for (int kh = 0; kh < kh_step; ++kh) {
            for (int kw = 0; kw < j.kw; ++kw)
                vaddps(zmm_wei(kh, kw), addr_wei(kh, kw));
        }
        L(l_store);
        for (int kh = 0; kh < kh_step; ++kh) {
            for (int kw = 0; kw < j.kw; ++kw)
                vmovups(addr_wei(kh, kw), zmm_wei(kh, kw));
        }
    };
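    // Register budget of emit_fma_block(): zmm0..zmm7 rotate over dst
    // pixels, zmm8 and up hold the kh_step x kw weight accumulators, and
    // zmm31 is reserved for the bias. E.g. an 11x11 stride-4 first layer
    // gets kh_step = 28 / 11 = 2, i.e. 22 weight accumulators per call.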
    auto emit_kh_loop = [&]() {
        const int kh_step_rem = j.kh % j.kh_step;
        xor_(reg_kh, reg_kh);
        mov(reg_kh_step, j.kh_step);

        Label l_kh_loop;
        L(l_kh_loop); {
            Label l_done;

            if (kh_step_rem != 0) {
                Label l_keep_kh_step;
                cmp(reg_kh, j.kh - j.kh_step);
                jle(l_keep_kh_step, T_NEAR);

                mov(reg_kh_step, kh_step_rem);
                emit_fma_block(kh_step_rem);
                jmp(l_done, T_NEAR);

                L(l_keep_kh_step);
            }

            emit_fma_block(j.kh_step);

            L(l_done);

            add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld
                    * jcp.typesize_in);
            add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out);
            add(reg_kh, j.kh_step);

            cmp(reg_kh, j.kh);
            jl(l_kh_loop, T_NEAR);
        }

        const int kh_steps = rnd_up(j.kh, j.kh_step);
        sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in);
        sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out);
    };
    auto emit_oh_loop = [&]() {
        mov(reg_oh, j.oh);
        Label l_oh_loop;
        L(l_oh_loop); {
            Label l_restore_mb_flag, l_jump;

            cmp(reg_oh, j.oh);
            je(l_restore_mb_flag, T_NEAR);

            and_(reg_flag, ~FLAG_MB_FIRST);
            jmp(l_jump, T_NEAR);

            L(l_restore_mb_flag);
            mov(reg_flag, reg_flag_save);

            L(l_jump);

            emit_kh_loop();

            add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld
                    * jcp.typesize_in);
            add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in);

            sub(reg_oh, 1);
            jnz(l_oh_loop, T_NEAR);
        }
    };
    auto emit_bia_store = [&]() {
        if (!j.with_bias) return;

        Label l_bia_store, l_bia_skip;
        test(reg_flag, FLAG_IC_FIRST);
        jz(l_bia_skip, T_NEAR);

        test(reg_flag, FLAG_MB_FIRST);
        jnz(l_bia_store, T_NEAR);
        vaddps(vbia, ptr[reg_ptr_bia]);
        L(l_bia_store);
        vmovups(ptr[reg_ptr_bia], vbia);
        L(l_bia_skip);
    };
    mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]);
    mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]);
    mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]);
    mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]);
    mov(reg_flag_save, ptr[param + GET_OFF(flags)]);

    vpxord(vbia, vbia, vbia);
    emit_oh_loop();
    emit_bia_store();

    return true;
}

void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop()
{
    if (flat_4ops_compute())
        return;
    if (compute_full_spat_loop())
        return;

    compute_oh_loop_common();
}
void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
{
    preamble();

    mov(reg_input, ptr[param + GET_OFF(src)]);
    mov(reg_output, ptr[param + GET_OFF(dst)]);
    mov(reg_kernel, ptr[param + GET_OFF(filt)]);

    compute_loop();

    postamble();
}
status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
        jit_conv_conf_t &jcp, const convolution_desc_t &cd,
        cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &diff_weights_pd,
        cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd) {
    if (!mayiuse(avx512_common))
        return status::unimplemented;

    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper diff_weights_d(&diff_weights_pd);
    const memory_desc_wrapper diff_bias_d(&diff_bias_pd);
    const memory_desc_wrapper diff_dst_d(&diff_dst_pd);

    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
    jcp.iw = src_d.dims()[ndims-1];
    jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
    jcp.ow = diff_dst_d.dims()[ndims-1];

    jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
    jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];

    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
    jcp.l_pad = cd.padding[0][ndims-3];

    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
    jcp.stride_w = cd.strides[ndims-3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
    jcp.dilate_w = cd.dilates[ndims-3];
    const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
    bool ok = true
        // general condition to simplify dilations
        && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
        && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
        && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
        // special condition to simplify dilations in compute_oh_loop_common
        && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
    if (!ok)
        return status::unimplemented;
    jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
    jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
            + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
    jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
            + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));

    /* XXX: currently, does not support stride_d > 1 or dilation > 0 */
    if (ndims == 5)
        if (jcp.stride_d > 1 || jcp.dilate_d > 0)
            return status::unimplemented;

    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;
    jcp.aligned_threads = 0;

    /* check for the 1st convolution */
    jcp.is_1stconv = is_1stconv(jcp);

    jcp.oc_block = jcp.simd_w;

    bool ok_to_pad_channels = true
        && jcp.ngroups == 1
        && src_d.data_type() == data_type::f32;

    if (ok_to_pad_channels)
        jcp.oc = rnd_up(jcp.oc, jcp.simd_w);

    if (jcp.oc % jcp.oc_block)
        return status::unimplemented;
    auto src_format = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    auto wei_format = with_groups
        ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
        : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);

    /* conditions on bias memory */
    jcp.with_bias = cd.diff_bias_desc.format != memory_format::undef;
    if (jcp.with_bias) {
        if (diff_bias_d.format() == any)
            CHECK(diff_bias_pd.set_format(x));
        if (diff_bias_d.format() != x)
            return status::unimplemented;
    }

    jcp.nb_oc = jcp.oc / jcp.oc_block;

    if (diff_dst_d.format() == any)
        CHECK(diff_dst_pd.set_format(src_format));
    if (diff_dst_d.format() != src_format)
        return status::unimplemented;
    /* kernel applicability check wrt boundaries
     * the conditions are quite general across the kernels we have,
     * but ideally the check should belong to a specific kernel... */
    const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
    const bool boundaries_ok = true
        && jcp.t_pad <= max_pad
        && jcp.b_pad <= max_pad;
    if (!boundaries_ok)
        return status::unimplemented;

    /* yet another common check */
    if (jcp.kw > 14)
        return status::unimplemented;

    /* setting register strategy */
    for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
        if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
    }
    if (jcp.is_1stconv) {
        const auto want_src_format = pick(ndims - 3, ncw, nchw, ncdhw);
        if (src_d.format() == any)
            CHECK(src_pd.set_format(want_src_format));

        const bool src_ok = true
            && utils::everyone_is(data_type::f32,
                    src_d.data_type(), diff_weights_d.data_type(),
                    diff_dst_d.data_type())
            && one_of(jcp.ic, 1, 3)
            && IMPLICATION(jcp.ic == 1, one_of(src_d.format(),
                        want_src_format, pick(ndims - 3, nwc, nhwc, ndhwc)))
            && IMPLICATION(jcp.ic != 1, src_d.format() == want_src_format)
            && jcp.ngroups == 1;
        if (!src_ok)
            return status::unimplemented;
        const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad,
                    jcp.stride_w), 16);
        const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1);
        const int kh_step_rem = jcp.kh % kh_step;
        const auto want_4fma_wfmt = with_groups
            ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
            : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
        const bool use_4fma = true
            && one_of(ndims, 3, 4)
            && mayiuse(avx512_mic_4ops)
            && mkldnn_thr_syncable()
            && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
            && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
            && jcp.kw <= 28 - jcp.with_bias
            && jcp.stride_w == 4
            && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
            && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
            && IMPLICATION(diff_weights_d.format() != any,
                    diff_weights_d.format() == want_4fma_wfmt);

        if (use_4fma) {
            jcp.ver = ver_4fma;
            jcp.kh_step = kh_step;
            jcp.tr_ld = tr_ld;
            jcp.ic_block = 1;
            if (diff_weights_d.format() == any)
                CHECK(diff_weights_pd.set_format(want_4fma_wfmt));
        } else {
            jcp.ver = ver_fma;
            jcp.ic_block = jcp.ic;

            const auto want_wfmt = with_groups
                ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
                : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
            if (diff_weights_d.format() == any)
                CHECK(diff_weights_pd.set_format(want_wfmt));
            if (diff_weights_d.format() != want_wfmt)
                return status::unimplemented;
        }

        jcp.nb_ic = jcp.ic / jcp.ic_block;
        jcp.src_fmt = src_d.format();
    } else {
        if (src_d.format() == any)
            CHECK(src_pd.set_format(src_format));
        if (diff_weights_d.format() == any)
            CHECK(diff_weights_pd.set_format(wei_format));

        const bool ok = true
            && src_d.format() == src_format
            && diff_weights_d.format() == (wei_format);
        if (!ok)
            return status::unimplemented;

        jcp.ic_block = jcp.simd_w;
        if (ok_to_pad_channels)
            jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
        jcp.nb_ic = jcp.ic / jcp.ic_block;
        jcp.src_fmt = src_d.format();
        if ((mayiuse(avx512_mic_4ops) || mayiuse(avx512_core_vnni))
                && mkldnn_thr_syncable()
                && one_of(ndims, 3, 4)
                && jcp.stride_w == 1
                && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
                && ((src_d.data_type() == data_type::s16
                        && diff_weights_d.data_type() == data_type::s32
                        && diff_dst_d.data_type() == data_type::s16))) {
            if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni;
            else jcp.ver = ver_4vnni;
        } else if ((mayiuse(avx512_mic) || mayiuse(avx512_core))
                && utils::everyone_is(data_type::f32,
                    src_d.data_type(), diff_weights_d.data_type(),
                    diff_dst_d.data_type())) {
            jcp.ver = ver_fma;
            if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops)
                    && jcp.stride_w == 1
                    && everyone_is(0, jcp.dilate_d, jcp.dilate_h,
                        jcp.dilate_w)
                    && mkldnn_thr_syncable()) {
                jcp.ver = ver_4fma;
            }
        } else {
            return status::unimplemented;
        }
        if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
            jcp.ur_w = jcp.ow;
            // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to
            // cross the right boundary. The only requirement is not to have
            // NaNs there because another multiplicand is always guaranteed
            // to be zero. This also may require the top-level driver to
            // allocate four extra guarding elements at the very end of the
            // buffer. I'm not proud of this hack, but it improves
            // performance by about 5-10% depending on the dimensions (Roma)

            // for vnni, that's the result of performance tuning
            const int tr_round = (utils::one_of(jcp.ver, ver_4fma, ver_vnni))
                ? 4 : 8;

            jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round);
            jcp.tr_src_num_guard_elems = tr_round; // upper bound

            if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
                jcp.tr_ow = rnd_up(jcp.ow, 2);
                jcp.ur_w = jcp.tr_ow;
            }
        }
    }
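    // Rounding example for the hack above: a 13-wide row with a 3-wide
    // kernel needs iw + kw - 1 = 15 transposed elements; ver_4fma rounds
    // this up to tr_iw = 16 (tr_round = 4), so the 4-wide loads may read at
    // most tr_round extra guard elements past the end of the row.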
    if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
        jcp.typesize_in = sizeof(int16_t);
        jcp.typesize_out = sizeof(int32_t);
    } else if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) {
        jcp.typesize_in = sizeof(float);
        jcp.typesize_out = sizeof(float);
    } else
        return status::unimplemented;

    bool args_ok = true
        && jcp.ic % jcp.ic_block == 0
        && jcp.oc % jcp.oc_block == 0
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
        && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[
                with_groups + 1]
        && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[
                with_groups + 0];
    if (!args_ok) return status::unimplemented;
    { // balancing
        int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
        balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
        jcp.nthr = nthr;
        jcp.nthr_mb = nthr_mb;
        jcp.nthr_g = nthr_g;
        jcp.nthr_oc_b = nthr_oc_b;
        jcp.nthr_ic_b = nthr_ic_b;
    }

    return status::success;
}
void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
    if (utils::one_of(jcp.ver, ver_4fma, ver_4vnni, ver_vnni)) {
        if (jcp.is_1stconv) {
            const size_t tr_src_size =
                jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
            scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
        } else {
            // XXX: See the comment about tr_iw and guarding elements in
            // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
            const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
            const size_t min_tr_src_size_per_thr
                = jcp.ih * jcp.ic_block * jcp.tr_iw;
            const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
                + jcp.tr_src_num_guard_elems;
            scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
        }
        /* prepare synchronization contexts */
        if (jcp.nthr_oc_b > 1) {
            const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
            scratchpad.book(key_conv_tr_src_bctx,
                    sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
        }

        if (utils::one_of(jcp.ver, ver_4vnni, ver_vnni)) {
            const size_t tr_diff_dst_size = jcp.nthr_mb * jcp.ngroups
                * jcp.nb_oc * jcp.oc_block * jcp.tr_ow * jcp.oh;
            scratchpad.book(key_conv_tr_diff_dst,
                    jcp.typesize_in * tr_diff_dst_size);

            /* prepare synchronization contexts */
            if (jcp.nthr_ic_b > 1) {
                const size_t tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
                scratchpad.book(key_conv_tr_diff_dst_bctx,
                        sizeof(simple_barrier::ctx_t)
                        * tr_diff_dst_bctx_size);
            }
        }
    }
    if (jcp.nthr_mb > 1) {
        const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
            * jcp.kh * jcp.kw * jcp.kd;
        const int bia_size = jcp.ngroups * jcp.oc;
        const size_t wei_bia_reduction_size = wei_size + bia_size;

        scratchpad.book(key_conv_wei_bia_reduction,
                jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
        scratchpad.book(key_conv_wei_bia_reduction_bctx,
                sizeof(simple_barrier::ctx_t));
    }

    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
}
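// Reduction sizing: the minibatch reduction keeps one private weights+bias
// copy per extra thread, so e.g. jcp.nthr_mb = 4 books 3 * (wei_size +
// bia_size) elements above, the remaining thread presumably accumulating
// straight into the user's diff_weights.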
void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
        const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
        int &nthr_oc_b_, int &nthr_ic_b_)
{
    nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;

    const int max_threads = mkldnn_get_max_threads();

    if (max_threads < j.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }

    if (!mkldnn_thr_syncable()
            && utils::one_of(j.ver, ver_4fma, ver_4vnni, ver_vnni)) {
        // should not happen -- the driver is not ready
        // for TBB-like non-synchronous threading yet
        return;
    }

    if (j.ver == ver_4fma && j.is_1stconv) {
        nthr_g_ = 1;
        nthr_oc_b_ = 1;
        nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
        nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
        nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
        return;
    }

    nthr_g_ = j.ngroups;
    const int nthr = max_threads / nthr_g_;
    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). the high level
         * optimizer tries to minimize memory consumption. a few notes:
         * (n1) unclear why, but that essentially helps first convolution...
         * (n2) assuming the reduction over minibatch is always there:
         *   - instead of 8 it should be 5 here (write ~= 2 read):
         *     kernel: temporal workspace 1 write
         *     reduction: 1 read from workspace and 1 write to the diff_wei
         *   - but experiments showed 8 works better than 5 or 6... */
        const int src_coef = j.ver == ver_4fma || j.ver == ver_vnni ? 4 : 1;
        const int dst_coef = 1;
        const int wei_coef = j.ver == ver_vnni ? 4 : 8;

        return 0
            + src_coef
            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
            * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
            / j.stride_d / j.stride_h / j.stride_w /* (n1) */
            + dst_coef
            * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
            * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
            + wei_coef /* (n2) */
            * div_up(j.ngroups, nthr_g_)
            * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
            * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
    };

    int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
    /* step 1: find the best thread distribution with lowest memory cost */
    const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
    for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
        for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);

            int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                nthr_mb_ = nthr_mb;
                nthr_oc_b_ = nthr_oc_b;
                nthr_ic_b_ = nthr_ic_b;
            }
        }

        if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
    }
    if (j.ver != ver_vnni && !mayiuse(avx512_mic)) {
        auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
            return 1
                * div_up(j.mb, nthr_mb)
                * div_up(j.ngroups, nthr_g_)
                * div_up(j.nb_oc, nthr_oc_b)
                * div_up(j.nb_ic, nthr_ic_b);
        };

        /* step 2: search for a thread distribution with lower compute cost.
         * the constraints:
         *  - memory cost cannot exceed 110% of the best found in step 1
         *  - unless the candidate's compute cost is at least 25% lower
         *    (i.e. the current best is >= 133% of it)
         * note: both constants were found empirically */
        int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
        for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
            const int nthr_par = nthr / nthr_mb;
            const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
            for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
                int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
                int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
                int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);

                const bool opt1 = comp_cost <= best_comp_cost
                    && mem_cost < 1.1 * best_mem_cost;
                const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;

                if (opt1 || opt2) {
                    best_comp_cost = comp_cost;
                    nthr_mb_ = nthr_mb;
                    nthr_oc_b_ = nthr_oc_b;
                    nthr_ic_b_ = nthr_ic_b;
                }
            }

            if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
        }
    }
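    // E.g. if step 1 settled on (mem = 100, comp = 400), a candidate with
    // comp = 290 passes opt2 whatever its memory cost (4 * 290 = 1160 <=
    // 3 * 400 = 1200), while a candidate with comp = 390 needs
    // mem < 1.1 * 100 = 110 to pass opt1.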
    if (nthr_mb_ > max_threads / 2 && nthr_mb_ < max_threads)
        nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
    nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;

    assert(nthr_ <= max_threads);
    assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
}

template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;

}
}
}

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s