1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
19 #define GET_OFF(field) offsetof(jit_deconv_call_s, field)
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::memory_tracking::names;
28 using namespace mkldnn::impl::utils;
29 using namespace Xbyak;
33 #define wht_blk_off(d, g, ...) \
34 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
35 (d).blk_off(__VA_ARGS__))
/* Validate the deconvolution descriptor and fill the jit_conv_conf_t (jcp)
 * with all blocking/padding/stride parameters the JIT kernel needs.
 * May set memory formats on src/weights/dst/bias pds when they are 'any'.
 * Returns status::unimplemented for any configuration the kernel
 * does not support. */
37 status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
38 jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
39 cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
40 cpu_memory_t::pd_t &dst_pd, const bool with_bias,
41 cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
42 const memory_desc_wrapper src_d(&src_pd);
43 const memory_desc_wrapper dst_d(&dst_pd);
44 const memory_desc_wrapper weights_d(&weights_pd);
45 const memory_desc_wrapper bias_d(&bias_pd);
/* Hard requirements: AVX512-core ISA, u8/s8 source, s8 weights,
 * and one of f32/s32/s8/u8 destination. */
47 if (!(mayiuse(avx512_core)
48 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
49 && weights_d.data_type() == data_type::s8
50 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
51 data_type::s8, data_type::u8)))
52 return status::unimplemented;
54 jcp = zero<decltype(jcp)>();
/* Grouped weights carry one extra leading dimension. */
56 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
/* s8 source requires the signed-input compensation path throughout. */
57 jcp.signed_input = src_d.data_type() == data_type::s8;
58 const int ndims = jcp.ndims = dst_d.ndims();
59 const bool is_1d = ndims == 3;
61 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
62 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
63 jcp.ic = src_d.dims()[1] / jcp.ngroups;
64 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
65 jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
/* Depthwise: grouped with exactly one ic and one oc per group. */
66 jcp.is_depthwise = true && with_groups
67 && utils::everyone_is(1, jcp.ic_without_padding,
68 jcp.oc_without_padding);
70 /* TODO: future work, on hold until depthwise specialized kernel is
72 if (jcp.is_depthwise && jcp.signed_input)
73 return status::unimplemented;
/* Plain channel-last activations; weights use blocked 4i16o4i layouts
 * (with an _s8s8 flavor when input is signed). */
75 auto dst_format = pick(ndims - 3, nwc, nhwc);
76 auto src_format = pick(ndims - 3, nwc, nhwc);
77 #define pick_signed(fmt) (jcp.signed_input ? fmt##_s8s8 : fmt)
78 const auto w_format = is_1d ?
79 (jcp.is_depthwise ? Goiw16g :
80 (with_groups ? pick_signed(gOIw4i16o4i) :
81 pick_signed(OIw4i16o4i))) :
82 (jcp.is_depthwise ? Goihw16g :
83 (with_groups ? pick_signed(gOIhw4i16o4i) :
84 pick_signed(OIhw4i16o4i)));
/* Set formats for 'any' descriptors; reject non-matching explicit ones. */
87 if (dst_d.format() == any)
88 CHECK(dst_pd.set_format(dst_format));
89 if (dst_d.format() != dst_format)
90 return status::unimplemented;
91 if (src_d.format() == any)
92 CHECK(src_pd.set_format(src_format));
93 if (src_d.format() != src_format)
94 return status::unimplemented;
95 if (weights_d.format() == any)
96 CHECK(weights_pd.set_format(w_format));
97 if (weights_d.format() != w_format)
98 return status::unimplemented;
100 jcp.with_bias = with_bias;
102 if (bias_d.format() == any)
103 CHECK(bias_pd.set_format(x));
104 if (bias_d.format() != x)
105 return status::unimplemented;
/* Problem geometry; 1D cases fold the H dimension to 1. */
108 jcp.prop_kind = cd.prop_kind;
109 jcp.mb = src_d.dims()[0];
110 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
111 jcp.iw = src_d.dims()[ndims - 1];
112 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
113 jcp.ow = dst_d.dims()[ndims - 1];
114 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
115 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
116 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
117 jcp.l_pad = cd.padding[0][ndims - 3];
118 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
119 jcp.stride_w = cd.strides[ndims - 3];
120 jcp.src_fmt = src_d.format();
122 if (jcp.is_depthwise) {
/* Non-grouped case: pad channel counts up to full blocks. */
131 if (jcp.ngroups == 1) {
132 jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
133 jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
135 if (jcp.ic % jcp.ic_block != 0)
136 return status::unimplemented;
139 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
140 jcp.dilate_w = cd.dilates[ndims - 3];
/* Dilation is only supported together with unit stride. */
142 if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
143 || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
144 return status::unimplemented;
146 /* padding: bottom and right */
147 jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
148 - (jcp.oh + jcp.t_pad - 1);
149 jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
150 - (jcp.ow + jcp.l_pad - 1);
152 if (!post_ops_ok(jcp, attr))
153 return status::unimplemented;
155 const auto &p = attr.post_ops_;
156 const int eltwise_ind = p.find(primitive_kind::eltwise);
157 jcp.with_eltwise = eltwise_ind != -1;
158 if (jcp.with_eltwise)
159 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
161 jcp.ver = ver_avx512_core;
162 if (mayiuse(avx512_core_vnni))
/* mask_ == 1<<1 means per-output-channel scales. */
164 const auto &oscales = attr.output_scales_;
165 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
167 assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
169 jcp.dst_dt = dst_d.data_type();
170 jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
172 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
173 jcp.typesize_in = types::data_type_size(src_d.data_type());
174 jcp.typesize_out = types::data_type_size(dst_d.data_type());
176 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
177 jcp.nb_oc = jcp.oc / jcp.oc_block;
178 jcp.nb_ic = jcp.ic / jcp.ic_block;
180 /* kernel blocking params */
/* VNNI frees two registers (no zmm_one/zmm_tmp needed for multiplies). */
181 const int regs = jcp.ver == ver_vnni ? 30 : 28;
/* Pick the largest nb_oc_blocking (<=4) dividing nb_oc that still
 * leaves enough registers to cover the left padding. */
182 jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
183 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
184 if (jcp.nb_oc % jcp.nb_oc_blocking == 0
185 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
188 jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
189 int l_overflow = max(
190 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
192 if (jcp.ow < jcp.ur_w) {
/* Shrink ur_w until it is stride-aligned and covers both boundaries
 * in a single compute-loop call. */
196 for (; jcp.ur_w >= 1; jcp.ur_w--) {
197 /* ur_w should be multiple of stride_w in order
198 to simplify logic for get_ow_start and get_ow_end */
199 bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
201 /* boundary conditions:
202 These conditions ensure all elements close to boundary
203 are computed in a single call of compute loop */
204 bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
205 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
206 int r_overflow_no_tail
207 = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
208 - max(0, jcp.r_pad) - jcp.ur_w_tail)
210 bool right_boundary_covered
211 = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
213 if (is_multiple_of_stride && left_boundary_covered
214 && right_boundary_covered)
216 else if (jcp.ur_w == 1)
217 /* The boundary conditions above are also important
218 to maintain simplicity of calls to icb_loop,
219 if those conditions are not satisfied,
220 then special cases will need to be added
221 to use correct l_overflow/r_overflow values
222 when different iterations of compute loop
223 work on the locations close to boundary.
224 So to keep code simple, return unimplemented
225 for extreme case when a good ur_w cannot be found.
227 return status::unimplemented;
/* Pre-VNNI signed path halves weights at reorder time; compensate
 * with wei_adj_scale = 1/2 applied to scales/bias. */
232 = (jcp.signed_input && (jcp.ver != ver_vnni)) ? (1.f / 2.f) : 1.f;
234 jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
235 return status::success;
/* Returns whether an eltwise post-op must be applied at the given
 * position relative to the sum post-op:
 *   position == 0 -> eltwise before sum, position == 1 -> eltwise after sum. */
238 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
239 using namespace primitive_kind;
240 const auto &p = attr_.post_ops_;
243 /* eltwise before sum */
244 return p.contain(eltwise, 0);
245 } else if (position == 1) {
246 /* eltwise after sum */
247 return p.contain(sum, 0) && p.contain(eltwise, 1);
/* Applies the eltwise injector to all live accumulator registers.
 * For a full row (ur_w == jcp.ur_w) one contiguous range is used;
 * for a tail row, each oc-block's partial range is processed separately. */
252 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
254 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
255 if (ur_w == jcp.ur_w)
256 eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
258 for (int k = 0; k < nb_oc_block; k++)
259 eltwise_injector_->compute_vector_range(
260 k * jcp.ur_w, k * jcp.ur_w + ur_w);
/* Accepts at most two post-ops: a single eltwise, a single sum,
 * or {sum, eltwise} in either order. Everything else is rejected. */
263 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
264 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
265 using namespace primitive_kind;
266 const auto &p = attr.post_ops_;
268 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
272 case 1: return is_eltwise(0) || p.contain(sum, 0);
274 return (p.contain(sum, 0) && is_eltwise(1))
275 || (p.contain(sum, 1) && is_eltwise(0));
276 default: return false;
/* Books scratchpad space for adjusted output scales, needed only on the
 * pre-VNNI signed-input path where scales are rescaled by 1/wei_adj_scale
 * at execution time (see execute_forward_*). At least 16 floats are
 * reserved so the common-scale case can be broadcast. */
282 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
283 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
284 const primitive_attr_t &attr) {
285 if (jcp.signed_input && jcp.ver != ver_vnni) {
286 size_t count = nstl::max(attr.output_scales_.count_, 16);
287 scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
/* Emits the innermost int8 multiply-accumulate loop over kernel width,
 * input-channel sub-blocks, and output positions. When h_padded is true,
 * inputs are replaced with the shifted zero value (signed-input
 * compensation). last_ic_block_flag selects tail handling for the
 * final ic block / spatial block. */
291 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
292 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
295 const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
/* Signed-input path walks every position (stride holes included,
 * for compensation); unsigned path steps by stride_w. */
296 const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
/* Byte offset of the source element feeding output position oj. */
298 auto src_offset = [=](int oj, int icb, int ki) {
299 return jcp.typesize_in
300 * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
301 * jcp.ngroups * jcp.ic_without_padding
/* Byte offset into the blocked weights for (ocb, icb, ki). */
305 auto kernel_offset = [=](int ocb, int icb, int ki) {
306 return jcp.typesize_in
307 * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
308 + icb * jcp.oc_block * jcp.ic_block / 4
309 + ki * ch_block_all);
/* One int8 dot-product-accumulate: native vpdpbusd on VNNI, plain
 * 32-bit multiply for depthwise, else the vpmaddubsw/vpmaddwd pair. */
312 auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
313 if (jcp.ver == ver_vnni) {
314 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
315 } else if (jcp.is_depthwise) {
316 vpmulld(zmm_tmp, vreg_src, vreg_wei);
317 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
319 vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
320 vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
321 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
325 for (int ki = 0; ki < jcp.kw; ki++) {
/* Valid (in-bounds) output range for this kernel column. */
327 int jj_start = get_ow_start(ki, l_overflow);
328 int jj_end = get_ow_end(ur_w, ki, r_overflow);
/* Signed input processes the whole row so padded positions still
 * contribute the compensation value. */
330 int _start = (jcp.signed_input) ? 0 : jj_start;
331 int _end = (jcp.signed_input) ? ur_w : jj_end;
/* ic is consumed 4 bytes at a time; tail_size < 4 remains. */
333 int tail_size = jcp.ic_without_padding % 4;
334 int n_ic_blocks = jcp.is_depthwise ?
336 (last_ic_block_flag & ~no_last_block ?
337 div_up(jcp.ic_without_padding % jcp.ic_block,
341 for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
342 if (h_padded == true) {
343 /* fill padded area with shifted values */
344 Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
345 vpxord(inp, inp, inp);
346 vpsubb(inp, inp, zmm_shift);
349 for (int jj = _start; jj < _end; jj += ur_w_stride) {
351 int aux_src_off = src_offset(jj, icb1, ki);
/* In-bounds AND on a stride hit: load real input data. */
353 if (jj >= jj_start && jj < jj_end
354 && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
355 if (jcp.is_depthwise) {
356 vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
358 aux_reg_src, aux_src_off));
359 } else if ((last_ic_block_flag & last_sp_block)
360 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
/* ic tail: assemble <4 bytes with vpinsrb before broadcast. */
361 xmm_t xmm_tmp = xmm_t(
362 zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
363 for (int r = 0; r < tail_size; ++r)
364 vpinsrb(xmm_tmp, xmm_tmp,
365 ptr[aux_reg_src + aux_src_off + r], r);
367 zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
369 vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
371 aux_reg_src, aux_src_off));
/* Shift u8-domain input by 128 for the s8 compensation trick. */
373 if (jcp.signed_input)
374 vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
375 zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
377 /* fill padded area with shifted values */
378 if (jcp.signed_input) {
379 Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
380 vpxord(inp, inp, inp);
381 vpsubb(inp, inp, zmm_shift);
/* Multiply loaded inputs against each oc block's weights. */
386 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
387 int aux_filt_off = kernel_offset(ocb, icb1, ki);
389 if (_end - _start > 0) {
390 if (jcp.is_depthwise)
392 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
395 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
397 for (int jj = _start; jj < _end; jj += ur_w_stride) {
398 Zmm inp = (h_padded == true) ?
399 zmm_inp(0, jcp.nb_oc_blocking) :
400 zmm_inp(jj, jcp.nb_oc_blocking);
401 compute(zmm_out(jj, ocb), zmm_wei, inp);
/* Emits the loop over kernel height. For signed input the top/bottom
 * overflow regions and the stride 'holes' are also walked with padded
 * (compensation-only) compute_ker calls; weights are traversed in
 * transposed order, so 'bottom' overflow is handled first. */
408 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
409 int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
411 int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
/* Per-row source step; deconvolution walks the source backwards (sub). */
412 int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
413 * jcp.ngroups * jcp.ic_without_padding;
/* Signed path visits each kh row (stride holes included). */
414 const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
415 int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
417 Label kh_loop_label, skip_kh_loop;
418 Label t_overflow_label, no_t_overflow_label, b_overflow_label,
421 mov(aux_reg_src, reg_src);
422 mov(aux_reg_filt, reg_filt);
424 if (jcp.signed_input && jcp.ndims > 3) {
425 /* Weights are transposed, so first compute 'bottom' padding. */
426 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
427 cmp(reg_overflow, 0);
428 je(no_b_overflow_label, T_NEAR);
429 L(b_overflow_label); {
/* h_padded = true: accumulate compensation only, no real input. */
430 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
432 add(aux_reg_filt, shift_filt_kh);
434 cmp(reg_overflow, 0);
435 jg(b_overflow_label, T_NEAR);
437 L(no_b_overflow_label);
440 mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
/* kh_padding may be zero when padding exceeds the filter span
 * (or always possibly zero on the signed path) — guard the loop. */
442 if (jcp.signed_input || ((!jcp.signed_input)
443 && ((min(jcp.t_pad, jcp.b_pad) < 0)
444 || ((jcp.kh - 1) * (jcp.dilate_h + 1)
445 < nstl::max(jcp.t_pad, jcp.b_pad))))) {
447 je(skip_kh_loop, T_NEAR);
/* Main kh iteration: real data, then step src back / filter forward. */
451 compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
452 sub(aux_reg_src, shift_src_ih);
453 add(aux_reg_filt, shift_filt_kh);
456 /* Insert weight compensation in stride 'holes' */
457 if (jcp.signed_input && jcp.stride_h > 1) {
461 je(skip_kh_loop, T_NEAR);
462 mov(reg_comp_strides, jcp.stride_h - 1);
466 ur_w, 0, 0, last_ic_block_flag, true);
467 add(aux_reg_filt, shift_filt_kh);
468 dec(reg_comp_strides);
469 cmp(reg_comp_strides, 0);
470 jg(kh_comp_loop, T_NEAR);
474 jg(kh_loop_label, T_NEAR);
/* Finally the 'top' overflow region (transposed weight order). */
477 if (jcp.signed_input && jcp.ndims > 3) {
478 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
479 cmp(reg_overflow, 0);
480 je(no_t_overflow_label, T_NEAR);
481 L(t_overflow_label); {
482 compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
484 add(aux_reg_filt, shift_filt_kh);
486 cmp(reg_overflow, 0);
487 jg(t_overflow_label, T_NEAR);
489 L(no_t_overflow_label);
/* Zeroes all accumulator registers for the upcoming row, and for signed
 * input initializes zmm_shift with the broadcast byte 128 (as -128 via
 * two's complement) used to shift s8 inputs into the u8 domain. */
493 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
494 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
495 for (int ur = 0; ur < ur_w; ur++) {
496 zmm_t zmm = zmm_out(ur, ocb);
497 vpxord(zmm, zmm, zmm);
500 if (jcp.signed_input) {
501 xor_(reg_scratch, reg_scratch);
502 Reg8 _t8 = reg_scratch.cvt8();
503 mov(_t8, (int8_t)-128);
504 vpbroadcastb(zmm_shift, _t8);
/* Loads a vector of type_in from op into zmm_in (with optional
 * zero-masked channel tail via ktail_mask) and converts it to f32. */
508 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
509 data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
510 zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
513 case data_type::s32: vmovups(zmm, op); break;
514 case data_type::s8: vpmovsxbd(zmm, op); break;
515 case data_type::u8: vpmovzxbd(zmm, op); break;
516 default: assert(!"unsupported data type");
/* Integer inputs need an explicit int->float conversion. */
518 if (type_in != data_type::f32)
519 vcvtdq2ps(zmm_in, zmm_in);
/* Post-processes and stores the accumulated row: add signed-input
 * compensation and bias, apply output scales, eltwise/sum post-ops
 * (in attribute order), round/convert to dst_dt, and write out with an
 * optional oc-tail mask when this is the last (partial) oc block. */
522 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
523 int ur_w, bool last_oc_block) {
524 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
525 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
527 if (jcp.signed_input)
528 mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
530 const auto &p = attr_.post_ops_;
531 const int sum_idx = p.find(primitive_kind::sum);
532 const float *p_sum_scale
533 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
534 if (p_sum_scale && *p_sum_scale != 1.f)
535 mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
/* Pre-VNNI signed path: bias must be rescaled by wei_adj_scale. */
537 if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
538 mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
539 vmovq(xmm_bias_alpha(), reg_bias_alpha);
540 vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
543 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
/* Tail masking applies only on the very last oc block. */
544 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
/* Per-oc scales index by ocb only when is_oc_scale is set. */
546 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
548 auto zmm_bias = zmm_tmp;
550 int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
551 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
552 cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
553 if (jcp.signed_input && jcp.ver != ver_vnni)
554 vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
556 if (jcp.signed_input) {
557 int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
558 auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
559 cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
562 for (int ur = 0; ur < ur_w; ur++) {
563 zmm_t zmm = zmm_out(ur, ocb);
565 if (jcp.signed_input)
566 vaddps(zmm, zmm, zmm_comp);
568 vaddps(zmm, zmm, zmm_bias);
569 zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
570 vmulps(mask_zmm, zmm,
571 EVEX_compress_addr(reg_ptr_scales, scale_offset));
/* Post-ops in attribute order: eltwise-before-sum, sum, eltwise-after. */
574 if (maybe_eltwise(0))
575 compute_eltwise(ur_w);
576 if (p_sum_scale) { // post_op: sum
577 for (int k = 0; k < jcp.nb_oc_blocking; k++) {
579 = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
580 for (int j = 0; j < ur_w; j++) {
581 int aux_output_offset
584 + j * jcp.oc_without_padding * jcp.ngroups);
585 auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
586 Zmm zmm = zmm_out(j, k);
587 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
588 if (*p_sum_scale == 1.f)
589 vaddps(zmm, zmm_prev_dst);
/* Scaled sum: fused multiply-add with broadcast sum scale. */
591 vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
595 if (maybe_eltwise(1))
596 compute_eltwise(ur_w);
598 for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
599 const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
600 for (int ur = 0; ur < ur_w; ur++) {
601 zmm_t zmm = zmm_out(ur, ocb);
/* u8 destination cannot represent negatives: clamp at zero. */
602 if (jcp.dst_dt == data_type::u8) {
603 vpxord(zmm_zero, zmm_zero, zmm_zero);
604 vmaxps(zmm, zmm_zero, zmm);
/* Convert f32 accumulators to integers per attr round mode. */
606 if (jcp.dst_dt != data_type::f32) {
607 if (attr_.round_mode_ == round_mode::nearest)
608 vcvtps2dq(zmm | T_rn_sae, zmm);
609 else if (attr_.round_mode_ == round_mode::down)
610 vcvtps2dq(zmm | T_rd_sae, zmm);
612 assert(!"unimplemented");
615 for (int ur = 0; ur < ur_w; ur++) {
616 int aux_dst_off = jcp.typesize_out
617 * (ur * jcp.ngroups * jcp.oc_without_padding
618 + ocb * jcp.oc_block);
619 auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
621 zmm_t zmm = zmm_out(ur, ocb);
622 zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
623 switch (jcp.dst_dt) {
625 case data_type::s32: vmovups(addr, r_zmm); break;
626 case data_type::s8: vpmovsdb(addr, r_zmm); break;
627 case data_type::u8: vpmovusdb(addr, r_zmm); break;
628 default: assert(!"unknown dst_dt");
/* Emits the loop over input-channel blocks for one output row chunk:
 * zero accumulators, run kh_loop per icb (with tail handling when ic is
 * padded), restore the src/filter pointers, then store the results with
 * an oc-tail mask when this is the last oc chunk. */
634 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
635 int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
637 int shift_src_icb = jcp.typesize_in * jcp.ic_block;
639 = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
641 prepare_output(ur_w);
643 Label skip_icb_loop, icb_loop_label;
645 mov(reg_icb, jcp.nb_ic);
/* Padded ic: the last icb needs tail-aware compute. */
648 if (jcp.ic_without_padding != jcp.ic) {
649 Label common_ker, end_ker;
651 jg(common_ker, T_NEAR);
653 kh_loop(ur_w, l_overflow, r_overflow,
654 is_last_sp_block ? last_sp_block : last_ic_block);
655 jmp(end_ker, T_NEAR);
658 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
662 kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
665 add(reg_src, shift_src_icb);
666 add(reg_filt, shift_filt_icb);
669 jg(icb_loop_label, T_NEAR);
672 /* come-back pointers */
673 sub(reg_src, jcp.nb_ic * shift_src_icb);
674 sub(reg_filt, jcp.nb_ic * shift_filt_icb);
/* Masked store only for the final (partial) oc/ch chunk. */
677 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
678 Label common_store, end_store;
679 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
680 if (jcp.is_depthwise)
681 cmp(reg_oc_blocks, jcp.nb_ch - 1);
683 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
684 jne(common_store, T_NEAR);
686 store_output(ur_w, true);
687 jmp(end_store, T_NEAR);
690 store_output(ur_w, false);
695 store_output(ur_w, false);
/* Top-level code generator: sets up constants (zmm_one, oc-tail mask),
 * loads the call arguments, and emits the output-width loop structure —
 * a left-overflow block, the steady-state ow loop, a right-overflow
 * block, and the ur_w tail — each delegating to icb_loop. */
699 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
702 xor_(reg_scratch, reg_scratch);
703 Reg16 _t = reg_scratch.cvt16();
705 vpbroadcastw(zmm_one, _t);
/* Build the k-mask covering the oc (or group-channel) tail. */
707 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
708 int tail_size = jcp.is_depthwise ?
709 jcp.ngroups % jcp.ch_block :
710 jcp.oc_without_padding % jcp.oc_block;
711 int mask = (1 << tail_size) - 1;
712 Reg32 regw_tmp = reg_nur_w.cvt32();
714 kmovw(ktail_mask, regw_tmp);
717 mov(reg_src, ptr[param1 + GET_OFF(src)]);
718 mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
719 mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
/* Pointer advance per ur_w output columns. */
721 int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
722 * jcp.oc_without_padding;
723 int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
724 * jcp.ic_without_padding;
726 int l_overflow = max(
727 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
729 = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
733 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
734 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
736 int nur_w = jcp.ow / jcp.ur_w;
/* Degenerate widths: a single icb_loop call covers the whole row. */
740 if (jcp.ur_w == jcp.ow) {
741 icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
742 } else if (nur_w == 0) {
743 icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
744 add(reg_src, src_shift);
745 add(reg_dst, dst_shift);
746 if (jcp.ur_w_tail != 0)
747 icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
749 xor_(reg_nur_w, reg_nur_w);
750 if (l_overflow > 0) {
751 icb_loop(jcp.ur_w, l_overflow, 0, false);
752 add(reg_src, src_shift);
753 add(reg_dst, dst_shift);
/* Steady state: interior chunks with no overflow on either side. */
756 if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
760 icb_loop(jcp.ur_w, 0, 0, false);
761 add(reg_src, src_shift);
762 add(reg_dst, dst_shift);
764 cmp(reg_nur_w, nur_w);
765 jl(ow_loop_label, T_NEAR);
768 if (r_overflow1 > 0) {
769 icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
770 add(reg_src, src_shift);
771 add(reg_dst, dst_shift);
773 if (jcp.ur_w_tail != 0) {
774 icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
/* Eltwise LUT must be emitted after the kernel body. */
779 if (jcp.with_eltwise)
780 eltwise_injector_->prepare_table();
/* 1D forward driver: prepares (possibly adjusted) output scales and the
 * compensation pointer, then parallelizes over (mb, group, oc-chunk)
 * and invokes the JIT kernel once per work item. */
783 template <data_type_t src_type, data_type_t dst_type>
784 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
785 dst_type>::execute_forward_1d() const {
786 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
787 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
788 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
789 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
791 const memory_desc_wrapper src_d(pd()->src_pd());
792 const memory_desc_wrapper dst_d(pd()->dst_pd());
793 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
794 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
796 auto &jcp = kernel_->jcp;
798 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
799 int nb_groups = jcp.nb_ch;
/* Pre-VNNI signed path: rescale output scales by 1/wei_adj_scale into
 * scratchpad (see init_scratchpad). */
801 const float *oscales = pd()->attr()->output_scales_.scales_;
802 if (jcp.signed_input && jcp.ver != ver_vnni) {
804 = scratchpad().template get<float>(key_conv_adjusted_scales);
805 size_t count = pd()->attr()->output_scales_.count_;
806 float factor = 1.f / pd()->jcp_.wei_adj_scale;
808 utils::array_set(local_scales, oscales[0] * factor, 16);
810 for (size_t c = 0; c < count; c++)
811 local_scales[c] = oscales[c] * factor;
813 oscales = local_scales;
/* Compensation values live right after the weights blob. */
815 size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
816 auto w = const_cast<wei_data_t *>(weights);
817 int32_t *compensation
818 = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
820 parallel(0, [&](const int ithr, const int nthr) {
821 int start{ 0 }, end{ 0 };
822 int work_amount = jcp.mb * nb_groups * oc_chunks;
823 balance211(work_amount, nthr, ithr, start, end);
825 auto p = jit_deconv_call_s();
827 int n{ 0 }, g{ 0 }, occ{ 0 };
828 if (jcp.loop_order == loop_ngc)
829 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
830 else if (jcp.loop_order == loop_cgn)
831 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
833 assert(!"unsupported loop order");
834 while (start < end) {
836 int ocb = occ * jcp.nb_oc_blocking;
837 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
838 int g_ic = g * jcp.ch_block * jcp.ic;
840 p.dst = dst + dst_d.blk_off(n, g_oc);
841 p.src = src + src_d.blk_off(n, g_ic);
842 p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
843 p.bias = jcp.with_bias ?
844 bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
846 p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
847 p.scales = &oscales[jcp.is_oc_scale * g_oc];
850 p.kh_padding = jcp.kh;
851 p.oc_blocks = jcp.is_depthwise ? g : ocb;
853 kernel_->jit_ker(&p);
856 if (jcp.loop_order == loop_ngc)
857 nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
858 else if (jcp.loop_order == loop_cgn)
859 nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
861 assert(!"unsupported loop order");
/* 2D forward driver: like execute_forward_1d but also parallelizes over
 * output rows. For each oh it computes the kernel-height window
 * (kh_lo/kh_len), the matching source row (ih_max), and the t/b overflow
 * counts passed to the kernel for signed-input compensation. */
866 template <data_type_t src_type, data_type_t dst_type>
867 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
868 dst_type>::execute_forward_2d() const {
869 auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
870 auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
871 auto bias = reinterpret_cast<const char *>(this->input_memory(2));
872 auto dst = reinterpret_cast<dst_data_t *>(this->memory());
874 const memory_desc_wrapper src_d(pd()->src_pd());
875 const memory_desc_wrapper dst_d(pd()->dst_pd());
876 const memory_desc_wrapper weights_d(pd()->weights_pd(0));
877 const memory_desc_wrapper bias_d(pd()->weights_pd(1));
879 auto &jcp = kernel_->jcp;
881 int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
882 int nb_groups = jcp.nb_ch;
/* Strides (in elements) for one step along H in src/dst/weights. */
884 size_t src_h_stride = src_d.blk_off(0, 0, 1);
885 size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
886 size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
/* Pre-VNNI signed path: rescale output scales by 1/wei_adj_scale. */
888 const float *oscales = pd()->attr()->output_scales_.scales_;
889 if (jcp.signed_input && jcp.ver != ver_vnni) {
891 = scratchpad().template get<float>(key_conv_adjusted_scales);
892 size_t count = pd()->attr()->output_scales_.count_;
893 float factor = 1.f / pd()->jcp_.wei_adj_scale;
895 utils::array_set(local_scales, oscales[0] * factor, 16);
897 for (size_t c = 0; c < count; c++)
898 local_scales[c] = oscales[c] * factor;
900 oscales = local_scales;
/* Compensation values live right after the weights blob. */
902 size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
903 auto w = const_cast<wei_data_t *>(weights);
904 int32_t *compensation
905 = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
907 int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
909 parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
910 int start{ 0 }, end{ 0 };
911 balance211(work_amount, nthr, ithr, start, end);
913 auto p = jit_deconv_call_s();
916 int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
917 if (jcp.loop_order == loop_ngc)
918 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
920 else if (jcp.loop_order == loop_cgn)
921 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
924 assert(!"unsupported loop order");
925 while (start < end) {
927 int ocb = occ * jcp.nb_oc_blocking;
928 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
929 int g_ic = g * jcp.ch_block * jcp.ic;
930 int work_rem = end - start;
931 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
933 auto dst_w = dst + dst_d.blk_off(n, g_oc);
934 auto src_w = src + src_d.blk_off(n, g_ic);
935 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
936 auto bias_w = jcp.with_bias ?
937 bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
939 int32_t *compensation_w
940 = (jcp.signed_input) ? compensation + g_oc : 0;
942 auto scales = &oscales[jcp.is_oc_scale * g_oc];
943 for (int oj = oh_s; oj < oh_e; oj++) {
944 int ih_max = 0, kh_lo = 0, kh_len = 0;
945 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
/* Unit stride + dilation: clip the dilated filter window. */
947 int dilate_h = jcp.dilate_h + 1;
948 // Note: use div_up to account for "holes" in filter
949 int o_t_overflow = div_up(
950 max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
953 = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
956 kh_len = jcp.kh - o_t_overflow - o_b_overflow;
957 kh_lo = o_b_overflow;
958 ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
/* Strided case: only filter taps hitting this oj contribute. */
960 int o_t_overflow = max(
961 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
963 = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
965 int overflow_kh_hi = jcp.kh - 1
966 - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
967 int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
969 kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
970 + 1 - o_t_overflow - o_b_overflow;
971 kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
972 ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
/* Signed input starts at filter row 0 (kernel walks the
 * overflow regions itself for compensation). */
976 = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
977 p.src = src_w + ih_max * src_h_stride;
978 p.dst = dst_w + oj * dst_h_stride;
979 p.filt = wht_w + wei_stride;
981 p.compensation = compensation_w;
983 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
985 p.b_overflow = kh_lo;
986 p.kh_padding = kh_len;
988 p.oc_blocks = jcp.is_depthwise ? g : ocb;
989 kernel_->jit_ker(&p);
991 if (jcp.loop_order == loop_ngc)
992 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
993 oc_chunks, oh_s, jcp.oh);
994 else if (jcp.loop_order == loop_cgn)
995 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
996 jcp.mb, oh_s, jcp.oh);
998 assert(!"unsupported loop order");
1003 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1005 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1007 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1009 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1011 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1013 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1015 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1017 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,