/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "type_helpers.hpp"
#include "cpu_memory.hpp"
#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
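
/* Heuristic choice of the driver loop order (minibatch / groups / oc blocks /
   ow) based on the problem shape; grouped convolutions prefer a group-outer
   order. */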
void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
    jcp.loop_order = loop_cwgn;
    if (jcp.ngroups > 1) {
        jcp.loop_order = loop_ngcw;
            jcp.loop_order = loop_nhwcg;
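
/* Zero the ur_w x nb_oc_blocking (or nb_ch_blocking) accumulators and, for
   signed input, broadcast the value 128 used to shift s8 source values into
   the unsigned range. */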
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    for (int k = 0; k < nb_oc_block; k++)
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            vpxord(vmm, vmm, vmm);
    if (jcp.signed_input) {
        xor_(reg_scratch, reg_scratch);
        if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
            Reg32 _t32 = reg_scratch.cvt32();
            mov(_t32, (uint32_t)128);
            vpbroadcastd(vmm_shift, _t32);
            Reg8 _t8 = reg_scratch.cvt8();
            mov(_t8, (int8_t)128);
            vpbroadcastb(vmm_shift, _t8);
    if (jcp.is_fast_depthwise) {
        vpxord(zmm_zero_blend, zmm_zero_blend, zmm_zero_blend);
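
/* Attach the channel-tail opmask to a register operand: ktail_mask with
   merge-masking for stores and zero-masking for loads (Zmm specialization
   below). */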
template<typename Vmm>
const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
    vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {

const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
    vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
    return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
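
/* Load an f32/s32/s8/u8 operand into vmm_in, honoring the channel-tail mask,
   and convert it to f32. */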
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Operand &op, bool mask_flag) {
    //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
    const Vmm vmm = vmm_mask(vmm_in, mask_flag);
    case data_type::s32: vmovups(vmm, op); break;
    case data_type::s8: vpmovsxbd(vmm, op); break;
    case data_type::u8: vpmovzxbd(vmm, op); break;
    default: assert(!"unsupported data type");
    if (type_in != data_type::f32)
        vcvtdq2ps(vmm_in, vmm_in);
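
/* Post-process the accumulators and write them to dst: apply the input-shift
   compensation (signed input), bias and per-channel scales, run the attached
   post-ops (eltwise / depthwise / sum), then round, saturate and store in
   jcp.dst_dt. */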
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
        int ur_w, bool last_oc_block_flag) {
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    if (jcp.signed_input)
        mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
    if (p_sum_scale && *p_sum_scale != 1.f)
        mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        /* put 'wei_adj_scale = 0.5' for bias calculation */
        mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
        vmovq(xmm_bias_alpha(), reg_bias_alpha);
        vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
            int bias_offset = jcp.typesize_bia * k * oc_block;
            auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
            cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
            if (jcp.signed_input && jcp.ver != ver_vnni)
                vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
        if (jcp.signed_input) {
            int comp_offset = sizeof(int32_t) * k * oc_block;
            auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
            cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
        /* add to zmm_accum: compensation, bias and permute */
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            if (jcp.is_fast_depthwise)
                vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
            if (jcp.signed_input)
                vaddps(vmm, vmm, vmm_comp);
                vaddps(vmm, vmm, vmm_bias);
            const Vmm vmm_k = vmm_mask(vmm, mask_flag);
                    EVEX_compress_addr(reg_ptr_scales, scale_offset));
    int eltwise_inj_idx = 0;
    int depthwise_inj_idx = 0;
    for (int i = 0; i < p.len_; i++) {
        auto& post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            if (ur_w == jcp.ur_w)
                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, nb_oc_block * jcp.ur_w);
                for (int k = 0; k < nb_oc_block; k++)
                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w, k * jcp.ur_w + ur_w);
        } else if (post_op.is_depthwise()) {
            mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
            mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
            add(reg_d_weights, ptr[param1 + GET_OFF(oc_off)]);
            add(reg_d_bias, ptr[param1 + GET_OFF(oc_off)]);
            for (int k = 0; k < nb_oc_block; k++) {
                depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                        k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
                add(reg_d_weights, oc_block * sizeof(float));
                add(reg_d_bias, oc_block * sizeof(float));
        } else if (post_op.is_sum(false)) {
            for (int k = 0; k < nb_oc_block; k++) {
                const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
                for (int j = 0; j < ur_w; j++) {
                    int aux_output_offset
                                    + j * jcp.oc_without_padding * jcp.ngroups);
                    auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
                    Zmm zmm = zmm_out(j, k);
                    cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
                    if (*p_sum_scale == 1.f)
                        vaddps(zmm, vmm_prev_dst);
                        vfmadd231ps(zmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
    /* write out register to output_addr */
    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            if (jcp.dst_dt == data_type::u8) {
                vpxord(vmm_zero, vmm_zero, vmm_zero);
                vmaxps(vmm, vmm_zero, vmm);
            if (jcp.dst_dt != data_type::f32) {
                /* Note: using Zmm for rounding in Xmm/Ymm kernel
                   because there is no instruction to do rounding
                   from Xmm/Ymm -> Xmm/Ymm.
                   Embedded rounding is not supported for Xmm.
                   TODO: maybe avoid Zmm if it helps performance. */
                Zmm zmm = zmm_out(j, k);
                if (attr_.round_mode_ == round_mode::nearest)
                    vcvtps2dq(zmm | T_rn_sae, zmm);
                else if (attr_.round_mode_ == round_mode::down)
                    vcvtps2dq(zmm | T_rd_sae, zmm);
                    assert(!"unimplemented");
        for (int j = 0; j < ur_w; j++) {
            int aux_output_offset = jcp.typesize_out
                    * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
            auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
            Vmm vmm = vmm_out(j, k);
            const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
            switch (jcp.dst_dt) {
            case data_type::s32: vmovups(addr, r_vmm); break;
            case data_type::s8: vpmovsdb(addr, r_vmm); break;
            case data_type::u8: vpmovusdb(addr, r_vmm); break;
            default: assert(!"unknown dst_dt");
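
/* Depthwise inner kernel: for every kw tap and channel block, multiply the
   broadcast source rows by the s8 weights and accumulate into the ur_w
   output registers.  Only the Zmm specialization below is ever used. */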
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    assert(!"invalid group blocking for depthwise convolution");

void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    auto input_offset = [=](int oi, int ii, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l)
                        + ii * jcp.ch_block);
    auto kernel_offset = [=](int ii, int ki) {
        return jcp.typesize_in * ((ii * jcp.kh * jcp.kw + ki) * jcp.ch_block);
    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei,
        // okay for depthwise since src is zero-extended
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
            // zmm_src is a tmp register that can be safely overwritten here
            vpmaddwd(vreg_src, vreg_src, vreg_wei);
            vpaddd(vreg_acc, vreg_acc, vreg_src);
    for (int ki = 0; ki < jcp.kw; ki++) {
        for (int ii = 0; ii < jcp.nb_ch_blocking; ii++) {
            int aux_kernel_offset = kernel_offset(ii, ki);
            if (jcp.is_fast_depthwise) {
                vbroadcasti32x4(zmm_wei,
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                vpblendmb(zmm_wei | kblend_mask, zmm_zero_blend, zmm_wei);
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                if (jcp.ver == ver_vnni) {
                    vpxord(zmm_src, zmm_src, zmm_src);
                    vpaddb(zmm_src, zmm_src, vmm_shift);
                for (int jj = 0; jj < ur_w; jj++) {
                    if (jcp.ver != ver_vnni) {
                        vpxord(zmm_src, zmm_src, zmm_src);
                        vpaddb(zmm_src, zmm_src, vmm_shift);
                    compute(zmm_out(jj, ii), zmm_wei, zmm_src);
            const bool mask_flag = last_ic_block_flag != no_last_block
                    && ii == jcp.nb_ch_blocking - 1;
            const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
            int jj_start = get_ow_start(ki, pad_l);
            int jj_end = get_ow_end(ur_w, ki, pad_r);
            int start_ = jcp.signed_input ? 0 : jj_start;
            int end_ = jcp.signed_input ? ur_w : jj_end;
            for (int jj = start_; jj < end_; jj++) {
                if (jj >= jj_start && jj < jj_end) {
                    int aux_input_offset = input_offset(jj, ii, ki);
                    if (jcp.is_fast_depthwise) {
                        vbroadcasti32x4(zmm_src,
                                EVEX_compress_addr(aux_reg_inp, aux_input_offset));
                                EVEX_compress_addr(aux_reg_inp, aux_input_offset));
                    if (jcp.signed_input) {
                        vpaddb(zmm_src, zmm_src, vmm_shift);
                    if (jcp.signed_input) {
                        vpxord(zmm_src, zmm_src, zmm_src);
                        vpaddb(zmm_src, zmm_src, vmm_shift);
                compute(zmm_out(jj, ii), zmm_wei, zmm_src);
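
/* Generic inner kernel: for each kw tap and 4-wide ic chunk, broadcast a
   source dword and multiply-accumulate it against the s8 weights, using
   vpdpbusd on VNNI hardware and the vpmaddubsw/vpmaddwd/vpaddd sequence
   otherwise. */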
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    if (jcp.is_depthwise)
        return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int ch_block_all = jcp.ch_block * ic_block * oc_block;
    int nb_oc_block = jcp.nb_oc_blocking;
    auto input_offset = [=](int oi, int ic, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                        * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
    auto kernel_offset = [=](int ii, int ic, int ki) {
        return jcp.typesize_in
                * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
                        + 4 * ic * oc_block);
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
            vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            vpaddd(vreg_acc, vreg_acc, vmm_tmp);
    for (int ki = 0; ki < kw; ki++) {
        int jj_start = get_ow_start(ki, pad_l);
        int jj_end = get_ow_end(ur_w, ki, pad_r);
        int tail_size = jcp.ic_without_padding % 4;
        int _start = (jcp.signed_input) ? 0 : jj_start;
        int _end = (jcp.signed_input) ? ur_w : jj_end;
        /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
        int icb = (last_ic_block_flag != no_last_block)
                ? div_up((jcp.ic_without_padding % ic_block), 4)
        for (int ic = 0; ic < icb; ic++) {
            if (h_padded == true) {
                /* fill padded area with shifted values */
                Vmm inp = vmm_inp(0, nb_oc_block);
                vpxord(inp, inp, inp);
                vpaddb(inp, inp, vmm_shift);
                for (int jj = _start; jj < _end; jj++) {
                    int aux_input_offset = input_offset(jj, ic, ki);
                    if (jj >= jj_start && jj < jj_end) {
                        if (last_ic_block_flag == last_sp_block
                                && tail_size != 0 && ic == icb - 1) {
                            Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
                            for (int r = 0; r < tail_size; ++r)
                                vpinsrb(xmm_tmp, xmm_tmp,
                                        ptr[aux_reg_inp + aux_input_offset + r], r);
                            vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
                            vpbroadcastd(vmm_inp(jj, nb_oc_block),
                                            aux_reg_inp, aux_input_offset));
                        if (jcp.signed_input)
                            vpaddb(vmm_inp(jj, nb_oc_block),
                                    vmm_inp(jj, nb_oc_block), vmm_shift);
                        /* fill padded area with shifted values */
                        if (jcp.signed_input) {
                            Vmm inp = vmm_inp(jj, nb_oc_block);
                            vpxord(inp, inp, inp);
                            vpaddb(inp, inp, vmm_shift);
            for (int ii = 0; ii < nb_oc_block; ii++) {
                int aux_kernel_offset = kernel_offset(ii, ic, ki);
                        EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                for (int jj = _start; jj < _end; jj++) {
                    Vmm inp = (h_padded == true)
                            ? vmm_inp(0, nb_oc_block) : vmm_inp(jj, nb_oc_block);
                    compute(vmm_out(jj, ii), vmm_wei, inp);
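
/* Loop over the kh filter rows.  With signed input the zero-padded rows at
   the top and bottom still have to contribute the 128 shift, which is what
   the t_overflow/b_overflow loops compute. */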
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
    Label kh_label, skip_kh_loop;
    Label t_overflow_label, no_t_overflow_label,
            b_overflow_label, no_b_overflow_label;
    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
    int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
            * jcp.ic_without_padding * jcp.ngroups;
    mov(aux_reg_inp, reg_inp);
    mov(aux_reg_ker, reg_ker);
    if (jcp.signed_input) {
        mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
        cmp(reg_overflow, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label); {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
            add(aux_reg_ker, shift_kernel_ptr);
            cmp(reg_overflow, 0);
            jg(t_overflow_label, T_NEAR);
        L(no_t_overflow_label);
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if ((jcp.signed_input) || (!jcp.signed_input &&
            (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
        je(skip_kh_loop, T_NEAR);
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
        add(aux_reg_ker, shift_kernel_ptr);
        add(aux_reg_inp, shift_input_ptr);
        jg(kh_label, T_NEAR);
    if (jcp.signed_input) {
        mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
        cmp(reg_overflow, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label); {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
            add(aux_reg_ker, shift_kernel_ptr);
            cmp(reg_overflow, 0);
            jg(b_overflow_label, T_NEAR);
        L(no_b_overflow_label);
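
/* Loop over input-channel blocks: run kh_loop for each of the jcp.nb_ic
   blocks, then store the results, switching to the tail variants when the
   last ic or oc block is not full. */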
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
        int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
    prepare_output(ur_w);
    mov(reg_icb, jcp.nb_ic);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;
        cmp(reg_icb, 1); // The last IC block
        jne(common_ker, T_NEAR);
        kh_loop(ur_w, pad_l, pad_r,
                is_last_sp_block ? last_sp_block : last_ic_block);
        jmp(end_ker, T_NEAR);
        kh_loop(ur_w, pad_l, pad_r, no_last_block);
        kh_loop(ur_w, pad_l, pad_r, no_last_block);
    int inp_step = jcp.ic_block;
    int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
    add(reg_inp, jcp.typesize_in * inp_step);
    add(reg_ker, jcp.typesize_in * ker_step);
    jg(icb_label, T_NEAR);
    sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
    sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;
        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
            cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);
        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);
        store_output(ur_w, false);
        store_output(ur_w, false);
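
/* Kernel entry point: set up constants, masks and pointers, then walk the
   output row either in a single pass (jcp.nb_ow == 1) or one ow block per
   call, with left/right padding handled in the first/last blocks. */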
template<typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
                    post_op.eltwise.alpha,
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
                    post_op.depthwise.alg
    Label permute_index_table;
    int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
            * jcp.ic_without_padding * jcp.ngroups;
    int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
            * jcp.ic_without_padding * jcp.ngroups;
    int inp_shift = jcp.typesize_in *
            (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
    int out_shift = jcp.typesize_out *
            (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
    if (jcp.is_depthwise) {
        zmm_src = Zmm(jcp.max_regs_ur);
        if (jcp.is_fast_depthwise) {
            zmm_zero_blend = Zmm(jcp.max_regs_ur + 1);
            zmm_permute = Zmm(jcp.max_regs_ur + 2);
    if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
        xor_(reg_scratch, reg_scratch);
        Reg16 _t16 = reg_scratch.cvt16();
        vpbroadcastw(vmm_one, _t16);
    mov(reg_inp, ptr[param1 + GET_OFF(src)]);
    mov(reg_out, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        int tail_size = jcp.is_depthwise
                ? jcp.ngroups % jcp.ch_block
                : jcp.oc_without_padding % jcp.oc_block;
        int mask = (1 << tail_size) - 1;
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
        Reg32 regw_tmp = reg_oi.cvt32();
        kmovw(ktail_mask, regw_tmp);
    if (jcp.is_fast_depthwise) {
        // prepare mask register for blending weights
        mov(reg_scratch, 0x8888444422221111);
        kmovq(kblend_mask, reg_scratch);
        // load permute indices from data section
        mov(reg_scratch, permute_index_table);
        vmovdqu32(zmm_permute, ptr[reg_scratch]);
    int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));
    int n_oi = jcp.ow / jcp.ur_w;
    int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
    if (jcp.nb_ow == 1) {
        if (r_pad1 > 0 || jcp.ur_w_tail == 0)
        xor_(reg_oi, reg_oi);
        if (jcp.ow == jcp.ur_w) {
            icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
            icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);
            if (jcp.ur_w_tail != 0) {
                icb_loop(jcp.ur_w_tail, 0, r_pad, true);
            icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);
            if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1))
                icb_loop(jcp.ur_w, 0, 0, false);
                add(reg_inp, inp_shift);
                add(reg_out, out_shift);
                jl(ow_loop_label, T_NEAR);
            if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
                icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
                add(reg_inp, inp_shift);
                add(reg_out, out_shift);
            if (jcp.ur_w_tail != 0) {
                icb_loop(jcp.ur_w_tail, 0, r_pad, true);
        // Only a single ow block is processed per kernel call.  The block
        // index is passed in as parameter owb, and padding handling
        // depends on it.
        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
                oi_loop_label, oi_loop_end_label;
        assert(jcp.ow_block % jcp.ur_w == 0);
        int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
        // to simplify code (and general regs usage),
        // size of ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;
        int n_oi_last_ow_block
                = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded
                = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded
                = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
        if (last_ow_block_padded) n_oi_last_ow_block--;
        else if (first_ow_block_padded) n_oi_first_ow_block--;
        else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // is that the first ow-block ?
        jg(middle_ow_blocks_label, T_NEAR);
        // the first ow block, compute left padding
        mov(reg_oi, n_oi_first_ow_block);
            icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
            add(reg_inp, inp_shift_pad);
            add(reg_out, out_shift);
        jmp(oi_loop_label, T_NEAR);
        // middle or last ow block entry
        L(middle_ow_blocks_label);
            // adjust for left padding only; nothing is computed for it here
            add(reg_inp, inp_shift_pad_second_block);
        // set number of iteration for oi-loop
        if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
            mov(reg_oi, n_oi_last_ow_block);
            je(oi_loop_label, T_NEAR);
        if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
            mov(reg_oi, n_oi_next_last_ow_block);
            je(oi_loop_label, T_NEAR);
        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
        // oi loop w/o padding
        jle(oi_loop_end_label, T_NEAR);
            icb_loop(jcp.ur_w, 0, 0, false);
            add(reg_inp, inp_shift);
            add(reg_out, out_shift);
        jmp(oi_loop_label, T_NEAR);
        L(oi_loop_end_label);
        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, 0); // first ow-block ?
        if (first_ow_block_padded)
            je(last_oi_label, T_NEAR);
            je(end_label, T_NEAR);
        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        jl(end_label, T_NEAR);
        if (next_last_ow_block_padded)
            je(last_oi_label, T_NEAR);
            je(end_label, T_NEAR);
        // this is the last ow block
        if (!last_ow_block_padded)
            jmp(tail_label, T_NEAR);
        // last oi block with right padding
        icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
        add(reg_inp, inp_shift);
        add(reg_out, out_shift);
        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        jl(end_label, T_NEAR);
        if (jcp.ur_w_tail != 0) {
            icb_loop(jcp.ur_w_tail, 0, r_pad, true);
    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
    if (jcp.is_fast_depthwise) {
        L(permute_index_table);
        const uint32_t _idx[]
                = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
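
/* Post-ops are accepted only in the combinations enumerated below: up to one
   sum mixed with eltwise/depthwise ops, at most three entries in total. */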
bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr)
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;
    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
    case 1: return is_simple(0) || is_sum(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
                   (is_simple(0) && is_simple(1));
    case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
    default: return false;
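
/* Validate the convolution descriptor, fill jcp, pick blocking parameters and
   memory formats, and return unimplemented for unsupported cases. */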
status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
        cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
        cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
    using namespace prop_kind;
    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper weights_d(&weights_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper bias_d(&bias_pd);
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!(mayiuse(avx512_core)
            && one_of(src_d.data_type(), data_type::u8, data_type::s8)
            && weights_d.data_type() == data_type::s8
            && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8)))
        return status::unimplemented;
    jcp = zero<decltype(jcp)>();
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
    if (jcp.is_depthwise) {
    if (jcp.ngroups == 1) {
        /* For non-grouped convolutions, pad channels up to a multiple of 16
           if needed */
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    } else if (jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
        /* MKL-DNN does not support channel padding for grouped convolutions.
           Use Ymm when the number of channels per group is a multiple of 8,
           Xmm when it is a multiple of 4. */
        jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
        jcp.oc_block = jcp.ic_block;
    if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
        return status::unimplemented;
    jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.ih + jcp.t_pad - 1);
    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;
    jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
    jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
            && jcp.ngroups % jcp.ch_block == 0; // a group count that is not a
                                                // multiple of 16 would require
                                                // byte masking for src loads
    if (jcp.is_depthwise) {
        jcp.max_regs_ur = jcp.is_fast_depthwise
                ? (jcp.signed_input ? 27 : 28)
                : (jcp.signed_input ? 29 : 30);
        jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
    memory_format_t w_format;
    if (jcp.ic_block == 16 || jcp.ch_block == 16) {
        w_format = with_groups
            ? (jcp.is_depthwise ? (jcp.signed_input ? Goihw16g_s8s8 : Goihw16g)
                                : (jcp.signed_input) ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i)
            : (jcp.signed_input) ? OIhw4i16o4i_s8s8 : OIhw4i16o4i;
        /* Non-grouped conv is always padded to a multiple of 16 channels */
    } else if (with_groups && jcp.ic_block == 8) {
        w_format = jcp.signed_input ? gOIhw2i8o4i_s8s8 : gOIhw2i8o4i;
        w_format = jcp.signed_input ? gOIhw4o4i_s8s8 : gOIhw4o4i;
    if (weights_d.format() == any)
        CHECK(weights_pd.set_format(w_format));
    if (weights_d.format() != w_format)
        return status::unimplemented;
    if (dst_d.format() == any)
        CHECK(dst_pd.set_format(nhwc));
    if (dst_d.format() != nhwc)
        return status::unimplemented;
    if (src_d.format() == any)
        CHECK(src_pd.set_format(nhwc));
    if (src_d.format() != nhwc)
        return status::unimplemented;
    if (bias_d.format() == any)
        CHECK(bias_pd.set_format(x));
    if (bias_d.format() != x)
        return status::unimplemented;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia = jcp.with_bias
            ? types::data_type_size(bias_d.data_type())
    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
    jcp.nb_ch_blocking = jcp.is_depthwise
            ? (jcp.nb_ch % 4 == 0 ? 4 : jcp.nb_ch % 2 == 0 ? 2 : 1)
    // If OC blocking is incommensurate with the number of OC blocks (general
    // requirement for all convolutions), or if it results in an unrolling
    // factor smaller than the left padding (special requirement for SSD:fc6),
    // then search for a smaller OC blocking that satisfies both constraints.
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
        int ur_w = jcp.max_regs_ur / (jcp.nb_oc_blocking + 1);
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && (jcp.l_pad <= ur_w
                        && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
    jcp.ur_w = jcp.max_regs_ur
            / (jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w)
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
    jcp.ow_block = jcp.ow;
    int base_work_amount
            = jcp.mb * jcp.nb_ch * jcp.oh * (jcp.nb_oc / jcp.nb_oc_blocking);
            = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
    int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
    for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
        if (ow_block < jcp.nb_oc_blocking * jcp.oc_block && best_thr_eff > 0.8f)
        if (div_up(jcp.ow, ow_block) != nb_ow)
        auto work_amount = base_work_amount * nb_ow;
        float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
        if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
            jcp.ow_block = ow_block;
            best_thr_eff = thr_eff;
        if (best_thr_eff > 0.9f)
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
            && jcp.oc % jcp.oc_block == 0
            && jcp.l_pad <= jcp.ur_w
            && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
        return status::unimplemented;
    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
            + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.iw + jcp.l_pad - 1));
    if (r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;
    pick_loop_order(jcp, nthreads);
    jcp.nb_ic_L2 = jcp.nb_ic;
    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;
    assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
    jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;
    return status::success;
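
/* Scratchpad is only needed for the adjusted output scales used on
   non-VNNI hardware with signed input (see wei_adj_scale). */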
void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        size_t count = nstl::max(attr.output_scales_.count_, jcp.ic_block);
        scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s