updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "memory_tracking.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22
23 #include "cpu_memory.hpp"
24
25 #include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
36 using namespace Xbyak;
37
38 namespace {
39 void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
40 {
41     jcp.loop_order = loop_cwgn;
42     if (jcp.ngroups > 1) {
43         jcp.loop_order = loop_ngcw;
44         if (jcp.mb < nthr)
45             jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
46     }
47 }
48 }
49
50 template<typename Vmm>
51 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
52 {
53     int nb_oc_block
54             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
55     for (int k = 0; k < nb_oc_block; k++)
56         for (int j = 0; j < ur_w; j++) {
57             Vmm vmm = vmm_out(j, k);
58             vpxord(vmm, vmm, vmm);
59         }
60     if (jcp.signed_input) {
61         xor_(reg_scratch, reg_scratch);
62         if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
63             Reg32 _t32 = reg_scratch.cvt32();
64             mov(_t32, (uint32_t)128);
65             vpbroadcastd(vmm_shift, _t32);
66         } else {
67             Reg8 _t8 = reg_scratch.cvt8();
68             mov(_t8, (int8_t)128);
69             vpbroadcastb(vmm_shift, _t8);
70         }
71     }
72 }
73
74 template<typename Vmm>
75 const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
76     vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {
77     return vmm_in;
78 }
79
80 template<>
81 const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
82     vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
83     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
84                      : zmm_in;
85 }
86
87
88 template<typename Vmm>
89 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
90         const Vmm vmm_in, const Operand &op, bool mask_flag) {
91     //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
92     const Vmm vmm = vmm_mask(vmm_in, mask_flag);
93     switch (type_in) {
94     case data_type::f32:
95     case data_type::s32: vmovups(vmm, op); break;
96     case data_type::s8: vpmovsxbd(vmm, op); break;
97     case data_type::u8: vpmovzxbd(vmm, op); break;
98     default: assert(!"unsupported data type");
99     }
100     if (type_in != data_type::f32)
101         vcvtdq2ps(vmm_in, vmm_in);
102 }
103
104 template<typename Vmm>
105 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
106         int ur_w, bool last_oc_block_flag) {
107     int nb_oc_block
108             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
109     int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
110
111     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
112     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
113     if (jcp.signed_input)
114         mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
115
116     const auto &p = attr_.post_ops_;
117     const int sum_idx = p.find(primitive_kind::sum);
118     const float *p_sum_scale = nullptr;
119     if (sum_idx != -1) {
120         const auto &p_entry = p.entry_[sum_idx];
121         p_sum_scale = &p_entry.sum.scale;
122     }
123
124     if (p_sum_scale && *p_sum_scale != 1.f)
125         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
126
127     if (jcp.signed_input && jcp.ver != ver_vnni) {
128         /* put 'wei_adj_scale = 0.5' for bias calculation */
129         mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
130         vmovq(xmm_bias_alpha(), reg_bias_alpha);
131         vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
132     }
133
134     for (int k = 0; k < nb_oc_block; k++) {
135         const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
136         int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
137         if (jcp.with_bias) {
138             int bias_offset = jcp.typesize_bia * k * oc_block;
139             auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
140
141             cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
142             if (jcp.signed_input && jcp.ver != ver_vnni)
143                 /* bias *= 0.5 */
144                 vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
145         }
146         if (jcp.signed_input) {
147             int comp_offset = sizeof(int32_t) * k * oc_block;
148             auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
149
150             cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
151         }
152         /* add to zmm_accum: compensation, bias and permute */
153         for (int j = 0; j < ur_w; j++) {
154             Vmm vmm = vmm_out(j, k);
155             if (jcp.is_fast_depthwise)
156                 vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
157             vcvtdq2ps(vmm, vmm);
158             if (jcp.signed_input)
159                 vaddps(vmm, vmm, vmm_comp);
160             if (jcp.with_bias)
161                 vaddps(vmm, vmm, vmm_bias);
162
163             const Vmm vmm_k = vmm_mask(vmm, mask_flag);
164             vmulps(vmm_k, vmm,
165                     EVEX_compress_addr(reg_ptr_scales, scale_offset));
166         }
167     }
168
169     int eltwise_inj_idx = 0;
170     int depthwise_inj_idx = 0;
171     for (int i = 0; i < p.len_; i++) {
172         auto& post_op = p.entry_[i];
173         if (post_op.is_eltwise()) {
174             if (ur_w == jcp.ur_w)
175                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, nb_oc_block * jcp.ur_w);
176             else
177                 for (int k = 0; k < nb_oc_block; k++)
178                     eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w, k * jcp.ur_w + ur_w);
179
180             eltwise_inj_idx++;
181         } else if (post_op.is_depthwise()) {
182             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
183             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
184
185             add(reg_d_weights, ptr[param1 + GET_OFF(oc_off)]);
186             add(reg_d_bias, ptr[param1 + GET_OFF(oc_off)]);
187
188             for (int k = 0; k < nb_oc_block; k++) {
189                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
190                         k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
191
192                 add(reg_d_weights, oc_block * sizeof(float));
193                 add(reg_d_bias, oc_block * sizeof(float));
194             }
195
196             depthwise_inj_idx++;
197         } else if (post_op.is_sum(false)) {
198             for (int k = 0; k < nb_oc_block; k++) {
199                 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
200                 for (int j = 0; j < ur_w; j++) {
201                     int aux_output_offset
202                             = jcp.typesize_out
203                             * (k * oc_block
204                                       + j * jcp.oc_without_padding * jcp.ngroups);
205                     auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
206                     Zmm zmm = zmm_out(j, k);
207                     cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
208                     if (*p_sum_scale == 1.f)
209                         vaddps(zmm, vmm_prev_dst);
210                     else
211                         vfmadd231ps(zmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
212                 }
213             }
214         }
215     }
216
217     /* write out register to output_addr */
218     for (int k = 0; k < nb_oc_block; k++) {
219         const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
220         for (int j = 0; j < ur_w; j++) {
221             Vmm vmm = vmm_out(j, k);
222             if (jcp.dst_dt == data_type::u8) {
223                 vpxord(vmm_zero, vmm_zero, vmm_zero);
224                 vmaxps(vmm, vmm_zero, vmm);
225             }
226
227             if (jcp.dst_dt != data_type::f32) {
228                 /* Note: using Zmm for rounding in Xmm/Ymm kernel
229                    because there is no instruction to do rounding
230                    from Xmm/Ymm -> Xmm/Ymm.
231                    Embedded rounding is not supported for Xmm.
232                    TODO: maybe avoid Zmm if it helps performance.*/
233                 Zmm zmm = zmm_out(j, k);
234                 if (attr_.round_mode_ == round_mode::nearest)
235                     vcvtps2dq(zmm | T_rn_sae, zmm);
236                 else if (attr_.round_mode_ == round_mode::down)
237                     vcvtps2dq(zmm | T_rd_sae, zmm);
238                 else
239                     assert(!"unimplemented");
240             }
241         }
242
243         for (int j = 0; j < ur_w; j++) {
244             int aux_output_offset = jcp.typesize_out
245                     * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
246             auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
247
248             Vmm vmm = vmm_out(j, k);
249             const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
250
251             switch (jcp.dst_dt) {
252             case data_type::f32:
253             case data_type::s32: vmovups(addr, r_vmm); break;
254             case data_type::s8: vpmovsdb(addr, r_vmm); break;
255             case data_type::u8: vpmovusdb(addr, r_vmm); break;
256             default: assert(!"unknown dst_dt");
257             }
258         }
259     }
260
261 }
262
263 template <typename Vmm>
264 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
265         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
266     assert(!"invalid group blocking for depthwise convolution");
267 }
268
269 template <>
270 void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
271         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
272
273     auto input_spatial_index = [=](int oi, int ki) {
274         return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
275     };
276
277     auto input_offset2 = [=](int ii, int ci) {
278         return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
279     };
280
281     auto input_offset3 = [=](int oi, int ci, int ki) {
282         return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
283     };
284
285     auto kernel_offset = [=](int ci, int ki) {
286         return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
287     };
288
289     auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
290         // okay for depthwise since src is zero-extended
291         if (jcp.ver == ver_vnni) {
292             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
293         } else {
294             vpmaddwd(zmm_tmp, vreg_src, vreg_wei);
295             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
296         }
297     };
298
299     int ii_start = 0;
300     int ii_end = -1;
301     if (jcp.is_resrc_depthwise && !h_padded) {
302         // find bounds of input spatial indices
303         bool first = true;
304         for (int ki = 0; ki < jcp.kw; ki++) {
305             int oi_start = get_ow_start(ki, pad_l);
306             int oi_end = get_ow_end(ur_w, ki, pad_r);
307             for (int oi = oi_start; oi < oi_end; oi++) {
308                 int ii = input_spatial_index(oi, ki);
309                 if (first || ii < ii_start)
310                     ii_start = ii;
311                 if (first || ii > ii_end)
312                     ii_end = ii;
313                 first = false;
314             }
315         }
316     }
317
318     if (jcp.signed_input) {
319         vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero);
320         vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift);
321     }
322     for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
323         const bool mask_flag = last_ic_block_flag != no_last_block
324                 && ci == jcp.nb_ch_blocking - 1;
325         if (jcp.is_resrc_depthwise && !h_padded) {
326             // now we can load input once and reuse up to jcp.kw times
327             for (int ii = ii_start; ii <= ii_end; ii++) {
328                 int aux_input_offset = input_offset2(ii, ci);
329                 const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
330                 const Zmm zmm_inp_msk = mask_flag
331                         ? zmm_inp_tmp | ktail_mask | T_z
332                         : zmm_inp_tmp;
333                 if (jcp.is_fast_depthwise) {
334                     assert(!mask_flag);
335                     vbroadcasti32x4(zmm_inp_msk,
336                             EVEX_compress_addr(aux_reg_inp, aux_input_offset));
337                 } else {
338                     vpmovzxbd(zmm_inp_msk,
339                             EVEX_compress_addr(aux_reg_inp, aux_input_offset));
340                 }
341                 if (jcp.signed_input)
342                     vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift);
343             }
344         }
345         for (int ki = 0; ki < jcp.kw; ki++) {
346             int aux_kernel_offset = kernel_offset(ci, ki);
347             if (jcp.is_fast_depthwise) {
348                 vbroadcasti32x4(zmm_wei,
349                         EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
350                 vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei);
351             } else {
352                 vpmovsxbd(zmm_wei,
353                         EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
354             }
355             if (h_padded) {
356                 assert(jcp.signed_input);
357                 for (int oi = 0; oi < ur_w; oi++)
358                     compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
359             } else {
360                 const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
361                 int oi_start = get_ow_start(ki, pad_l);
362                 int oi_end = get_ow_end(ur_w, ki, pad_r);
363                 int start_ = jcp.signed_input ? 0 : oi_start;
364                 int end_ = jcp.signed_input ? ur_w : oi_end;
365                 for (int oi = start_; oi < end_; oi++) {
366                     if (oi >= oi_start && oi < oi_end) {
367                         if (jcp.is_resrc_depthwise) {
368                             int ii = input_spatial_index(oi, ki);
369                             zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
370                         } else {
371                             int aux_input_offset = input_offset3(oi, ci, ki);
372                             if (jcp.is_fast_depthwise) {
373                                 assert(!mask_flag);
374                                 vbroadcasti32x4(r_zmm_src,
375                                         EVEX_compress_addr(aux_reg_inp,
376                                                         aux_input_offset));
377                             } else {
378                                 vpmovzxbd(r_zmm_src,
379                                         EVEX_compress_addr(aux_reg_inp,
380                                                   aux_input_offset));
381                             }
382                             if (jcp.signed_input)
383                                 vpaddb(zmm_src, zmm_src, vmm_shift);
384                         }
385                     } else if (jcp.signed_input) {
386                         zmm_src = zmm_shifted_zero;
387                     }
388                     compute(zmm_out(oi, ci), zmm_wei, zmm_src);
389                 }
390             }
391         }
392     }
393 }
394
395 template<typename Vmm>
396 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
397         int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
398     if (jcp.is_depthwise)
399         return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
400
401     int kw = jcp.kw;
402     int stride_w = jcp.stride_w;
403     int ic_block = jcp.ic_block;
404     int oc_block = jcp.oc_block;
405     int ch_block_all = jcp.ch_block * ic_block * oc_block;
406
407     int nb_oc_block = jcp.nb_oc_blocking;
408
409     auto input_offset = [=](int oi, int ic, int ki) {
410         return jcp.typesize_in
411                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
412                           * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
413     };
414     auto kernel_offset = [=](int ii, int ic, int ki) {
415         return jcp.typesize_in
416                 * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
417                     + 4 * ic * oc_block);
418     };
419     auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
420         if (jcp.ver == ver_vnni) {
421             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
422         } else {
423             vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
424             vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
425             vpaddd(vreg_acc, vreg_acc, vmm_tmp);
426         }
427     };
428
429     for (int ki = 0; ki < kw; ki++) {
430         int jj_start = get_ow_start(ki, pad_l);
431         int jj_end = get_ow_end(ur_w, ki, pad_r);
432         int tail_size = jcp.ic_without_padding % 4;
433         int _start = (jcp.signed_input) ? 0 : jj_start;
434         int _end = (jcp.signed_input) ? ur_w : jj_end;
435         /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
436         int icb = (last_ic_block_flag != no_last_block)
437             ? div_up((jcp.ic_without_padding % ic_block), 4)
438             : ic_block / 4;
439         for (int ic = 0; ic < icb; ic++) {
440             if (h_padded == true) {
441                 /* fill padded area with shifted values */
442                 Vmm inp = vmm_inp(0,nb_oc_block);
443                 vpxord(inp, inp, inp);
444                 vpaddb(inp, inp, vmm_shift);
445             } else {
446                 for (int jj = _start; jj < _end; jj++) {
447                     int aux_input_offset = input_offset(jj, ic, ki);
448                     if (jj >= jj_start && jj < jj_end) {
449                         if (last_ic_block_flag == last_sp_block
450                                 && tail_size != 0 && ic == icb - 1) {
451                             Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
452                             for (int r = 0; r < tail_size; ++r)
453                                 vpinsrb(xmm_tmp, xmm_tmp,
454                                     ptr[aux_reg_inp + aux_input_offset + r], r);
455                             vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
456                         } else {
457                             vpbroadcastd(vmm_inp(jj, nb_oc_block),
458                                     EVEX_compress_addr(
459                                                  aux_reg_inp, aux_input_offset));
460                         }
461                         if (jcp.signed_input)
462                             vpaddb(vmm_inp(jj, nb_oc_block),
463                                    vmm_inp(jj, nb_oc_block), vmm_shift);
464                     } else {
465                         /* fill padded area with shifted values */
466                         if (jcp.signed_input) {
467                             Vmm inp = vmm_inp(jj, nb_oc_block);
468                             vpxord(inp, inp, inp);
469                             vpaddb(inp, inp, vmm_shift);
470                         }
471                     }
472                 }
473             }
474             for (int ii = 0; ii < nb_oc_block; ii++) {
475                 int aux_kernel_offset = kernel_offset(ii, ic, ki);
476                 vmovups(vmm_wei,
477                         EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
478                 for (int jj = _start; jj < _end; jj++)  {
479                     Vmm inp = (h_padded == true)
480                         ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block);
481                     compute(vmm_out(jj, ii), vmm_wei, inp);
482                 }
483             }
484         }
485     }
486 }
487
488 template<typename Vmm>
489 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
490         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
491     Label kh_label, skip_kh_loop;
492     Label t_overflow_label, no_t_overflow_label,
493           b_overflow_label, no_b_overflow_label;
494
495     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
496     int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
497     int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
498         * jcp.ic_without_padding * jcp.ngroups;
499
500     mov(aux_reg_inp, reg_inp);
501     mov(aux_reg_ker, reg_ker);
502
503     if (jcp.signed_input && jcp.ndims > 3) {
504         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
505         cmp(reg_overflow, 0);
506         je(no_t_overflow_label, T_NEAR);
507         L(t_overflow_label); {
508             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
509
510             add(aux_reg_ker, shift_kernel_ptr);
511             dec(reg_overflow);
512             cmp(reg_overflow, 0);
513             jg(t_overflow_label, T_NEAR);
514         }
515         L(no_t_overflow_label);
516     }
517     mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
518     if ((jcp.signed_input) || (!jcp.signed_input &&
519        (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
520         cmp(reg_kj, 0);
521         je(skip_kh_loop, T_NEAR);
522     }
523     L(kh_label); {
524         compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
525
526         add(aux_reg_ker, shift_kernel_ptr);
527         add(aux_reg_inp, shift_input_ptr);
528         dec(reg_kj);
529         cmp(reg_kj, 0);
530         jg(kh_label, T_NEAR);
531     }
532     L(skip_kh_loop);
533     if (jcp.signed_input && jcp.ndims > 3) {
534         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
535         cmp(reg_overflow, 0);
536         je(no_b_overflow_label, T_NEAR);
537         L(b_overflow_label); {
538             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
539
540             add(aux_reg_ker, shift_kernel_ptr);
541             dec(reg_overflow);
542             cmp(reg_overflow, 0);
543             jg(b_overflow_label, T_NEAR);
544         }
545         L(no_b_overflow_label);
546     }
547 }
548
549 template<typename Vmm>
550 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
551         int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
552 {
553     prepare_output(ur_w);
554
555     // IC loop
556     Label icb_label;
557     mov(reg_icb, jcp.nb_ic);
558     L(icb_label);
559     if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
560         Label common_ker, end_ker;
561
562         cmp(reg_icb, 1); // The last IC block
563         jne(common_ker, T_NEAR);
564
565         kh_loop(ur_w, pad_l, pad_r,
566                 is_last_sp_block ? last_sp_block : last_ic_block);
567         jmp(end_ker, T_NEAR);
568
569         L(common_ker);
570         kh_loop(ur_w, pad_l, pad_r, no_last_block);
571
572         L(end_ker);
573     } else {
574         kh_loop(ur_w, pad_l, pad_r, no_last_block);
575     }
576     // End of IC Loop
577     int inp_step = jcp.ic_block;
578     int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
579     add(reg_inp, jcp.typesize_in * inp_step);
580     add(reg_ker, jcp.typesize_in * ker_step);
581
582     dec(reg_icb);
583     cmp(reg_icb, 0);
584     jg(icb_label, T_NEAR);
585
586     sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
587     sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
588
589     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
590         Label common_store, end_store;
591
592         if (jcp.is_depthwise)
593             cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
594         else
595             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
596
597         jne(common_store, T_NEAR);
598
599         store_output(ur_w, true); // last oc block
600         jmp(end_store, T_NEAR);
601
602         L(common_store);
603         store_output(ur_w, false);
604
605         L(end_store);
606     } else {
607         store_output(ur_w, false);
608     }
609 }
610
611 template<typename Vmm>
612 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
613 {
614     const auto &p = attr_.post_ops_;
615     for (int i = 0; i < p.len_; i++) {
616         auto &post_op = p.entry_[i];
617         if (post_op.is_eltwise()) {
618             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
619                     this,
620                     post_op.eltwise.alg,
621                     post_op.eltwise.alpha,
622                     post_op.eltwise.beta
623             ));
624         } else if (post_op.is_depthwise()) {
625             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
626                     this,
627                     post_op.depthwise.alg
628             ));
629         }
630     }
631
632     Label permute_index_table;
633     int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
634         * jcp.ic_without_padding * jcp.ngroups;
635     int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
636         * jcp.ic_without_padding * jcp.ngroups;
637     int inp_shift = jcp.typesize_in *
638                         (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
639                          * jcp.ngroups);
640     int out_shift = jcp.typesize_out *
641                         (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
642     preamble();
643
644     if (jcp.is_depthwise) {
645         int idx = jcp.max_regs_ur - 1;
646         if (!jcp.is_resrc_depthwise)
647             zmm_src = Zmm(++idx);
648         if (jcp.ver != ver_vnni)
649             zmm_tmp = Zmm(++idx);
650         if (jcp.is_fast_depthwise)
651             zmm_permute = Zmm(++idx);
652         if (jcp.signed_input) {
653             zmm_shifted_zero = Zmm(++idx);
654             ++idx; // due to extra register used for shifts and compensations
655         }
656         assert(idx == ker_dw_reg_base_idx);
657     }
658
659     if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
660         xor_(reg_scratch, reg_scratch);
661         Reg16 _t16 = reg_scratch.cvt16();
662         mov(_t16, 0x1);
663         vpbroadcastw(vmm_one, _t16);
664     }
665
666     mov(reg_inp, ptr[param1 + GET_OFF(src)]);
667     mov(reg_out, ptr[param1 + GET_OFF(dst)]);
668     mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
669
670     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
671         int tail_size = jcp.is_depthwise
672             ? jcp.ngroups % jcp.ch_block
673             : jcp.oc_without_padding % jcp.oc_block;
674         int mask = (1 << tail_size) - 1;
675         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
676         Reg32 regw_tmp = reg_oi.cvt32();
677         mov(regw_tmp, mask);
678         kmovw(ktail_mask, regw_tmp);
679     }
680     if (jcp.is_fast_depthwise) {
681         // prepare mask register for blending weights
682         mov(reg_scratch, 0x8888444422221111);
683         kmovq(kblend_mask, reg_scratch);
684         // load permute indices from data section
685         mov(reg_scratch, permute_index_table);
686         vmovdqu32(zmm_permute, ptr[reg_scratch]);
687     }
688
689     int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
690                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
691                     - (jcp.iw + jcp.l_pad - 1));
692     int n_oi = jcp.ow / jcp.ur_w;
693     int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
694         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
695
696     if (jcp.nb_ow == 1) {
697         if (r_pad1 > 0 || jcp.ur_w_tail == 0)
698             n_oi--;
699
700         xor_(reg_oi, reg_oi);
701         if (jcp.ow == jcp.ur_w) {
702             icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
703         } else {
704             if (n_oi == 0) {
705                 icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
706                 add(reg_inp, inp_shift_pad);
707                 add(reg_out, out_shift);
708                 if (jcp.ur_w_tail != 0) {
709                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
710                 }
711             } else {
712                 if (jcp.l_pad > 0) {
713                     icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
714                     add(reg_inp, inp_shift_pad);
715                     add(reg_out, out_shift);
716
717                     inc(reg_oi);
718                 }
719                 if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1))
720                 {
721                     Label ow_loop_label;
722                     L(ow_loop_label); {
723                         icb_loop(jcp.ur_w, 0, 0, false);
724                         add(reg_inp, inp_shift);
725                         add(reg_out, out_shift);
726
727                         inc(reg_oi);
728                         cmp(reg_oi, n_oi);
729                         jl(ow_loop_label, T_NEAR);
730                     }
731                 }
732                 if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
733                     icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
734                     add(reg_inp, inp_shift);
735                     add(reg_out, out_shift);
736                 }
737                 if (jcp.ur_w_tail != 0) {
738                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
739                 }
740             }
741         }
742     } else {
743         // ow block is only processed.
744         // Number of block is passed as parameter owb,
745         // and padding processing depends on this number.
746         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
747             oi_loop_label, oi_loop_end_label;
748
749         assert(jcp.ow_block % jcp.ur_w == 0);
750         int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
751         // to simplify code (and general regs usage),
752         // size of ow block must be >= 2 * ur_w
753         assert(n_oi_not_last_ow_block > 1);
754         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
755         int n_oi_first_ow_block = n_oi_not_last_ow_block;
756         int n_oi_last_ow_block
757             = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
758         // prepare right padding
759         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
760         bool first_ow_block_padded
761                 = next_last_ow_block_padded && jcp.nb_ow == 2;
762         bool last_ow_block_padded
763                 = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
764
765         if (last_ow_block_padded) n_oi_last_ow_block--;
766         else if (first_ow_block_padded) n_oi_first_ow_block--;
767         else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
768
769         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
770         cmp(reg_owb, 0); // is that the first ow-block ?
771         jg(middle_ow_blocks_label, T_NEAR);
772
773         // the first ow block, compute left padding
774         mov(reg_oi, n_oi_first_ow_block);
775         if (jcp.l_pad > 0) {
776             icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
777             add(reg_inp, inp_shift_pad);
778             add(reg_out, out_shift);
779
780             dec(reg_oi);
781         }
782         jmp(oi_loop_label, T_NEAR);
783
784         // middle or last ow block entry
785         L(middle_ow_blocks_label);
786
787         if (jcp.l_pad > 0) {
788             // just to consider left padding, not compute
789             add(reg_inp, inp_shift_pad_second_block);
790         }
791
792         // set number of iteration for oi-loop
793         if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
794             cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
795             mov(reg_oi, n_oi_last_ow_block);
796             je(oi_loop_label, T_NEAR);
797         }
798
799         if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
800             cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
801
802             mov(reg_oi, n_oi_next_last_ow_block);
803             je(oi_loop_label, T_NEAR);
804         }
805         mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
806
807         // oi loop w/o padding
808         L(oi_loop_label); {
809             cmp(reg_oi, 0);
810             jle(oi_loop_end_label, T_NEAR);
811
812             icb_loop(jcp.ur_w, 0, 0, false);
813
814             add(reg_inp, inp_shift);
815             add(reg_out, out_shift);
816             dec(reg_oi);
817
818             jmp(oi_loop_label, T_NEAR);
819         }
820         L(oi_loop_end_label);
821
822         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
823         cmp(reg_owb, 0); // first ow-block ?
824         if (first_ow_block_padded)
825             je(last_oi_label, T_NEAR);
826         else
827             je(end_label, T_NEAR);
828
829         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
830         jl(end_label, T_NEAR);
831         if (next_last_ow_block_padded)
832             je(last_oi_label, T_NEAR);
833         else
834             je(end_label, T_NEAR);
835
836         // that is last block
837         if (!last_ow_block_padded)
838             jmp(tail_label, T_NEAR);
839
840         // last oi block with right padding
841         L(last_oi_label);
842         icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
843         add(reg_inp, inp_shift);
844         add(reg_out, out_shift);
845
846         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
847         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
848         jl(end_label, T_NEAR);
849
850         // ur_w tail
851         L(tail_label);
852         if (jcp.ur_w_tail != 0) {
853             icb_loop(jcp.ur_w_tail, 0, r_pad, true);
854         }
855         L(end_label);
856     }
857     postamble();
858
859     for (auto& inj : eltwise_injectors)
860         inj->prepare_table();
861
862     if (jcp.is_fast_depthwise) {
863         align(64);
864         L(permute_index_table);
865         const uint32_t _idx[]
866                 = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
867         for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
868             dd(_idx[i]);
869     }
870 }
871
872 bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
873         jit_conv_conf_t &jcp, const primitive_attr_t &attr)
874 {
875     const auto &p = attr.post_ops_;
876
877     auto all_post_ops_supported = [&]() {
878         bool ok = true;
879
880         for (int i = 0; i < p.len_; i++) {
881             ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
882         }
883         return ok;
884     };
885     auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
886
887     return all_post_ops_supported() &&
888            count(primitive_kind::sum) <= 1;
889 }
890
891 status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
892             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
893             cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
894             cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
895             int nthreads)
896 {
897     using namespace prop_kind;
898
899     const memory_desc_wrapper src_d(&src_pd);
900     const memory_desc_wrapper weights_d(&weights_pd);
901     const memory_desc_wrapper dst_d(&dst_pd);
902     const memory_desc_wrapper bias_d(&bias_pd);
903
904     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
905     int ndims = src_d.ndims();
906     bool is_1d = ndims == 3;
907
908     if (!(mayiuse(avx512_core)
909          && one_of(src_d.data_type(), data_type::u8, data_type::s8)
910          && weights_d.data_type() == data_type::s8
911          && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
912             data_type::s8, data_type::u8)))
913         return status::unimplemented;
914
915     jcp = zero<decltype(jcp)>();
916     jcp.ndims = ndims;
917     jcp.prop_kind = cd.prop_kind;
918     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
919     jcp.mb = src_d.dims()[0];
920     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
921     jcp.oc_without_padding = jcp.oc;
922     jcp.ic = src_d.dims()[1] / jcp.ngroups;
923     jcp.ic_without_padding = jcp.ic;
924     jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
925     jcp.iw = src_d.dims()[ndims - 1];
926     jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
927     jcp.ow = dst_d.dims()[ndims - 1];
928     jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
929     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
930     jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
931     jcp.l_pad = cd.padding[0][ndims - 3];
932     jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
933     jcp.stride_w = cd.strides[ndims - 3];
934     jcp.src_fmt = src_d.format();
935     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
936
937     jcp.ur_h = 1; /* no code-unrolling by h so far */
938
939     jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
940     jcp.dilate_w = cd.dilates[ndims - 3];
941
942     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
943     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
944
945     if (jcp.is_depthwise) {
946         jcp.ch_block = 16;
947         jcp.ic_block = 1;
948         jcp.oc_block = 1;
949     } else {
950         jcp.ch_block = 1;
951         jcp.ic_block = 16;
952         jcp.oc_block = 16;
953
954         if (jcp.ngroups == 1) {
955             /* For non grouped convolutions, pad channels by 16 if needed */
956             jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
957             jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
958         } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
959             /* For grouped convolutions, MKL-DNN doesn't support padding.
960                Use Ymm when channels per group is multiple of 8,
961                Xmm when channels per group is multiple of 4 */
962             jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
963             jcp.oc_block = jcp.ic_block;
964         }
965         if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0)
966             return status::unimplemented;
967     }
968
969     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
970             - (jcp.ih + jcp.t_pad - 1);
971
972     if (!post_ops_ok(jcp, attr))
973         return status::unimplemented;
974
975     jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
976     jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
977             && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16
978                                                 // would require byte masking
979                                                 // for load from src
980     jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
981             && jcp.kw < 4 && jcp.dilate_w == 0;
982     if (jcp.is_depthwise) {
983         jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
984                 - 2 * jcp.signed_input - (jcp.ver != ver_vnni);
985     } else {
986         jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
987     }
988
989     auto src_format = pick(ndims - 3, nwc, nhwc);
990     auto dst_format = pick(ndims - 3, nwc, nhwc);
991 #define pick_signed(fmt) (jcp.signed_input ? fmt##_s8s8 : fmt)
992     memory_format_t w_format;
993     if (jcp.ic_block == 16 || jcp.ch_block == 16) {
994         w_format = is_1d ?
995                 (with_groups ? (jcp.is_depthwise ? pick_signed(Goiw16g) :
996                                                    pick_signed(gOIw4i16o4i)) :
997                                pick_signed(OIw4i16o4i)) :
998                 (with_groups ? (jcp.is_depthwise ? pick_signed(Goihw16g) :
999                                                    pick_signed(gOIhw4i16o4i)) :
1000                                pick_signed(OIhw4i16o4i));
1001         /* Non-grouped conv will always be padded by 16*/
1002     } else if (with_groups && jcp.ic_block == 8) {
1003         w_format = pick_signed(gOIhw2i8o4i);
1004     } else {
1005         w_format = pick_signed(gOIhw4o4i);
1006     }
1007 #undef pick_signed
1008
1009     if (weights_d.format() == any)
1010         CHECK(weights_pd.set_format(w_format));
1011     if (weights_d.format() != w_format)
1012         return status::unimplemented;
1013     if (dst_d.format() == any)
1014         CHECK(dst_pd.set_format(dst_format));
1015     if (dst_d.format() != dst_format)
1016         return status::unimplemented;
1017     if (src_d.format() == any)
1018         CHECK(src_pd.set_format(src_format));
1019     if (src_d.format() != src_format)
1020         return status::unimplemented;
1021     if (jcp.with_bias) {
1022         if (bias_d.format() == any)
1023             CHECK(bias_pd.set_format(x));
1024         if (bias_d.format() != x)
1025             return status::unimplemented;
1026     }
1027
1028     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
1029     jcp.dst_dt = cd.dst_desc.data_type;
1030
1031     jcp.typesize_in = types::data_type_size(src_d.data_type());
1032     jcp.typesize_out = types::data_type_size(dst_d.data_type());
1033     jcp.typesize_bia = jcp.with_bias
1034         ? types::data_type_size(bias_d.data_type())
1035         : 0;
1036
1037     jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
1038     jcp.nb_ic = jcp.ic / jcp.ic_block;
1039     jcp.nb_oc = jcp.oc / jcp.oc_block;
1040
1041     // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
1042     int nb_ch_blocking = 4;
1043     for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--)
1044         if (jcp.nb_ch % nb_ch_blocking == 0)
1045             break;
1046     jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
1047
1048     // If OC blocking is incommensurate with the number of OC blocks (general
1049     // requirement for all convolutions), or if it results in an unrolling
1050     // factor smaller than the left padding (special requirement for SSD:fc6),
1051     // then search for a smaller OC blocking that satisfies both constraints.
1052     auto is_oc_blocking_ok = [&](int block) {
1053         int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
1054         return jcp.nb_oc % block == 0
1055                 && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1;
1056     };
1057
1058     // choose nb_oc work chunk size for distribution within threads
1059     int max_threading_nb_oc_chunk = 4;
1060     // Performance improvements for googlenet_v3 and resnet_50 with mb = 1;
1061     // TODO: generalize this condition and rewrite it in appropriate manner
1062     if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3
1063             && jcp.stride_w == 1 && jcp.ic % 64 == 0)
1064         max_threading_nb_oc_chunk = 2;
1065     jcp.nb_oc_blocking_thr_chunk =
1066         nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
1067     for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
1068         if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk))
1069             break;
1070     }
1071
1072     // choose oc blocking for computational kernel
1073     jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
1074     // Performance improvements for googlenet_v3 with mb = 1;
1075     // TODO: generalize this condition and rewrite it in appropriate manner
1076     const int size_treshold_for_nb_oc_blocking_reduction = 17;
1077     if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction
1078             && jcp.stride_w == 1
1079             && !(jcp.kh == 1 && jcp.kw == 3)
1080             && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) {
1081         const int max_nb_oc_blocking = 2;
1082         jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc);
1083         for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
1084             if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0
1085                 && is_oc_blocking_ok(jcp.nb_oc_blocking))
1086                 break;
1087     }
1088
1089     if (jcp.is_resrc_depthwise)
1090         jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
1091                 / (jcp.nb_ch_blocking + jcp.stride_w);
1092     else
1093         jcp.ur_w
1094                 = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking
1095                                                       : jcp.nb_oc_blocking + 1);
1096     if (jcp.ow < jcp.ur_w)
1097         jcp.ur_w = jcp.ow;
1098     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1099
1100     jcp.ow_block = jcp.ow;
1101     int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh
1102                          * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
1103     float best_thr_eff
1104             = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
1105     int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
1106     for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
1107         int ow_block
1108                 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
1109         if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
1110              && best_thr_eff > 0.8f)
1111             break;
1112         if (div_up(jcp.ow, ow_block) != nb_ow)
1113             continue;
1114         auto work_amount = base_work_amount * nb_ow;
1115         float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
1116         if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
1117             jcp.ow_block = ow_block;
1118             best_thr_eff = thr_eff;
1119         }
1120         if (best_thr_eff > 0.9f)
1121             break;
1122     }
1123     jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1124
1125     bool args_ok = true
1126         && jcp.oc % jcp.oc_block == 0
1127         && jcp.l_pad <= jcp.ur_w
1128         && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
1129     if (!args_ok)
1130         return status::unimplemented;
1131
1132     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
1133                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
1134                     - (jcp.iw + jcp.l_pad - 1));
1135     if (r_pad_no_tail > jcp.ur_w)
1136         return status::unimplemented;
1137
1138     pick_loop_order(jcp, nthreads);
1139
1140     jcp.nb_ic_L2 = jcp.nb_ic;
1141
1142     const auto &oscales = attr.output_scales_;
1143     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
1144
1145     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
1146
1147     jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;
1148
1149     return status::success;
1150 }
1151
1152 void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
1153         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
1154         const primitive_attr_t &attr) {
1155     if (jcp.signed_input && jcp.ver != ver_vnni) {
1156         size_t count = nstl::max(attr.output_scales_.count_, jcp.ic_block);
1157         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
1158     }
1159 }
1160
1161 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
1162 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
1163 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
1164 }
1165 }
1166 }
1167
1168 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s