updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_deconvolution.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
18
19 #define GET_OFF(field) offsetof(jit_deconv_call_s, field)
20
21 namespace mkldnn {
22 namespace impl {
23 namespace cpu {
24
25 using namespace mkldnn::impl::status;
26 using namespace mkldnn::impl::memory_format;
27 using namespace mkldnn::impl::memory_tracking::names;
28 using namespace mkldnn::impl::utils;
29 using namespace Xbyak;
30
31 using namespace nstl;
32
33 #define wht_blk_off(d, g, ...)                             \
34     (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
35                            (d).blk_off(__VA_ARGS__))
36
37 status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
38         jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
39         cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
40         cpu_memory_t::pd_t &dst_pd, const bool with_bias,
41         cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
42     const memory_desc_wrapper src_d(&src_pd);
43     const memory_desc_wrapper dst_d(&dst_pd);
44     const memory_desc_wrapper weights_d(&weights_pd);
45     const memory_desc_wrapper bias_d(&bias_pd);
46
47     if (!(mayiuse(avx512_core)
48                 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
49                 && weights_d.data_type() == data_type::s8
50                 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
51                            data_type::s8, data_type::u8)))
52         return status::unimplemented;
53
54     jcp = zero<decltype(jcp)>();
55
56     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
57     jcp.signed_input = src_d.data_type() == data_type::s8;
58     const int ndims = jcp.ndims = dst_d.ndims();
59     const bool is_1d = ndims == 3;
60
61     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
62     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
63     jcp.ic = src_d.dims()[1] / jcp.ngroups;
64     jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
65     jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
66     jcp.is_depthwise = true && with_groups
67             && utils::everyone_is(1, jcp.ic_without_padding,
68                                jcp.oc_without_padding);
69
70     /* TODO: future work, on hold until depthwise specialized kernel is
71      * implemented. */
72     if (jcp.is_depthwise && jcp.signed_input)
73         return status::unimplemented;
74
75     auto dst_format = pick(ndims - 3, nwc, nhwc);
76     auto src_format = pick(ndims - 3, nwc, nhwc);
77 #define pick_signed(fmt) (jcp.signed_input ? fmt##_s8s8 : fmt)
78     const auto w_format = is_1d ?
79             (jcp.is_depthwise ? Goiw16g :
80                                 (with_groups ? pick_signed(gOIw4i16o4i) :
81                                                pick_signed(OIw4i16o4i))) :
82             (jcp.is_depthwise ? Goihw16g :
83                                 (with_groups ? pick_signed(gOIhw4i16o4i) :
84                                                pick_signed(OIhw4i16o4i)));
85 #undef pick_signed
86
87     if (dst_d.format() == any)
88         CHECK(dst_pd.set_format(dst_format));
89     if (dst_d.format() != dst_format)
90         return status::unimplemented;
91     if (src_d.format() == any)
92         CHECK(src_pd.set_format(src_format));
93     if (src_d.format() != src_format)
94         return status::unimplemented;
95     if (weights_d.format() == any)
96         CHECK(weights_pd.set_format(w_format));
97     if (weights_d.format() != w_format)
98         return status::unimplemented;
99
100     jcp.with_bias = with_bias;
101     if (jcp.with_bias) {
102         if (bias_d.format() == any)
103             CHECK(bias_pd.set_format(x));
104         if (bias_d.format() != x)
105             return status::unimplemented;
106     }
107
108     jcp.prop_kind = cd.prop_kind;
109     jcp.mb = src_d.dims()[0];
110     jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
111     jcp.iw = src_d.dims()[ndims - 1];
112     jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
113     jcp.ow = dst_d.dims()[ndims - 1];
114     jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
115     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
116     jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
117     jcp.l_pad = cd.padding[0][ndims - 3];
118     jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
119     jcp.stride_w = cd.strides[ndims - 3];
120     jcp.src_fmt = src_d.format();
121
122     if (jcp.is_depthwise) {
123         jcp.ch_block = 16;
124         jcp.oc_block = 1;
125         jcp.ic_block = 1;
126     } else {
127         jcp.ch_block = 1;
128         jcp.oc_block = 16;
129         jcp.ic_block = 16;
130
131         if (jcp.ngroups == 1) {
132             jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
133             jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
134         }
135         if (jcp.ic % jcp.ic_block != 0)
136             return status::unimplemented;
137     }
138
139     jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
140     jcp.dilate_w = cd.dilates[ndims - 3];
141
142     if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
143             || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
144         return status::unimplemented;
145
146     /* padding: bottom and right */
147     jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
148             - (jcp.oh + jcp.t_pad - 1);
149     jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
150             - (jcp.ow + jcp.l_pad - 1);
151
152     if (!post_ops_ok(jcp, attr))
153         return status::unimplemented;
154
155     const auto &p = attr.post_ops_;
156     const int eltwise_ind = p.find(primitive_kind::eltwise);
157     jcp.with_eltwise = eltwise_ind != -1;
158     if (jcp.with_eltwise)
159         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
160
161     jcp.ver = ver_avx512_core;
162     if (mayiuse(avx512_core_vnni))
163         jcp.ver = ver_vnni;
164     const auto &oscales = attr.output_scales_;
165     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
166
167     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
168
169     jcp.dst_dt = dst_d.data_type();
170     jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
171     jcp.typesize_bia
172             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
173     jcp.typesize_in = types::data_type_size(src_d.data_type());
174     jcp.typesize_out = types::data_type_size(dst_d.data_type());
175
176     jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
177     jcp.nb_oc = jcp.oc / jcp.oc_block;
178     jcp.nb_ic = jcp.ic / jcp.ic_block;
179
180     /* kernel blocking params */
181     const int regs = jcp.ver == ver_vnni ? 30 : 28;
182     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
183     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
184         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
185                 && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
186             break;
187
188     jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
189     int l_overflow = max(
190             0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
191
192     if (jcp.ow < jcp.ur_w) {
193         jcp.ur_w = jcp.ow;
194         jcp.ur_w_tail = 0;
195     } else {
196         for (; jcp.ur_w >= 1; jcp.ur_w--) {
197             /* ur_w should be multiple of stride_w in order
198                to simplify logic for get_ow_start and get_ow_end */
199             bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
200
201             /* boundary conditions:
202                These conditions ensure all elements close to boundary
203                are computed in a single call of compute loop */
204             bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
205             jcp.ur_w_tail = jcp.ow % jcp.ur_w;
206             int r_overflow_no_tail
207                     = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
208                                      - max(0, jcp.r_pad) - jcp.ur_w_tail)
209                                     / jcp.stride_w);
210             bool right_boundary_covered
211                     = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
212
213             if (is_multiple_of_stride && left_boundary_covered
214                     && right_boundary_covered)
215                 break;
216             else if (jcp.ur_w == 1)
217                 /* The boundary conditions above are also important
218                    to maintain simplicity of calls to icb_loop,
219                    if those conditions are not satisfied,
220                    then special cases will need to be added
221                    to use correct l_overflow/r_overflow values
222                    when different iterations of compute loop
223                    work on the locations close to boundary.
224                    So to keep code simple, return unimplemented
225                    for extreme case when a good ur_w cannot be found.
226                  */
227                 return status::unimplemented;
228         }
229     }
230
231     jcp.wei_adj_scale
232             = (jcp.signed_input && (jcp.ver != ver_vnni)) ? (1.f / 2.f) : 1.f;
233
234     jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
235     return status::success;
236 }
237
238 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
239     using namespace primitive_kind;
240     const auto &p = attr_.post_ops_;
241
242     if (position == 0) {
243         /* eltwise before sum */
244         return p.contain(eltwise, 0);
245     } else if (position == 1) {
246         /* eltwise after sum */
247         return p.contain(sum, 0) && p.contain(eltwise, 1);
248     }
249     return false;
250 }
251
252 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
253     int nb_oc_block
254             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
255     if (ur_w == jcp.ur_w)
256         eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
257     else
258         for (int k = 0; k < nb_oc_block; k++)
259             eltwise_injector_->compute_vector_range(
260                     k * jcp.ur_w, k * jcp.ur_w + ur_w);
261 }
262
263 bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
264         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
265     using namespace primitive_kind;
266     const auto &p = attr.post_ops_;
267
268     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
269
270     switch (p.len_) {
271     case 0: return true;
272     case 1: return is_eltwise(0) || p.contain(sum, 0);
273     case 2:
274         return (p.contain(sum, 0) && is_eltwise(1))
275                 || (p.contain(sum, 1) && is_eltwise(0));
276     default: return false;
277     }
278
279     return false;
280 }
281
282 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
283         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
284         const primitive_attr_t &attr) {
285     if (jcp.signed_input && jcp.ver != ver_vnni) {
286         size_t count = nstl::max(attr.output_scales_.count_, 16);
287         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
288     }
289 }
290
291 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
292         int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
293         bool h_padded) {
294
295     const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
296     const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
297
298     auto src_offset = [=](int oj, int icb, int ki) {
299         return jcp.typesize_in
300                 * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
301                                   * jcp.ngroups * jcp.ic_without_padding
302                           + icb * 4);
303     };
304
305     auto kernel_offset = [=](int ocb, int icb, int ki) {
306         return jcp.typesize_in
307                 * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
308                           + icb * jcp.oc_block * jcp.ic_block / 4
309                           + ki * ch_block_all);
310     };
311
312     auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
313         if (jcp.ver == ver_vnni) {
314             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
315         } else if (jcp.is_depthwise) {
316             vpmulld(zmm_tmp, vreg_src, vreg_wei);
317             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
318         } else {
319             vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
320             vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
321             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
322         }
323     };
324
325     for (int ki = 0; ki < jcp.kw; ki++) {
326
327         int jj_start = get_ow_start(ki, l_overflow);
328         int jj_end = get_ow_end(ur_w, ki, r_overflow);
329
330         int _start = (jcp.signed_input) ? 0 : jj_start;
331         int _end = (jcp.signed_input) ? ur_w : jj_end;
332
333         int tail_size = jcp.ic_without_padding % 4;
334         int n_ic_blocks = jcp.is_depthwise ?
335                 1 :
336                 (last_ic_block_flag & ~no_last_block ?
337                                 div_up(jcp.ic_without_padding % jcp.ic_block,
338                                         4) :
339                                 jcp.ic_block / 4);
340
341         for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
342             if (h_padded == true) {
343                 /* fill padded area with shifted values */
344                 Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
345                 vpxord(inp, inp, inp);
346                 vpsubb(inp, inp, zmm_shift);
347             } else {
348
349                 for (int jj = _start; jj < _end; jj += ur_w_stride) {
350
351                     int aux_src_off = src_offset(jj, icb1, ki);
352
353                     if (jj >= jj_start && jj < jj_end
354                             && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
355                         if (jcp.is_depthwise) {
356                             vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
357                                     EVEX_compress_addr(
358                                               aux_reg_src, aux_src_off));
359                         } else if ((last_ic_block_flag & last_sp_block)
360                                 && tail_size != 0 && icb1 == n_ic_blocks - 1) {
361                             xmm_t xmm_tmp = xmm_t(
362                                     zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
363                             for (int r = 0; r < tail_size; ++r)
364                                 vpinsrb(xmm_tmp, xmm_tmp,
365                                         ptr[aux_reg_src + aux_src_off + r], r);
366                             vpbroadcastd(
367                                     zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
368                         } else {
369                             vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
370                                     EVEX_compress_addr(
371                                                  aux_reg_src, aux_src_off));
372                         }
373                         if (jcp.signed_input)
374                             vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
375                                     zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
376                     } else {
377                         /* fill padded area with shifted values */
378                         if (jcp.signed_input) {
379                             Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
380                             vpxord(inp, inp, inp);
381                             vpsubb(inp, inp, zmm_shift);
382                         }
383                     }
384                 }
385             }
386             for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
387                 int aux_filt_off = kernel_offset(ocb, icb1, ki);
388
389                 if (_end - _start > 0) {
390                     if (jcp.is_depthwise)
391                         vpmovsxbd(zmm_wei,
392                                 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
393                     else
394                         vmovups(zmm_wei,
395                                 EVEX_compress_addr(aux_reg_filt, aux_filt_off));
396                 }
397                 for (int jj = _start; jj < _end; jj += ur_w_stride) {
398                     Zmm inp = (h_padded == true) ?
399                             zmm_inp(0, jcp.nb_oc_blocking) :
400                             zmm_inp(jj, jcp.nb_oc_blocking);
401                     compute(zmm_out(jj, ocb), zmm_wei, inp);
402                 }
403             }
404         }
405     }
406 }
407
408 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
409         int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
410
411     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
412     int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
413             * jcp.ngroups * jcp.ic_without_padding;
414     const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
415     int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
416
417     Label kh_loop_label, skip_kh_loop;
418     Label t_overflow_label, no_t_overflow_label, b_overflow_label,
419             no_b_overflow_label;
420
421     mov(aux_reg_src, reg_src);
422     mov(aux_reg_filt, reg_filt);
423
424     if (jcp.signed_input && jcp.ndims > 3) {
425         /* Weights are transposed, so first compute 'bottom' padding. */
426         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
427         cmp(reg_overflow, 0);
428         je(no_b_overflow_label, T_NEAR);
429         L(b_overflow_label); {
430             compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
431
432             add(aux_reg_filt, shift_filt_kh);
433             dec(reg_overflow);
434             cmp(reg_overflow, 0);
435             jg(b_overflow_label, T_NEAR);
436         }
437         L(no_b_overflow_label);
438     }
439
440     mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
441
442     if (jcp.signed_input || ((!jcp.signed_input)
443         && ((min(jcp.t_pad, jcp.b_pad) < 0)
444             || ((jcp.kh - 1) * (jcp.dilate_h + 1)
445                 < nstl::max(jcp.t_pad, jcp.b_pad))))) {
446         cmp(reg_kh, 0);
447         je(skip_kh_loop, T_NEAR);
448     }
449
450     L(kh_loop_label); {
451         compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
452         sub(aux_reg_src, shift_src_ih);
453         add(aux_reg_filt, shift_filt_kh);
454         dec(reg_kh);
455
456         /* Insert weight compensation in stride 'holes' */
457         if (jcp.signed_input && jcp.stride_h > 1) {
458             Label kh_comp_loop;
459
460             cmp(reg_kh, 0);
461             je(skip_kh_loop, T_NEAR);
462             mov(reg_comp_strides, jcp.stride_h - 1);
463             L(kh_comp_loop);
464             {
465                 compute_ker(
466                         ur_w, 0, 0, last_ic_block_flag, true);
467                 add(aux_reg_filt, shift_filt_kh);
468                 dec(reg_comp_strides);
469                 cmp(reg_comp_strides, 0);
470                 jg(kh_comp_loop, T_NEAR);
471             }
472         }
473         cmp(reg_kh, 0);
474         jg(kh_loop_label, T_NEAR);
475     }
476     L(skip_kh_loop);
477     if (jcp.signed_input && jcp.ndims > 3) {
478         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
479         cmp(reg_overflow, 0);
480         je(no_t_overflow_label, T_NEAR);
481         L(t_overflow_label); {
482             compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
483
484             add(aux_reg_filt, shift_filt_kh);
485             dec(reg_overflow);
486             cmp(reg_overflow, 0);
487             jg(t_overflow_label, T_NEAR);
488         }
489         L(no_t_overflow_label);
490     }
491 }
492
493 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
494     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
495         for (int ur = 0; ur < ur_w; ur++) {
496             zmm_t zmm = zmm_out(ur, ocb);
497             vpxord(zmm, zmm, zmm);
498         }
499     }
500     if (jcp.signed_input) {
501         xor_(reg_scratch, reg_scratch);
502         Reg8 _t8 = reg_scratch.cvt8();
503         mov(_t8, (int8_t)-128);
504         vpbroadcastb(zmm_shift, _t8);
505     }
506 }
507
508 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
509         data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
510     zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
511     switch (type_in) {
512     case data_type::f32:
513     case data_type::s32: vmovups(zmm, op); break;
514     case data_type::s8: vpmovsxbd(zmm, op); break;
515     case data_type::u8: vpmovzxbd(zmm, op); break;
516     default: assert(!"unsupported data type");
517     }
518     if (type_in != data_type::f32)
519         vcvtdq2ps(zmm_in, zmm_in);
520 }
521
522 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
523         int ur_w, bool last_oc_block) {
524     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
525     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
526
527     if (jcp.signed_input)
528         mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
529
530     const auto &p = attr_.post_ops_;
531     const int sum_idx = p.find(primitive_kind::sum);
532     const float *p_sum_scale
533             = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
534     if (p_sum_scale && *p_sum_scale != 1.f)
535         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
536
537     if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
538         mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
539         vmovq(xmm_bias_alpha(), reg_bias_alpha);
540         vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
541     }
542
543     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
544         const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
545         int scale_offset
546                 = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
547
548         auto zmm_bias = zmm_tmp;
549         if (jcp.with_bias) {
550             int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
551             auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
552             cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
553             if (jcp.signed_input && jcp.ver != ver_vnni)
554                 vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
555         }
556         if (jcp.signed_input) {
557             int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
558             auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
559             cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
560         }
561
562         for (int ur = 0; ur < ur_w; ur++) {
563             zmm_t zmm = zmm_out(ur, ocb);
564             vcvtdq2ps(zmm, zmm);
565             if (jcp.signed_input)
566                 vaddps(zmm, zmm, zmm_comp);
567             if (jcp.with_bias)
568                 vaddps(zmm, zmm, zmm_bias);
569             zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
570             vmulps(mask_zmm, zmm,
571                     EVEX_compress_addr(reg_ptr_scales, scale_offset));
572         }
573     }
574     if (maybe_eltwise(0))
575         compute_eltwise(ur_w);
576     if (p_sum_scale) { // post_op: sum
577         for (int k = 0; k < jcp.nb_oc_blocking; k++) {
578             const bool mask_flag
579                     = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
580             for (int j = 0; j < ur_w; j++) {
581                 int aux_output_offset
582                         = jcp.typesize_out
583                         * (k * jcp.oc_block
584                                   + j * jcp.oc_without_padding * jcp.ngroups);
585                 auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
586                 Zmm zmm = zmm_out(j, k);
587                 cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
588                 if (*p_sum_scale == 1.f)
589                     vaddps(zmm, zmm_prev_dst);
590                 else
591                     vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
592             }
593         }
594     }
595     if (maybe_eltwise(1))
596         compute_eltwise(ur_w);
597
598     for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
599         const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
600         for (int ur = 0; ur < ur_w; ur++) {
601             zmm_t zmm = zmm_out(ur, ocb);
602             if (jcp.dst_dt == data_type::u8) {
603                 vpxord(zmm_zero, zmm_zero, zmm_zero);
604                 vmaxps(zmm, zmm_zero, zmm);
605             }
606             if (jcp.dst_dt != data_type::f32) {
607                 if (attr_.round_mode_ == round_mode::nearest)
608                     vcvtps2dq(zmm | T_rn_sae, zmm);
609                 else if (attr_.round_mode_ == round_mode::down)
610                     vcvtps2dq(zmm | T_rd_sae, zmm);
611                 else
612                     assert(!"unimplemented");
613             }
614         }
615         for (int ur = 0; ur < ur_w; ur++) {
616             int aux_dst_off = jcp.typesize_out
617                     * (ur * jcp.ngroups * jcp.oc_without_padding
618                                       + ocb * jcp.oc_block);
619             auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
620
621             zmm_t zmm = zmm_out(ur, ocb);
622             zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
623             switch (jcp.dst_dt) {
624             case data_type::f32:
625             case data_type::s32: vmovups(addr, r_zmm); break;
626             case data_type::s8: vpmovsdb(addr, r_zmm); break;
627             case data_type::u8: vpmovusdb(addr, r_zmm); break;
628             default: assert(!"unknown dst_dt");
629             }
630         }
631     }
632 }
633
634 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
635         int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
636
637     int shift_src_icb = jcp.typesize_in * jcp.ic_block;
638     int shift_filt_icb
639             = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
640
641     prepare_output(ur_w);
642
643     Label skip_icb_loop, icb_loop_label;
644
645     mov(reg_icb, jcp.nb_ic);
646     L(icb_loop_label); {
647
648         if (jcp.ic_without_padding != jcp.ic) {
649             Label common_ker, end_ker;
650             cmp(reg_icb, 1);
651             jg(common_ker, T_NEAR);
652
653             kh_loop(ur_w, l_overflow, r_overflow,
654                     is_last_sp_block ? last_sp_block : last_ic_block);
655             jmp(end_ker, T_NEAR);
656
657             L(common_ker);
658             kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
659
660             L(end_ker);
661         } else {
662             kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
663         }
664
665         add(reg_src, shift_src_icb);
666         add(reg_filt, shift_filt_icb);
667         dec(reg_icb);
668         cmp(reg_icb, 0);
669         jg(icb_loop_label, T_NEAR);
670     }
671
672     /* come-back pointers */
673     sub(reg_src, jcp.nb_ic * shift_src_icb);
674     sub(reg_filt, jcp.nb_ic * shift_filt_icb);
675     L(skip_icb_loop);
676
677     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
678         Label common_store, end_store;
679         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
680         if (jcp.is_depthwise)
681             cmp(reg_oc_blocks, jcp.nb_ch - 1);
682         else
683             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
684         jne(common_store, T_NEAR);
685
686         store_output(ur_w, true);
687         jmp(end_store, T_NEAR);
688
689         L(common_store);
690         store_output(ur_w, false);
691
692         L(end_store);
693
694     } else {
695         store_output(ur_w, false);
696     }
697 }
698
699 void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
700     preamble();
701
702     xor_(reg_scratch, reg_scratch);
703     Reg16 _t = reg_scratch.cvt16();
704     mov(_t, 0x1);
705     vpbroadcastw(zmm_one, _t);
706
707     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
708         int tail_size = jcp.is_depthwise ?
709                 jcp.ngroups % jcp.ch_block :
710                 jcp.oc_without_padding % jcp.oc_block;
711         int mask = (1 << tail_size) - 1;
712         Reg32 regw_tmp = reg_nur_w.cvt32();
713         mov(regw_tmp, mask);
714         kmovw(ktail_mask, regw_tmp);
715     }
716
717     mov(reg_src, ptr[param1 + GET_OFF(src)]);
718     mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
719     mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
720
721     int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
722             * jcp.oc_without_padding;
723     int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
724             * jcp.ic_without_padding;
725
726     int l_overflow = max(
727             0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
728     int r_overflow
729             = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
730                             / jcp.stride_w);
731
732     int r_overflow1
733             = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
734                                    - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
735                             / jcp.stride_w);
736     int nur_w = jcp.ow / jcp.ur_w;
737     if (r_overflow1 > 0)
738         nur_w--;
739
740     if (jcp.ur_w == jcp.ow) {
741         icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
742     } else if (nur_w == 0) {
743         icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
744         add(reg_src, src_shift);
745         add(reg_dst, dst_shift);
746         if (jcp.ur_w_tail != 0)
747             icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
748     } else {
749         xor_(reg_nur_w, reg_nur_w);
750         if (l_overflow > 0) {
751             icb_loop(jcp.ur_w, l_overflow, 0, false);
752             add(reg_src, src_shift);
753             add(reg_dst, dst_shift);
754             inc(reg_nur_w);
755         }
756         if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
757             Label ow_loop_label;
758             L(ow_loop_label);
759             {
760                 icb_loop(jcp.ur_w, 0, 0, false);
761                 add(reg_src, src_shift);
762                 add(reg_dst, dst_shift);
763                 inc(reg_nur_w);
764                 cmp(reg_nur_w, nur_w);
765                 jl(ow_loop_label, T_NEAR);
766             }
767         }
768         if (r_overflow1 > 0) {
769             icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
770             add(reg_src, src_shift);
771             add(reg_dst, dst_shift);
772         }
773         if (jcp.ur_w_tail != 0) {
774             icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
775         }
776     }
777     postamble();
778
779     if (jcp.with_eltwise)
780         eltwise_injector_->prepare_table();
781 }
782
783 template <data_type_t src_type, data_type_t dst_type>
784 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
785         dst_type>::execute_forward_1d() const {
786     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
787     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
788     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
789     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
790
791     const memory_desc_wrapper src_d(pd()->src_pd());
792     const memory_desc_wrapper dst_d(pd()->dst_pd());
793     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
794     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
795
796     auto &jcp = kernel_->jcp;
797
798     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
799     int nb_groups = jcp.nb_ch;
800
801     const float *oscales = pd()->attr()->output_scales_.scales_;
802     if (jcp.signed_input && jcp.ver != ver_vnni) {
803         auto local_scales
804                 = scratchpad().template get<float>(key_conv_adjusted_scales);
805         size_t count = pd()->attr()->output_scales_.count_;
806         float factor = 1.f / pd()->jcp_.wei_adj_scale;
807         if (count == 1) {
808             utils::array_set(local_scales, oscales[0] * factor, 16);
809         } else {
810             for (size_t c = 0; c < count; c++)
811                 local_scales[c] = oscales[c] * factor;
812         }
813         oscales = local_scales;
814     }
815     size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
816     auto w = const_cast<wei_data_t *>(weights);
817     int32_t *compensation
818             = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
819
820     parallel(0, [&](const int ithr, const int nthr) {
821         int start{ 0 }, end{ 0 };
822         int work_amount = jcp.mb * nb_groups * oc_chunks;
823         balance211(work_amount, nthr, ithr, start, end);
824
825         auto p = jit_deconv_call_s();
826
827         int n{ 0 }, g{ 0 }, occ{ 0 };
828         if (jcp.loop_order == loop_ngc)
829             nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
830         else if (jcp.loop_order == loop_cgn)
831             nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
832         else
833             assert(!"unsupported loop order");
834         while (start < end) {
835
836             int ocb = occ * jcp.nb_oc_blocking;
837             int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
838             int g_ic = g * jcp.ch_block * jcp.ic;
839
840             p.dst = dst + dst_d.blk_off(n, g_oc);
841             p.src = src + src_d.blk_off(n, g_ic);
842             p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
843             p.bias = jcp.with_bias ?
844                     bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
845                     0;
846             p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
847             p.scales = &oscales[jcp.is_oc_scale * g_oc];
848             p.t_overflow = 0;
849             p.b_overflow = 0;
850             p.kh_padding = jcp.kh;
851             p.oc_blocks = jcp.is_depthwise ? g : ocb;
852
853             kernel_->jit_ker(&p);
854
855             ++start;
856             if (jcp.loop_order == loop_ngc)
857                 nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
858             else if (jcp.loop_order == loop_cgn)
859                 nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
860             else
861                 assert(!"unsupported loop order");
862         }
863     });
864 }
865
866 template <data_type_t src_type, data_type_t dst_type>
867 void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
868         dst_type>::execute_forward_2d() const {
869     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
870     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
871     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
872     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
873
874     const memory_desc_wrapper src_d(pd()->src_pd());
875     const memory_desc_wrapper dst_d(pd()->dst_pd());
876     const memory_desc_wrapper weights_d(pd()->weights_pd(0));
877     const memory_desc_wrapper bias_d(pd()->weights_pd(1));
878
879     auto &jcp = kernel_->jcp;
880
881     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
882     int nb_groups = jcp.nb_ch;
883
884     size_t src_h_stride = src_d.blk_off(0, 0, 1);
885     size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
886     size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
887
888     const float *oscales = pd()->attr()->output_scales_.scales_;
889     if (jcp.signed_input && jcp.ver != ver_vnni) {
890         auto local_scales
891                 = scratchpad().template get<float>(key_conv_adjusted_scales);
892         size_t count = pd()->attr()->output_scales_.count_;
893         float factor = 1.f / pd()->jcp_.wei_adj_scale;
894         if (count == 1) {
895             utils::array_set(local_scales, oscales[0] * factor, 16);
896         } else {
897             for (size_t c = 0; c < count; c++)
898                 local_scales[c] = oscales[c] * factor;
899         }
900         oscales = local_scales;
901     }
902     size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
903     auto w = const_cast<wei_data_t *>(weights);
904     int32_t *compensation
905             = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
906
907     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
908
909     parallel(0, (size_t)work_amount, [&](const int ithr, const int nthr) {
910         int start{ 0 }, end{ 0 };
911         balance211(work_amount, nthr, ithr, start, end);
912
913         auto p = jit_deconv_call_s();
914
915         /*loop order = cgn*/
916         int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
917         if (jcp.loop_order == loop_ngc)
918             nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
919                     oh_s, jcp.oh);
920         else if (jcp.loop_order == loop_cgn)
921             nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
922                     oh_s, jcp.oh);
923         else
924             assert(!"unsupported loop order");
925         while (start < end) {
926
927             int ocb = occ * jcp.nb_oc_blocking;
928             int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
929             int g_ic = g * jcp.ch_block * jcp.ic;
930             int work_rem = end - start;
931             int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
932
933             auto dst_w = dst + dst_d.blk_off(n, g_oc);
934             auto src_w = src + src_d.blk_off(n, g_ic);
935             auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
936             auto bias_w = jcp.with_bias ?
937                     bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
938                     0;
939             int32_t *compensation_w
940                     = (jcp.signed_input) ? compensation + g_oc : 0;
941
942             auto scales = &oscales[jcp.is_oc_scale * g_oc];
943             for (int oj = oh_s; oj < oh_e; oj++) {
944                 int ih_max = 0, kh_lo = 0, kh_len = 0;
945                 if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
946                     /* dilation */
947                     int dilate_h = jcp.dilate_h + 1;
948                     // Note: use div_up to account for "holes" in filter
949                     int o_t_overflow = div_up(
950                             max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
951                             dilate_h);
952                     int o_b_overflow
953                             = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
954                                                      + oj - jcp.b_pad),
955                                     dilate_h);
956                     kh_len = jcp.kh - o_t_overflow - o_b_overflow;
957                     kh_lo = o_b_overflow;
958                     ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
959                 } else {
960                     int o_t_overflow = max(
961                             0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
962                     int o_b_overflow
963                             = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
964                                             / jcp.stride_h);
965                     int overflow_kh_hi = jcp.kh - 1
966                             - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
967                     int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
968
969                     kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
970                             + 1 - o_t_overflow - o_b_overflow;
971                     kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
972                     ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
973                 }
974
975                 int wei_stride
976                         = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
977                 p.src = src_w + ih_max * src_h_stride;
978                 p.dst = dst_w + oj * dst_h_stride;
979                 p.filt = wht_w + wei_stride;
980                 p.bias = bias_w;
981                 p.compensation = compensation_w;
982                 p.t_overflow = max(
983                         0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
984                                             + 1));
985                 p.b_overflow = kh_lo;
986                 p.kh_padding = kh_len;
987                 p.scales = scales;
988                 p.oc_blocks = jcp.is_depthwise ? g : ocb;
989                 kernel_->jit_ker(&p);
990             }
991             if (jcp.loop_order == loop_ngc)
992                 nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
993                         oc_chunks, oh_s, jcp.oh);
994             else if (jcp.loop_order == loop_cgn)
995                 nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
996                         jcp.mb, oh_s, jcp.oh);
997             else
998                 assert(!"unsupported loop order");
999         }
1000     });
1001 }
1002
1003 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1004         data_type::u8>;
1005 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1006         data_type::s8>;
1007 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1008         data_type::f32>;
1009 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
1010         data_type::s32>;
1011 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1012         data_type::u8>;
1013 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1014         data_type::s8>;
1015 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1016         data_type::f32>;
1017 template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
1018         data_type::s32>;
1019 }
1020 }
1021 }