1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
19 #define GET_OFF(field) offsetof(jit_deconv_call_s, field)
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::memory_tracking::names;
28 using namespace mkldnn::impl::utils;
29 using namespace Xbyak;
// wht_blk_off(d, g, ...): blocked offset into the weights descriptor `d`.
// The group index `g` is passed to blk_off() only when the primitive has
// groups (pd()->with_groups()); otherwise `g` is dropped.
33 #define wht_blk_off(d, g, ...) \
34 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
35 (d).blk_off(__VA_ARGS__))
// Checks whether this int8 AVX512-core JIT deconvolution kernel supports the
// given descriptor and attributes, and fills `jcp` with the kernel
// configuration. Memory descriptors whose format is `any` are set to the
// formats the kernel expects (nhwc for src/dst, a 4i16o4i-blocked or
// Goihw16g weights format, x for bias); any other format is rejected.
// Returns status::unimplemented for unsupported configurations and
// status::success otherwise.
37 status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
38 jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
39 cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
40 cpu_memory_t::pd_t &dst_pd, const bool with_bias,
41 cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
42 const memory_desc_wrapper src_d(&src_pd);
43 const memory_desc_wrapper dst_d(&dst_pd);
44 const memory_desc_wrapper weights_d(&weights_pd);
45 const memory_desc_wrapper bias_d(&bias_pd);
// ISA and data-type gate: AVX512-core, u8/s8 source, s8 weights and an
// f32/s32/s8/u8 destination.
47 if (!(mayiuse(avx512_core)
48 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
49 && weights_d.data_type() == data_type::s8
50 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
51 data_type::s8, data_type::u8)))
52 return status::unimplemented;
54 jcp = zero<decltype(jcp)>();
// Grouped weights carry one extra leading (groups) dimension.
56 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
57 jcp.signed_input = src_d.data_type() == data_type::s8;
59 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
60 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
61 jcp.ic = src_d.dims()[1] / jcp.ngroups;
62 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
63 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
// Depthwise: grouped, with exactly one input and one output channel per
// group.
64 jcp.is_depthwise = true && with_groups
65 && utils::everyone_is(1, jcp.ic_without_padding,
66 jcp.oc_without_padding);
68 /* TODO: future work, on hold until depthwise specialized kernel is
70 if (jcp.is_depthwise && jcp.signed_input)
71 return status::unimplemented;
// Weights layout: Goihw16g for depthwise, otherwise 4i16o4i blocking
// (grouped or not), using the _s8s8 variant when the input is signed.
73 const auto w_format = jcp.is_depthwise ? Goihw16g : with_groups ?
74 (jcp.signed_input ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i) :
75 (jcp.signed_input ? OIhw4i16o4i_s8s8 : OIhw4i16o4i);
// Resolve `any` formats to the ones this kernel supports and reject
// everything else.
77 if (dst_d.format() == any)
78 CHECK(dst_pd.set_format(nhwc));
79 if (dst_d.format() != nhwc)
80 return status::unimplemented;
81 if (src_d.format() == any)
82 CHECK(src_pd.set_format(nhwc));
83 if (src_d.format() != nhwc)
84 return status::unimplemented;
85 if (weights_d.format() == any)
86 CHECK(weights_pd.set_format(w_format));
87 if (weights_d.format() != w_format)
88 return status::unimplemented;
90 jcp.with_bias = with_bias;
92 if (bias_d.format() == any)
93 CHECK(bias_pd.set_format(x));
94 if (bias_d.format() != x)
95 return status::unimplemented;
98 jcp.ndims = dst_d.ndims();
99 jcp.prop_kind = cd.prop_kind;
100 jcp.mb = src_d.dims()[0];
101 jcp.ih = src_d.dims()[2];
102 jcp.iw = src_d.dims()[3];
103 jcp.oh = dst_d.dims()[2];
104 jcp.ow = dst_d.dims()[3];
105 jcp.kh = weights_d.dims()[with_groups + 2];
106 jcp.kw = weights_d.dims()[with_groups + 3];
107 jcp.t_pad = cd.padding[0][0];
108 jcp.l_pad = cd.padding[0][1];
109 jcp.stride_h = cd.strides[0];
110 jcp.stride_w = cd.strides[1];
111 jcp.src_fmt = src_d.format();
113 if (jcp.is_depthwise) {
122 if (jcp.ngroups == 1) {
123 jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
124 jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
126 if (jcp.ic % jcp.ic_block != 0)
127 return status::unimplemented;
130 jcp.dilate_h = cd.dilates[0];
131 jcp.dilate_w = cd.dilates[1];
// Dilation is only supported together with unit stride.
133 if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
134 || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
135 return status::unimplemented;
137 /* padding: bottom and right */
// Chosen so that (i - 1) * stride + (k - 1) * (dilate + 1) equals
// o + pad_begin + pad_end - 1 in each spatial dimension.
138 jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
139 - (jcp.oh + jcp.t_pad - 1);
140 jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
141 - (jcp.ow + jcp.l_pad - 1);
// Only the post-op chains accepted by post_ops_ok() are supported.
143 if (!post_ops_ok(jcp, attr))
144 return status::unimplemented;
146 const auto &p = attr.post_ops_;
147 const int eltwise_ind = p.find(primitive_kind::eltwise);
148 jcp.with_eltwise = eltwise_ind != -1;
149 if (jcp.with_eltwise)
150 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
// Kernel ISA version: AVX512-core by default. The avx512_core_vnni check
// below presumably switches jcp.ver to ver_vnni, which compute_ker() uses to
// emit vpdpbusd.
152 jcp.ver = ver_avx512_core;
153 if (mayiuse(avx512_core_vnni))
155 const auto &oscales = attr.output_scales_;
// Output scales are either per output channel (mask == 1 << 1) or a single
// common scale (mask == 0).
156 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
158 assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
160 jcp.dst_dt = dst_d.data_type();
161 jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
163 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
164 jcp.typesize_in = types::data_type_size(src_d.data_type());
165 jcp.typesize_out = types::data_type_size(dst_d.data_type());
167 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
168 jcp.nb_oc = jcp.oc / jcp.oc_block;
169 jcp.nb_ic = jcp.ic / jcp.ic_block;
171 /* kernel blocking params */
// `regs` is the vector-register budget that is split into ur_w columns for
// nb_oc_blocking accumulator blocks plus one input block. The non-VNNI path
// presumably keeps two fewer because it also needs zmm_tmp and zmm_one.
// nb_oc_blocking is searched downward from min(4, nb_oc) for a value that
// divides nb_oc and keeps l_pad <= regs / (nb_oc_blocking + 1).
172 const int regs = jcp.ver == ver_vnni ? 30 : 28;
173 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
174 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
175 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
176 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
179 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
// Left-boundary overflow of the dilated kernel beyond l_pad, in units of
// stride_w; the ur_w search below requires ur_w >= l_overflow * stride_w.
180 int l_overflow = max(
181 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
183 if (jcp.ow < jcp.ur_w) {
187 for (; jcp.ur_w >= 1; jcp.ur_w--) {
188 /* ur_w should be multiple of stride_w in order
189 to simplify logic for get_ow_start and get_ow_end */
190 bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
192 /* boundary conditions:
193 These conditions ensure all elements close to boundary
194 are computed in a single call of compute loop */
195 bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
196 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
197 int r_overflow_no_tail
198 = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
199 - max(0, jcp.r_pad) - jcp.ur_w_tail)
201 bool right_boundary_covered
202 = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
204 if (is_multiple_of_stride && left_boundary_covered
205 && right_boundary_covered)
207 else if (jcp.ur_w == 1)
208 /* The boundary conditions above are also important
209 to maintain simplicity of calls to icb_loop,
210 if those conditions are not satisfied,
211 then special cases will need to be added
212 to use correct l_overflow/r_overflow values
213 when different iterations of compute loop
214 work on the locations close to boundary.
215 So to keep code simple, return unimplemented
216 for extreme case when a good ur_w cannot be found.
218 return status::unimplemented;
223 = (jcp.signed_input && (jcp.ver != ver_vnni)) ? (1.f / 2.f) : 1.f;
// Work iteration order used by execute_forward(): n, g, oc-chunk for grouped
// problems (loop_ngc), otherwise oc-chunk, g, n (loop_cgn).
225 jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
226 return status::success;
// Returns whether an eltwise post-op is applied at `position`:
// 0 -> the eltwise is the first post-op (applied before any sum),
// 1 -> a sum is the first post-op and the eltwise follows it.
229 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
230 using namespace primitive_kind;
231 const auto &p = attr_.post_ops_;
234 /* eltwise before sum */
235 return p.contain(eltwise, 0);
236 } else if (position == 1) {
237 /* eltwise after sum */
238 return p.contain(sum, 0) && p.contain(eltwise, 1);
// Applies the eltwise post-op (via eltwise_injector_) to the accumulator
// vector registers of one row of `ur_w` output columns. nb_oc_block is
// nb_ch_blocking for depthwise and nb_oc_blocking otherwise.
// - Full-width row (ur_w == jcp.ur_w): one contiguous register range
//   [0, nb_oc_block * jcp.ur_w) is processed.
// - Narrower rows: for each block k, the range
//   [k * jcp.ur_w, k * jcp.ur_w + ur_w) is processed.
243 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
245 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
246 if (ur_w == jcp.ur_w)
247 eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
249 for (int k = 0; k < nb_oc_block; k++)
250 eltwise_injector_->compute_vector_range(
251 k * jcp.ur_w, k * jcp.ur_w + ur_w);
// Returns whether the post-op chain in `attr` is supported by this kernel.
// The switch below presumably dispatches on the number of post-ops. In the
// visible cases:
// - one post-op must be an eltwise or a sum;
// - two post-ops must be a sum and an eltwise, in either order;
// - anything else (default) is rejected.
254 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
255 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
256 using namespace primitive_kind;
257 const auto &p = attr.post_ops_;
259 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
263 case 1: return is_eltwise(0) || p.contain(sum, 0);
265 return (p.contain(sum, 0) && is_eltwise(1))
266 || (p.contain(sum, 1) && is_eltwise(0));
267 default: return false;
// Books scratchpad space for adjusted output scales. It is needed only for
// signed input without VNNI, where execute_forward() multiplies the output
// scales by 1 / wei_adj_scale. At least 16 floats are reserved because a
// single common scale is replicated 16 times there.
273 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
274 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
275 const primitive_attr_t &attr) {
276 if (jcp.signed_input && jcp.ver != ver_vnni) {
277 size_t count = nstl::max(attr.output_scales_.count_, 16);
278 scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
// Emits the multiply-accumulate work for one kernel row. For every kernel
// column ki and every input-channel sub-block, it:
// 1. loads (or synthesizes) the source values for the output columns of a
//    ur_w-wide row;
// 2. accumulates them against the weights of each of the nb_oc_blocking
//    output-channel blocks into zmm_out(jj, ocb).
// l_overflow / r_overflow bound the output columns that have valid input for
// a given ki (via get_ow_start / get_ow_end).
// For signed input, columns without valid input are not skipped: they are
// filled with the shifted value of zero. Presumably this keeps them
// consistent with the weight compensation that store_output() adds.
// The trailing bool argument `h_padded` (true when called from kh_loop() for
// overflow and stride-hole rows) makes the whole row use that padded value.
282 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
283 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
286 const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
// Unsigned input only visits columns aligned to the stride; signed input
// visits every column (misaligned ones receive the shifted zero).
287 const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
// Addressing helpers and the multiply-accumulate primitive:
// - src_offset: byte offset of the source pixel read by output column oj
//   for kernel column ki (input column
//   (oj + l_pad - ki * (dilate_w + 1)) / stride_w) and ic sub-block icb.
// - kernel_offset: byte offset of the weights for oc block ocb, ic sub-block
//   icb and kernel column ki.
// - compute: vreg_acc += vreg_src x vreg_wei. VNNI uses vpdpbusd; depthwise
//   uses a 32-bit vpmulld + vpaddd; otherwise vpmaddubsw produces int16
//   pair sums, and vpmaddwd with zmm_one widens them to int32 (presumably
//   zmm_one holds 16-bit ones, set up in generate()).
289 auto src_offset = [=](int oj, int icb, int ki) {
290 return jcp.typesize_in
291 * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
292 * jcp.ngroups * jcp.ic_without_padding
296 auto kernel_offset = [=](int ocb, int icb, int ki) {
297 return jcp.typesize_in
298 * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
299 + icb * jcp.oc_block * jcp.ic_block / 4
300 + ki * ch_block_all);
303 auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
304 if (jcp.ver == ver_vnni) {
305 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
306 } else if (jcp.is_depthwise) {
307 vpmulld(zmm_tmp, vreg_src, vreg_wei);
308 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
310 vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
311 vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
312 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
316 for (int ki = 0; ki < jcp.kw; ki++) {
// [jj_start, jj_end): output columns of this row with valid input for ki.
318 int jj_start = get_ow_start(ki, l_overflow);
319 int jj_end = get_ow_end(ur_w, ki, r_overflow);
// Signed input processes the whole row; unsigned only the valid columns.
321 int _start = (jcp.signed_input) ? 0 : jj_start;
322 int _end = (jcp.signed_input) ? ur_w : jj_end;
// Input channels left over in the last group of 4 (the "4i" packing of the
// weights format).
324 int tail_size = jcp.ic_without_padding % 4;
// Number of 4-channel sub-blocks to process; for the last (padded) ic block
// only the sub-blocks covering real channels are counted.
325 int n_ic_blocks = jcp.is_depthwise ?
327 (last_ic_block_flag & ~no_last_block ?
328 div_up(jcp.ic_without_padding % jcp.ic_block,
332 for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
// Padded kernel row: a single register holding the shifted zero is shared
// by all columns (see the compute loop below).
333 if (h_padded == true) {
334 /* fill padded area with shifted values */
335 Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
336 vpxord(inp, inp, inp);
337 vpsubb(inp, inp, zmm_shift);
340 for (int jj = _start; jj < _end; jj += ur_w_stride) {
342 int aux_src_off = src_offset(jj, icb1, ki);
// Load only for columns inside the valid range that map exactly onto an
// input pixel (stride-aligned).
344 if (jj >= jj_start && jj < jj_end
345 && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
// Three load forms:
// - depthwise: zero-extend bytes to dwords;
// - last spatial block with an ic tail: gather only the tail_size
//   remaining bytes with vpinsrb, so nothing past the channel tail is read;
// - otherwise: broadcast 4 input bytes (one dword) to every lane.
346 if (jcp.is_depthwise) {
347 vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
349 aux_reg_src, aux_src_off));
350 } else if ((last_ic_block_flag & last_sp_block)
351 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
352 xmm_t xmm_tmp = xmm_t(
353 zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
354 for (int r = 0; r < tail_size; ++r)
355 vpinsrb(xmm_tmp, xmm_tmp,
356 ptr[aux_reg_src + aux_src_off + r], r);
358 zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
360 vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
362 aux_reg_src, aux_src_off));
// Signed input: subtracting the bytes of zmm_shift (-128, see
// prepare_output()) adds 128 modulo 256, moving s8 values into u8 range.
364 if (jcp.signed_input)
365 vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
366 zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
368 /* fill padded area with shifted values */
369 if (jcp.signed_input) {
370 Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
371 vpxord(inp, inp, inp);
372 vpsubb(inp, inp, zmm_shift);
377 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
378 int aux_filt_off = kernel_offset(ocb, icb1, ki);
380 if (_end - _start > 0) {
381 if (jcp.is_depthwise)
383 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
386 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
388 for (int jj = _start; jj < _end; jj += ur_w_stride) {
// Padded rows reuse the single pre-filled register zmm_inp(0, ...).
389 Zmm inp = (h_padded == true) ?
390 zmm_inp(0, jcp.nb_oc_blocking) :
391 zmm_inp(jj, jcp.nb_oc_blocking);
392 compute(zmm_out(jj, ocb), zmm_wei, inp);
// Emits the loop over kernel rows for one ur_w-wide output row.
// - Unsigned input: only the kh_padding kernel rows read from the call
//   arguments are computed, with the source pointer stepping back one
//   dilated input row per iteration.
// - Signed input: the rows outside the source are also processed with
//   padded input (compute_ker(..., true)). b_overflow rows come first,
//   t_overflow rows last. When stride_h > 1, the stride_h - 1 kernel rows
//   between consecutive contributing rows ("stride holes") are processed the
//   same way. Presumably this makes every weight row contribute exactly as
//   assumed by the weight compensation added in store_output().
399 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
400 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
402 int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
// Source step per processed kernel row: one dilated input row.
403 int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
404 * jcp.ngroups * jcp.ic_without_padding;
// Weights step per processed kernel row: stride_h rows for unsigned input
// (jumping directly to the next row of the same stride phase), one row for
// signed input (every row is visited).
405 const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
406 int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
408 Label kh_loop_label, skip_kh_loop;
409 Label t_overflow_label, no_t_overflow_label, b_overflow_label,
412 mov(aux_reg_src, reg_src);
413 mov(aux_reg_filt, reg_filt);
415 if (jcp.signed_input) {
416 /* Weights are transposed, so first compute 'bottom' padding. */
417 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
418 cmp(reg_overflow, 0);
419 je(no_b_overflow_label, T_NEAR);
420 L(b_overflow_label); {
421 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
423 add(aux_reg_filt, shift_filt_kh);
425 cmp(reg_overflow, 0);
426 jg(b_overflow_label, T_NEAR);
428 L(no_b_overflow_label);
// Main loop over the kh_padding kernel rows that read real source rows.
431 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
// The early exit for an empty kh range is only emitted when the
// configuration can produce one (always for signed input).
433 if (jcp.signed_input || ((!jcp.signed_input)
434 && ((min(jcp.t_pad, jcp.b_pad) < 0)
435 || ((jcp.kh - 1) * (jcp.dilate_h + 1)
436 < nstl::max(jcp.t_pad, jcp.b_pad))))) {
438 je(skip_kh_loop, T_NEAR);
442 compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
443 sub(aux_reg_src, shift_src_ih);
444 add(aux_reg_filt, shift_filt_kh);
447 /* Insert weight compensation in stride 'holes' */
448 if (jcp.signed_input && jcp.stride_h > 1) {
452 je(skip_kh_loop, T_NEAR);
453 mov(reg_comp_strides, jcp.stride_h - 1);
457 ur_w, 0, 0, last_ic_block_flag, true);
458 add(aux_reg_filt, shift_filt_kh);
459 dec(reg_comp_strides);
460 cmp(reg_comp_strides, 0);
461 jg(kh_comp_loop, T_NEAR);
465 jg(kh_loop_label, T_NEAR);
// Signed input: finally process the top-overflow kernel rows with padded
// input.
468 if (jcp.signed_input) {
469 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
470 cmp(reg_overflow, 0);
471 je(no_t_overflow_label, T_NEAR);
472 L(t_overflow_label); {
473 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
475 add(aux_reg_filt, shift_filt_kh);
477 cmp(reg_overflow, 0);
478 jg(t_overflow_label, T_NEAR);
480 L(no_t_overflow_label);
// Zeroes the ur_w x nb_oc_blocking accumulator registers. For signed input
// it also broadcasts the byte -128 into every byte of zmm_shift; compute_ker()
// subtracts zmm_shift from s8 input to move it into u8 range.
484 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
485 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
486 for (int ur = 0; ur < ur_w; ur++) {
487 zmm_t zmm = zmm_out(ur, ocb);
488 vpxord(zmm, zmm, zmm);
491 if (jcp.signed_input) {
492 xor_(reg_scratch, reg_scratch);
493 Reg8 _t8 = reg_scratch.cvt8();
494 mov(_t8, (int8_t)-128);
495 vpbroadcastb(zmm_shift, _t8);
// Loads `op`, interpreted as `type_in`, into zmm_in and converts it to f32.
// s32 is loaded as-is and s8/u8 are sign-/zero-extended to 32 bits; every
// non-f32 type is then converted with vcvtdq2ps. With `mask_flag` the load
// uses ktail_mask with zeroing (T_z), so lanes past the channel tail are 0.
499 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
500 data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
501 zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
504 case data_type::s32: vmovups(zmm, op); break;
505 case data_type::s8: vpmovsxbd(zmm, op); break;
506 case data_type::u8: vpmovzxbd(zmm, op); break;
507 default: assert(!"unsupported data type");
509 if (type_in != data_type::f32)
510 vcvtdq2ps(zmm_in, zmm_in);
// Finishes one ur_w-wide output row and writes it to reg_dst. For each
// accumulator it applies, in order:
// 1. the s32 compensation (signed input only);
// 2. the bias, multiplied by wei_adj_scale for signed input without VNNI;
// 3. the output scales;
// 4. the eltwise post-op, if it comes first;
// 5. the sum post-op: dst * sum_scale is added;
// 6. the eltwise post-op, if it follows the sum;
// 7. a clamp at 0 for u8 destinations;
// 8. rounding to int32 per attr round_mode (integer destinations only);
// 9. the store, with saturation for s8/u8.
// `last_oc_block` enables ktail_mask on the last output-channel block, so
// only valid channels are loaded and stored.
513 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
514 int ur_w, bool last_oc_block) {
515 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
516 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
518 if (jcp.signed_input)
519 mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
// Sum post-op scale. Its address is loaded into a register only when the
// scale differs from 1.f; a scale of 1 uses a plain vaddps below.
521 const auto &p = attr_.post_ops_;
522 const int sum_idx = p.find(primitive_kind::sum);
523 const float *p_sum_scale
524 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
525 if (p_sum_scale && *p_sum_scale != 1.f)
526 mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
// Signed input without VNNI: broadcast wei_adj_scale so the bias can be
// scaled the same way as the weights contribution.
528 if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
529 mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
530 vmovq(xmm_bias_alpha(), reg_bias_alpha);
531 vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
534 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
535 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
537 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
539 auto zmm_bias = zmm_tmp;
541 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
542 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
543 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
544 if (jcp.signed_input && jcp.ver != ver_vnni)
545 vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
547 if (jcp.signed_input) {
548 int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
549 auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
550 cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
553 for (int ur = 0; ur < ur_w; ur++) {
554 zmm_t zmm = zmm_out(ur, ocb);
556 if (jcp.signed_input)
557 vaddps(zmm, zmm, zmm_comp);
559 vaddps(zmm, zmm, zmm_bias);
560 zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
561 vmulps(mask_zmm, zmm,
562 EVEX_compress_addr(reg_ptr_scales, scale_offset));
565 if (maybe_eltwise(0))
566 compute_eltwise(ur_w);
567 if (p_sum_scale) { // post_op: sum
568 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
570 = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
571 for (int j = 0; j < ur_w; j++) {
572 int aux_output_offset
575 + j * jcp.oc_without_padding * jcp.ngroups);
576 auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
577 Zmm zmm = zmm_out(j, k);
578 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
579 if (*p_sum_scale == 1.f)
580 vaddps(zmm, zmm_prev_dst);
582 vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
586 if (maybe_eltwise(1))
587 compute_eltwise(ur_w);
// Conversion to the destination type, then the store. For s8/u8
// destinations vpmovsdb / vpmovusdb narrow int32 with signed / unsigned
// saturation. The tail block is stored through ktail_mask so only valid
// channels are written.
589 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
590 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
591 for (int ur = 0; ur < ur_w; ur++) {
592 zmm_t zmm = zmm_out(ur, ocb);
// u8 destination: clamp negative values to 0 before conversion.
593 if (jcp.dst_dt == data_type::u8) {
594 vpxord(zmm_zero, zmm_zero, zmm_zero);
595 vmaxps(zmm, zmm_zero, zmm);
597 if (jcp.dst_dt != data_type::f32) {
598 if (attr_.round_mode_ == round_mode::nearest)
599 vcvtps2dq(zmm | T_rn_sae, zmm);
600 else if (attr_.round_mode_ == round_mode::down)
601 vcvtps2dq(zmm | T_rd_sae, zmm);
603 assert(!"unimplemented");
606 for (int ur = 0; ur < ur_w; ur++) {
607 int aux_dst_off = jcp.typesize_out
608 * (ur * jcp.ngroups * jcp.oc_without_padding
609 + ocb * jcp.oc_block);
610 auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
612 zmm_t zmm = zmm_out(ur, ocb);
613 zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
614 switch (jcp.dst_dt) {
616 case data_type::s32: vmovups(addr, r_zmm); break;
617 case data_type::s8: vpmovsdb(addr, r_zmm); break;
618 case data_type::u8: vpmovusdb(addr, r_zmm); break;
619 default: assert(!"unknown dst_dt");
// Emits the computation of one ur_w-wide output row:
// 1. zero the accumulators;
// 2. loop over the nb_ic input-channel blocks, calling kh_loop() for each;
//    when ic is padded, the last block uses a tail-aware flag
//    (last_sp_block on the last spatial block, otherwise last_ic_block);
// 3. restore the src/filter pointers;
// 4. store the row.
// When oc (or ngroups, for depthwise) is not a multiple of the block size,
// the call whose oc_blocks argument selects the last block stores with the
// tail mask.
625 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
626 int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
// Pointer increments per ic block.
628 int shift_src_icb = jcp.typesize_in * jcp.ic_block;
630 = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
632 prepare_output(ur_w);
634 Label skip_icb_loop, icb_loop_label;
636 mov(reg_icb, jcp.nb_ic);
639 if (jcp.ic_without_padding != jcp.ic) {
640 Label common_ker, end_ker;
642 jg(common_ker, T_NEAR);
644 kh_loop(ur_w, l_overflow, r_overflow,
645 is_last_sp_block ? last_sp_block : last_ic_block);
646 jmp(end_ker, T_NEAR);
649 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
653 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
656 add(reg_src, shift_src_icb);
657 add(reg_filt, shift_filt_icb);
660 jg(icb_loop_label, T_NEAR);
663 /* come-back pointers */
664 sub(reg_src, jcp.nb_ic * shift_src_icb);
665 sub(reg_filt, jcp.nb_ic * shift_filt_icb);
// Masked store for the last output-channel block (last channel block for
// depthwise), detected at run time from the oc_blocks call argument.
668 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
669 Label common_store, end_store;
670 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
671 if (jcp.is_depthwise)
672 cmp(reg_oc_blocks, jcp.nb_ch - 1);
674 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
675 jne(common_store, T_NEAR);
677 store_output(ur_w, true);
678 jmp(end_store, T_NEAR);
681 store_output(ur_w, false);
686 store_output(ur_w, false);
// Emits the kernel body for one call. execute_forward() calls the kernel
// once per output row, passing that row's src/dst/filter pointers.
// Steps:
// 1. set up zmm_one and, for an oc/group tail, ktail_mask;
// 2. load the src/filt/dst pointers from the call arguments;
// 3. walk the output row in chunks of ur_w columns: an optional first chunk
//    with left overflow, a loop over interior chunks, an optional last full
//    chunk with right overflow, and a final ur_w_tail chunk;
// 4. emit the eltwise injector table (if eltwise is used) after the body.
690 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
// zmm_one: 16-bit constant used as the vpmaddwd multiplier in compute_ker().
693 xor_(reg_scratch, reg_scratch);
694 Reg16 _t = reg_scratch.cvt16();
696 vpbroadcastw(zmm_one, _t);
// ktail_mask selects the valid lanes of the last oc block (of the last
// channel block for depthwise).
698 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
699 int tail_size = jcp.is_depthwise ?
700 jcp.ngroups % jcp.ch_block :
701 jcp.oc_without_padding % jcp.oc_block;
702 int mask = (1 << tail_size) - 1;
703 Reg32 regw_tmp = reg_nur_w.cvt32();
705 kmovw(ktail_mask, regw_tmp);
708 mov(reg_src, ptr[param1 + GET_OFF(src)]);
709 mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
710 mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
// Pointer advance per ur_w chunk: ur_w output columns correspond to
// ur_w / stride_w input columns.
712 int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
713 * jcp.oc_without_padding;
714 int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
715 * jcp.ic_without_padding;
// Boundary overflows passed to icb_loop():
// - l_overflow: the first chunk;
// - r_overflow: the chunk that ends the row;
// - r_overflow1: the last full chunk, reduced by ur_w_tail since a tail
//   chunk follows it.
717 int l_overflow = max(
718 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
720 = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
724 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
725 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
727 int nur_w = jcp.ow / jcp.ur_w;
// Dispatch on how the output row splits into ur_w chunks.
731 if (jcp.ur_w == jcp.ow) {
732 icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
733 } else if (nur_w == 0) {
734 icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
735 add(reg_src, src_shift);
736 add(reg_dst, dst_shift);
737 if (jcp.ur_w_tail != 0)
738 icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
740 xor_(reg_nur_w, reg_nur_w);
741 if (l_overflow > 0) {
742 icb_loop(jcp.ur_w, l_overflow, 0, false);
743 add(reg_src, src_shift);
744 add(reg_dst, dst_shift);
747 if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
751 icb_loop(jcp.ur_w, 0, 0, false);
752 add(reg_src, src_shift);
753 add(reg_dst, dst_shift);
755 cmp(reg_nur_w, nur_w);
756 jl(ow_loop_label, T_NEAR);
759 if (r_overflow1 > 0) {
760 icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
761 add(reg_src, src_shift);
762 add(reg_dst, dst_shift);
764 if (jcp.ur_w_tail != 0) {
765 icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
770 if (jcp.with_eltwise)
771 eltwise_injector_->prepare_table();
// Runs the forward deconvolution. The mb * groups * oc_chunks * oh output
// rows are split across threads with balance211(). For each output row the
// code computes which kernel rows contribute (kh_lo, kh_len), the source
// row they start from (ih_max), and the top/bottom kernel overflow, then
// calls the JIT kernel once.
774 template <data_type_t src_type, data_type_t dst_type>
775 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
776 dst_type>::execute_forward() const {
777 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
778 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
779 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
780 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
782 const memory_desc_wrapper src_d(pd()->src_pd());
783 const memory_desc_wrapper dst_d(pd()->dst_pd());
784 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
785 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
787 auto &jcp = kernel_->jcp;
789 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
790 int nb_groups = jcp.nb_ch;
// Element offsets of one spatial row in src/dst and of one kernel row in the
// weights.
792 size_t src_h_stride = src_d.blk_off(0, 0, 1);
793 size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
794 size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
// Signed input without VNNI: the output scales are multiplied by
// 1 / wei_adj_scale into scratchpad memory. A single common scale is
// replicated 16 times. Presumably this undoes a wei_adj_scale applied to the
// s8s8 weights; the bias is scaled the same way in store_output().
796 const float *oscales = pd()->attr()->output_scales_.scales_;
797 if (jcp.signed_input && jcp.ver != ver_vnni) {
799 = scratchpad().template get<float>(key_conv_adjusted_scales);
800 size_t count = pd()->attr()->output_scales_.count_;
801 float factor = 1.f / pd()->jcp_.wei_adj_scale;
803 utils::array_set(local_scales, oscales[0] * factor, 16);
805 for (size_t c = 0; c < count; c++)
806 local_scales[c] = oscales[c] * factor;
808 oscales = local_scales;
810 size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
// For signed input, the int32 compensation values are stored right after the
// ngroups * oc * ic * kh * kw weight elements.
811 auto w = const_cast<wei_data_t *>(weights);
812 int32_t *compensation
813 = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
// Work items are (n, g, oc-chunk, output row); jcp.loop_order selects the
// nesting order.
815 parallel(0, [&](const int ithr, const int nthr) {
816 int start{ 0 }, end{ 0 };
817 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
818 balance211(work_amount, nthr, ithr, start, end);
820 auto p = jit_deconv_call_s();
823 int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
824 if (jcp.loop_order == loop_ngc)
825 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
827 else if (jcp.loop_order == loop_cgn)
828 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
831 assert(!"unsupported loop order");
832 while (start < end) {
// Channel offsets and base pointers for this (n, g, oc-chunk); the rows
// oh_s .. oh_e - 1 are processed in one go.
834 int ocb = occ * jcp.nb_oc_blocking;
835 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
836 int g_ic = g * jcp.ch_block * jcp.ic;
837 int work_rem = end - start;
838 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
840 auto dst_w = dst + dst_d.blk_off(n, g_oc);
841 auto src_w = src + src_d.blk_off(n, g_ic);
842 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
843 auto bias_w = jcp.with_bias ?
844 bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
846 int32_t *compensation_w
847 = (jcp.signed_input) ? compensation + g_oc : 0;
849 auto scales = &oscales[jcp.is_oc_scale * g_oc];
850 for (int oj = oh_s; oj < oh_e; oj++) {
851 int ih_max = 0, kh_lo = 0, kh_len = 0;
// Dilated kernels are only allowed with stride_h == 1 (see init_conf).
852 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
854 int dilate_h = jcp.dilate_h + 1;
855 // Note: use div_up to account for "holes" in filter
856 int o_t_overflow = div_up(
857 max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
860 = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
863 kh_len = jcp.kh - o_t_overflow - o_b_overflow;
864 kh_lo = o_b_overflow;
865 ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
867 int o_t_overflow = max(
868 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
870 = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
872 int overflow_kh_hi = jcp.kh - 1
873 - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
874 int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
876 kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
877 + 1 - o_t_overflow - o_b_overflow;
878 kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
879 ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
883 = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
// Signed input starts at kernel row 0 because kh_loop() itself walks the
// b_overflow rows before the kh_padding rows.
884 p.src = src_w + ih_max * src_h_stride;
885 p.dst = dst_w + oj * dst_h_stride;
886 p.filt = wht_w + wei_stride;
888 p.compensation = compensation_w;
890 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
892 p.b_overflow = kh_lo;
893 p.kh_padding = kh_len;
895 p.oc_blocks = jcp.is_depthwise ? g : ocb;
896 kernel_->jit_ker(&p);
898 if (jcp.loop_order == loop_ngc)
899 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
900 oc_chunks, oh_s, jcp.oh);
901 else if (jcp.loop_order == loop_cgn)
902 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
903 jcp.mb, oh_s, jcp.oh);
905 assert(!"unsupported loop order");
// Explicit instantiations of the primitive for u8 and s8 source data types.
910 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
912 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
914 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
916 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
918 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
920 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
922 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
924 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,