Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <float.h>
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "mkldnn_thread.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_uni_1x1_conv_utils.hpp"
25 #include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
35
36 using namespace Xbyak;
37
38 bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_relu(int position)
39 {
40     using namespace primitive_kind;
41     const auto &p = attr_.post_ops_;
42
43     if (position == 0) {
44         /* relu before sum */
45         return false
46             || jcp.with_eltwise
47             || p.contain(eltwise, 0)
48             || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
49     } else if (position == 1) {
50         /* relu after sum */
51         const int sum_idx = p.contain(sum, 0)
52             ? 0 : (p.contain(sum, 1) ? 1 : -1);
53         if (sum_idx == -1)
54             return false;
55
56         return false
57             || p.contain(eltwise, sum_idx + 1)
58             || jcp.dst_dt == data_type::u8;
59     }
60
61     return false;
62 }
63
64 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
65 {
66     mov(aux1_reg_bcast_data, reg_bcast_data);
67     mov(aux_reg_bcast_data, reg_bcast_data);
68
69     mov(aux_reg_output_data, reg_output_data);
70     mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));
71
72     Label bcast_loop;
73     Label bcast_loop_tail;
74
75     cmp(bcast_loop_iter, jcp.ur);
76     jl(bcast_loop_tail, T_NEAR);
77
78     L(bcast_loop); {
79         assert(jcp.bcast_block % jcp.ur == 0);
80         int num_substeps = jcp.bcast_block / jcp.ur;
81         assert(num_substeps > 0 && num_substeps < 10);
82         for (int i = 0; i < num_substeps; i++) {
83             reduce_loop(load_loop_blk, jcp.ur, i, false);
84             if (i < num_substeps - 1) {
85                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
86                 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
87             }
88             else {
89                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
90                     - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
91                 int output_offset = jcp.bcast_loop_output_step
92                     - (num_substeps - 1) * jcp.bcast_loop_output_substep;
93
94                 add(aux_reg_output_data, output_offset);
95             }
96         }
97         sub(bcast_loop_iter, jcp.bcast_block);
98         cmp(bcast_loop_iter, jcp.bcast_block);
99         jge(bcast_loop, T_NEAR);
100     }
101
102     L(bcast_loop_tail);
103     if (jcp.ur_tail) {
104         Label bcast_loop_tail_out;
105         cmp(bcast_loop_iter, 0);
106         jz(bcast_loop_tail_out, T_NEAR);
107         reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
108         L(bcast_loop_tail_out);
109     }
110 }
111
112 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in,
113         zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
114     zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
115     switch (type_in) {
116     case data_type::f32:
117     case data_type::s32: vmovups(zmm, op); break;
118     case data_type::s8: vpmovsxbd(zmm, op); break;
119     case data_type::u8: vpmovzxbd(zmm, op); break;
120     default: assert(!"unsupported data type");
121     }
122     if (type_in != data_type::f32)
123         vcvtdq2ps(zmm_in, zmm_in);
124 }
125
// Emits the code for one tile of `ur` output pixels x `load_loop_blk` oc
// blocks. The sequence is:
//   1. Zero the int32 accumulators.
//   2. Run the reduction (input-channel) loop.
//   3. Convert, post-process and store the results.
//
// load_loop_blk: number of oc blocks (jcp.load_block lanes each) handled.
// ur:            number of output pixels handled (jcp.ur or jcp.ur_tail at
//                the call sites in bcast_loop).
// substep, wraparound: not referenced in this body.
void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
         int ur, int substep, bool wraparound)
{
    // Register map: accumulators use Zmm(0 .. ur*load_loop_blk - 1);
    // the per-oc-block weight vectors follow them.
    auto vreg_load = [=](int i_load) {
        return Zmm(ur * load_loop_blk + i_load);
    };

    auto vreg_accum = [=](int i_load, int i_ur) {
        return Zmm(i_ur * load_loop_blk + i_load);
    };

    // Same register as vreg_load(0); only used inside store(), after the
    // weights are no longer needed.
    auto zmm_bias_alpha = [=]() {
        return Zmm(ur * load_loop_blk);
    };

    auto xmm_bias_alpha = [=]() {
        return Xmm(ur * load_loop_blk);
    };
    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_bias_data,
                                  jcp.typesize_bia * jcp.oc_block * i_load);
    };

    // Compensation values are int32, one per output channel.
    auto comp_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_comp_data,
                                  sizeof(int32_t) * jcp.oc_block * i_load);
    };

    // With a common (non per-oc) scale the offset collapses to 0.
    auto scale_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_ptr_scales,
                    jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
    };

    // Source element i_reduce of pixel i_ur; consecutive pixels are
    // jcp.ic_without_padding elements apart.
    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.ic_without_padding * i_ur + i_reduce);

        return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
                                bcast);
    };

    // Weight vector for reduce position i_reduce and oc block i_load.
    auto load_ptr = [=](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return EVEX_compress_addr(aux_reg_load_data,
                                  u1 * jcp.reduce_loop_load_step
                                  + jcp.typesize_in * offt);
    };

    // Destination pixels are jcp.oc_without_padding elements apart.
    auto output_ptr = [=](int i_load, int i_ur) {
        return EVEX_compress_addr(aux_reg_output_data,
            jcp.typesize_out * (jcp.oc_without_padding * i_ur
                                + i_load * jcp.load_block));
    };

    // Zero all accumulators. For signed (s8) input, also broadcast the
    // byte -128 into zmm_shift; it is applied to the source in fma_block.
    auto init = [=]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vpxord(r, r, r);
            }
        if (jcp.signed_input) {
            xor_(reg_scratch, reg_scratch);
            Reg8 _t8 = reg_scratch.cvt8();
            mov(_t8, (int8_t)-128);
            vpbroadcastb(zmm_shift, _t8);
        }
    };

    // Converts the int32 accumulators to f32 and applies the epilogue.
    // Per element, in order:
    //   + compensation (signed input),
    //   + bias (scaled by jcp.wei_adj_scale on the signed non-VNNI path),
    //   * output scale,
    //   optional ReLU, optional sum with the previous dst value,
    //   optional ReLU,
    //   rounding to int32 per attr_.round_mode_ for non-f32 dst.
    // The result is then stored with saturating narrowing for s8/u8.
    // With mask_flag_in, only the last oc block is masked by ktail_mask.
    // NOTE(review): reg_bcast_data / reg_load_data are spilled around the
    // loads of reg_ptr_scales / reg_ptr_sum_scale, presumably because they
    // share physical registers - confirm against the header.
    auto store = [=](const bool mask_flag_in) {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = (sum_idx != -1)
            ? &p.entry_[sum_idx].sum.scale
            : nullptr;
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        if (p_sum_scale && *p_sum_scale != 1.f) {
            mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
            mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
        }
        if (jcp.signed_input && jcp.ver != ver_vnni) {
            // Broadcast jcp.wei_adj_scale (as f32) for the bias multiply.
            mov(reg_scratch, float2int(jcp.wei_adj_scale));
            vmovq(xmm_bias_alpha(), reg_scratch);
            vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
        }
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            auto zmm_bias = zmm_tmp;
            auto zmm_comp = zmm_bcast;
            if (jcp.with_bias) {
                if (jcp.signed_input)
                    mov(reg_bias_data,
                        EVEX_compress_addr(rsp,reg_bias_data_off));
                cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag);
                if (jcp.signed_input && jcp.ver != ver_vnni)
                    vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
            }
            if (jcp.signed_input) {
                mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
                cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag);
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vcvtdq2ps(r, r);
                if (jcp.signed_input)
                    vaddps(r, r, zmm_comp);
                if (jcp.with_bias)
                    vaddps(r, r, zmm_bias);

                // Zero-masked multiply: lanes past the oc tail become 0.
                zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
                vmulps(mask_zmm, r, scale_ptr(i_load));
                if (maybe_relu(0)) {
                    vpxord(zmm_zero, zmm_zero, zmm_zero);
                    vmaxps(r, zmm_zero, r);
                }
                if (p_sum_scale) { // post_op: sum
                    vpxord(zmm_zero, zmm_zero, zmm_zero);
                    auto zmm_prev_dst = zmm_zero;

                    cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
                        mask_flag);

                    if (*p_sum_scale == 1.f)
                        vaddps(r, zmm_prev_dst);
                    else
                        vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
                }
                if (maybe_relu(1)) {
                    vpxord(zmm_zero, zmm_zero, zmm_zero);
                    vmaxps(r, zmm_zero, r);
                }
                if (jcp.dst_dt != data_type::f32) {
                    if (attr_.round_mode_ == round_mode::nearest) {
                        vcvtps2dq(r | T_rn_sae, r);
                    } else if (attr_.round_mode_ == round_mode::down) {
                        vcvtps2dq(r | T_rd_sae, r);
                    } else
                        assert(!"unimplemented");
                }
            }
            // Merge-masked stores: only the valid tail lanes are written.
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
                switch (jcp.dst_dt) {
                case data_type::f32:
                case data_type::s32:
                    vmovups(output_ptr(i_load, i_ur), r_zmm); break;
                case data_type::s8:
                    vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break;
                case data_type::u8:
                    vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break;
                default: assert(!"unknown dst_dt");
                }
            }
        }
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
    };

    // acc += dot4(src bytes, weight bytes) per dword lane.
    // VNNI: vpdpbusd. Otherwise: vpmaddubsw (u8 x s8 -> s16 pairs), then
    // vpmaddwd with zmm_one (int16 ones, set in generate()) to widen to
    // int32, then vpaddd.
    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
            vpaddd(vreg_acc, vreg_acc, zmm_tmp);
        }
    };

    // One unrolled reduce block, 4 input channels per step.
    // When last_block is set and ic is padded, only
    // rnd_up(ic_without_padding % ic_block, 4) channels are processed.
    // In that case, if ic_without_padding is not a multiple of 4, the final
    // step gathers only the tail_size valid source bytes with vpinsrb
    // instead of a full dword broadcast.
    // NOTE(review): the remaining bytes of xmm_bcast are not cleared;
    // presumably harmless because the matching padded weights are zero -
    // confirm.
    auto fma_block = [=](bool last_block) {
        int reduce_step = 4;
        int tail_size = jcp.ic_without_padding % reduce_step;
        int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
            ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
            : jcp.reduce_loop_unroll;
        for (int i_reduce = 0; i_reduce < loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (last_block && tail_size != 0
                    && i_reduce == loop_unroll - reduce_step) {
                    Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
                    for (int r = 0; r < tail_size; ++r)
                        vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
                        + jcp.ic_without_padding * i_ur + i_reduce + r], r);
                    vpbroadcastd(zmm_bcast, xmm_bcast);
                } else {
                    vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
                }
                // s8 source: subtracting -128 per byte (mod 256) moves it
                // into the unsigned range expected by vpmaddubsw/vpdpbusd;
                // the compensation term is added back in store().
                if (jcp.signed_input)
                    vpsubb(zmm_bcast, zmm_bcast, zmm_shift);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(i_load, i_ur),
                                vreg_load(i_load), zmm_bcast);
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    // Run full blocks while more than one unroll of work remains; the final
    // block is always emitted after the loop (as the tail).
    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop); {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    if (jcp.ic != jcp.ic_without_padding) {
        fma_block(true);
    } else {
        fma_block(false);
    }

    // With padded oc, use the masked store only for the last load block
    // (reg_load_loop_work exhausted) of the last oc range (FLAG_OC_LAST set
    // in reg_reduce_pos_flag). reg_load_loop_work is restored afterwards.
    if (jcp.oc_without_padding != jcp.oc) {
        Label end_store, common_store;
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);

        /*Check if it is the last load_loop_blk*/
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        cmp(reg_load_loop_work, 0);
        jg(common_store, T_NEAR);

        /*Check if it is the last ocb*/
        test(reg_reduce_pos_flag, FLAG_OC_LAST);
        jz(common_store, T_NEAR);

        store(true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store(false);

        L(end_store);

        add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    } else {
        store(false);
    }
}
388
389 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
390 {
391     preamble();
392
393     xor_(reg_scratch, reg_scratch);
394     Reg16 _t = reg_scratch.cvt16();
395     mov(_t, 0x1);
396     vpbroadcastw(zmm_one, _t);
397
398     sub(rsp, stack_space_needed);
399
400     if (jcp.oc_without_padding != jcp.oc) {
401         int tail_size = jcp.oc_without_padding % jcp.oc_block;
402         int mask = (1 << tail_size) - 1;
403         Reg32 regw_tmp = reg_last_load.cvt32();
404         mov(regw_tmp, mask);
405         kmovw(ktail_mask, regw_tmp);
406     }
407
408     if (jcp.with_bias)
409         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
410     if (jcp.signed_input) {
411         mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
412         mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
413         mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
414     }
415     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
416     mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
417     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
418     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
419     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
420
421     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
422     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
423     mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
424     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
425     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
426
427
428     auto load_loop_body = [=](int load_loop_blk) {
429         bcast_loop(load_loop_blk);
430         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
431         if (jcp.with_bias) {
432             if (jcp.signed_input)
433                 mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
434             add(reg_bias_data,
435                 load_loop_blk * jcp.load_block * jcp.typesize_bia);
436             if (jcp.signed_input)
437                 mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
438         }
439         if (jcp.signed_input) {
440             mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
441             add(reg_comp_data,
442                 load_loop_blk * jcp.load_block * sizeof(int32_t));
443             mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
444         }
445         mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
446         mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
447         add(reg_ptr_scales,
448             jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
449         mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
450         mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
451         add(reg_output_data,
452             load_loop_blk * jcp.load_block * jcp.typesize_out);
453         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
454     };
455
456     const int simd_w = 16;
457
458     Label load_loop_blk[7];
459
460     static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
461     const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
462     const int *ur_cases_fma = ur_cases_fma_expl_bcast;
463     const int *ur_cases = ur_cases_fma;
464     const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
465
466     for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
467         int label_idx = num_ur_cases - ur_idx - 1;
468         if (jcp.ur <= ur_cases[ur_idx]) {
469             cmp(reg_load_loop_work, simd_w * (label_idx + 1));
470             jle(load_loop_blk[label_idx], T_NEAR);
471         }
472     }
473
474     for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
475         if (jcp.ur <= ur_cases[ur_idx]) {
476             int label_idx = num_ur_cases - ur_idx - 1;
477             L(load_loop_blk[label_idx]);
478             {
479                 if (label_idx == 0) {
480                     cmp(reg_load_loop_work, 0);
481                     je(load_loop_blk[num_ur_cases], T_NEAR);
482                 }
483                 load_loop_body(label_idx + 1);
484                 if (label_idx - 1 > 0) {
485                     cmp(reg_load_loop_work, 2 * label_idx * simd_w);
486                     je(load_loop_blk[label_idx - 1], T_NEAR);
487                 }
488                 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
489                 jge(load_loop_blk[label_idx]);
490             }
491             for (int idx = label_idx - 1; idx > 0; --idx) {
492                 cmp(reg_load_loop_work, simd_w * (idx + 1));
493                 je(load_loop_blk[idx], T_NEAR);
494             }
495             if (ur_idx < num_ur_cases - 2) {
496                 cmp(reg_load_loop_work, simd_w);
497                 jle(load_loop_blk[0], T_NEAR);
498             }
499         }
500     }
501     L(load_loop_blk[num_ur_cases]);
502
503     add(rsp, stack_space_needed);
504
505     postamble();
506 }
507
508 bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
509         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
510     using namespace primitive_kind;
511     const auto &p = attr.post_ops_;
512
513     auto is_relu = [&](int idx) {
514         return p.entry_[idx].kind == eltwise
515             && p.entry_[idx].eltwise.scale == 1.
516             && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
517             && p.entry_[idx].eltwise.alpha == 0.;
518     };
519
520     switch (p.len_) {
521     case 0: return true;
522     case 1: return true
523                 && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
524                 && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
525     case 2: return true
526                 && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
527                 && IMPLICATION(!jcp.with_eltwise, false
528                         || (p.contain(sum, 0) && is_relu(1))
529                         || (p.contain(sum, 1) && is_relu(0)));
530     case 3: return true
531                 && jcp.with_eltwise == false
532                 && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
533     default: return false;
534     }
535
536     return false;
537 }
538
539 status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
540         jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
541         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
542         const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
543         const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
544         int nthreads, bool reduce_src)
545 {
546     if (!mayiuse(avx512_core)) return status::unimplemented;
547
548     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
549     if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
550         || weights_d.data_type() != data_type::s8
551         || !one_of(dst_d.data_type(),
552             data_type::f32, data_type::s32, data_type::s8, data_type::u8))
553         return status::unimplemented;
554     if (!one_of(weights_d.format(), gOIhw4i16o4i, OIhw4i16o4i,
555                 gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8)) {
556         return status::unimplemented;
557     }
558     jcp.ver = ver_avx512_core;
559     if (mayiuse(avx512_core_vnni))
560         jcp.ver = ver_vnni;
561
562     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
563     jcp.mb = src_d.dims()[0];
564     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
565     jcp.oc_without_padding = jcp.oc;
566     jcp.ic = src_d.dims()[1] / jcp.ngroups;
567     jcp.ic_without_padding = jcp.ic;
568     jcp.ih = src_d.dims()[2];
569     jcp.iw = src_d.dims()[3];
570     jcp.oh = dst_d.dims()[2];
571     jcp.ow = dst_d.dims()[3];
572     jcp.kh = weights_d.dims()[with_groups + 2];
573     jcp.kw = weights_d.dims()[with_groups + 3];
574     jcp.t_pad = cd.padding[0][0];
575     jcp.l_pad = cd.padding[0][1];
576     jcp.stride_h = cd.strides[0];
577     jcp.stride_w = cd.strides[1];
578     jcp.src_fmt = src_d.format();
579     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
580     jcp.with_eltwise = with_relu;
581     jcp.eltwise_alpha = relu_negative_slope;
582     if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
583         return status::unimplemented;
584
585     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
586
587     jcp.os = jcp.oh * jcp.ow;
588     jcp.is = jcp.ih * jcp.iw;
589     jcp.tr_is = rnd_up(jcp.is, 4);
590
591     if (!post_ops_ok(jcp, attr))
592         return status::unimplemented;
593
594     bool args_ok = true
595         && jcp.ngroups == 1
596         && src_d.format() == nhwc
597         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
598         && dst_d.format() == nhwc;
599     if (!args_ok) return status::unimplemented;
600
601     const int simd_w = 16;
602
603     jcp.oc = rnd_up(jcp.oc, simd_w);
604     jcp.ic = rnd_up(jcp.ic, simd_w);
605
606     args_ok = true
607         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
608         && jcp.t_pad == 0 && jcp.l_pad == 0
609         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
610         && jcp.kh == 1 && jcp.kw == 1;
611     if (!args_ok) return status::unimplemented;
612
613     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
614     jcp.dst_dt = cd.dst_desc.data_type;
615
616     jcp.ic_block = jcp.oc_block = simd_w;
617
618     jcp.typesize_in = types::data_type_size(src_d.data_type());
619     jcp.typesize_out = types::data_type_size(dst_d.data_type());
620     jcp.typesize_bia = jcp.with_bias
621         ? types::data_type_size(bias_d.data_type())
622         : 0;
623
624     const int SMALL_SPATIAL = 7 * 7;
625     const int BIG_REDUCE_DIM = 1024;
626
627     int load_blocking = 0;
628     int load_blocking_max = 0;
629     int bcast_blocking = 0;
630     int bcast_blocking_max = 0;
631     int reduce_blocking = 0;
632     int reduce_blocking_max = 0;
633     jcp.load_grp_count = 1;
634     jcp.use_vmovntps = false;
635
636     const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
637     const int L2_capacity = (L2_size * 3) / 4;
638
639     int size_treshold = 28;
640     int max_regs = 0;
641     int min_regs = 6;
642     if (jcp.ver == ver_vnni)
643         max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold)
644                     && (jcp.oc < 128 || jcp.ic < 128)) ?  min_regs : 9;
645     else
646         max_regs = 8;
647     jcp.expl_bcast = true;
648
649     const int spatial = jcp.oh;
650     jcp.ur = 1;
651     for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
652         if ((spatial >= size_treshold && spatial % ur_w == 0)
653                 || (spatial < size_treshold && jcp.os % ur_w == 0)) {
654             jcp.ur = ur_w;
655             break;
656         }
657     }
658     if (jcp.ur == 1) {
659         jcp.ur = nstl::min(max_regs, jcp.os);
660         int os_tail = jcp.os % max_regs;
661         for (int i = max_regs; i >= min_regs; i--) {
662             int i_tail = jcp.os % i;
663             if (i_tail > os_tail || i_tail == 0) {
664                 jcp.ur = i;
665                 os_tail = i_tail;
666                 if (i_tail == 0)
667                     break;
668             }
669         }
670     }
671
672     jcp.reduce_dim = jcp.ic;
673     jcp.reduce_block = jcp.ic_block;
674
675     jcp.load_dim = jcp.oc;
676     jcp.load_block = jcp.oc_block;
677
678     jcp.bcast_dim = jcp.is;
679
680     jcp.bcast_block = jcp.ur;
681
682     jcp.reduce_loop_unroll = jcp.reduce_block;
683     jcp.reduce_loop_bcast_step
684             = jcp.reduce_loop_unroll * jcp.typesize_in;
685
686     jcp.reduce_loop_load_step
687             = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
688
689     jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out;
690     jcp.bcast_loop_output_substep = -1; // unused
691     jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in;
692     jcp.bcast_loop_bcast_substep = -1; // unused
693
694     jcp.load_loop_load_step
695             = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
696
697     jcp.load_loop_iter_step = jcp.load_block;
698
699     jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
700
701     int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
702     int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
703
704     reduce_blocking = nb_reduce;
705     if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
706         reduce_blocking = 64;
707     else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
708         reduce_blocking = 16;
709     reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
710     reduce_blocking *= jcp.reduce_block;
711
712     bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
713     if (cmp_reduce)
714         jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
715     load_blocking = jcp.load_dim;
716
717     jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
718     jcp.load_grp_count = best_divider(
719             nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
720
721     if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) {
722         jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
723     } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads
724             && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
725         jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); //
726         load_blocking = jcp.load_block;
727     }
728
729     bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
730                              div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
731     bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
732     bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
733
734     int space_for_bcast
735             = (L2_capacity - /* kernel_size - */
736                 2 * jcp.load_block * reduce_blocking
737                     - jcp.ur * reduce_blocking - 3 * 1024);
738     if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
739         space_for_bcast /= 2;
740
741     int bcast_in_cache
742             = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
743     bcast_blocking = nstl::min(
744             bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
745
746     load_blocking_max = load_blocking;
747     bcast_blocking_max = bcast_blocking * 3 / 2;
748     reduce_blocking_max = reduce_blocking;
749
750     assert(load_blocking);
751     assert(load_blocking_max);
752     assert(bcast_blocking);
753     assert(bcast_blocking_max);
754     assert(reduce_blocking);
755     assert(reduce_blocking_max);
756     assert(load_blocking % jcp.load_block == 0);
757     assert(reduce_blocking % jcp.reduce_block == 0);
758     assert(load_blocking_max % jcp.load_block == 0);
759     assert(reduce_blocking_max % jcp.reduce_block == 0);
760
761     assert(jcp.reduce_loop_unroll % 4 == 0);
762     assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
763
764     assert(jcp.bcast_block % jcp.ur == 0);
765     assert(jcp.reduce_dim % jcp.reduce_block == 0);
766
767     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
768
769     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
770     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
771     jcp.nb_load_blocking = load_blocking / jcp.load_block;
772     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
773     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
774     jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
775
776     jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
777     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
778     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
779
780     const auto &oscales = attr.output_scales_;
781     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
782     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
783
784     jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;
785
786     return status::success;
787 }
788
789 }
790 }
791 }