Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_u8s8s32x_1x1_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <float.h>
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "mkldnn_thread.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_uni_1x1_conv_utils.hpp"
25 #include "jit_avx512_core_u8s8s32x_1x1_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
35
36 using namespace Xbyak;
37
38 bool jit_avx512_core_u8s8s32x_1x1_conv_kernel::maybe_relu(int position)
39 {
40     using namespace primitive_kind;
41     const auto &p = attr_.post_ops_;
42
43     if (position == 0) {
44         /* relu before sum */
45         return false
46             || jcp.with_eltwise
47             || p.contain(eltwise, 0)
48             || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
49     } else if (position == 1) {
50         /* relu after sum */
51         const int sum_idx = p.contain(sum, 0)
52             ? 0 : (p.contain(sum, 1) ? 1 : -1);
53         if (sum_idx == -1)
54             return false;
55
56         return false
57             || p.contain(eltwise, sum_idx + 1)
58             || jcp.dst_dt == data_type::u8;
59     }
60
61     return false;
62 }
63
64 void jit_avx512_core_u8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
65 {
66     mov(aux1_reg_bcast_data, reg_bcast_data);
67     mov(aux_reg_bcast_data, reg_bcast_data);
68
69     mov(aux_reg_output_data, reg_output_data);
70     mov(aux_reg_acc_s32, reg_acc_s32);
71
72     mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
73
74     Label bcast_loop;
75     Label bcast_loop_tail;
76
77     cmp(bcast_loop_iter, jcp.ur);
78     jl(bcast_loop_tail, T_NEAR);
79
80     L(bcast_loop); {
81         assert(jcp.bcast_block % jcp.ur == 0);
82         int num_substeps = jcp.bcast_block / jcp.ur;
83         assert(num_substeps > 0 && num_substeps < 10);
84         for (int i = 0; i < num_substeps; i++) {
85             reduce_loop(load_loop_blk, jcp.ur, i, false);
86             if (i < num_substeps - 1) {
87                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
88                 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
89                 int ws_offset =
90                     (jcp.bcast_loop_output_substep / jcp.typesize_out)
91                         * jcp.typesize_acc;
92                 add(aux_reg_acc_s32, ws_offset);
93             }
94             else {
95                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
96                     - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
97                 int output_offset = jcp.bcast_loop_output_step
98                     - (num_substeps - 1) * jcp.bcast_loop_output_substep;
99                 add(aux_reg_output_data, output_offset);
100                 int ws_offset = (output_offset / jcp.typesize_out)
101                     * jcp.typesize_acc;
102                 add(aux_reg_acc_s32, ws_offset);
103             }
104         }
105         sub(bcast_loop_iter, jcp.bcast_block);
106         cmp(bcast_loop_iter, jcp.bcast_block);
107         jge(bcast_loop, T_NEAR);
108     }
109
110     L(bcast_loop_tail);
111     if (jcp.ur_tail) {
112         Label bcast_loop_tail_out;
113         cmp(bcast_loop_iter, 0);
114         jz(bcast_loop_tail_out, T_NEAR);
115         reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
116         L(bcast_loop_tail_out);
117     }
118 }
119
120 void jit_avx512_core_u8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
121          int ur, int substep, bool wraparound)
122 {
123     auto vreg_load = [=](int i_load) {
124         return Zmm(ur * load_loop_blk + i_load);
125     };
126
127     auto vreg_accum = [=](int i_load, int i_ur) {
128         return Zmm(i_ur * load_loop_blk + i_load);
129     };
130
131     auto xreg_accum = [=](int i_load, int i_ur) {
132         return Xmm(i_ur * load_loop_blk + i_load);
133     };
134
135     auto bias_ptr = [=](int i_load) {
136         return EVEX_compress_addr(reg_bias_data,
137                                   jcp.typesize_bia * jcp.oc_block * i_load);
138     };
139     auto scale_ptr = [=](int i_load) {
140         return EVEX_compress_addr(reg_ptr_scales,
141                     jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
142     };
143
144     auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
145         assert(i_ur < jcp.ur);
146         assert(i_reduce <= jcp.reduce_loop_unroll);
147         assert(jcp.reduce_loop_unroll == jcp.reduce_block);
148
149         int offt = (jcp.reduce_dim * i_ur + i_reduce);
150
151         return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
152                                 bcast);
153     };
154
155     auto load_ptr = [=](int i_reduce, int i_load) {
156         int u0 = i_reduce % jcp.reduce_loop_unroll;
157         int u1 = i_reduce / jcp.reduce_loop_unroll;
158
159         int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
160
161         return EVEX_compress_addr(aux_reg_load_data,
162                                   u1 * jcp.reduce_loop_load_step
163                                   + jcp.typesize_in * offt);
164     };
165
166     auto output_ptr = [=](int i_load, int i_ur) {
167         return EVEX_compress_addr(aux_reg_output_data,
168             jcp.typesize_out * (jcp.load_dim * i_ur + i_load * jcp.load_block));
169     };
170
171     auto acc_s32_ptr = [=](int i_load, int i_ur) {
172         return EVEX_compress_addr(aux_reg_acc_s32,
173             jcp.typesize_acc * (jcp.load_dim * i_ur + i_load * jcp.load_block));
174     };
175
176     auto init = [=]() {
177         Label l_first_load, l_ret;
178
179         test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
180         jnz(l_first_load, T_NEAR); // FISRT load: if not zero jump to <l_first_load>
181
182         for (int i_load = 0; i_load < load_loop_blk; ++i_load)
183             for (int i_ur = 0; i_ur < ur; ++i_ur) {
184                 auto r = vreg_accum(i_load, i_ur);
185                 vmovups(r, acc_s32_ptr(i_load, i_ur));
186             }
187         jmp(l_ret, T_NEAR);
188
189         L(l_first_load);
190         for (int i_load = 0; i_load < load_loop_blk; ++i_load)
191             for (int i_ur = 0; i_ur < ur; ++i_ur) {
192                 auto r = vreg_accum(i_load, i_ur);
193                 vpxord(r, r, r);
194             }
195         L(l_ret);
196     };
197
198     auto store = [=]() {
199         Label l_update_acc, l_ret;
200
201         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
202         jz(l_update_acc, T_NEAR); // LAST channel: if zero jump to <l_update_acc>
203
204         const auto &p = attr_.post_ops_;
205         const int sum_idx = p.find(primitive_kind::sum);
206         const float *p_sum_scale = (sum_idx != -1)
207             ? &p.entry_[sum_idx].sum.scale
208             : nullptr;
209
210         if (jcp.with_bias) {
211             mov(EVEX_compress_addr(rsp, aux_reg_acc_s32_offt), aux_reg_acc_s32);
212             mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_offt));
213         }
214         mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
215         mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
216         if (p_sum_scale && *p_sum_scale != 1.f) {
217             mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
218             mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
219         }
220         vpxord(zmm_zero, zmm_zero, zmm_zero);
221         for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
222             auto zmm_bias = zmm_tmp;
223             if (jcp.with_bias) {
224                 switch (jcp.bia_dt) {
225                 case data_type::f32:
226                 case data_type::s32: vmovups(zmm_bias,
227                                         bias_ptr(i_load)); break;
228                 case data_type::s8: vpmovsxbd(zmm_bias,
229                                         bias_ptr(i_load)); break;
230                 case data_type::u8: vpmovzxbd(zmm_bias,
231                                         bias_ptr(i_load)); break;
232                 default: assert(!"unsupported bias data type");
233                 }
234                 if (jcp.bia_dt != data_type::f32)
235                     vcvtdq2ps(zmm_bias, zmm_bias);
236             }
237             for (int i_ur = 0; i_ur < ur; ++i_ur) {
238                 auto r = vreg_accum(i_load, i_ur);
239                 auto x = xreg_accum(i_load, i_ur);
240                 vcvtdq2ps(r, r);
241                 if (jcp.with_bias)
242                     vaddps(r, r, zmm_bias);
243                 vmulps(r, r, scale_ptr(i_load));
244                 if (maybe_relu(0))
245                     vmaxps(r, zmm_zero, r);
246                 if (p_sum_scale) { // post_op: sum
247                     auto zmm_prev_dst = zmm_bcast;
248                     switch (jcp.dst_dt) {
249                     case data_type::f32:
250                     case data_type::s32: vmovups(zmm_prev_dst,
251                                             output_ptr(i_load, i_ur)); break;
252                     case data_type::s8: vpmovsxbd(zmm_prev_dst,
253                                             output_ptr(i_load, i_ur)); break;
254                     case data_type::u8: vpmovzxbd(zmm_prev_dst,
255                                             output_ptr(i_load, i_ur)); break;
256                     default: assert(!"unsupported dst data type");
257                     }
258                     if (jcp.dst_dt != data_type::f32)
259                         vcvtdq2ps(zmm_prev_dst, zmm_prev_dst);
260                     if (*p_sum_scale == 1.f)
261                         vaddps(r, zmm_prev_dst);
262                     else
263                         vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
264                 }
265                 if (maybe_relu(1))
266                     vmaxps(r, zmm_zero, r);
267                 if (jcp.dst_dt != data_type::f32) {
268                     if (attr_.round_mode_ == round_mode::nearest) {
269                         vcvtps2dq(r | T_rn_sae, r);
270                     } else if (attr_.round_mode_ == round_mode::down) {
271                         vcvtps2dq(r | T_rd_sae, r);
272                     } else
273                         assert(!"unimplemented");
274                 }
275                 switch (jcp.dst_dt) {
276                 case data_type::f32:
277                 case data_type::s32: vmovups(output_ptr(i_load, i_ur), r); break;
278                 case data_type::s8: vpmovsdb(x, r);
279                                     vmovups(output_ptr(i_load, i_ur), x); break;
280                 case data_type::u8: vpmovusdb(x, r);
281                                     vmovups(output_ptr(i_load, i_ur), x); break;
282                 default: assert(!"unknown dst_dt");
283                 }
284             }
285         }
286         if (jcp.with_bias)
287             mov(aux_reg_acc_s32, EVEX_compress_addr(rsp, aux_reg_acc_s32_offt));
288         mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
289         if (p_sum_scale && *p_sum_scale != 1.f)
290             mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
291         jmp(l_ret, T_NEAR);
292
293         L(l_update_acc);
294
295         mov(aux_reg_bcast_data, EVEX_compress_addr(rsp, aux_reg_acc_s32_offt));
296         for (int i_load = 0; i_load < load_loop_blk; ++i_load)
297             for (int i_ur = 0; i_ur < ur; ++i_ur) {
298                 auto r = vreg_accum(i_load, i_ur);
299                 vmovups(acc_s32_ptr(i_load, i_ur), r);
300
301             }
302         L(l_ret);
303     };
304
305     auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
306         if (jcp.ver == ver_vnni) {
307             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
308         } else {
309             vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
310             vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
311             vpaddd(vreg_acc, vreg_acc, zmm_tmp);
312         }
313     };
314
315     auto fma_block = [=](bool last_block) {
316         int reduce_step = 4;
317         for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
318                 i_reduce += reduce_step) {
319             for (int i_load = 0; i_load < load_loop_blk; ++i_load)
320                 vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
321             for (int i_ur = 0; i_ur < ur; ++i_ur) {
322                 vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
323                 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
324                     compute(vreg_accum(i_load, i_ur),
325                                 vreg_load(i_load), zmm_bcast);
326                 }
327             }
328         }
329     };
330
331     Label reduce_loop;
332     Label reduce_loop_tail;
333
334     mov(aux_reg_load_data, reg_load_data);
335
336     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
337     init();
338
339     mov(reduce_loop_iter, reg_reduce_loop_work);
340     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
341     jle(reduce_loop_tail, T_NEAR);
342
343     L(reduce_loop); {
344         fma_block(false);
345         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
346         add(aux_reg_load_data, jcp.reduce_loop_load_step);
347         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
348         jg(reduce_loop, T_NEAR);
349     }
350
351     L(reduce_loop_tail);
352     fma_block(true);
353
354     store();
355 }
356
357 void jit_avx512_core_u8s8s32x_1x1_conv_kernel::generate()
358 {
359     preamble();
360
361     xor_(reg_scratch, reg_scratch);
362     Reg16 _t = reg_scratch.cvt16();
363     mov(_t, 0x1);
364     vpbroadcastw(zmm_one, _t);
365
366     sub(rsp, stack_space_needed);
367     if (jcp.with_bias) {
368         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
369         mov(EVEX_compress_addr(rsp, reg_bias_data_offt), reg_bias_data);
370     }
371     mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
372     mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
373     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
374     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
375     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
376
377     mov(reg_acc_s32, ptr[param1 + GET_OFF(acc_s32)]);
378     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
379     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
380     mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
381     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
382     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(reduce_pos_flag)]);
383
384
385     auto load_loop_body = [=](int load_loop_blk) {
386         bcast_loop(load_loop_blk);
387         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
388         if (jcp.with_bias) {
389             mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_offt));
390             add(reg_bias_data,
391                 load_loop_blk * jcp.load_block * jcp.typesize_bia);
392             mov(EVEX_compress_addr(rsp, reg_bias_data_offt), reg_bias_data);
393         }
394         mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
395         mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
396         add(reg_ptr_scales,
397             jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
398         mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
399         mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
400         add(reg_output_data,
401             load_loop_blk * jcp.load_block * jcp.typesize_out);
402         add(reg_acc_s32,
403             load_loop_blk * jcp.load_block * jcp.typesize_acc);
404         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
405     };
406
407     const int simd_w = 16;
408
409     Label load_loop_blk[7];
410
411     static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
412     const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
413     const int *ur_cases_fma = ur_cases_fma_expl_bcast;
414     const int *ur_cases = ur_cases_fma;
415     const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
416
417     for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
418         int label_idx = num_ur_cases - ur_idx - 1;
419         if (jcp.ur <= ur_cases[ur_idx]) {
420             cmp(reg_load_loop_work, simd_w * (label_idx + 1));
421             jle(load_loop_blk[label_idx], T_NEAR);
422         }
423     }
424
425     for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
426         if (jcp.ur <= ur_cases[ur_idx]) {
427             int label_idx = num_ur_cases - ur_idx - 1;
428             L(load_loop_blk[label_idx]);
429             {
430                 if (label_idx == 0) {
431                     cmp(reg_load_loop_work, 0);
432                     je(load_loop_blk[num_ur_cases], T_NEAR);
433                 }
434                 load_loop_body(label_idx + 1);
435                 if (label_idx - 1 > 0) {
436                     cmp(reg_load_loop_work, 2 * label_idx * simd_w);
437                     je(load_loop_blk[label_idx - 1], T_NEAR);
438                 }
439                 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
440                 jge(load_loop_blk[label_idx]);
441             }
442             for (int idx = label_idx - 1; idx > 0; --idx) {
443                 cmp(reg_load_loop_work, simd_w * (idx + 1));
444                 je(load_loop_blk[idx], T_NEAR);
445             }
446             if (ur_idx < num_ur_cases - 2) {
447                 cmp(reg_load_loop_work, simd_w);
448                 jle(load_loop_blk[0], T_NEAR);
449             }
450         }
451     }
452     L(load_loop_blk[num_ur_cases]);
453
454     add(rsp, stack_space_needed);
455
456     postamble();
457 }
458
459 bool jit_avx512_core_u8s8s32x_1x1_conv_kernel::post_ops_ok(
460         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
461     using namespace primitive_kind;
462     const auto &p = attr.post_ops_;
463
464     auto is_relu = [&](int idx) {
465         return p.entry_[idx].kind == eltwise
466             && p.entry_[idx].eltwise.scale == 1.
467             && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
468             && p.entry_[idx].eltwise.alpha == 0.;
469     };
470
471    switch (p.len_) {
472     case 0: return true;
473     case 1: return true
474                 && implication(jcp.with_eltwise, p.contain(sum, 0))
475                 && implication(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
476     case 2: return true
477                 && implication(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
478                 && implication(!jcp.with_eltwise, false
479                         || (p.contain(sum, 0) && is_relu(1))
480                         || (p.contain(sum, 1) && is_relu(0)));
481     case 3: return true
482                 && jcp.with_eltwise == false
483                 && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
484     default: return false;
485     }
486
487     return false;
488 }
489
490 status_t jit_avx512_core_u8s8s32x_1x1_conv_kernel::init_conf(
491         jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
492         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
493         const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
494         const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
495         int nthreads, bool reduce_src)
496 {
497     if (!mayiuse(avx512_core)) return status::unimplemented;
498
499     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
500     if (src_d.data_type() != data_type::u8
501         || weights_d.data_type() != data_type::s8
502         || !one_of(dst_d.data_type(),
503             data_type::f32, data_type::s32, data_type::s8, data_type::u8))
504         return status::unimplemented;
505     if (!one_of(weights_d.format(), gOIhw4i16o4i, OIhw4i16o4i))
506         return status::unimplemented;
507
508     jcp.ver = ver_avx512_core;
509     if (mayiuse(avx512_core_vnni))
510         jcp.ver = ver_vnni;
511
512     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
513     jcp.mb = src_d.dims()[0];
514     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
515     jcp.ic = src_d.dims()[1] / jcp.ngroups;
516     jcp.ih = src_d.dims()[2];
517     jcp.iw = src_d.dims()[3];
518     jcp.oh = dst_d.dims()[2];
519     jcp.ow = dst_d.dims()[3];
520     jcp.kh = weights_d.dims()[with_groups + 2];
521     jcp.kw = weights_d.dims()[with_groups + 3];
522     jcp.t_pad = cd.padding[0][0];
523     jcp.l_pad = cd.padding[0][1];
524     jcp.stride_h = cd.strides[0];
525     jcp.stride_w = cd.strides[1];
526     jcp.src_fmt = src_d.format();
527     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
528     jcp.with_eltwise = with_relu;
529     jcp.eltwise_alpha = relu_negative_slope;
530     if (!implication(with_relu, relu_negative_slope == 0.))
531         return status::unimplemented;
532
533     jcp.os = jcp.oh * jcp.ow;
534     jcp.is = jcp.ih * jcp.iw;
535     jcp.tr_is = rnd_up(jcp.is, 4);
536
537     if (!post_ops_ok(jcp, attr))
538         return status::unimplemented;
539
540     bool args_ok = true
541         && jcp.ngroups == 1
542         && src_d.format() == nhwc
543         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
544         && dst_d.format() == nhwc;
545     if (!args_ok) return status::unimplemented;
546
547     const int simd_w = 16;
548
549     args_ok = true
550         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
551         && jcp.t_pad == 0 && jcp.l_pad == 0
552         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
553         && jcp.kh == 1 && jcp.kw == 1;
554     if (!args_ok) return status::unimplemented;
555
556     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
557     jcp.dst_dt = cd.dst_desc.data_type;
558
559     jcp.ic_block = jcp.oc_block = simd_w;
560
561     jcp.typesize_in = types::data_type_size(src_d.data_type());
562     jcp.typesize_out = types::data_type_size(dst_d.data_type());
563     jcp.typesize_acc = sizeof(int32_t);
564     jcp.typesize_bia = jcp.with_bias
565         ? types::data_type_size(bias_d.data_type())
566         : 0;
567
568     const int SMALL_SPATIAL = 7 * 7;
569     const int BIG_REDUCE_DIM = 1024;
570
571     int load_blocking = 0;
572     int load_blocking_max = 0;
573     int bcast_blocking = 0;
574     int bcast_blocking_max = 0;
575     int reduce_blocking = 0;
576     int reduce_blocking_max = 0;
577     jcp.load_grp_count = 1;
578     jcp.use_vmovntps = false;
579
580     const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
581     const int L2_capacity = (L2_size * 3) / 4;
582
583     int size_treshold = 28;
584     int max_regs = (jcp.ver == ver_vnni) ? 9 : 8;
585     int min_regs = 6;
586     jcp.expl_bcast = true;
587
588     const int spatial = jcp.oh;
589     jcp.ur = 1;
590     for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
591         if ((spatial >= size_treshold && spatial % ur_w == 0)
592                 || (spatial < size_treshold && jcp.os % ur_w == 0)) {
593             jcp.ur = ur_w;
594             break;
595         }
596     }
597     if (jcp.ur == 1) {
598         jcp.ur = nstl::min(max_regs, jcp.os);
599         int os_tail = jcp.os % max_regs;
600         for (int i = max_regs; i >= min_regs; i--) {
601             int i_tail = jcp.os % i;
602             if (i_tail > os_tail || i_tail == 0) {
603                 jcp.ur = i;
604                 os_tail = i_tail;
605                 if (i_tail == 0)
606                     break;
607             }
608         }
609     }
610
611     jcp.reduce_dim = jcp.ic;
612     jcp.reduce_block = jcp.ic_block;
613
614     jcp.load_dim = jcp.oc;
615     jcp.load_block = jcp.oc_block;
616
617     jcp.bcast_dim = jcp.is;
618
619     jcp.bcast_block = jcp.ur;
620
621     jcp.reduce_loop_unroll = jcp.reduce_block;
622     jcp.reduce_loop_bcast_step
623             = jcp.reduce_loop_unroll * jcp.typesize_in;
624
625     jcp.reduce_loop_load_step
626             = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
627
628     jcp.bcast_loop_output_step = jcp.ur * jcp.load_dim * jcp.typesize_out;
629     jcp.bcast_loop_output_substep = -1; // unused
630     jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_dim * jcp.typesize_in;
631     jcp.bcast_loop_bcast_substep = -1; // unused
632
633     jcp.load_loop_load_step
634             = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
635
636     jcp.load_loop_iter_step = jcp.load_block;
637
638     jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
639
640     int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
641     int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
642
643     reduce_blocking = nb_reduce;
644     if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
645         reduce_blocking = 64;
646     else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
647         reduce_blocking = 16;
648     reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
649     reduce_blocking *= jcp.reduce_block;
650
651     bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
652     if (cmp_reduce)
653         jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
654     load_blocking = jcp.load_dim;
655
656     jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
657     jcp.load_grp_count = best_divider(
658             nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
659
660     if (jcp.bcast_dim <= 64 && jcp.load_dim * jcp.reduce_dim >= L2_size) {
661         jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
662     } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
663             && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
664         jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); //
665         load_blocking = jcp.load_block;
666     }
667
668     bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
669                              div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
670     bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
671     bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
672
673     int space_for_bcast
674             = (L2_capacity - /* kernel_size - */
675                 2 * jcp.load_block * reduce_blocking
676                     - jcp.ur * reduce_blocking - 3 * 1024);
677     if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
678         space_for_bcast /= 2;
679
680     int bcast_in_cache
681             = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
682     bcast_blocking = nstl::min(
683             bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
684
685     load_blocking_max = load_blocking;
686     bcast_blocking_max = bcast_blocking * 3 / 2;
687     reduce_blocking_max = reduce_blocking;
688
689     assert(load_blocking);
690     assert(load_blocking_max);
691     assert(bcast_blocking);
692     assert(bcast_blocking_max);
693     assert(reduce_blocking);
694     assert(reduce_blocking_max);
695     assert(load_blocking % jcp.load_block == 0);
696     assert(reduce_blocking % jcp.reduce_block == 0);
697     assert(load_blocking_max % jcp.load_block == 0);
698     assert(reduce_blocking_max % jcp.reduce_block == 0);
699
700     assert(jcp.reduce_loop_unroll % 4 == 0);
701     assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
702
703     assert(jcp.bcast_block % jcp.ur == 0);
704     assert(jcp.reduce_dim % jcp.reduce_block == 0);
705
706     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
707
708     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
709     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
710     jcp.nb_load_blocking = load_blocking / jcp.load_block;
711     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
712     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
713     jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
714
715     jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
716     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
717     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
718
719     const auto &oscales = attr.output_scales_;
720     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
721     assert(utils::implication(!jcp.is_oc_scale, oscales.mask_ == 0));
722
723     return status::success;
724 }
725
726 }
727 }
728 }