/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_memory.hpp"

#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
{
    jcp.loop_order = loop_cwgn;
    if (jcp.ngroups > 1) {
        jcp.loop_order = loop_ngcw;
        if (jcp.mb < nthr)
            jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
    }
}
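
/* prepare_output: zero the accumulator registers for the ur_w x nb_oc_block
 * (nb_ch_blocking for depthwise) tile; for signed input also broadcast the
 * value 128 into vmm_shift, which is used to move s8 source values into the
 * u8 range. */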
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
{
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    for (int k = 0; k < nb_oc_block; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            vpxord(vmm, vmm, vmm);
        }
    if (jcp.signed_input) {
        xor_(reg_scratch, reg_scratch);
        if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
            Reg32 _t32 = reg_scratch.cvt32();
            mov(_t32, (uint32_t)128);
            vpbroadcastd(vmm_shift, _t32);
        } else {
            Reg8 _t8 = reg_scratch.cvt8();
            mov(_t8, (int8_t)128);
            vpbroadcastb(vmm_shift, _t8);
        }
    }
}
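
/* vmm_mask: attach the channel-tail opmask to a register. Only the Zmm
 * specialization applies ktail_mask (with zeroing for loads); the Xmm/Ymm
 * variant returns the register unchanged. */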
template<typename Vmm>
const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
        vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {
    return vmm_in;
}

template<>
const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
        vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
                     : zmm_in;
}
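
/* cvt2ps: load s32/s8/u8 (or f32) data from memory, optionally masked for the
 * channel tail, and convert it to f32 in place. */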
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Operand &op, bool mask_flag) {
    //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
    const Vmm vmm = vmm_mask(vmm_in, mask_flag);
    switch (type_in) {
    case data_type::f32:
    case data_type::s32: vmovups(vmm, op); break;
    case data_type::s8: vpmovsxbd(vmm, op); break;
    case data_type::u8: vpmovzxbd(vmm, op); break;
    default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32)
        vcvtdq2ps(vmm_in, vmm_in);
}
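
/* store_output: apply bias, the signed-input compensation, per-channel scales
 * and the attached post ops to the s32 accumulators, then round and store the
 * result in the destination data type. */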
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
        int ur_w, bool last_oc_block_flag) {
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;

    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    if (jcp.signed_input)
        mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
    }

    if (p_sum_scale && *p_sum_scale != 1.f)
        mov(reg_ptr_sum_scale, (size_t)p_sum_scale);

    if (jcp.signed_input && jcp.ver != ver_vnni) {
        /* put 'wei_adj_scale = 0.5' for bias calculation */
        mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
        vmovq(xmm_bias_alpha(), reg_bias_alpha);
        vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
    }

    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);

        int bias_offset = jcp.typesize_bia * k * oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);

        cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
        if (jcp.signed_input && jcp.ver != ver_vnni)
            vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());

        if (jcp.signed_input) {
            int comp_offset = sizeof(int32_t) * k * oc_block;
            auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);

            cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
        }
        /* add to zmm_accum: compensation, bias and permute */
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            if (jcp.is_fast_depthwise)
                vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
            if (jcp.signed_input)
                vaddps(vmm, vmm, vmm_comp);
            vaddps(vmm, vmm, vmm_bias);

            const Vmm vmm_k = vmm_mask(vmm, mask_flag);
            vmulps(vmm_k, vmm,
                    EVEX_compress_addr(reg_ptr_scales, scale_offset));
        }
    }
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            if (ur_w == jcp.ur_w)
                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(
                        0, nb_oc_block * jcp.ur_w);
            else
                for (int k = 0; k < nb_oc_block; k++)
                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(
                            k * jcp.ur_w, k * jcp.ur_w + ur_w);
            eltwise_inj_idx++;
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights,
                    reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias,
                    reinterpret_cast<size_t>(post_op.depthwise.biases_data));

            add(reg_d_weights, ptr[param1 + GET_OFF(oc_off)]);
            add(reg_d_bias, ptr[param1 + GET_OFF(oc_off)]);

            for (int k = 0; k < nb_oc_block; k++) {
                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                        k * jcp.ur_w, k * jcp.ur_w + ur_w,
                        reg_d_weights, reg_d_bias);

                add(reg_d_weights, oc_block * sizeof(float));
                add(reg_d_bias, oc_block * sizeof(float));
            }
            depthwise_inj_idx++;
        } else if (post_op.is_sum(false)) {
            for (int k = 0; k < nb_oc_block; k++) {
                const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
                for (int j = 0; j < ur_w; j++) {
                    int aux_output_offset = jcp.typesize_out
                            * (k * oc_block
                                    + j * jcp.oc_without_padding * jcp.ngroups);
                    auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
                    Zmm zmm = zmm_out(j, k);
                    cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
                    if (*p_sum_scale == 1.f)
                        vaddps(zmm, vmm_prev_dst);
                    else
                        vfmadd231ps(zmm, vmm_prev_dst,
                                zword_b[reg_ptr_sum_scale]);
                }
            }
        }
    }
    /* write out register to output_addr */
    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            if (jcp.dst_dt == data_type::u8) {
                vpxord(vmm_zero, vmm_zero, vmm_zero);
                vmaxps(vmm, vmm_zero, vmm);
            }

            if (jcp.dst_dt != data_type::f32) {
                /* Note: using Zmm for rounding in Xmm/Ymm kernel
                   because there is no instruction to do rounding
                   from Xmm/Ymm -> Xmm/Ymm.
                   Embedded rounding is not supported for Xmm.
                   TODO: maybe avoid Zmm if it helps performance. */
                Zmm zmm = zmm_out(j, k);
                if (attr_.round_mode_ == round_mode::nearest)
                    vcvtps2dq(zmm | T_rn_sae, zmm);
                else if (attr_.round_mode_ == round_mode::down)
                    vcvtps2dq(zmm | T_rd_sae, zmm);
                else
                    assert(!"unimplemented");
            }
        }

        for (int j = 0; j < ur_w; j++) {
            int aux_output_offset = jcp.typesize_out
                    * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
            auto addr = EVEX_compress_addr(reg_out, aux_output_offset);

            Vmm vmm = vmm_out(j, k);
            const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);

            switch (jcp.dst_dt) {
            case data_type::f32:
            case data_type::s32: vmovups(addr, r_vmm); break;
            case data_type::s8: vpmovsdb(addr, r_vmm); break;
            case data_type::u8: vpmovusdb(addr, r_vmm); break;
            default: assert(!"unknown dst_dt");
            }
        }
    }
}
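
/* compute_ker_dw: depthwise convolution microkernel. Only the Zmm
 * specialization is implemented; the generic template must never be reached
 * at run time. */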
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    assert(!"invalid group blocking for depthwise convolution");
}

template <>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    auto input_spatial_index = [=](int oi, int ki) {
        return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
    };
    auto input_offset2 = [=](int ii, int ci) {
        return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
    };
    auto input_offset3 = [=](int oi, int ci, int ki) {
        return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
    };
    auto kernel_offset = [=](int ci, int ki) {
        return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
    };
    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
        // okay for depthwise since src is zero-extended
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddwd(zmm_tmp, vreg_src, vreg_wei);
            vpaddd(vreg_acc, vreg_acc, zmm_tmp);
        }
    };
    int ii_start = 0;
    int ii_end = -1;
    if (jcp.is_resrc_depthwise && !h_padded) {
        // find bounds of input spatial indices
        bool first = true;
        for (int ki = 0; ki < jcp.kw; ki++) {
            int oi_start = get_ow_start(ki, pad_l);
            int oi_end = get_ow_end(ur_w, ki, pad_r);
            for (int oi = oi_start; oi < oi_end; oi++) {
                int ii = input_spatial_index(oi, ki);
                if (first || ii < ii_start)
                    ii_start = ii;
                if (first || ii > ii_end)
                    ii_end = ii;
                first = false;
            }
        }
    }

    if (jcp.signed_input) {
        vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero);
        vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift);
    }
    for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
        const bool mask_flag = last_ic_block_flag != no_last_block
                && ci == jcp.nb_ch_blocking - 1;
        if (jcp.is_resrc_depthwise && !h_padded) {
            // now we can load input once and reuse up to jcp.kw times
            for (int ii = ii_start; ii <= ii_end; ii++) {
                int aux_input_offset = input_offset2(ii, ci);
                const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
                const Zmm zmm_inp_msk = mask_flag
                        ? zmm_inp_tmp | ktail_mask | T_z
                        : zmm_inp_tmp;
                if (jcp.is_fast_depthwise) {
                    vbroadcasti32x4(zmm_inp_msk,
                            EVEX_compress_addr(aux_reg_inp, aux_input_offset));
                } else {
                    vpmovzxbd(zmm_inp_msk,
                            EVEX_compress_addr(aux_reg_inp, aux_input_offset));
                }
                if (jcp.signed_input)
                    vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift);
            }
        }
        for (int ki = 0; ki < jcp.kw; ki++) {
            int aux_kernel_offset = kernel_offset(ci, ki);
            if (jcp.is_fast_depthwise) {
                vbroadcasti32x4(zmm_wei,
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei);
            } else {
                vpmovsxbd(zmm_wei,
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
            }
            if (h_padded) {
                assert(jcp.signed_input);
                for (int oi = 0; oi < ur_w; oi++)
                    compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
            } else {
                const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
                int oi_start = get_ow_start(ki, pad_l);
                int oi_end = get_ow_end(ur_w, ki, pad_r);
                int start_ = jcp.signed_input ? 0 : oi_start;
                int end_ = jcp.signed_input ? ur_w : oi_end;
                for (int oi = start_; oi < end_; oi++) {
                    if (oi >= oi_start && oi < oi_end) {
                        if (jcp.is_resrc_depthwise) {
                            int ii = input_spatial_index(oi, ki);
                            zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
                        } else {
                            int aux_input_offset = input_offset3(oi, ci, ki);
                            if (jcp.is_fast_depthwise) {
                                vbroadcasti32x4(r_zmm_src,
                                        EVEX_compress_addr(aux_reg_inp,
                                                aux_input_offset));
                            } else {
                                vpmovzxbd(r_zmm_src,
                                        EVEX_compress_addr(aux_reg_inp,
                                                aux_input_offset));
                            }
                            if (jcp.signed_input)
                                vpaddb(zmm_src, zmm_src, vmm_shift);
                        }
                    } else if (jcp.signed_input) {
                        zmm_src = zmm_shifted_zero;
                    }
                    compute(zmm_out(oi, ci), zmm_wei, zmm_src);
                }
            }
        }
    }
}
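
/* compute_ker: regular (non-depthwise) int8 microkernel. With VNNI a single
 * vpdpbusd accumulates 4-wide u8 x s8 dot products directly into s32.
 * Without VNNI the same result is emulated in three steps:
 *     vpmaddubsw  tmp, src, wei   // u8 x s8 -> adjacent s16 sums
 *     vpmaddwd    tmp, tmp, one   // s16 pairs -> s32 (vmm_one holds 1s)
 *     vpaddd      acc, acc, tmp
 * For signed input the source bytes are shifted by 128 (vmm_shift) before the
 * multiply, and the bias/compensation path in store_output corrects for it. */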
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    if (jcp.is_depthwise)
        return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);

    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int ch_block_all = jcp.ch_block * ic_block * oc_block;

    int nb_oc_block = jcp.nb_oc_blocking;

    auto input_offset = [=](int oi, int ic, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                        * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
    };
    auto kernel_offset = [=](int ii, int ic, int ki) {
        return jcp.typesize_in
                * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
                        + 4 * ic * oc_block);
    };
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = get_ow_start(ki, pad_l);
        int jj_end = get_ow_end(ur_w, ki, pad_r);
        int tail_size = jcp.ic_without_padding % 4;
        int _start = (jcp.signed_input) ? 0 : jj_start;
        int _end = (jcp.signed_input) ? ur_w : jj_end;
        /* Skip the last input loads if (ic % 16) / 4 < ic_block / 4 */
        int icb = (last_ic_block_flag != no_last_block)
                ? div_up((jcp.ic_without_padding % ic_block), 4)
                : ic_block / 4;
        for (int ic = 0; ic < icb; ic++) {
            if (h_padded == true) {
                /* fill padded area with shifted values */
                Vmm inp = vmm_inp(0, nb_oc_block);
                vpxord(inp, inp, inp);
                vpaddb(inp, inp, vmm_shift);
            } else {
                for (int jj = _start; jj < _end; jj++) {
                    int aux_input_offset = input_offset(jj, ic, ki);
                    if (jj >= jj_start && jj < jj_end) {
                        if (last_ic_block_flag == last_sp_block
                                && tail_size != 0 && ic == icb - 1) {
                            Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
                            for (int r = 0; r < tail_size; ++r)
                                vpinsrb(xmm_tmp, xmm_tmp,
                                        ptr[aux_reg_inp + aux_input_offset + r], r);
                            vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
                        } else {
                            vpbroadcastd(vmm_inp(jj, nb_oc_block),
                                    EVEX_compress_addr(
                                            aux_reg_inp, aux_input_offset));
                        }
                        if (jcp.signed_input)
                            vpaddb(vmm_inp(jj, nb_oc_block),
                                    vmm_inp(jj, nb_oc_block), vmm_shift);
                    } else {
                        /* fill padded area with shifted values */
                        if (jcp.signed_input) {
                            Vmm inp = vmm_inp(jj, nb_oc_block);
                            vpxord(inp, inp, inp);
                            vpaddb(inp, inp, vmm_shift);
                        }
                    }
                }
            }
            for (int ii = 0; ii < nb_oc_block; ii++) {
                int aux_kernel_offset = kernel_offset(ii, ic, ki);
                vmovups(vmm_wei,
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                for (int jj = _start; jj < _end; jj++) {
                    Vmm inp = (h_padded == true)
                            ? vmm_inp(0, nb_oc_block) : vmm_inp(jj, nb_oc_block);
                    compute(vmm_out(jj, ii), vmm_wei, inp);
                }
            }
        }
    }
}
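
/* kh_loop: iterate over the kernel height. For signed input the vertical
 * padding rows (t_overflow / b_overflow) are processed with h_padded = true
 * so that the shifted-zero contribution is still accumulated. */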
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
    Label kh_label, skip_kh_loop;
    Label t_overflow_label, no_t_overflow_label,
            b_overflow_label, no_b_overflow_label;

    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
    int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
            * jcp.ic_without_padding * jcp.ngroups;

    mov(aux_reg_inp, reg_inp);
    mov(aux_reg_ker, reg_ker);

    if (jcp.signed_input && jcp.ndims > 3) {
        mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
        cmp(reg_overflow, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label); {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }

    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if ((jcp.signed_input) || (!jcp.signed_input &&
            (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    L(kh_label); {
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_ker, shift_kernel_ptr);
        add(aux_reg_inp, shift_input_ptr);
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }
    L(skip_kh_loop);

    if (jcp.signed_input && jcp.ndims > 3) {
        mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
        cmp(reg_overflow, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label); {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }
}
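
/* icb_loop: loop over input-channel blocks around kh_loop, then store the
 * results; the last IC/OC block takes the masked (tail) path when the number
 * of channels is not a multiple of the block size. */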
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
        int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
{
    prepare_output(ur_w);

    Label icb_label;
    mov(reg_icb, jcp.nb_ic);
    L(icb_label);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;

        cmp(reg_icb, 1); // The last IC block
        jne(common_ker, T_NEAR);
        kh_loop(ur_w, pad_l, pad_r,
                is_last_sp_block ? last_sp_block : last_ic_block);
        jmp(end_ker, T_NEAR);

        L(common_ker);
        kh_loop(ur_w, pad_l, pad_r, no_last_block);

        L(end_ker);
    } else {
        kh_loop(ur_w, pad_l, pad_r, no_last_block);
    }

    int inp_step = jcp.ic_block;
    int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
    add(reg_inp, jcp.typesize_in * inp_step);
    add(reg_ker, jcp.typesize_in * ker_step);

    dec(reg_icb);
    cmp(reg_icb, 0);
    jg(icb_label, T_NEAR);

    sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
    sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;

        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
        else
            cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}
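
/* generate: emit the full kernel. Creates the eltwise/depthwise post-op
 * injectors, sets up the tail and weight-blend masks, and drives the output
 * width either in a single pass (nb_ow == 1) or in ow blocks selected by the
 * owb runtime parameter. */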
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
{
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
                    this,
                    post_op.depthwise.alg));
        }
    }

    Label permute_index_table;
    int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
            * jcp.ic_without_padding * jcp.ngroups;
    int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
            * jcp.ic_without_padding * jcp.ngroups;
    int inp_shift = jcp.typesize_in *
            (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
                    * jcp.ngroups);
    int out_shift = jcp.typesize_out *
            (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
    if (jcp.is_depthwise) {
        int idx = jcp.max_regs_ur - 1;
        if (!jcp.is_resrc_depthwise)
            zmm_src = Zmm(++idx);
        if (jcp.ver != ver_vnni)
            zmm_tmp = Zmm(++idx);
        if (jcp.is_fast_depthwise)
            zmm_permute = Zmm(++idx);
        if (jcp.signed_input) {
            zmm_shifted_zero = Zmm(++idx);
            ++idx; // due to the extra register used for shifts and compensations
        }
        assert(idx == ker_dw_reg_base_idx);
    }

    if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
        xor_(reg_scratch, reg_scratch);
        Reg16 _t16 = reg_scratch.cvt16();
        mov(_t16, 0x1);
        vpbroadcastw(vmm_one, _t16);
    }

    mov(reg_inp, ptr[param1 + GET_OFF(src)]);
    mov(reg_out, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        int tail_size = jcp.is_depthwise
                ? jcp.ngroups % jcp.ch_block
                : jcp.oc_without_padding % jcp.oc_block;
        int mask = (1 << tail_size) - 1;
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
        Reg32 regw_tmp = reg_oi.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }

    if (jcp.is_fast_depthwise) {
        // prepare mask register for blending weights
        mov(reg_scratch, 0x8888444422221111);
        kmovq(kblend_mask, reg_scratch);
        // load permute indices from data section
        mov(reg_scratch, permute_index_table);
        vmovdqu32(zmm_permute, ptr[reg_scratch]);
    }
    int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));
    int n_oi = jcp.ow / jcp.ur_w;
    int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);

    if (jcp.nb_ow == 1) {
        if (r_pad1 > 0 || jcp.ur_w_tail == 0)
            n_oi--;

        xor_(reg_oi, reg_oi);
        if (jcp.ow == jcp.ur_w) {
            icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
        } else {
            icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);
            if (jcp.ur_w_tail != 0) {
                icb_loop(jcp.ur_w_tail, 0, r_pad, true);
            }

            icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);

            if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1)) {
                Label ow_loop_label;
                L(ow_loop_label); {
                    icb_loop(jcp.ur_w, 0, 0, false);
                    add(reg_inp, inp_shift);
                    add(reg_out, out_shift);

                    inc(reg_oi);
                    cmp(reg_oi, n_oi);
                    jl(ow_loop_label, T_NEAR);
                }
            }
            if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
                icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
                add(reg_inp, inp_shift);
                add(reg_out, out_shift);
            }
            if (jcp.ur_w_tail != 0) {
                icb_loop(jcp.ur_w_tail, 0, r_pad, true);
            }
        }
    } else {
        // Only one ow block is processed per kernel call.
        // The block index is passed as the owb parameter,
        // and padding handling depends on it.
        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
                oi_loop_label, oi_loop_end_label;

        assert(jcp.ow_block % jcp.ur_w == 0);
        int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
        // to simplify the code (and general-purpose register usage),
        // the size of an ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;
        int n_oi_last_ow_block
                = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;

        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded
                = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded
                = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;

        if (last_ow_block_padded) n_oi_last_ow_block--;
        else if (first_ow_block_padded) n_oi_first_ow_block--;
        else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is this the first ow-block?
        jg(middle_ow_blocks_label, T_NEAR);

        // the first ow block: compute left padding
        mov(reg_oi, n_oi_first_ow_block);
        icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
        add(reg_inp, inp_shift_pad);
        add(reg_out, out_shift);

        jmp(oi_loop_label, T_NEAR);

        // middle or last ow block entry
        L(middle_ow_blocks_label);

        // only account for left padding here, nothing is computed
        add(reg_inp, inp_shift_pad_second_block);

        // set the number of iterations for the oi-loop
        if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 1); // last ow-block?
            mov(reg_oi, n_oi_last_ow_block);
            je(oi_loop_label, T_NEAR);
        }
        if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
            mov(reg_oi, n_oi_next_last_ow_block);
            je(oi_loop_label, T_NEAR);
        }
        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks

        // oi loop w/o padding
        L(oi_loop_label); {
            cmp(reg_oi, 0);
            jle(oi_loop_end_label, T_NEAR);

            icb_loop(jcp.ur_w, 0, 0, false);

            add(reg_inp, inp_shift);
            add(reg_out, out_shift);
            dec(reg_oi);

            jmp(oi_loop_label, T_NEAR);
        }
        L(oi_loop_end_label);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // first ow-block?
        if (first_ow_block_padded)
            je(last_oi_label, T_NEAR);
        else
            je(end_label, T_NEAR);

        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
        jl(end_label, T_NEAR);
        if (next_last_ow_block_padded)
            je(last_oi_label, T_NEAR);
        else
            je(end_label, T_NEAR);

        // this is the last ow-block
        if (!last_ow_block_padded)
            jmp(tail_label, T_NEAR);

        // last oi block with right padding
        L(last_oi_label);
        icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
        add(reg_inp, inp_shift);
        add(reg_out, out_shift);

        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        jl(end_label, T_NEAR);

        L(tail_label);
        if (jcp.ur_w_tail != 0) {
            icb_loop(jcp.ur_w_tail, 0, r_pad, true);
        }
        L(end_label);
    }

    for (auto &inj : eltwise_injectors)
        inj->prepare_table();

    if (jcp.is_fast_depthwise) {
        L(permute_index_table);
        const uint32_t _idx[]
                = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dd(_idx[i]);
    }
}
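
/* post_ops_ok: accept any combination of eltwise and depthwise post ops plus
 * at most one sum. */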
bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr)
{
    const auto &p = attr.post_ops_;
    auto all_post_ops_supported = [&]() {
        bool ok = true;
        for (int i = 0; i < p.len_; i++) {
            ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
        }
        return ok;
    };
    auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };

    return all_post_ops_supported() &&
           count(primitive_kind::sum) <= 1;
}
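
/* init_conf: validate the problem descriptor, choose blocking parameters
 * (ic/oc/ch blocks, ur_w, ow blocks) and memory formats, and fill jcp. */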
status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
        cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
        cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
        int nthreads)
{
    using namespace prop_kind;

    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper weights_d(&weights_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper bias_d(&bias_pd);

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    bool is_1d = ndims == 3;

    if (!(mayiuse(avx512_core)
            && one_of(src_d.data_type(), data_type::u8, data_type::s8)
            && weights_d.data_type() == data_type::s8
            && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8)))
        return status::unimplemented;

    jcp = zero<decltype(jcp)>();

    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    jcp.ur_h = 1; /* no code-unrolling by h so far */

    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

    if (jcp.is_depthwise) {
        jcp.ch_block = 16;
        jcp.ic_block = 1;
        jcp.oc_block = 1;
    } else {
        jcp.ch_block = 1;
        jcp.ic_block = 16;
        jcp.oc_block = 16;

        if (jcp.ngroups == 1) {
            /* For non-grouped convolutions, pad channels by 16 if needed */
            jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
            jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
        } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
            /* For grouped convolutions, MKL-DNN doesn't support padding.
               Use Ymm when the number of channels per group is a multiple of 8,
               Xmm when it is a multiple of 4. */
            jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
            jcp.oc_block = jcp.ic_block;
        }
        if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
            return status::unimplemented;
    }

    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;
    jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
    jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
            && jcp.ngroups % jcp.ch_block == 0; // groups not divisible by 16
                                                // would require byte masking
    jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
            && jcp.kw < 4 && jcp.dilate_w == 0;
    if (jcp.is_depthwise) {
        jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
                - 2 * jcp.signed_input - (jcp.ver != ver_vnni);
    } else {
        jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
    }

    auto src_format = pick(ndims - 3, nwc, nhwc);
    auto dst_format = pick(ndims - 3, nwc, nhwc);
#define pick_signed(fmt) (jcp.signed_input ? fmt##_s8s8 : fmt)
    memory_format_t w_format;
    if (jcp.ic_block == 16 || jcp.ch_block == 16) {
        w_format = is_1d ?
                (with_groups ? (jcp.is_depthwise ? pick_signed(Goiw16g) :
                        pick_signed(gOIw4i16o4i)) :
                        pick_signed(OIw4i16o4i)) :
                (with_groups ? (jcp.is_depthwise ? pick_signed(Goihw16g) :
                        pick_signed(gOIhw4i16o4i)) :
                        pick_signed(OIhw4i16o4i));
        /* Non-grouped conv will always be padded by 16 */
    } else if (with_groups && jcp.ic_block == 8) {
        w_format = pick_signed(gOIhw2i8o4i);
    } else {
        w_format = pick_signed(gOIhw4o4i);
    }

    if (weights_d.format() == any)
        CHECK(weights_pd.set_format(w_format));
    if (weights_d.format() != w_format)
        return status::unimplemented;
    if (dst_d.format() == any)
        CHECK(dst_pd.set_format(dst_format));
    if (dst_d.format() != dst_format)
        return status::unimplemented;
    if (src_d.format() == any)
        CHECK(src_pd.set_format(src_format));
    if (src_d.format() != src_format)
        return status::unimplemented;
    if (jcp.with_bias) {
        if (bias_d.format() == any)
            CHECK(bias_pd.set_format(x));
        if (bias_d.format() != x)
            return status::unimplemented;
    }
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia = jcp.with_bias
            ? types::data_type_size(bias_d.data_type())
            : 0;

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
    int nb_ch_blocking = 4;
    for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--)
        if (jcp.nb_ch % nb_ch_blocking == 0)
            break;
    jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;

    // If OC blocking is incommensurate with the number of OC blocks (general
    // requirement for all convolutions), or if it results in an unrolling
    // factor smaller than the left padding (special requirement for SSD:fc6),
    // then search for a smaller OC blocking that satisfies both constraints.
    auto is_oc_blocking_ok = [&](int block) {
        int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
        return jcp.nb_oc % block == 0
                && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1;
    };

    // choose the nb_oc work chunk size for distribution within threads
    int max_threading_nb_oc_chunk = 4;
    // Performance improvements for googlenet_v3 and resnet_50 with mb = 1;
    // TODO: generalize this condition and rewrite it in an appropriate manner
    if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3
            && jcp.stride_w == 1 && jcp.ic % 64 == 0)
        max_threading_nb_oc_chunk = 2;
    jcp.nb_oc_blocking_thr_chunk =
            nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
    for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
        if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk))
            break;
    }

    // choose oc blocking for the computational kernel
    jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
    // Performance improvements for googlenet_v3 with mb = 1;
    // TODO: generalize this condition and rewrite it in an appropriate manner
    const int size_threshold_for_nb_oc_blocking_reduction = 17;
    if (jcp.mb == 1 && jcp.ow <= size_threshold_for_nb_oc_blocking_reduction
            && jcp.stride_w == 1
            && !(jcp.kh == 1 && jcp.kw == 3)
            && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) {
        const int max_nb_oc_blocking = 2;
        jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc);
        for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
            if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0
                    && is_oc_blocking_ok(jcp.nb_oc_blocking))
                break;
    }

    if (jcp.is_resrc_depthwise)
        jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
                / (jcp.nb_ch_blocking + jcp.stride_w);
    else
        jcp.ur_w
                = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking
                        : jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w)
        jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
    jcp.ow_block = jcp.ow;
    int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh
            * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
    float best_thr_eff
            = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
    int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
    for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
        int ow_block
                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
        if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
                && best_thr_eff > 0.8f)
            break;
        if (div_up(jcp.ow, ow_block) != nb_ow)
            continue;
        auto work_amount = base_work_amount * nb_ow;
        float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
        if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
            jcp.ow_block = ow_block;
            best_thr_eff = thr_eff;
        }
        if (best_thr_eff > 0.9f)
            break;
    }
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);

    bool args_ok = true
            && jcp.oc % jcp.oc_block == 0
            && jcp.l_pad <= jcp.ur_w
            && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
    if (!args_ok)
        return status::unimplemented;

    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));
    if (r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    pick_loop_order(jcp, nthreads);

    jcp.nb_ic_L2 = jcp.nb_ic;

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));

    jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;

    return status::success;
}
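
/* init_scratchpad: on pre-VNNI hardware with signed input the output scales
 * are adjusted by wei_adj_scale, so space for the adjusted scales is booked
 * here. */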
void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        size_t count = nstl::max(attr.output_scales_.count_,
                (size_t)jcp.ic_block);
        scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
    }
}

template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s