Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_u8s8s32x_deconvolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_avx512_core_u8s8s32x_deconvolution.hpp"
18
19 #define GET_OFF(field) offsetof(jit_deconv_call_s, field)
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::utils;
28
29 using namespace nstl;
30
31 #define wht_blk_off(d, g, ...) \
32         (conf_.with_groups() \
33          ? (d).blk_off((g), __VA_ARGS__) \
34          : (d).blk_off(__VA_ARGS__))
35
36 status_t jit_avx512_core_u8s8s32x_deconv_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
37         const deconvolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
38         cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
39         const bool with_bias, cpu_memory_t::pd_t &bias_pd,
40         const primitive_attr_t &attr) {
41     const memory_desc_wrapper src_d(&src_pd);
42     const memory_desc_wrapper dst_d(&dst_pd);
43     const memory_desc_wrapper weights_d(&weights_pd);
44     const memory_desc_wrapper bias_d(&bias_pd);
45
46     if (!(mayiuse(avx512_core) &&
47             src_d.data_type() == data_type::u8
48          && weights_d.data_type() == data_type::s8
49          && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
50             data_type::s8, data_type::u8)))
51         return status::unimplemented;
52
53     jcp = zero<decltype(jcp)>();
54
55     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
56
57     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
58     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
59     jcp.ic = src_d.dims()[1] / jcp.ngroups;
60     jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
61     jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
62     jcp.is_depthwise = true && with_groups && utils::everyone_is(1,
63             jcp.ic_without_padding, jcp.oc_without_padding);
64
65     const auto w_format = with_groups
66         ? (jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i)
67         : OIhw4i16o4i;
68
69     if (dst_d.format() == any)
70         CHECK(dst_pd.set_format(nhwc));
71     if (dst_d.format() != nhwc)
72         return status::unimplemented;
73     if (src_d.format() == any)
74         CHECK(src_pd.set_format(nhwc));
75     if (src_d.format() != nhwc)
76         return status::unimplemented;
77     if (weights_d.format() == any)
78         CHECK(weights_pd.set_format(w_format));
79     if (weights_d.format() != w_format)
80         return status::unimplemented;
81
82     jcp.with_bias = with_bias;
83     if (jcp.with_bias) {
84         if (bias_d.format() == any)
85             CHECK(bias_pd.set_format(x));
86         if (bias_d.format() != x)
87             return status::unimplemented;
88     }
89
90     jcp.ndims = dst_d.ndims();
91     jcp.prop_kind = cd.prop_kind;
92     jcp.mb = src_d.dims()[0];
93     jcp.ih = src_d.dims()[2];
94     jcp.iw = src_d.dims()[3];
95     jcp.oh = dst_d.dims()[2];
96     jcp.ow = dst_d.dims()[3];
97     jcp.kh = weights_d.dims()[with_groups + 2];
98     jcp.kw = weights_d.dims()[with_groups + 3];
99     jcp.t_pad = cd.padding[0][0];
100     jcp.l_pad = cd.padding[0][1];
101     jcp.stride_h = cd.strides[0];
102     jcp.stride_w = cd.strides[1];
103     jcp.src_fmt = src_d.format();
104     jcp.with_eltwise = false;/*TODO: support post-ops*/
105
106     if (jcp.is_depthwise) {
107         jcp.ch_block = 16;
108         jcp.oc_block = 1;
109         jcp.ic_block = 1;
110     } else {
111         jcp.ch_block = 1;
112         jcp.oc_block = 16;
113         jcp.ic_block = 16;
114
115         if (jcp.ngroups == 1) {
116             jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
117             jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
118         }
119         if (jcp.ic % jcp.ic_block != 0)
120             return status::unimplemented;
121     }
122
123     jcp.dilate_h = cd.dilates[0];
124     jcp.dilate_w = cd.dilates[1];
125
126     if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
127             || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
128             return status::unimplemented;
129
130     /*bottom and right :padding*/
131     jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
132             - (jcp.oh + jcp.t_pad - 1);
133     jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
134             - (jcp.ow + jcp.l_pad - 1);
135
136     if (!attr.post_ops_.has_default_values())
137         return status::unimplemented;
138
139     jcp.ver = ver_avx512_core;
140     if (mayiuse(avx512_core_vnni))
141         jcp.ver = ver_vnni;
142     const auto &oscales = attr.output_scales_;
143     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
144
145     jcp.dst_dt = dst_d.data_type();
146     jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
147     jcp.typesize_bia = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
148     jcp.typesize_in = types::data_type_size(src_d.data_type());
149     jcp.typesize_out = types::data_type_size(dst_d.data_type());
150
151     jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
152     jcp.nb_oc = jcp.oc / jcp.oc_block;
153     jcp.nb_ic = jcp.ic / jcp.ic_block;
154
155     /*kernel blocking params*/
156     const int regs = jcp.ver == ver_vnni ? 31 : 29;
157     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
158     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
159         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
160                 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
161             break;
162
163     jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
164     int l_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
165     int r_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
166                      - max(0, jcp.r_pad)) / jcp.stride_w);
167     if (jcp.ow < jcp.ur_w)
168         jcp.ur_w = jcp.ow;
169     for (; jcp.ur_w > 1; jcp.ur_w--)
170         if (jcp.ur_w % jcp.stride_w == 0
171                 && max(l_overflow,
172                     r_overflow - (jcp.ow % jcp.ur_w) / jcp.stride_w) * jcp.stride_w <= jcp.ur_w)
173             break;
174     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
175
176     jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
177     return status::success;
178 }
179
180 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::compute_ker(
181         int ur_w, int l_overflow, int r_overflow, ker_block_t last_block) {
182
183     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
184     int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1)
185         * jcp.iw * jcp.ngroups * jcp.ic_without_padding;
186     int shift_filt_kh = jcp.typesize_in *  jcp.kw * jcp.stride_h * ch_block_all;
187
188     auto src_offset = [=] (int oj, int icb, int ki) {
189          return jcp.typesize_in *
190            (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w) * jcp.ngroups * jcp.ic_without_padding + icb * 4);
191     };
192
193     auto kernel_offset = [=] (int ocb, int icb, int ki) {
194         return jcp.typesize_in *
195             (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all + icb * jcp.oc_block * jcp.ic_block/4
196              + ki * ch_block_all);
197     };
198
199     auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
200         if (jcp.ver == ver_vnni) {
201             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
202         } else if (jcp.is_depthwise) {
203             vpmulld(zmm_tmp, vreg_src, vreg_wei);
204             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
205         } else {
206             vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
207             vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
208             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
209         }
210     };
211
212     mov(aux_reg_src, reg_src);
213     mov(aux_reg_filt, reg_filt);
214     mov(reg_kj, reg_kh);
215     Xbyak::Label kh_loop_label;
216     L(kh_loop_label); {
217        for (int ki = 0; ki < jcp.kw; ki++) {
218            int jj_start = get_ow_start(ki, l_overflow);
219            int jj_end = get_ow_end(ur_w, ki, r_overflow);
220            int tail_size = jcp.ic_without_padding % 4;
221            int n_ic_blocks = jcp.is_depthwise
222                            ? 1
223                            : (last_block &  ~no_last_block
224                                    ? div_up(jcp.ic_without_padding % jcp.ic_block, 4)
225                                    : jcp.ic_block / 4);
226            for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
227                for (int jj = jj_start; jj < jj_end; jj += jcp.stride_w) {
228                     assert((jj + jcp.l_pad - ki) % jcp.stride_w == 0);
229
230                    int aux_src_off = src_offset(jj, icb1, ki);
231                    if (jcp.is_depthwise) {
232                        vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
233                                    EVEX_compress_addr(aux_reg_src, aux_src_off));
234                    } else if ((last_block & last_sp_block)
235                            && tail_size != 0 && icb1 == n_ic_blocks - 1) {
236                        xmm_t xmm_tmp = xmm_t(zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
237                        for (int r = 0; r < tail_size; ++r)
238                            vpinsrb(xmm_tmp, xmm_tmp,
239                                    ptr[aux_reg_src + aux_src_off + r], r);
240                        vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
241                    } else {
242                        vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
243                                EVEX_compress_addr(aux_reg_src, aux_src_off));
244                    }
245                }
246
247                for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
248                    int aux_filt_off = kernel_offset(ocb, icb1, ki);
249                    if (jj_end - jj_start > 0) {
250                        if (jcp.is_depthwise)
251                            vpmovsxbd(zmm_wei,
252                                EVEX_compress_addr(aux_reg_filt, aux_filt_off));
253                        else
254                            vmovups(zmm_wei,
255                                    EVEX_compress_addr(aux_reg_filt, aux_filt_off));
256                    }
257                    for (int jj = jj_start; jj < jj_end; jj += jcp.stride_w) {
258                        compute(zmm_out(jj, ocb),
259                                zmm_wei, zmm_inp(jj, jcp.nb_oc_blocking));
260                    }
261                }
262            }
263        }
264        sub(aux_reg_src, shift_src_ih);
265        add(aux_reg_filt, shift_filt_kh);
266        dec(reg_kj);
267        cmp(reg_kj, 0);
268        jg(kh_loop_label, T_NEAR);
269     }
270 }
271
272 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
273     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
274         for (int ur = 0; ur < ur_w; ur++) {
275                 zmm_t zmm = zmm_out(ur, ocb);
276                 vpxord(zmm, zmm, zmm);
277         }
278     }
279 }
280
281 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::cvt2ps(data_type_t type_in,
282         zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
283     zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
284     switch (type_in) {
285     case data_type::f32:
286     case data_type::s32: vmovups(zmm, op); break;
287     case data_type::s8: vpmovsxbd(zmm, op); break;
288     case data_type::u8: vpmovzxbd(zmm, op); break;
289     default: assert(!"unsupported data type");
290     }
291     if (type_in != data_type::f32)
292         vcvtdq2ps(zmm_in, zmm_in);
293 }
294
295 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::store_output(int ur_w, bool last_oc_block) {
296     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
297     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
298
299     vpxord(zmm_zero, zmm_zero, zmm_zero);
300     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
301         const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
302         int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
303
304         auto zmm_bias = zmm_tmp;
305         if (jcp.with_bias) {
306             int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
307             auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
308             cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
309         }
310
311         for (int ur = 0; ur < ur_w; ur++) {
312             zmm_t zmm = zmm_out(ur, ocb);
313             vcvtdq2ps(zmm, zmm);
314             if (jcp.with_bias) vaddps(zmm, zmm, zmm_bias);
315             zmm_t mask_zmm = mask_flag
316                            ? zmm | ktail_mask | T_z
317                            : zmm;
318             vmulps(mask_zmm, zmm,
319                     EVEX_compress_addr(reg_ptr_scales, scale_offset));
320
321             if (jcp.dst_dt == data_type::u8) vmaxps(zmm, zmm_zero, zmm);
322
323             if (jcp.dst_dt != data_type::f32) {
324                 if (attr_.round_mode_ == round_mode::nearest)
325                     vcvtps2dq(zmm | T_rn_sae, zmm);
326                 else if (attr_.round_mode_ == round_mode::down)
327                     vcvtps2dq(zmm | T_rd_sae, zmm);
328                 else
329                     assert(!"unimplemented");
330             }
331         }
332         for (int ur = 0; ur < ur_w; ur++) {
333             int aux_dst_off = jcp.typesize_out
334                 * (ur * jcp.ngroups * jcp.oc_without_padding + ocb * jcp.oc_block);
335             auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
336
337             zmm_t zmm = zmm_out(ur, ocb);
338             zmm_t r_zmm = mask_flag
339                         ? zmm | ktail_mask
340                         : zmm;
341             switch (jcp.dst_dt) {
342             case data_type::f32:
343             case data_type::s32: vmovups(addr, r_zmm); break;
344             case data_type::s8: vpmovsdb(addr, r_zmm); break;
345             case data_type::u8: vpmovusdb(addr, r_zmm); break;
346             default: assert(!"unknown dst_dt");
347             }
348         }
349     }
350 }
351
352 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::compute_loop(
353         int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
354
355     int shift_src_icb = jcp.typesize_in * jcp.ic_block;
356     int shift_filt_icb = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
357
358     prepare_output(ur_w);
359
360     Xbyak::Label icb_loop_label;
361     mov(reg_icb, jcp.nb_ic);
362     L(icb_loop_label); {
363
364         if (jcp.ic_without_padding != jcp.ic) {
365             Xbyak::Label common_ker, end_ker;
366             cmp(reg_icb, 1);
367             jg(common_ker, T_NEAR);
368
369             compute_ker(ur_w, l_overflow, r_overflow,
370                     is_last_sp_block ? last_sp_block : last_ic_block);
371             jmp(end_ker, T_NEAR);
372
373             L(common_ker);
374             compute_ker(ur_w, l_overflow, r_overflow, no_last_block);
375
376             L(end_ker);
377         } else {
378             compute_ker(ur_w, l_overflow, r_overflow, no_last_block);
379         }
380
381         add(reg_src, shift_src_icb);
382         add(reg_filt, shift_filt_icb);
383         dec(reg_icb);
384         cmp(reg_icb, 0);
385         jg(icb_loop_label, T_NEAR);
386     }
387     sub(reg_src, jcp.nb_ic * shift_src_icb);
388     sub(reg_filt, jcp.nb_ic * shift_filt_icb);
389
390     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
391         Xbyak::Label common_store, end_store;
392         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
393         if (jcp.is_depthwise)
394             cmp(reg_oc_blocks, jcp.nb_ch - 1);
395         else
396             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
397         jne(common_store, T_NEAR);
398
399         store_output(ur_w, true);
400         jmp(end_store, T_NEAR);
401
402         L(common_store);
403         store_output(ur_w, false);
404
405         L(end_store);
406
407     } else {
408         store_output(ur_w, false);
409     }
410 }
411
412 void jit_avx512_core_u8s8s32x_deconv_fwd_kernel::generate() {
413     preamble();
414
415     Xbyak::Reg16 _t = reg_scratch.cvt16();
416     mov(_t, 0x1);
417     vpbroadcastw(zmm_one, _t);
418
419     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
420         int tail_size = jcp.is_depthwise
421             ? jcp.ngroups % jcp.ch_block
422             : jcp.oc_without_padding % jcp.oc_block;
423         int mask = (1 << tail_size) - 1;
424         Xbyak::Reg32 regw_tmp = reg_nur_w.cvt32();
425         mov(regw_tmp, mask);
426         kmovw(ktail_mask, regw_tmp);
427     }
428
429     mov(reg_src, ptr[param1 + GET_OFF(src)]);
430     mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
431     mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
432     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
433
434     int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups * jcp.oc_without_padding;
435     int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups * jcp.ic_without_padding;
436
437     int l_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
438     int r_overflow = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
439                      - max(0, jcp.r_pad)) / jcp.stride_w);
440
441     int r_overflow1 = nstl::max(0, ((jcp.kw -1) * (jcp.dilate_w + 1)
442                 - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
443     int nur_w = jcp.ow / jcp.ur_w;
444     if (r_overflow1 > 0) nur_w--;
445
446     if (jcp.ur_w == jcp.ow) {
447         compute_loop(jcp.ur_w, l_overflow, r_overflow, true);
448     } else if (nur_w == 0) {
449         compute_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
450         add(reg_src, src_shift);
451         add(reg_dst, dst_shift);
452         if (jcp.ur_w_tail != 0)
453             compute_loop(jcp.ur_w_tail, 0, r_overflow, true);
454     } else {
455         xor_(reg_nur_w, reg_nur_w);
456         if (l_overflow > 0) {
457             compute_loop(jcp.ur_w, l_overflow, 0, false);
458             add(reg_src, src_shift);
459             add(reg_dst, dst_shift);
460             inc(reg_nur_w);
461         }
462         if ((l_overflow <= 0 && nur_w > 0)
463                 || (l_overflow > 0 && nur_w > 1)) {
464             Xbyak::Label ow_loop_label;
465             L(ow_loop_label); {
466                 compute_loop(jcp.ur_w, 0, 0, false);
467                 add(reg_src, src_shift);
468                 add(reg_dst, dst_shift);
469                 inc(reg_nur_w);
470                 cmp(reg_nur_w, nur_w);
471                 jl(ow_loop_label, T_NEAR);
472             }
473         }
474         if (r_overflow1 > 0) {
475             compute_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
476             add(reg_src, src_shift);
477             add(reg_dst, dst_shift);
478         }
479         if (jcp.ur_w_tail != 0) {
480             compute_loop(jcp.ur_w_tail, 0, r_overflow, true);
481         }
482     }
483     postamble();
484 }
485
486 template <data_type_t dst_type>
487 void _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<dst_type>::
488 execute_forward()
489 {
490     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
491     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
492     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
493     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
494
495     const memory_desc_wrapper src_d(conf_.src_pd());
496     const memory_desc_wrapper dst_d(conf_.dst_pd());
497     const memory_desc_wrapper weights_d(conf_.weights_pd(0));
498     const memory_desc_wrapper bias_d(conf_.weights_pd(1));
499
500     auto &jcp = kernel_->jcp;
501
502     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
503     int nb_groups = jcp.nb_ch;
504
505     size_t src_h_stride = src_d.blk_off(0, 0, 1);
506     size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
507     size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
508
509     const auto &oscales = conf_.attr()->output_scales_;
510
511     parallel(0,
512             [&](const int ithr, const int nthr) {
513             int start{0}, end{0};
514             int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
515             balance211(work_amount, nthr, ithr, start, end);
516
517             auto p = jit_deconv_call_s();
518
519             /*loop order = cgn*/
520             int n{0}, g{0}, occ{0}, oh_s{0};
521             if (jcp.loop_order == loop_ngc)
522                 nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
523                     oh_s, jcp.oh);
524             else if (jcp.loop_order == loop_cgn)
525                 nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
526                     oh_s, jcp.oh);
527             else
528                 assert(!"unsupported loop order");
529             while (start < end) {
530
531                 int ocb = occ * jcp.nb_oc_blocking;
532                 int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
533                 int g_ic = g * jcp.ch_block * jcp.ic;
534                 int work_rem = end - start;
535                 int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
536
537                 auto dst_w = dst + dst_d.blk_off(n, g_oc);
538                 auto src_w = src + src_d.blk_off(n, g_ic);
539                 auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
540                 auto bias_w = jcp.with_bias
541                             ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
542                             : 0;
543
544                 auto scales = &oscales.scales_[jcp.is_oc_scale * g_oc];
545                 for (int oj = oh_s; oj < oh_e; oj++) {
546                     int ih_max, kh_lo, kh_len;
547                     if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
548                             int dilate_h = jcp.dilate_h + 1;
549                             // Note: use div_up to account for "holes" in filter
550                             int o_t_overflow
551                                 = div_up(max(0, (jcp.kh - 1) * dilate_h
552                                         - oj - jcp.t_pad), dilate_h);
553                             int o_b_overflow
554                                 = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
555                                         - jcp.ih + oj - jcp.b_pad), dilate_h);
556                             kh_len = jcp.kh - o_t_overflow - o_b_overflow;
557                             kh_lo = o_b_overflow;
558                             ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
559                     } else {
560                         int o_t_overflow = max(0,
561                                 (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h); 
562                         int o_b_overflow = max(0,
563                                 ((oj + 1 + jcp.kh - 1)
564                                  - (jcp.oh + jcp.b_pad)) / jcp.stride_h);
565                         int overflow_kh_hi = jcp.kh - 1
566                             - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
567                         int overflow_kh_lo = ((oj + 1 + jcp.t_pad) - 1) % jcp.stride_h;
568
569                         kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
570                             + 1 - o_t_overflow - o_b_overflow;
571                         kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
572                         ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
573                     }
574
575                     p.src = src_w + ih_max * src_h_stride;
576                     p.dst = dst_w + oj * dst_h_stride;
577                     p.filt = wht_w + kh_lo * wht_kh_stride;
578                     p.bias = bias_w;
579                     p.kh_padding = kh_len;
580                     p.scales = scales;
581                     p.oc_blocks = jcp.is_depthwise ? g : ocb;
582                     kernel_->jit_ker(&p);
583                 }
584                 if (jcp.loop_order == loop_ngc)
585                     nd_iterator_jump(start, end,
586                             n, jcp.mb, g, nb_groups, occ, oc_chunks, oh_s, jcp.oh);
587                 else if (jcp.loop_order == loop_cgn)
588                     nd_iterator_jump(start, end,
589                             occ, oc_chunks, g, nb_groups, n, jcp.mb, oh_s, jcp.oh);
590                 else
591                     assert(!"unsupported loop order");
592             }
593     });
594 }
595
596 template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::u8>;
597 template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::s8>;
598 template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::f32>;
599 template struct _jit_avx512_core_u8s8s32x_deconvolution_fwd_t<data_type::s32>;
600 }
601 }
602 }