/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "jit_avx512_core_u8s8s32x_deconvolution.hpp"

#define GET_OFF(field) offsetof(jit_deconv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
using namespace nstl;
#define wht_blk_off(d, g, ...) \
        (conf_.with_groups() \
         ? (d).blk_off((g), __VA_ARGS__) \
         : (d).blk_off(__VA_ARGS__))
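
/* init_conf: validate the int8 deconvolution descriptor, pick the memory
 * formats the kernel expects (nhwc activations, blocked s8 weights) and
 * fill jcp with the blocking and unrolling parameters used by the JIT. */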
status_t jit_avx512_core_u8s8s32x_deconv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const deconvolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
        cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
        const bool with_bias, cpu_memory_t::pd_t &bias_pd,
        const primitive_attr_t &attr) {
    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper weights_d(&weights_pd);
    const memory_desc_wrapper bias_d(&bias_pd);

    if (!(mayiuse(avx512_core)
                && src_d.data_type() == data_type::u8
                && weights_d.data_type() == data_type::s8
                && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                        data_type::s8, data_type::u8)))
        return status::unimplemented;
    jcp = zero<decltype(jcp)>();

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.is_depthwise = true && with_groups && utils::everyone_is(1,
            jcp.ic_without_padding, jcp.oc_without_padding);
    const auto w_format = with_groups
            ? (jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i)
            : OIhw4i16o4i;
    if (dst_d.format() == any)
        CHECK(dst_pd.set_format(nhwc));
    if (dst_d.format() != nhwc)
        return status::unimplemented;
    if (src_d.format() == any)
        CHECK(src_pd.set_format(nhwc));
    if (src_d.format() != nhwc)
        return status::unimplemented;
    if (weights_d.format() == any)
        CHECK(weights_pd.set_format(w_format));
    if (weights_d.format() != w_format)
        return status::unimplemented;
    jcp.with_bias = with_bias;
    if (jcp.with_bias) {
        if (bias_d.format() == any)
            CHECK(bias_pd.set_format(x));
        if (bias_d.format() != x)
            return status::unimplemented;
    }
    jcp.ndims = dst_d.ndims();
    jcp.prop_kind = cd.prop_kind;
    jcp.mb = src_d.dims()[0];
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.src_fmt = src_d.format();
    jcp.with_eltwise = false; /* TODO: support post-ops */
    if (jcp.is_depthwise) {
        jcp.ch_block = 16;
        jcp.ic_block = 1;
        jcp.oc_block = 1;
    } else {
        jcp.ch_block = 1;
        jcp.ic_block = 16;
        jcp.oc_block = 16;

        if (jcp.ngroups == 1) {
            jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
            jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
        }
        if (jcp.ic % jcp.ic_block != 0)
            return status::unimplemented;
    }
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
            || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
        return status::unimplemented;

    /* bottom and right padding */
    jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
            - (jcp.oh + jcp.t_pad - 1);
    jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
            - (jcp.ow + jcp.l_pad - 1);

    if (!attr.post_ops_.has_default_values())
        return status::unimplemented;
    jcp.ver = ver_avx512_core;
    if (mayiuse(avx512_core_vnni))
        jcp.ver = ver_vnni;
    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;
    jcp.dst_dt = dst_d.data_type();
    jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    /* kernel blocking params */
    const int regs = jcp.ver == ver_vnni ? 31 : 29;
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
            break;

    jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
    int l_overflow = max(0,
            ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
    int r_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
            - max(0, jcp.r_pad)) / jcp.stride_w);
    if (jcp.ow < jcp.ur_w)
        jcp.ur_w = jcp.ow;
    /* ur_w must be a multiple of stride_w and wide enough to cover the
     * left and right boundary overflows in a single call */
    for (; jcp.ur_w > 1; jcp.ur_w--)
        if (jcp.ur_w % jcp.stride_w == 0
                && l_overflow * jcp.stride_w <= jcp.ur_w
                && (r_overflow - (jcp.ow % jcp.ur_w) / jcp.stride_w)
                        * jcp.stride_w <= jcp.ur_w)
            break;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
    jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
    return status::success;
}
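
/* compute_ker: emit the inner computation for a strip of ur_w output pixels:
 * iterate over the kernel width (and, at runtime, the kernel height),
 * broadcast 4 input channels at a time and accumulate s32 partial sums with
 * vpdpbusd on VNNI, vpmaddubsw/vpmaddwd/vpaddd otherwise, and
 * vpmulld/vpaddd for the depthwise case. */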
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::compute_ker(
        int ur_w, int l_overflow, int r_overflow, ker_block_t last_block) {

    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1)
            * jcp.iw * jcp.ngroups * jcp.ic_without_padding;
    int shift_filt_kh = jcp.typesize_in * jcp.kw * jcp.stride_h * ch_block_all;
    auto src_offset = [=](int oj, int icb, int ki) {
        return jcp.typesize_in
                * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
                        * jcp.ngroups * jcp.ic_without_padding + icb * 4);
    };

    auto kernel_offset = [=](int ocb, int icb, int ki) {
        return jcp.typesize_in
                * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
                        + icb * jcp.oc_block * jcp.ic_block / 4
                        + ki * ch_block_all);
    };
    auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else if (jcp.is_depthwise) {
            vpmulld(zmm_tmp, vreg_src, vreg_wei);
            vpaddd(vreg_acc, vreg_acc, zmm_tmp);
        } else {
            vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
            vpaddd(vreg_acc, vreg_acc, zmm_tmp);
        }
    };
    mov(aux_reg_src, reg_src);
    mov(aux_reg_filt, reg_filt);
    /* reload the kh counter: this emitted body runs once per input-channel
     * block inside compute_loop's runtime icb loop */
    mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);

    Xbyak::Label kh_loop_label;
    L(kh_loop_label); {
        for (int ki = 0; ki < jcp.kw; ki++) {
            int jj_start = get_ow_start(ki, l_overflow);
            int jj_end = get_ow_end(ur_w, ki, r_overflow);
            int tail_size = jcp.ic_without_padding % 4;
            int n_ic_blocks = jcp.is_depthwise
                    ? 1
                    : (last_block & ~no_last_block
                            ? div_up(jcp.ic_without_padding % jcp.ic_block, 4)
                            : jcp.ic_block / 4);
            for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
                for (int jj = jj_start; jj < jj_end; jj += jcp.stride_w) {
                    assert((jj + jcp.l_pad - ki) % jcp.stride_w == 0);

                    int aux_src_off = src_offset(jj, icb1, ki);
                    if (jcp.is_depthwise) {
                        vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
                                EVEX_compress_addr(aux_reg_src, aux_src_off));
                    } else if ((last_block & last_sp_block)
                            && tail_size != 0 && icb1 == n_ic_blocks - 1) {
                        xmm_t xmm_tmp = xmm_t(
                                zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
                        for (int r = 0; r < tail_size; ++r)
                            vpinsrb(xmm_tmp, xmm_tmp,
                                    ptr[aux_reg_src + aux_src_off + r], r);
                        vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
                    } else {
                        vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
                                EVEX_compress_addr(aux_reg_src, aux_src_off));
                    }
                }
                for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
                    int aux_filt_off = kernel_offset(ocb, icb1, ki);
                    if (jj_end - jj_start > 0) {
                        if (jcp.is_depthwise)
                            vpmovsxbd(zmm_wei, EVEX_compress_addr(
                                    aux_reg_filt, aux_filt_off));
                        else
                            vmovups(zmm_wei, EVEX_compress_addr(
                                    aux_reg_filt, aux_filt_off));
                    }
                    for (int jj = jj_start; jj < jj_end; jj += jcp.stride_w) {
                        compute(zmm_out(jj, ocb),
                                zmm_wei, zmm_inp(jj, jcp.nb_oc_blocking));
                    }
                }
            }
        }
        sub(aux_reg_src, shift_src_ih);
        add(aux_reg_filt, shift_filt_kh);
        dec(reg_kh);
        cmp(reg_kh, 0);
        jg(kh_loop_label, T_NEAR);
    }
}
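
/* prepare_output: zero the s32 accumulator registers for a strip of ur_w
 * output pixels across all output-channel blocks being processed. */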
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
        for (int ur = 0; ur < ur_w; ur++) {
            zmm_t zmm = zmm_out(ur, ocb);
            vpxord(zmm, zmm, zmm);
        }
    }
}
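
/* cvt2ps: load an f32/s32/s8/u8 operand into a zmm register (optionally
 * under the channel-tail mask) and convert integer values to f32. */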
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::cvt2ps(data_type_t type_in,
        zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
    zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
    switch (type_in) {
    case data_type::f32:
    case data_type::s32: vmovups(zmm, op); break;
    case data_type::s8: vpmovsxbd(zmm, op); break;
    case data_type::u8: vpmovzxbd(zmm, op); break;
    default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32)
        vcvtdq2ps(zmm_in, zmm_in);
}
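
/* store_output: convert the s32 accumulators to f32, add bias, apply the
 * output scales, clamp/round according to the destination data type and
 * store the results, masking the output-channel tail when needed. */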
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::store_output(
        int ur_w, bool last_oc_block) {
    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    vpxord(zmm_zero, zmm_zero, zmm_zero);
    for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
        const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
        int scale_offset
                = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);

        auto zmm_bias = zmm_tmp;
        if (jcp.with_bias) {
            int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
            auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
            cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
        }
        for (int ur = 0; ur < ur_w; ur++) {
            zmm_t zmm = zmm_out(ur, ocb);
            vcvtdq2ps(zmm, zmm);
            if (jcp.with_bias) vaddps(zmm, zmm, zmm_bias);
            zmm_t mask_zmm = mask_flag
                    ? zmm | ktail_mask | T_z
                    : zmm;
            vmulps(mask_zmm, zmm,
                    EVEX_compress_addr(reg_ptr_scales, scale_offset));

            if (jcp.dst_dt == data_type::u8) vmaxps(zmm, zmm_zero, zmm);

            if (jcp.dst_dt != data_type::f32) {
                if (attr_.round_mode_ == round_mode::nearest)
                    vcvtps2dq(zmm | T_rn_sae, zmm);
                else if (attr_.round_mode_ == round_mode::down)
                    vcvtps2dq(zmm | T_rd_sae, zmm);
                else
                    assert(!"unimplemented");
            }
        }
        for (int ur = 0; ur < ur_w; ur++) {
            int aux_dst_off = jcp.typesize_out
                    * (ur * jcp.ngroups * jcp.oc_without_padding
                            + ocb * jcp.oc_block);
            auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);

            zmm_t zmm = zmm_out(ur, ocb);
            zmm_t r_zmm = mask_flag
                    ? zmm | ktail_mask
                    : zmm;
            switch (jcp.dst_dt) {
            case data_type::f32:
            case data_type::s32: vmovups(addr, r_zmm); break;
            case data_type::s8: vpmovsdb(addr, r_zmm); break;
            case data_type::u8: vpmovusdb(addr, r_zmm); break;
            default: assert(!"unknown dst_dt");
            }
        }
    }
}
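
/* compute_loop: zero the accumulators, run the kernel over all input-channel
 * blocks (handling the padded input-channel tail on the last block), then
 * store the strip, masking the last output-channel block when it is not a
 * full vector of channels. */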
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::compute_loop(
        int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {

    int shift_src_icb = jcp.typesize_in * jcp.ic_block;
    int shift_filt_icb
            = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;

    prepare_output(ur_w);

    Xbyak::Label icb_loop_label;
    mov(reg_icb, jcp.nb_ic);
    L(icb_loop_label); {
        if (jcp.ic_without_padding != jcp.ic) {
            Xbyak::Label common_ker, end_ker;
            cmp(reg_icb, 1); /* the last icb handles the zero-padded ic tail */
            jg(common_ker, T_NEAR);

            compute_ker(ur_w, l_overflow, r_overflow,
                    is_last_sp_block ? last_sp_block : last_ic_block);
            jmp(end_ker, T_NEAR);

            L(common_ker);
            compute_ker(ur_w, l_overflow, r_overflow, no_last_block);

            L(end_ker);
        } else {
            compute_ker(ur_w, l_overflow, r_overflow, no_last_block);
        }
        add(reg_src, shift_src_icb);
        add(reg_filt, shift_filt_icb);
        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_loop_label, T_NEAR);
    }

    /* come back to the start of the input-channel blocks */
    sub(reg_src, jcp.nb_ic * shift_src_icb);
    sub(reg_filt, jcp.nb_ic * shift_filt_icb);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        Xbyak::Label common_store, end_store;
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - 1);
        else
            cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}
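
/* generate: kernel entry point. Set up the constant words used by vpmaddwd,
 * the output-channel tail mask and the argument registers, then unroll the
 * output width into full ur_w strips plus left/right boundary strips and a
 * tail, calling compute_loop for each. */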
void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::generate() {
    preamble();

    Xbyak::Reg16 _t = reg_scratch.cvt16();
    mov(_t, 0x1);
    vpbroadcastw(zmm_one, _t);

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        int tail_size = jcp.is_depthwise
                ? jcp.ngroups % jcp.ch_block
                : jcp.oc_without_padding % jcp.oc_block;
        int mask = (1 << tail_size) - 1;
        Xbyak::Reg32 regw_tmp = reg_nur_w.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
    }
    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
    mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
    mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);

    int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
            * jcp.oc_without_padding;
    int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
            * jcp.ic_without_padding;

    int l_overflow = max(0,
            ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
    int r_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
            - max(0, jcp.r_pad)) / jcp.stride_w);

    int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
            - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
    int nur_w = jcp.ow / jcp.ur_w;
    if (r_overflow1 > 0) nur_w--;
    if (jcp.ur_w == jcp.ow) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow, true);
    } else if (nur_w == 0) {
        compute_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
        add(reg_src, src_shift);
        add(reg_dst, dst_shift);
        if (jcp.ur_w_tail != 0)
            compute_loop(jcp.ur_w_tail, 0, r_overflow, true);
    } else {
        xor_(reg_nur_w, reg_nur_w);
        if (l_overflow > 0) {
            compute_loop(jcp.ur_w, l_overflow, 0, false);
            add(reg_src, src_shift);
            add(reg_dst, dst_shift);
            inc(reg_nur_w);
        }
        if ((l_overflow <= 0 && nur_w > 0)
                || (l_overflow > 0 && nur_w > 1)) {
            Xbyak::Label ow_loop_label;
            L(ow_loop_label); {
                compute_loop(jcp.ur_w, 0, 0, false);
                add(reg_src, src_shift);
                add(reg_dst, dst_shift);
                inc(reg_nur_w);
                cmp(reg_nur_w, nur_w);
                jl(ow_loop_label, T_NEAR);
            }
        }
        if (r_overflow1 > 0) {
            compute_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
            add(reg_src, src_shift);
            add(reg_dst, dst_shift);
        }
        if (jcp.ur_w_tail != 0) {
            compute_loop(jcp.ur_w_tail, 0, r_overflow, true);
        }
    }

    postamble();
}
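
/* execute_forward: the driver. Threads split the work over
 * (mb, group, oc chunk, oh); for every output row it computes which kernel
 * rows overlap valid input (kh_lo/kh_len) and the matching input row
 * (ih_max), then calls the JIT kernel on an nhwc row strip. */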
template <data_type_t dst_type>
void _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<dst_type>::
execute_forward() {
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());
    const memory_desc_wrapper src_d(conf_.src_pd());
    const memory_desc_wrapper dst_d(conf_.dst_pd());
    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
    const memory_desc_wrapper bias_d(conf_.weights_pd(1));

    auto &jcp = kernel_->jcp;

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    int nb_groups = jcp.nb_ch;

    size_t src_h_stride = src_d.blk_off(0, 0, 1);
    size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
    size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

    const auto &oscales = conf_.attr()->output_scales_;
    parallel(0, [&](const int ithr, const int nthr) {
        int start{0}, end{0};
        int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();
        int n{0}, g{0}, occ{0}, oh_s{0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {
            int ocb = occ * jcp.nb_oc_blocking;
            int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            int g_ic = g * jcp.ch_block * jcp.ic;
            int work_rem = end - start;
            int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

            auto dst_w = dst + dst_d.blk_off(n, g_oc);
            auto src_w = src + src_d.blk_off(n, g_ic);
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
            auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;

            auto scales = &oscales.scales_[jcp.is_oc_scale * g_oc];
            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max, kh_lo, kh_len;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    int o_b_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h + 1
                                    - jcp.ih + oj - jcp.b_pad), dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    int o_t_overflow = max(0,
                            (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    int o_b_overflow = max(0,
                            ((oj + 1 + jcp.kh - 1)
                                    - (jcp.oh + jcp.b_pad)) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1
                            - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
                    int overflow_kh_lo = ((oj + 1 + jcp.t_pad) - 1)
                            % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + oj * dst_h_stride;
                p.filt = wht_w + kh_lo * wht_kh_stride;
                p.bias = bias_w;
                p.kh_padding = kh_len;
                p.scales = scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                kernel_->jit_ker(&p);
            }
            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end,
                        n, jcp.mb, g, nb_groups, occ, oc_chunks, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end,
                        occ, oc_chunks, g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });
}

template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::u8>;
template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::s8>;
template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::f32>;
template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::s32>;