Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "memory_tracking.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22
23 #include "cpu_memory.hpp"
24
25 #include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
36 using namespace Xbyak;
37
38 namespace {
39 void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
40 {
41     jcp.loop_order = loop_cwgn;
42     if (jcp.ngroups > 1) {
43         jcp.loop_order = loop_ngcw;
44         if (jcp.mb < nthr)
45             jcp.loop_order = loop_nhwcg;
46     }
47 }
48 }
49
50 template<typename Vmm>
51 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
52 {
53     int nb_oc_block
54             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
55     for (int k = 0; k < nb_oc_block; k++)
56         for (int j = 0; j < ur_w; j++) {
57             Vmm vmm = vmm_out(j, k);
58             vpxord(vmm, vmm, vmm);
59         }
60     if (jcp.signed_input) {
61         xor_(reg_scratch, reg_scratch);
62         if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
63             Reg32 _t32 = reg_scratch.cvt32();
64             mov(_t32, (uint32_t)128);
65             vpbroadcastd(vmm_shift, _t32);
66         } else {
67             Reg8 _t8 = reg_scratch.cvt8();
68             mov(_t8, (int8_t)128);
69             vpbroadcastb(vmm_shift, _t8);
70         }
71     }
72     if (jcp.is_fast_depthwise) {
73        vpxord(zmm_zero_blend, zmm_zero_blend, zmm_zero_blend);
74     }
75 }
76
77 template<typename Vmm>
78 const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
79     vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {
80     return vmm_in;
81 }
82
83 template<>
84 const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
85     vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
86     return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
87                      : zmm_in;
88 }
89
90
91 template<typename Vmm>
92 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
93         const Vmm vmm_in, const Operand &op, bool mask_flag) {
94     //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
95     const Vmm vmm = vmm_mask(vmm_in, mask_flag);
96     switch (type_in) {
97     case data_type::f32:
98     case data_type::s32: vmovups(vmm, op); break;
99     case data_type::s8: vpmovsxbd(vmm, op); break;
100     case data_type::u8: vpmovzxbd(vmm, op); break;
101     default: assert(!"unsupported data type");
102     }
103     if (type_in != data_type::f32)
104         vcvtdq2ps(vmm_in, vmm_in);
105 }
106
107 template<typename Vmm>
108 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
109         int ur_w, bool last_oc_block_flag) {
110     int nb_oc_block
111             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
112     int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
113
114     mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
115     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
116     if (jcp.signed_input)
117         mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
118
119     const auto &p = attr_.post_ops_;
120     const int sum_idx = p.find(primitive_kind::sum);
121     const float *p_sum_scale = nullptr;
122     if (sum_idx != -1) {
123         const auto &p_entry = p.entry_[sum_idx];
124         p_sum_scale = &p_entry.sum.scale;
125     }
126
127     if (p_sum_scale && *p_sum_scale != 1.f)
128         mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
129
130     if (jcp.signed_input && jcp.ver != ver_vnni) {
131         /* put 'wei_adj_scale = 0.5' for bias calculation */
132         mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
133         vmovq(xmm_bias_alpha(), reg_bias_alpha);
134         vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
135     }
136
137     for (int k = 0; k < nb_oc_block; k++) {
138         const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
139         int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
140         if (jcp.with_bias) {
141             int bias_offset = jcp.typesize_bia * k * oc_block;
142             auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
143
144             cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
145             if (jcp.signed_input && jcp.ver != ver_vnni)
146                 /* bias *= 0.5 */
147                 vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
148         }
149         if (jcp.signed_input) {
150             int comp_offset = sizeof(int32_t) * k * oc_block;
151             auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
152
153             cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
154         }
155         /* add to zmm_accum: compensation, bias and permute */
156         for (int j = 0; j < ur_w; j++) {
157             Vmm vmm = vmm_out(j, k);
158             if (jcp.is_fast_depthwise)
159                 vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
160             vcvtdq2ps(vmm, vmm);
161             if (jcp.signed_input)
162                 vaddps(vmm, vmm, vmm_comp);
163             if (jcp.with_bias)
164                 vaddps(vmm, vmm, vmm_bias);
165
166             const Vmm vmm_k = vmm_mask(vmm, mask_flag);
167             vmulps(vmm_k, vmm,
168                     EVEX_compress_addr(reg_ptr_scales, scale_offset));
169         }
170     }
171
172     int eltwise_inj_idx = 0;
173     int depthwise_inj_idx = 0;
174     for (int i = 0; i < p.len_; i++) {
175         auto& post_op = p.entry_[i];
176         if (post_op.is_eltwise()) {
177             if (ur_w == jcp.ur_w)
178                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, nb_oc_block * jcp.ur_w);
179             else
180                 for (int k = 0; k < nb_oc_block; k++)
181                     eltwise_injectors[eltwise_inj_idx]->compute_vector_range(k * jcp.ur_w, k * jcp.ur_w + ur_w);
182
183             eltwise_inj_idx++;
184         } else if (post_op.is_depthwise()) {
185             mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
186             mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
187
188             add(reg_d_weights, ptr[param1 + GET_OFF(oc_off)]);
189             add(reg_d_bias, ptr[param1 + GET_OFF(oc_off)]);
190
191             for (int k = 0; k < nb_oc_block; k++) {
192                 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
193                         k * jcp.ur_w, k * jcp.ur_w + ur_w, reg_d_weights, reg_d_bias);
194
195                 add(reg_d_weights, oc_block * sizeof(float));
196                 add(reg_d_bias, oc_block * sizeof(float));
197             }
198
199             depthwise_inj_idx++;
200         } else if (post_op.is_sum(false)) {
201             for (int k = 0; k < nb_oc_block; k++) {
202                 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
203                 for (int j = 0; j < ur_w; j++) {
204                     int aux_output_offset
205                             = jcp.typesize_out
206                             * (k * oc_block
207                                       + j * jcp.oc_without_padding * jcp.ngroups);
208                     auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
209                     Zmm zmm = zmm_out(j, k);
210                     cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
211                     if (*p_sum_scale == 1.f)
212                         vaddps(zmm, vmm_prev_dst);
213                     else
214                         vfmadd231ps(zmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
215                 }
216             }
217         }
218     }
219
220     /* write out register to output_addr */
221     for (int k = 0; k < nb_oc_block; k++) {
222         const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
223         for (int j = 0; j < ur_w; j++) {
224             Vmm vmm = vmm_out(j, k);
225             if (jcp.dst_dt == data_type::u8) {
226                 vpxord(vmm_zero, vmm_zero, vmm_zero);
227                 vmaxps(vmm, vmm_zero, vmm);
228             }
229
230             if (jcp.dst_dt != data_type::f32) {
231                 /* Note: using Zmm for rounding in Xmm/Ymm kernel
232                    because there is no instruction to do rounding
233                    from Xmm/Ymm -> Xmm/Ymm.
234                    Embedded rounding is not supported for Xmm.
235                    TODO: maybe avoid Zmm if it helps performance.*/
236                 Zmm zmm = zmm_out(j, k);
237                 if (attr_.round_mode_ == round_mode::nearest)
238                     vcvtps2dq(zmm | T_rn_sae, zmm);
239                 else if (attr_.round_mode_ == round_mode::down)
240                     vcvtps2dq(zmm | T_rd_sae, zmm);
241                 else
242                     assert(!"unimplemented");
243             }
244         }
245
246         for (int j = 0; j < ur_w; j++) {
247             int aux_output_offset = jcp.typesize_out
248                     * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
249             auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
250
251             Vmm vmm = vmm_out(j, k);
252             const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
253
254             switch (jcp.dst_dt) {
255             case data_type::f32:
256             case data_type::s32: vmovups(addr, r_vmm); break;
257             case data_type::s8: vpmovsdb(addr, r_vmm); break;
258             case data_type::u8: vpmovusdb(addr, r_vmm); break;
259             default: assert(!"unknown dst_dt");
260             }
261         }
262     }
263
264 }
265
266 template <typename Vmm>
267 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
268         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
269     assert(!"invalid group blocking for depthwise convolution");
270 }
271
272 template <>
273 void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
274         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
275     auto input_offset = [=](int oi, int ii, int ki) {
276         return jcp.typesize_in
277                 * ((ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l)
278                                   * jcp.ngroups
279                           + ii * jcp.ch_block);
280     };
281
282     auto kernel_offset = [=](int ii, int ki) {
283         return jcp.typesize_in * ((ii * jcp.kh * jcp.kw + ki) * jcp.ch_block);
284     };
285
286     auto compute = [=](Zmm vreg_acc, Zmm vreg_wei,
287             Zmm vreg_src) {
288         // okay for depthwise since src is zero-extended
289         if (jcp.ver == ver_vnni) {
290             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
291         } else {
292             // zmm_src is a tmp register that can be safely overwritten here
293             vpmaddwd(vreg_src, vreg_src, vreg_wei);
294             vpaddd(vreg_acc, vreg_acc, vreg_src);
295         }
296     };
297
298     for (int ki = 0; ki < jcp.kw; ki++) {
299         for (int ii = 0; ii < jcp.nb_ch_blocking; ii++) {
300             int aux_kernel_offset = kernel_offset(ii, ki);
301             if (jcp.is_fast_depthwise) {
302                 vbroadcasti32x4(zmm_wei,
303                         EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
304                 vpblendmb(zmm_wei | kblend_mask, zmm_zero_blend, zmm_wei);
305             } else {
306                 vpmovsxbd(zmm_wei,
307                          EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
308             }
309             if (h_padded) {
310                 if (jcp.ver == ver_vnni) {
311                     vpxord(zmm_src, zmm_src, zmm_src);
312                     vpaddb(zmm_src, zmm_src, vmm_shift);
313                 }
314                 for (int jj = 0; jj < ur_w; jj++) {
315                     if (jcp.ver != ver_vnni) {
316                         vpxord(zmm_src, zmm_src, zmm_src);
317                         vpaddb(zmm_src, zmm_src, vmm_shift);
318                     }
319                     compute(zmm_out(jj, ii), zmm_wei, zmm_src);
320                 }
321             } else {
322                 const bool mask_flag = last_ic_block_flag != no_last_block
323                     && ii == jcp.nb_ch_blocking - 1;
324                 const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
325                 int jj_start = get_ow_start(ki, pad_l);
326                 int jj_end = get_ow_end(ur_w, ki, pad_r);
327                 int start_ = jcp.signed_input ? 0 : jj_start;
328                 int end_ = jcp.signed_input ? ur_w : jj_end;
329                 for (int jj = start_; jj < end_; jj++) {
330                     if (jj >= jj_start && jj < jj_end) {
331                         int aux_input_offset = input_offset(jj, ii, ki);
332                         if (jcp.is_fast_depthwise) {
333                            vbroadcasti32x4(zmm_src,
334                                     EVEX_compress_addr(aux_reg_inp, aux_input_offset));
335                         } else {
336                             vpmovzxbd(r_zmm_src,
337                                     EVEX_compress_addr(aux_reg_inp, aux_input_offset));
338                         }
339                         if (jcp.signed_input) {
340                             vpaddb(zmm_src, zmm_src, vmm_shift);
341                         }
342                     } else {
343                         if (jcp.signed_input) {
344                             vpxord(zmm_src, zmm_src, zmm_src);
345                             vpaddb(zmm_src, zmm_src, vmm_shift);
346                         }
347                     }
348                     compute(zmm_out(jj, ii), zmm_wei, zmm_src);
349                 }
350             }
351         }
352     }
353 }
354
355 template<typename Vmm>
356 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
357         int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
358     if (jcp.is_depthwise)
359         return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
360
361     int kw = jcp.kw;
362     int stride_w = jcp.stride_w;
363     int ic_block = jcp.ic_block;
364     int oc_block = jcp.oc_block;
365     int ch_block_all = jcp.ch_block * ic_block * oc_block;
366
367     int nb_oc_block = jcp.nb_oc_blocking;
368
369     auto input_offset = [=](int oi, int ic, int ki) {
370         return jcp.typesize_in
371                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
372                           * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
373     };
374     auto kernel_offset = [=](int ii, int ic, int ki) {
375         return jcp.typesize_in
376                 * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
377                     + 4 * ic * oc_block);
378     };
379     auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
380         if (jcp.ver == ver_vnni) {
381             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
382         } else {
383             vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
384             vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
385             vpaddd(vreg_acc, vreg_acc, vmm_tmp);
386         }
387     };
388
389     for (int ki = 0; ki < kw; ki++) {
390         int jj_start = get_ow_start(ki, pad_l);
391         int jj_end = get_ow_end(ur_w, ki, pad_r);
392         int tail_size = jcp.ic_without_padding % 4;
393         int _start = (jcp.signed_input) ? 0 : jj_start;
394         int _end = (jcp.signed_input) ? ur_w : jj_end;
395         /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
396         int icb = (last_ic_block_flag != no_last_block)
397             ? div_up((jcp.ic_without_padding % ic_block), 4)
398             : ic_block / 4;
399         for (int ic = 0; ic < icb; ic++) {
400             if (h_padded == true) {
401                 /* fill padded area with shifted values */
402                 Vmm inp = vmm_inp(0,nb_oc_block);
403                 vpxord(inp, inp, inp);
404                 vpaddb(inp, inp, vmm_shift);
405             } else {
406                 for (int jj = _start; jj < _end; jj++) {
407                     int aux_input_offset = input_offset(jj, ic, ki);
408                     if (jj >= jj_start && jj < jj_end) {
409                         if (last_ic_block_flag == last_sp_block
410                                 && tail_size != 0 && ic == icb - 1) {
411                             Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
412                             for (int r = 0; r < tail_size; ++r)
413                                 vpinsrb(xmm_tmp, xmm_tmp,
414                                     ptr[aux_reg_inp + aux_input_offset + r], r);
415                             vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
416                         } else {
417                             vpbroadcastd(vmm_inp(jj, nb_oc_block),
418                                     EVEX_compress_addr(
419                                                  aux_reg_inp, aux_input_offset));
420                         }
421                         if (jcp.signed_input)
422                             vpaddb(vmm_inp(jj, nb_oc_block),
423                                    vmm_inp(jj, nb_oc_block), vmm_shift);
424                     } else {
425                         /* fill padded area with shifted values */
426                         if (jcp.signed_input) {
427                             Vmm inp = vmm_inp(jj, nb_oc_block);
428                             vpxord(inp, inp, inp);
429                             vpaddb(inp, inp, vmm_shift);
430                         }
431                     }
432                 }
433             }
434             for (int ii = 0; ii < nb_oc_block; ii++) {
435                 int aux_kernel_offset = kernel_offset(ii, ic, ki);
436                 vmovups(vmm_wei,
437                         EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
438                 for (int jj = _start; jj < _end; jj++)  {
439                     Vmm inp = (h_padded == true)
440                         ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block);
441                     compute(vmm_out(jj, ii), vmm_wei, inp);
442                 }
443             }
444         }
445     }
446 }
447
448 template<typename Vmm>
449 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
450         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
451     Label kh_label, skip_kh_loop;
452     Label t_overflow_label, no_t_overflow_label,
453           b_overflow_label, no_b_overflow_label;
454
455     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
456     int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
457     int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
458         * jcp.ic_without_padding * jcp.ngroups;
459
460     mov(aux_reg_inp, reg_inp);
461     mov(aux_reg_ker, reg_ker);
462
463     if (jcp.signed_input) {
464         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
465         cmp(reg_overflow, 0);
466         je(no_t_overflow_label, T_NEAR);
467         L(t_overflow_label); {
468             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
469
470             add(aux_reg_ker, shift_kernel_ptr);
471             dec(reg_overflow);
472             cmp(reg_overflow, 0);
473             jg(t_overflow_label, T_NEAR);
474         }
475         L(no_t_overflow_label);
476     }
477     mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
478     if ((jcp.signed_input) || (!jcp.signed_input &&
479        (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
480         cmp(reg_kj, 0);
481         je(skip_kh_loop, T_NEAR);
482     }
483     L(kh_label); {
484         compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
485
486         add(aux_reg_ker, shift_kernel_ptr);
487         add(aux_reg_inp, shift_input_ptr);
488         dec(reg_kj);
489         cmp(reg_kj, 0);
490         jg(kh_label, T_NEAR);
491     }
492     L(skip_kh_loop);
493     if (jcp.signed_input) {
494         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
495         cmp(reg_overflow, 0);
496         je(no_b_overflow_label, T_NEAR);
497         L(b_overflow_label); {
498             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
499
500             add(aux_reg_ker, shift_kernel_ptr);
501             dec(reg_overflow);
502             cmp(reg_overflow, 0);
503             jg(b_overflow_label, T_NEAR);
504         }
505         L(no_b_overflow_label);
506     }
507 }
508
509 template<typename Vmm>
510 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
511         int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
512 {
513     prepare_output(ur_w);
514
515     // IC loop
516     Label icb_label;
517     mov(reg_icb, jcp.nb_ic);
518     L(icb_label);
519     if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
520         Label common_ker, end_ker;
521
522         cmp(reg_icb, 1); // The last IC block
523         jne(common_ker, T_NEAR);
524
525         kh_loop(ur_w, pad_l, pad_r,
526                 is_last_sp_block ? last_sp_block : last_ic_block);
527         jmp(end_ker, T_NEAR);
528
529         L(common_ker);
530         kh_loop(ur_w, pad_l, pad_r, no_last_block);
531
532         L(end_ker);
533     } else {
534         kh_loop(ur_w, pad_l, pad_r, no_last_block);
535     }
536     // End of IC Loop
537     int inp_step = jcp.ic_block;
538     int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
539     add(reg_inp, jcp.typesize_in * inp_step);
540     add(reg_ker, jcp.typesize_in * ker_step);
541
542     dec(reg_icb);
543     cmp(reg_icb, 0);
544     jg(icb_label, T_NEAR);
545
546     sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
547     sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
548
549     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
550         Label common_store, end_store;
551
552         if (jcp.is_depthwise)
553             cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
554         else
555             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
556
557         jne(common_store, T_NEAR);
558
559         store_output(ur_w, true); // last oc block
560         jmp(end_store, T_NEAR);
561
562         L(common_store);
563         store_output(ur_w, false);
564
565         L(end_store);
566     } else {
567         store_output(ur_w, false);
568     }
569 }
570
571 template<typename Vmm>
572 void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
573 {
574     const auto &p = attr_.post_ops_;
575     for (int i = 0; i < p.len_; i++) {
576         auto &post_op = p.entry_[i];
577         if (post_op.is_eltwise()) {
578             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
579                     this,
580                     post_op.eltwise.alg,
581                     post_op.eltwise.alpha,
582                     post_op.eltwise.beta
583             ));
584         } else if (post_op.is_depthwise()) {
585             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
586                     this,
587                     post_op.depthwise.alg
588             ));
589         }
590     }
591
592     Label permute_index_table;
593     int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
594         * jcp.ic_without_padding * jcp.ngroups;
595     int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
596         * jcp.ic_without_padding * jcp.ngroups;
597     int inp_shift = jcp.typesize_in *
598                         (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
599                          * jcp.ngroups);
600     int out_shift = jcp.typesize_out *
601                         (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
602     preamble();
603
604     if (jcp.is_depthwise) {
605         zmm_src = Zmm(jcp.max_regs_ur);
606         if (jcp.is_fast_depthwise) {
607             zmm_zero_blend = Zmm(jcp.max_regs_ur + 1);
608             zmm_permute = Zmm(jcp.max_regs_ur + 2);
609         }
610     }
611
612     if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
613         xor_(reg_scratch, reg_scratch);
614         Reg16 _t16 = reg_scratch.cvt16();
615         mov(_t16, 0x1);
616         vpbroadcastw(vmm_one, _t16);
617     }
618
619     mov(reg_inp, ptr[param1 + GET_OFF(src)]);
620     mov(reg_out, ptr[param1 + GET_OFF(dst)]);
621     mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
622
623     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
624         int tail_size = jcp.is_depthwise
625             ? jcp.ngroups % jcp.ch_block
626             : jcp.oc_without_padding % jcp.oc_block;
627         int mask = (1 << tail_size) - 1;
628         mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
629         Reg32 regw_tmp = reg_oi.cvt32();
630         mov(regw_tmp, mask);
631         kmovw(ktail_mask, regw_tmp);
632     }
633     if (jcp.is_fast_depthwise) {
634         // prepare mask register for blending weights
635         mov(reg_scratch, 0x8888444422221111);
636         kmovq(kblend_mask, reg_scratch);
637         // load permute indices from data section
638         mov(reg_scratch, permute_index_table);
639         vmovdqu32(zmm_permute, ptr[reg_scratch]);
640     }
641
642     int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
643                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
644                     - (jcp.iw + jcp.l_pad - 1));
645     int n_oi = jcp.ow / jcp.ur_w;
646     int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
647         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
648
649     if (jcp.nb_ow == 1) {
650         if (r_pad1 > 0 || jcp.ur_w_tail == 0)
651             n_oi--;
652
653         xor_(reg_oi, reg_oi);
654         if (jcp.ow == jcp.ur_w) {
655             icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
656         } else {
657             if (n_oi == 0) {
658                 icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
659                 add(reg_inp, inp_shift_pad);
660                 add(reg_out, out_shift);
661                 if (jcp.ur_w_tail != 0) {
662                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
663                 }
664             } else {
665                 if (jcp.l_pad > 0) {
666                     icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
667                     add(reg_inp, inp_shift_pad);
668                     add(reg_out, out_shift);
669
670                     inc(reg_oi);
671                 }
672                 if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1))
673                 {
674                     Label ow_loop_label;
675                     L(ow_loop_label); {
676                         icb_loop(jcp.ur_w, 0, 0, false);
677                         add(reg_inp, inp_shift);
678                         add(reg_out, out_shift);
679
680                         inc(reg_oi);
681                         cmp(reg_oi, n_oi);
682                         jl(ow_loop_label, T_NEAR);
683                     }
684                 }
685                 if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
686                     icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
687                     add(reg_inp, inp_shift);
688                     add(reg_out, out_shift);
689                 }
690                 if (jcp.ur_w_tail != 0) {
691                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
692                 }
693             }
694         }
695     } else {
696         // ow block is only processed.
697         // Number of block is passed as parameter owb,
698         // and padding processing depends on this number.
699         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
700             oi_loop_label, oi_loop_end_label;
701
702         assert(jcp.ow_block % jcp.ur_w == 0);
703         int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
704         // to simplify code (and general regs usage),
705         // size of ow block must be >= 2 * ur_w
706         assert(n_oi_not_last_ow_block > 1);
707         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
708         int n_oi_first_ow_block = n_oi_not_last_ow_block;
709         int n_oi_last_ow_block
710             = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
711         // prepare right padding
712         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
713         bool first_ow_block_padded
714                 = next_last_ow_block_padded && jcp.nb_ow == 2;
715         bool last_ow_block_padded
716                 = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
717
718         if (last_ow_block_padded) n_oi_last_ow_block--;
719         else if (first_ow_block_padded) n_oi_first_ow_block--;
720         else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
721
722         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
723         cmp(reg_owb, 0); // is that the first ow-block ?
724         jg(middle_ow_blocks_label, T_NEAR);
725
726         // the first ow block, compute left padding
727         mov(reg_oi, n_oi_first_ow_block);
728         if (jcp.l_pad > 0) {
729             icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
730             add(reg_inp, inp_shift_pad);
731             add(reg_out, out_shift);
732
733             dec(reg_oi);
734         }
735         jmp(oi_loop_label, T_NEAR);
736
737         // middle or last ow block entry
738         L(middle_ow_blocks_label);
739
740         if (jcp.l_pad > 0) {
741             // just to consider left padding, not compute
742             add(reg_inp, inp_shift_pad_second_block);
743         }
744
745         // set number of iteration for oi-loop
746         if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
747             cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
748             mov(reg_oi, n_oi_last_ow_block);
749             je(oi_loop_label, T_NEAR);
750         }
751
752         if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
753             cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
754
755             mov(reg_oi, n_oi_next_last_ow_block);
756             je(oi_loop_label, T_NEAR);
757         }
758         mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
759
760         // oi loop w/o padding
761         L(oi_loop_label); {
762             cmp(reg_oi, 0);
763             jle(oi_loop_end_label, T_NEAR);
764
765             icb_loop(jcp.ur_w, 0, 0, false);
766
767             add(reg_inp, inp_shift);
768             add(reg_out, out_shift);
769             dec(reg_oi);
770
771             jmp(oi_loop_label, T_NEAR);
772         }
773         L(oi_loop_end_label);
774
775         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
776         cmp(reg_owb, 0); // first ow-block ?
777         if (first_ow_block_padded)
778             je(last_oi_label, T_NEAR);
779         else
780             je(end_label, T_NEAR);
781
782         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
783         jl(end_label, T_NEAR);
784         if (next_last_ow_block_padded)
785             je(last_oi_label, T_NEAR);
786         else
787             je(end_label, T_NEAR);
788
789         // that is last block
790         if (!last_ow_block_padded)
791             jmp(tail_label, T_NEAR);
792
793         // last oi block with right padding
794         L(last_oi_label);
795         icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
796         add(reg_inp, inp_shift);
797         add(reg_out, out_shift);
798
799         mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
800         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
801         jl(end_label, T_NEAR);
802
803         // ur_w tail
804         L(tail_label);
805         if (jcp.ur_w_tail != 0) {
806             icb_loop(jcp.ur_w_tail, 0, r_pad, true);
807         }
808         L(end_label);
809     }
810     postamble();
811
812     for (auto& inj : eltwise_injectors)
813         inj->prepare_table();
814
815     if (jcp.is_fast_depthwise) {
816         align(64);
817         L(permute_index_table);
818         const uint32_t _idx[]
819                 = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
820         for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
821             dd(_idx[i]);
822     }
823 }
824
825 bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
826         jit_conv_conf_t &jcp, const primitive_attr_t &attr)
827 {
828     using namespace primitive_kind;
829     const auto &p = attr.post_ops_;
830
831     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
832     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
833     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
834     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
835
836     switch (p.len_) {
837         case 0: return true;
838         case 1: return is_simple(0) || is_sum(0);
839         case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
840                        (is_simple(0) && is_simple(1));
841         case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
842         default: return false;
843     }
844
845     return false;
846 }
847
848 status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
849             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
850             cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
851             cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
852             int nthreads)
853 {
854     using namespace prop_kind;
855
856     const memory_desc_wrapper src_d(&src_pd);
857     const memory_desc_wrapper weights_d(&weights_pd);
858     const memory_desc_wrapper dst_d(&dst_pd);
859     const memory_desc_wrapper bias_d(&bias_pd);
860
861     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
862
863     if (!(mayiuse(avx512_core)
864          && one_of(src_d.data_type(), data_type::u8, data_type::s8)
865          && weights_d.data_type() == data_type::s8
866          && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
867             data_type::s8, data_type::u8)))
868         return status::unimplemented;
869
870     jcp = zero<decltype(jcp)>();
871     jcp.prop_kind = cd.prop_kind;
872     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
873     jcp.mb = src_d.dims()[0];
874     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
875     jcp.oc_without_padding = jcp.oc;
876     jcp.ic = src_d.dims()[1] / jcp.ngroups;
877     jcp.ic_without_padding = jcp.ic;
878     jcp.ih = src_d.dims()[2];
879     jcp.iw = src_d.dims()[3];
880     jcp.oh = dst_d.dims()[2];
881     jcp.ow = dst_d.dims()[3];
882     jcp.kh = weights_d.dims()[with_groups + 2];
883     jcp.kw = weights_d.dims()[with_groups + 3];
884     jcp.t_pad = cd.padding[0][0];
885     jcp.l_pad = cd.padding[0][1];
886     jcp.stride_h = cd.strides[0];
887     jcp.stride_w = cd.strides[1];
888     jcp.src_fmt = src_d.format();
889     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
890
891     jcp.ur_h = 1;
892
893     jcp.dilate_h = cd.dilates[0];
894     jcp.dilate_w = cd.dilates[1];
895
896     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
897     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
898
899     if (jcp.is_depthwise) {
900         jcp.ch_block = 16;
901         jcp.ic_block = 1;
902         jcp.oc_block = 1;
903     } else {
904         jcp.ch_block = 1;
905         jcp.ic_block = 16;
906         jcp.oc_block = 16;
907
908         if (jcp.ngroups == 1) {
909             /* For non grouped convolutions, pad channels by 16 if needed */
910             jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
911             jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
912         } else if (jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
913             /* For grouped convolutions, MKL-DNN doesn't support padding.
914                Use Ymm when channels per group is multiple of 8,
915                Xmm when channels per group is multiple of 4 */
916             jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
917             jcp.oc_block = jcp.ic_block;
918         }
919         if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0)
920             return status::unimplemented;
921     }
922
923     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
924             - (jcp.ih + jcp.t_pad - 1);
925
926     if (!post_ops_ok(jcp, attr))
927         return status::unimplemented;
928
929     jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
930     jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
931         && jcp.ngroups % jcp.ch_block == 0; // for groups not multiple of 16 would require byte masking for load from src
932     if (jcp.is_depthwise) {
933         jcp.max_regs_ur = jcp.is_fast_depthwise
934             ? (jcp.signed_input ? 27 : 28)
935             : (jcp.signed_input ? 29 : 30);
936     } else {
937         jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
938     }
939
940     memory_format_t w_format;
941     if (jcp.ic_block == 16 || jcp.ch_block == 16) {
942         w_format = with_groups
943             ? (jcp.is_depthwise ? (jcp.signed_input ? Goihw16g_s8s8 : Goihw16g)
944                     : (jcp.signed_input) ? gOIhw4i16o4i_s8s8 : gOIhw4i16o4i)
945             : (jcp.signed_input) ? OIhw4i16o4i_s8s8 : OIhw4i16o4i;
946      /* Non-grouped conv will always be padded by 16*/
947     } else if (with_groups && jcp.ic_block == 8) {
948         w_format = jcp.signed_input ? gOIhw2i8o4i_s8s8 : gOIhw2i8o4i;
949     } else {
950         w_format = jcp.signed_input ? gOIhw4o4i_s8s8 : gOIhw4o4i;
951     }
952
953     if (weights_d.format() == any)
954         CHECK(weights_pd.set_format(w_format));
955     if (weights_d.format() != w_format)
956         return status::unimplemented;
957
958     if (dst_d.format() == any)
959         CHECK(dst_pd.set_format(nhwc));
960     if (dst_d.format() != nhwc)
961         return status::unimplemented;
962     if (src_d.format() == any)
963         CHECK(src_pd.set_format(nhwc));
964     if (src_d.format() != nhwc)
965         return status::unimplemented;
966     if (jcp.with_bias) {
967         if (bias_d.format() == any)
968             CHECK(bias_pd.set_format(x));
969         if (bias_d.format() != x)
970             return status::unimplemented;
971     }
972
973     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
974     jcp.dst_dt = cd.dst_desc.data_type;
975
976     jcp.typesize_in = types::data_type_size(src_d.data_type());
977     jcp.typesize_out = types::data_type_size(dst_d.data_type());
978     jcp.typesize_bia = jcp.with_bias
979         ? types::data_type_size(bias_d.data_type())
980         : 0;
981
982     jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
983     jcp.nb_ic = jcp.ic / jcp.ic_block;
984     jcp.nb_oc = jcp.oc / jcp.oc_block;
985
986     // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
987     jcp.nb_ch_blocking = jcp.is_depthwise
988         ? (jcp.nb_ch % 4 == 0 ? 4 : jcp.nb_ch % 2 == 0 ? 2 : 1)
989         : 1;
990
991     // If OC blocking is incommensurate with the number of OC blocks (general
992     // requirement for all convolutions), or if it results in an unrolling
993     // factor smaller than the left padding (special requirement for SSD:fc6),
994     // then search for a smaller OC blocking that satisfies both constraints.
995     jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
996     for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) {
997         int ur_w = jcp.max_regs_ur / (jcp.nb_oc_blocking + 1);
998         if (jcp.nb_oc % jcp.nb_oc_blocking == 0
999                 && (jcp.l_pad <= ur_w
1000                          && IMPLICATION(jcp.ow != 1, jcp.ow % ur_w != 1)))
1001             break;
1002     }
1003
1004     jcp.ur_w = jcp.max_regs_ur
1005             / (jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking + 1);
1006     if (jcp.ow < jcp.ur_w)
1007         jcp.ur_w = jcp.ow;
1008     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1009
1010     jcp.ow_block = jcp.ow;
1011     int base_work_amount
1012             = jcp.mb * jcp.nb_ch * jcp.oh * (jcp.nb_oc / jcp.nb_oc_blocking);
1013     float best_thr_eff
1014             = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
1015     int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
1016     for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
1017         int ow_block
1018                 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
1019         if (ow_block < jcp.nb_oc_blocking * jcp.oc_block && best_thr_eff > 0.8f)
1020             break;
1021         if (div_up(jcp.ow, ow_block) != nb_ow)
1022             continue;
1023         auto work_amount = base_work_amount * nb_ow;
1024         float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
1025         if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
1026             jcp.ow_block = ow_block;
1027             best_thr_eff = thr_eff;
1028         }
1029         if (best_thr_eff > 0.9f)
1030             break;
1031     }
1032     jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1033
1034     bool args_ok = true
1035         && jcp.oc % jcp.oc_block == 0
1036         && jcp.l_pad <= jcp.ur_w
1037         && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
1038     if (!args_ok)
1039         return status::unimplemented;
1040
1041     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
1042                     + (jcp.kw - 1) * (jcp.dilate_w + 1)
1043                     - (jcp.iw + jcp.l_pad - 1));
1044     if (r_pad_no_tail > jcp.ur_w)
1045         return status::unimplemented;
1046
1047     pick_loop_order(jcp, nthreads);
1048
1049     jcp.nb_ic_L2 = jcp.nb_ic;
1050
1051     const auto &oscales = attr.output_scales_;
1052     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
1053
1054     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
1055
1056     jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;
1057
1058     return status::success;
1059 }
1060
1061 void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
1062         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
1063         const primitive_attr_t &attr) {
1064     if (jcp.signed_input && jcp.ver != ver_vnni) {
1065         size_t count = nstl::max(attr.output_scales_.count_, jcp.ic_block);
1066         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
1067     }
1068 }
1069
1070 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
1071 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
1072 template struct  _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
1073 }
1074 }
1075 }
1076
1077 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s