Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_deconvolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
18
19 #define GET_OFF(field) offsetof(jit_deconv_call_s, field)
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::memory_tracking::names;
28 using namespace mkldnn::impl::utils;
29 using namespace Xbyak;
30
31 using namespace nstl;
32
33 #define wht_blk_off(d, g, ...)                             \
34     (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
35                            (d).blk_off(__VA_ARGS__))
36
37 status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
38         jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
39         cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
40         cpu_memory_t::pd_t &dst_pd, const bool with_bias,
41         cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
42     const memory_desc_wrapper src_d(&src_pd);
43     const memory_desc_wrapper dst_d(&dst_pd);
44     const memory_desc_wrapper weights_d(&weights_pd);
45     const memory_desc_wrapper bias_d(&bias_pd);
46
47     if (!(mayiuse(avx512_core)
48                 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
49                 && weights_d.data_type() == data_type::s8
50                 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
51                            data_type::s8, data_type::u8)))
52         return status::unimplemented;
53
54     jcp = zero<decltype(jcp)>();
55
56     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
57     jcp.signed_input = src_d.data_type() == data_type::s8;
58
59     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
60     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
61     jcp.ic = src_d.dims()[1] / jcp.ngroups;
62     jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
63     jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
64     jcp.is_depthwise = true && with_groups
65             && utils::everyone_is(1, jcp.ic_without_padding,
66                                jcp.oc_without_padding);
67
68     /* TODO: future work, on hold until depthwise specialized kernel is
69      * implemented. */
70     if (jcp.is_depthwise && jcp.signed_input)
71         return status::unimplemented;
72
73     const auto w_format = jcp.is_depthwise ? Goihw16g : with_groups ?
74             (jcp.signed_input ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i) :
75             (jcp.signed_input ? OIhw4i16o4i_s8s8 : OIhw4i16o4i);
76
77     if (dst_d.format() == any)
78         CHECK(dst_pd.set_format(nhwc));
79     if (dst_d.format() != nhwc)
80         return status::unimplemented;
81     if (src_d.format() == any)
82         CHECK(src_pd.set_format(nhwc));
83     if (src_d.format() != nhwc)
84         return status::unimplemented;
85     if (weights_d.format() == any)
86         CHECK(weights_pd.set_format(w_format));
87     if (weights_d.format() != w_format)
88         return status::unimplemented;
89
90     jcp.with_bias = with_bias;
91     if (jcp.with_bias) {
92         if (bias_d.format() == any)
93             CHECK(bias_pd.set_format(x));
94         if (bias_d.format() != x)
95             return status::unimplemented;
96     }
97
98     jcp.ndims = dst_d.ndims();
99     jcp.prop_kind = cd.prop_kind;
100     jcp.mb = src_d.dims()[0];
101     jcp.ih = src_d.dims()[2];
102     jcp.iw = src_d.dims()[3];
103     jcp.oh = dst_d.dims()[2];
104     jcp.ow = dst_d.dims()[3];
105     jcp.kh = weights_d.dims()[with_groups + 2];
106     jcp.kw = weights_d.dims()[with_groups + 3];
107     jcp.t_pad = cd.padding[0][0];
108     jcp.l_pad = cd.padding[0][1];
109     jcp.stride_h = cd.strides[0];
110     jcp.stride_w = cd.strides[1];
111     jcp.src_fmt = src_d.format();
112
113     if (jcp.is_depthwise) {
114         jcp.ch_block = 16;
115         jcp.oc_block = 1;
116         jcp.ic_block = 1;
117     } else {
118         jcp.ch_block = 1;
119         jcp.oc_block = 16;
120         jcp.ic_block = 16;
121
122         if (jcp.ngroups == 1) {
123             jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
124             jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
125         }
126         if (jcp.ic % jcp.ic_block != 0)
127             return status::unimplemented;
128     }
129
130     jcp.dilate_h = cd.dilates[0];
131     jcp.dilate_w = cd.dilates[1];
132
133     if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
134             || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
135         return status::unimplemented;
136
137     /* padding: bottom and right */
138     jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
139             - (jcp.oh + jcp.t_pad - 1);
140     jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
141             - (jcp.ow + jcp.l_pad - 1);
142
143     if (!post_ops_ok(jcp, attr))
144         return status::unimplemented;
145
146     const auto &p = attr.post_ops_;
147     const int eltwise_ind = p.find(primitive_kind::eltwise);
148     jcp.with_eltwise = eltwise_ind != -1;
149     if (jcp.with_eltwise)
150         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
151
152     jcp.ver = ver_avx512_core;
153     if (mayiuse(avx512_core_vnni))
154         jcp.ver = ver_vnni;
155     const auto &oscales = attr.output_scales_;
156     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
157
158     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
159
160     jcp.dst_dt = dst_d.data_type();
161     jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
162     jcp.typesize_bia
163             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
164     jcp.typesize_in = types::data_type_size(src_d.data_type());
165     jcp.typesize_out = types::data_type_size(dst_d.data_type());
166
167     jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
168     jcp.nb_oc = jcp.oc / jcp.oc_block;
169     jcp.nb_ic = jcp.ic / jcp.ic_block;
170
171     /* kernel blocking params */
172     const int regs = jcp.ver == ver_vnni ? 30 : 28;
173     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
174     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
175         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
176                 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
177             break;
178
179     jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
180     int l_overflow = max(
181             0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
182
183     if (jcp.ow < jcp.ur_w) {
184         jcp.ur_w = jcp.ow;
185         jcp.ur_w_tail = 0;
186     } else {
187         for (; jcp.ur_w >= 1; jcp.ur_w--) {
188             /* ur_w should be multiple of stride_w in order
189                to simplify logic for get_ow_start and get_ow_end */
190             bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
191
192             /* boundary conditions:
193                These conditions ensure all elements close to boundary
194                are computed in a single call of compute loop */
195             bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
196             jcp.ur_w_tail = jcp.ow % jcp.ur_w;
197             int r_overflow_no_tail
198                     = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
199                                      - max(0, jcp.r_pad) - jcp.ur_w_tail)
200                                     / jcp.stride_w);
201             bool right_boundary_covered
202                     = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
203
204             if (is_multiple_of_stride && left_boundary_covered
205                     && right_boundary_covered)
206                 break;
207             else if (jcp.ur_w == 1)
208                 /* The boundary conditions above are also important
209                    to maintain simplicity of calls to icb_loop,
210                    if those conditions are not satisfied,
211                    then special cases will need to be added
212                    to use correct l_overflow/r_overflow values
213                    when different iterations of compute loop
214                    work on the locations close to boundary.
215                    So to keep code simple, return unimplemented
216                    for extreme case when a good ur_w cannot be found.
217                  */
218                 return status::unimplemented;
219         }
220     }
221
222     jcp.wei_adj_scale
223             = (jcp.signed_input && (jcp.ver != ver_vnni)) ? (1.f / 2.f) : 1.f;
224
225     jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
226     return status::success;
227 }
228
229 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
230     using namespace primitive_kind;
231     const auto &p = attr_.post_ops_;
232
233     if (position == 0) {
234         /* eltwise before sum */
235         return p.contain(eltwise, 0);
236     } else if (position == 1) {
237         /* eltwise after sum */
238         return p.contain(sum, 0) && p.contain(eltwise, 1);
239     }
240     return false;
241 }
242
243 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
244     int nb_oc_block
245             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
246     if (ur_w == jcp.ur_w)
247         eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
248     else
249         for (int k = 0; k < nb_oc_block; k++)
250             eltwise_injector_->compute_vector_range(
251                     k * jcp.ur_w, k * jcp.ur_w + ur_w);
252 }
253
254 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
255         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
256     using namespace primitive_kind;
257     const auto &p = attr.post_ops_;
258
259     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
260
261     switch (p.len_) {
262     case 0: return true;
263     case 1: return is_eltwise(0) || p.contain(sum, 0);
264     case 2:
265         return (p.contain(sum, 0) && is_eltwise(1))
266                 || (p.contain(sum, 1) && is_eltwise(0));
267     default: return false;
268     }
269
270     return false;
271 }
272
273 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
274         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
275         const primitive_attr_t &attr) {
276     if (jcp.signed_input && jcp.ver != ver_vnni) {
277         size_t count = nstl::max(attr.output_scales_.count_, 16);
278         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
279     }
280 }
281
282 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
283         int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
284         bool h_padded) {
285
286     const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
287     const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
288
289     auto src_offset = [=](int oj, int icb, int ki) {
290         return jcp.typesize_in
291                 * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
292                                   * jcp.ngroups * jcp.ic_without_padding
293                           + icb * 4);
294     };
295
296     auto kernel_offset = [=](int ocb, int icb, int ki) {
297         return jcp.typesize_in
298                 * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
299                           + icb * jcp.oc_block * jcp.ic_block / 4
300                           + ki * ch_block_all);
301     };
302
303     auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
304         if (jcp.ver == ver_vnni) {
305             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
306         } else if (jcp.is_depthwise) {
307             vpmulld(zmm_tmp, vreg_src, vreg_wei);
308             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
309         } else {
310             vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
311             vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
312             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
313         }
314     };
315
316     for (int ki = 0; ki < jcp.kw; ki++) {
317
318         int jj_start = get_ow_start(ki, l_overflow);
319         int jj_end = get_ow_end(ur_w, ki, r_overflow);
320
321         int _start = (jcp.signed_input) ? 0 : jj_start;
322         int _end = (jcp.signed_input) ? ur_w : jj_end;
323
324         int tail_size = jcp.ic_without_padding % 4;
325         int n_ic_blocks = jcp.is_depthwise ?
326                 1 :
327                 (last_ic_block_flag & ~no_last_block ?
328                                 div_up(jcp.ic_without_padding % jcp.ic_block,
329                                         4) :
330                                 jcp.ic_block / 4);
331
332         for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
333             if (h_padded == true) {
334                 /* fill padded area with shifted values */
335                 Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
336                 vpxord(inp, inp, inp);
337                 vpsubb(inp, inp, zmm_shift);
338             } else {
339
340                 for (int jj = _start; jj < _end; jj += ur_w_stride) {
341
342                     int aux_src_off = src_offset(jj, icb1, ki);
343
344                     if (jj >= jj_start && jj < jj_end
345                             && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
346                         if (jcp.is_depthwise) {
347                             vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
348                                     EVEX_compress_addr(
349                                               aux_reg_src, aux_src_off));
350                         } else if ((last_ic_block_flag & last_sp_block)
351                                 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
352                             xmm_t xmm_tmp = xmm_t(
353                                     zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
354                             for (int r = 0; r < tail_size; ++r)
355                                 vpinsrb(xmm_tmp, xmm_tmp,
356                                         ptr[aux_reg_src + aux_src_off + r], r);
357                             vpbroadcastd(
358                                     zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
359                         } else {
360                             vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
361                                     EVEX_compress_addr(
362                                                  aux_reg_src, aux_src_off));
363                         }
364                         if (jcp.signed_input)
365                             vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
366                                     zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
367                     } else {
368                         /* fill padded area with shifted values */
369                         if (jcp.signed_input) {
370                             Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
371                             vpxord(inp, inp, inp);
372                             vpsubb(inp, inp, zmm_shift);
373                         }
374                     }
375                 }
376             }
377             for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
378                 int aux_filt_off = kernel_offset(ocb, icb1, ki);
379
380                 if (_end - _start > 0) {
381                     if (jcp.is_depthwise)
382                         vpmovsxbd(zmm_wei,
383                                 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
384                     else
385                         vmovups(zmm_wei,
386                                 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
387                 }
388                 for (int jj = _start; jj < _end; jj += ur_w_stride) {
389                     Zmm inp = (h_padded == true) ?
390                             zmm_inp(0, jcp.nb_oc_blocking) :
391                             zmm_inp(jj, jcp.nb_oc_blocking);
392                     compute(zmm_out(jj, ocb), zmm_wei, inp);
393                 }
394             }
395         }
396     }
397 }
398
399 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
400         int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
401
402     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
403     int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
404             * jcp.ngroups * jcp.ic_without_padding;
405     const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
406     int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
407
408     Label kh_loop_label, skip_kh_loop;
409     Label t_overflow_label, no_t_overflow_label, b_overflow_label,
410             no_b_overflow_label;
411
412     mov(aux_reg_src, reg_src);
413     mov(aux_reg_filt, reg_filt);
414
415     if (jcp.signed_input) {
416         /* Weights are transposed, so first compute 'bottom' padding. */
417         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
418         cmp(reg_overflow, 0);
419         je(no_b_overflow_label, T_NEAR);
420         L(b_overflow_label); {
421             compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
422
423             add(aux_reg_filt, shift_filt_kh);
424             dec(reg_overflow);
425             cmp(reg_overflow, 0);
426             jg(b_overflow_label, T_NEAR);
427         }
428         L(no_b_overflow_label);
429     }
430
431     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
432
433     if (jcp.signed_input || ((!jcp.signed_input)
434         && ((min(jcp.t_pad, jcp.b_pad) < 0)
435             || ((jcp.kh - 1) * (jcp.dilate_h + 1)
436                 < nstl::max(jcp.t_pad, jcp.b_pad))))) {
437         cmp(reg_kh, 0);
438         je(skip_kh_loop, T_NEAR);
439     }
440
441     L(kh_loop_label); {
442         compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
443         sub(aux_reg_src, shift_src_ih);
444         add(aux_reg_filt, shift_filt_kh);
445         dec(reg_kh);
446
447         /* Insert weight compensation in stride 'holes' */
448         if (jcp.signed_input && jcp.stride_h > 1) {
449             Label kh_comp_loop;
450
451             cmp(reg_kh, 0);
452             je(skip_kh_loop, T_NEAR);
453             mov(reg_comp_strides, jcp.stride_h - 1);
454             L(kh_comp_loop);
455             {
456                 compute_ker(
457                         ur_w, 0, 0, last_ic_block_flag, true);
458                 add(aux_reg_filt, shift_filt_kh);
459                 dec(reg_comp_strides);
460                 cmp(reg_comp_strides, 0);
461                 jg(kh_comp_loop, T_NEAR);
462             }
463         }
464         cmp(reg_kh, 0);
465         jg(kh_loop_label, T_NEAR);
466     }
467     L(skip_kh_loop);
468     if (jcp.signed_input) {
469         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
470         cmp(reg_overflow, 0);
471         je(no_t_overflow_label, T_NEAR);
472         L(t_overflow_label); {
473             compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
474
475             add(aux_reg_filt, shift_filt_kh);
476             dec(reg_overflow);
477             cmp(reg_overflow, 0);
478             jg(t_overflow_label, T_NEAR);
479         }
480         L(no_t_overflow_label);
481     }
482 }
483
484 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
485     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
486         for (int ur = 0; ur < ur_w; ur++) {
487             zmm_t zmm = zmm_out(ur, ocb);
488             vpxord(zmm, zmm, zmm);
489         }
490     }
491     if (jcp.signed_input) {
492         xor_(reg_scratch, reg_scratch);
493         Reg8 _t8 = reg_scratch.cvt8();
494         mov(_t8, (int8_t)-128);
495         vpbroadcastb(zmm_shift, _t8);
496     }
497 }
498
499 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
500         data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
501     zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
502     switch (type_in) {
503     case data_type::f32:
504     case data_type::s32: vmovups(zmm, op); break;
505     case data_type::s8: vpmovsxbd(zmm, op); break;
506     case data_type::u8: vpmovzxbd(zmm, op); break;
507     default: assert(!"unsupported data type");
508     }
509     if (type_in != data_type::f32)
510         vcvtdq2ps(zmm_in, zmm_in);
511 }
512
513 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
514         int ur_w, bool last_oc_block) {
515     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
516     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
517
518     if (jcp.signed_input)
519         mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
520
521     const auto &p = attr_.post_ops_;
522     const int sum_idx = p.find(primitive_kind::sum);
523     const float *p_sum_scale
524             = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
525     if (p_sum_scale && *p_sum_scale != 1.f)
526         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
527
528     if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
529         mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
530         vmovq(xmm_bias_alpha(), reg_bias_alpha);
531         vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
532     }
533
534     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
535         const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
536         int scale_offset
537                 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
538
539         auto zmm_bias = zmm_tmp;
540         if (jcp.with_bias) {
541             int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
542             auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
543             cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
544             if (jcp.signed_input && jcp.ver != ver_vnni)
545                 vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
546         }
547         if (jcp.signed_input) {
548             int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
549             auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
550             cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
551         }
552
553         for (int ur = 0; ur < ur_w; ur++) {
554             zmm_t zmm = zmm_out(ur, ocb);
555             vcvtdq2ps(zmm, zmm);
556             if (jcp.signed_input)
557                 vaddps(zmm, zmm, zmm_comp);
558             if (jcp.with_bias)
559                 vaddps(zmm, zmm, zmm_bias);
560             zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
561             vmulps(mask_zmm, zmm,
562                     EVEX_compress_addr(reg_ptr_scales, scale_offset));
563         }
564     }
565     if (maybe_eltwise(0))
566         compute_eltwise(ur_w);
567     if (p_sum_scale) { // post_op: sum
568         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
569             const bool mask_flag
570                     = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
571             for (int j = 0; j < ur_w; j++) {
572                 int aux_output_offset
573                         = jcp.typesize_out
574                         * (k * jcp.oc_block
575                                   + j * jcp.oc_without_padding * jcp.ngroups);
576                 auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
577                 Zmm zmm = zmm_out(j, k);
578                 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
579                 if (*p_sum_scale == 1.f)
580                     vaddps(zmm, zmm_prev_dst);
581                 else
582                     vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
583             }
584         }
585     }
586     if (maybe_eltwise(1))
587         compute_eltwise(ur_w);
588
589     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
590         const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
591         for (int ur = 0; ur < ur_w; ur++) {
592             zmm_t zmm = zmm_out(ur, ocb);
593             if (jcp.dst_dt == data_type::u8) {
594                 vpxord(zmm_zero, zmm_zero, zmm_zero);
595                 vmaxps(zmm, zmm_zero, zmm);
596             }
597             if (jcp.dst_dt != data_type::f32) {
598                 if (attr_.round_mode_ == round_mode::nearest)
599                     vcvtps2dq(zmm | T_rn_sae, zmm);
600                 else if (attr_.round_mode_ == round_mode::down)
601                     vcvtps2dq(zmm | T_rd_sae, zmm);
602                 else
603                     assert(!"unimplemented");
604             }
605         }
606         for (int ur = 0; ur < ur_w; ur++) {
607             int aux_dst_off = jcp.typesize_out
608                     * (ur * jcp.ngroups * jcp.oc_without_padding
609                                       + ocb * jcp.oc_block);
610             auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
611
612             zmm_t zmm = zmm_out(ur, ocb);
613             zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
614             switch (jcp.dst_dt) {
615             case data_type::f32:
616             case data_type::s32: vmovups(addr, r_zmm); break;
617             case data_type::s8: vpmovsdb(addr, r_zmm); break;
618             case data_type::u8: vpmovusdb(addr, r_zmm); break;
619             default: assert(!"unknown dst_dt");
620             }
621         }
622     }
623 }
624
625 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
626         int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
627
628     int shift_src_icb = jcp.typesize_in * jcp.ic_block;
629     int shift_filt_icb
630             = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
631
632     prepare_output(ur_w);
633
634     Label skip_icb_loop, icb_loop_label;
635
636     mov(reg_icb, jcp.nb_ic);
637     L(icb_loop_label); {
638
639         if (jcp.ic_without_padding != jcp.ic) {
640             Label common_ker, end_ker;
641             cmp(reg_icb, 1);
642             jg(common_ker, T_NEAR);
643
644             kh_loop(ur_w, l_overflow, r_overflow,
645                     is_last_sp_block ? last_sp_block : last_ic_block);
646             jmp(end_ker, T_NEAR);
647
648             L(common_ker);
649             kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
650
651             L(end_ker);
652         } else {
653             kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
654         }
655
656         add(reg_src, shift_src_icb);
657         add(reg_filt, shift_filt_icb);
658         dec(reg_icb);
659         cmp(reg_icb, 0);
660         jg(icb_loop_label, T_NEAR);
661     }
662
663     /* come-back pointers */
664     sub(reg_src, jcp.nb_ic * shift_src_icb);
665     sub(reg_filt, jcp.nb_ic * shift_filt_icb);
666     L(skip_icb_loop);
667
668     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
669         Label common_store, end_store;
670         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
671         if (jcp.is_depthwise)
672             cmp(reg_oc_blocks, jcp.nb_ch - 1);
673         else
674             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
675         jne(common_store, T_NEAR);
676
677         store_output(ur_w, true);
678         jmp(end_store, T_NEAR);
679
680         L(common_store);
681         store_output(ur_w, false);
682
683         L(end_store);
684
685     } else {
686         store_output(ur_w, false);
687     }
688 }
689
690 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
691     preamble();
692
693     xor_(reg_scratch, reg_scratch);
694     Reg16 _t = reg_scratch.cvt16();
695     mov(_t, 0x1);
696     vpbroadcastw(zmm_one, _t);
697
698     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
699         int tail_size = jcp.is_depthwise ?
700                 jcp.ngroups % jcp.ch_block :
701                 jcp.oc_without_padding % jcp.oc_block;
702         int mask = (1 << tail_size) - 1;
703         Reg32 regw_tmp = reg_nur_w.cvt32();
704         mov(regw_tmp, mask);
705         kmovw(ktail_mask, regw_tmp);
706     }
707
708     mov(reg_src, ptr[param1 + GET_OFF(src)]);
709     mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
710     mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
711
712     int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
713             * jcp.oc_without_padding;
714     int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
715             * jcp.ic_without_padding;
716
717     int l_overflow = max(
718             0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
719     int r_overflow
720             = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
721                             / jcp.stride_w);
722
723     int r_overflow1
724             = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
725                                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
726                             / jcp.stride_w);
727     int nur_w = jcp.ow / jcp.ur_w;
728     if (r_overflow1 > 0)
729         nur_w--;
730
731     if (jcp.ur_w == jcp.ow) {
732         icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
733     } else if (nur_w == 0) {
734         icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
735         add(reg_src, src_shift);
736         add(reg_dst, dst_shift);
737         if (jcp.ur_w_tail != 0)
738             icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
739     } else {
740         xor_(reg_nur_w, reg_nur_w);
741         if (l_overflow > 0) {
742             icb_loop(jcp.ur_w, l_overflow, 0, false);
743             add(reg_src, src_shift);
744             add(reg_dst, dst_shift);
745             inc(reg_nur_w);
746         }
747         if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
748             Label ow_loop_label;
749             L(ow_loop_label);
750             {
751                 icb_loop(jcp.ur_w, 0, 0, false);
752                 add(reg_src, src_shift);
753                 add(reg_dst, dst_shift);
754                 inc(reg_nur_w);
755                 cmp(reg_nur_w, nur_w);
756                 jl(ow_loop_label, T_NEAR);
757             }
758         }
759         if (r_overflow1 > 0) {
760             icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
761             add(reg_src, src_shift);
762             add(reg_dst, dst_shift);
763         }
764         if (jcp.ur_w_tail != 0) {
765             icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
766         }
767     }
768     postamble();
769
770     if (jcp.with_eltwise)
771         eltwise_injector_->prepare_table();
772 }
773
774 template <data_type_t src_type, data_type_t dst_type>
775 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
776         dst_type>::execute_forward() const {
777     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
778     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
779     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
780     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
781
782     const memory_desc_wrapper src_d(pd()->src_pd());
783     const memory_desc_wrapper dst_d(pd()->dst_pd());
784     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
785     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
786
787     auto &jcp = kernel_->jcp;
788
789     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
790     int nb_groups = jcp.nb_ch;
791
792     size_t src_h_stride = src_d.blk_off(0, 0, 1);
793     size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
794     size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
795
796     const float *oscales = pd()->attr()->output_scales_.scales_;
797     if (jcp.signed_input && jcp.ver != ver_vnni) {
798         auto local_scales
799                 = scratchpad().template get<float>(key_conv_adjusted_scales);
800         size_t count = pd()->attr()->output_scales_.count_;
801         float factor = 1.f / pd()->jcp_.wei_adj_scale;
802         if (count == 1) {
803             utils::array_set(local_scales, oscales[0] * factor, 16);
804         } else {
805             for (size_t c = 0; c < count; c++)
806                 local_scales[c] = oscales[c] * factor;
807         }
808         oscales = local_scales;
809     }
810     size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
811     auto w = const_cast<wei_data_t *>(weights);
812     int32_t *compensation
813             = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
814
815     parallel(0, [&](const int ithr, const int nthr) {
816         int start{ 0 }, end{ 0 };
817         int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
818         balance211(work_amount, nthr, ithr, start, end);
819
820         auto p = jit_deconv_call_s();
821
822         /*loop order = cgn*/
823         int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
824         if (jcp.loop_order == loop_ngc)
825             nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
826                     oh_s, jcp.oh);
827         else if (jcp.loop_order == loop_cgn)
828             nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
829                     oh_s, jcp.oh);
830         else
831             assert(!"unsupported loop order");
832         while (start < end) {
833
834             int ocb = occ * jcp.nb_oc_blocking;
835             int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
836             int g_ic = g * jcp.ch_block * jcp.ic;
837             int work_rem = end - start;
838             int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
839
840             auto dst_w = dst + dst_d.blk_off(n, g_oc);
841             auto src_w = src + src_d.blk_off(n, g_ic);
842             auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
843             auto bias_w = jcp.with_bias ?
844                     bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
845                     0;
846             int32_t *compensation_w
847                     = (jcp.signed_input) ? compensation + g_oc : 0;
848
849             auto scales = &oscales[jcp.is_oc_scale * g_oc];
850             for (int oj = oh_s; oj < oh_e; oj++) {
851                 int ih_max = 0, kh_lo = 0, kh_len = 0;
852                 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
853                     /* dilation */
854                     int dilate_h = jcp.dilate_h + 1;
855                     // Note: use div_up to account for "holes" in filter
856                     int o_t_overflow = div_up(
857                             max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
858                             dilate_h);
859                     int o_b_overflow
860                             = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
861                                                      + oj - jcp.b_pad),
862                                     dilate_h);
863                     kh_len = jcp.kh - o_t_overflow - o_b_overflow;
864                     kh_lo = o_b_overflow;
865                     ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
866                 } else {
867                     int o_t_overflow = max(
868                             0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
869                     int o_b_overflow
870                             = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
871                                             / jcp.stride_h);
872                     int overflow_kh_hi = jcp.kh - 1
873                             - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
874                     int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
875
876                     kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
877                             + 1 - o_t_overflow - o_b_overflow;
878                     kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
879                     ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
880                 }
881
882                 int wei_stride
883                         = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
884                 p.src = src_w + ih_max * src_h_stride;
885                 p.dst = dst_w + oj * dst_h_stride;
886                 p.filt = wht_w + wei_stride;
887                 p.bias = bias_w;
888                 p.compensation = compensation_w;
889                 p.t_overflow = max(
890                         0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
891                                             + 1));
892                 p.b_overflow = kh_lo;
893                 p.kh_padding = kh_len;
894                 p.scales = scales;
895                 p.oc_blocks = jcp.is_depthwise ? g : ocb;
896                 kernel_->jit_ker(&p);
897             }
898             if (jcp.loop_order == loop_ngc)
899                 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
900                         oc_chunks, oh_s, jcp.oh);
901             else if (jcp.loop_order == loop_cgn)
902                 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
903                         jcp.mb, oh_s, jcp.oh);
904             else
905                 assert(!"unsupported loop order");
906         }
907     });
908 }
909
910 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
911         data_type::u8>;
912 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
913         data_type::s8>;
914 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
915         data_type::f32>;
916 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
917         data_type::s32>;
918 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
919         data_type::u8>;
920 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
921         data_type::s8>;
922 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
923         data_type::f32>;
924 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
925         data_type::s32>;
926 }
927 }
928 }