platform/upstream/dldt.git: inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <float.h>
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "nstl.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "cpu_memory.hpp"
28 #include "cpu_barrier.hpp"
29
30 #include "jit_uni_1x1_conv_utils.hpp"
31 #include "jit_avx512_common_1x1_conv_kernel.hpp"
32
33 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
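// GET_OFF(field) expands to the byte offset of `field` within the
// jit_1x1_conv_call_s call-argument structure, so the generated code can read
// per-call arguments straight off the ABI argument register, e.g.
// mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]) as in generate().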
34
35 namespace mkldnn {
36 namespace impl {
37 namespace cpu {
38
39 using namespace mkldnn::impl::prop_kind;
40 using namespace mkldnn::impl::memory_format;
41 using namespace mkldnn::impl::utils;
42
43 using namespace Xbyak;
44
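// A short orientation note: bcast_loop walks the broadcast (spatial) dimension
// of the current block. bcast_loop_iter starts from the value stashed on the
// stack by generate(); each full iteration covers jcp.bcast_block rows split
// into jcp.ur-sized substeps handed to reduce_loop, and the jcp.ur_tail
// remainder is handled at the end. The ver_4fma branch appears to differ
// mainly in how the wraparound/tail case is dispatched.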
45 void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk)
46 {
47     mov(aux1_reg_bcast_data, reg_bcast_data);
48     mov(aux_reg_bcast_data, reg_bcast_data);
49
50     mov(aux_reg_output_data, reg_output_data);
51     mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
52
53     if (jcp.ver == ver_4fma)
54     {
55         Label bcast_loop;
56         Label bcast_loop_wraparound;
57         Label bcast_loop_out;
58         Label bcast_loop_ur_full;
59
60         cmp(bcast_loop_iter, jcp.ur);
61         jle(bcast_loop_wraparound, T_NEAR);
62
63         L(bcast_loop); {
64             assert(jcp.bcast_block % jcp.ur == 0);
65             int num_substeps = jcp.bcast_block / jcp.ur;
66             assert(num_substeps > 0 && num_substeps < 10);
67             for (int i = 0; i < num_substeps; i++) {
68                 reduce_loop(load_loop_blk, jcp.ur, i, false);
69                 if (i < num_substeps - 1) {
70                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
71                     add(aux_reg_output_data, jcp.bcast_loop_output_substep);
72                 }
73                 else {
74                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
75                         - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
76                     add(aux_reg_output_data, jcp.bcast_loop_output_step
77                         - (num_substeps - 1) * jcp.bcast_loop_output_substep);
78                 }
79             }
80             sub(bcast_loop_iter, jcp.bcast_block);
81             cmp(bcast_loop_iter, jcp.bcast_block);
82             jg(bcast_loop, T_NEAR);
83         }
84
85         L(bcast_loop_wraparound);
86         if (jcp.ur_tail) {
87             je(bcast_loop_ur_full, T_NEAR);
88             reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
89             jmp(bcast_loop_out, T_NEAR);
90         }
91         L(bcast_loop_ur_full);
92         reduce_loop(load_loop_blk, jcp.ur, 0, true);
93         L(bcast_loop_out);
94     }
95     else
96     {
97         Label bcast_loop;
98         Label bcast_loop_tail;
99
100         cmp(bcast_loop_iter, jcp.ur);
101         jl(bcast_loop_tail, T_NEAR);
102
103         L(bcast_loop); {
104             assert(jcp.bcast_block % jcp.ur == 0);
105             int num_substeps = jcp.bcast_block / jcp.ur;
106             assert(num_substeps > 0 && num_substeps < 10);
107             for (int i = 0; i < num_substeps; i++) {
108                 reduce_loop(load_loop_blk, jcp.ur, i, false);
109                 if (i < num_substeps - 1) {
110                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
111                     add(aux_reg_output_data, jcp.bcast_loop_output_substep);
112                 }
113                 else {
114                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
115                         - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
116                     add(aux_reg_output_data, jcp.bcast_loop_output_step
117                         - (num_substeps - 1) * jcp.bcast_loop_output_substep);
118                 }
119             }
120             sub(bcast_loop_iter, jcp.bcast_block);
121             cmp(bcast_loop_iter, jcp.bcast_block);
122             jge(bcast_loop, T_NEAR);
123         }
124
125         L(bcast_loop_tail);
126         if (jcp.ur_tail) {
127             Label bcast_loop_tail_out;
128             cmp(bcast_loop_iter, 0);
129             jz(bcast_loop_tail_out, T_NEAR);
130             reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
131             L(bcast_loop_tail_out);
132         }
133     }
134 }
135
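// reduce_loop accumulates over the reduce dimension for one bcast substep.
// Roughly, the zmm register file is laid out as follows:
//   zmm[0 .. ur*load_loop_blk-1]                  accumulators (vreg_accum)
//   zmm[rnd_up(ur*load_loop_blk, fma_step) .. ]   load registers (vreg_load)
//   vreg_bcast                                    broadcast value for expl_bcast
// init() seeds the accumulators from the bias or zeroes them, fma_block()
// performs the multiply-accumulate over reduce_loop_unroll elements, and
// store() applies post-ops and writes the results back.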
136 void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
137          int ur, int substep, bool wraparound)
138 {
139     auto vreg_load = [=](int i_load, int i_fma) {
140         return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
141                     + jcp.fma_step * i_load + i_fma);
142     };
143
144     auto vreg_accum = [=](int i_load, int i_ur) {
145         return Zmm(i_ur + i_load * ur);
146     };
147
148     auto bias_ptr = [=](int i_load) {
149         return EVEX_compress_addr(reg_bias_data,
150                                   jcp.typesize_out * jcp.oc_block * i_load);
151     };
152
153     auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
154         assert(i_ur < jcp.ur);
155         assert(i_reduce <= jcp.reduce_loop_unroll);
156         int offt;
157         if (one_of(jcp.prop_kind, forward_training, forward_inference,
158                    backward_data)) {
159             assert(jcp.reduce_loop_unroll == jcp.reduce_block);
160             offt = (i_reduce == jcp.reduce_loop_unroll)
161                     ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll
162                     : i_ur * jcp.reduce_loop_unroll + i_reduce;
163         } else {
164             if (jcp.transpose_src) {
165                 const int reduce_group = i_reduce / 4;
166                 const int reduce_shift = i_reduce % 4;
167                 offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift;
168             }
169             else
170                 offt = i_reduce * jcp.ic_block + i_ur;
171         }
172         return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
173                                 bcast);
174     };
175
176     auto load_ptr = [=](int i_reduce, int i_load) {
177         int offt;
178         int u0 = i_reduce % jcp.reduce_loop_unroll;
179         int u1 = i_reduce / jcp.reduce_loop_unroll;
180         if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
181             offt = (i_load * jcp.reduce_block + u0) * jcp.load_block;
182         else
183             offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
184         return EVEX_compress_addr(aux_reg_load_data,
185                                   u1 * jcp.reduce_loop_load_step
186                                   + jcp.typesize_in * offt);
187     };
188
189     auto output_ptr = [=](int i_load, int i_ur) {
190         if (one_of(jcp.prop_kind, forward_training, forward_inference,
191                    backward_data))
192             return EVEX_compress_addr(aux_reg_output_data,
193                     (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
194                     * jcp.typesize_out);
195         else
196             return ptr[aux_reg_output_data +
197                        (i_load
198                             ? reg_output_stride * i_load
199                             : 0) // TODO: Xbyak should allow 0 scale
200                        + jcp.typesize_out * jcp.load_block * i_ur];
201     };
202
203     auto init = [=]() {
204         Label init_done;
205         Label init_zero;
206
207         if (jcp.with_sum) {
208             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
209                 for (int i_ur = 0; i_ur < ur; ++i_ur) {
210                     mic_prefetcht1(output_ptr(i_load, i_ur));
211                 }
212             }
213         }
214
215         if (jcp.with_bias
216             && one_of(jcp.prop_kind, forward_training, forward_inference)) {
217             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
218             jz(init_zero, T_NEAR);
219
220             for (int i_load = 0; i_load < load_loop_blk; i_load++)
221                 for (int i_ur = 0; i_ur < ur; ++i_ur)
222                     vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load));
223             jmp(init_done, T_NEAR);
224         }
225
226         L(init_zero);
227         for (int i_load = 0; i_load < load_loop_blk; ++i_load)
228             for (int i_ur = 0; i_ur < ur; ++i_ur) {
229                 auto r = vreg_accum(i_load, i_ur);
230                 vpxord(r, r, r);
231             }
232         L(init_done);
233     };
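    // The reduce-position flags (first_last_flag from the call arguments) control
    // partial reductions: FLAG_REDUCE_FIRST means this is the first reduce chunk,
    // so the accumulators are seeded from the bias (or zero) and nothing is added
    // from the destination; FLAG_REDUCE_LAST means this is the final chunk, so the
    // post-ops are applied before the results are stored.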
234
235     auto vadd = [=](const Xmm& x1, const Xmm& x2, const Operand& op) {
236         if (jcp.ver == ver_4vnni)
237             vpaddd(x1, x2, op);
238         else
239             vaddps(x1, x2, op);
240     };
241
242     auto store = [=]() {
243
244         Label store_noadd;
245         if (!jcp.with_sum) {
246             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
247             jnz(store_noadd, T_NEAR);
248         }
249
250         for (int i_ur = 0; i_ur < ur; ++i_ur)
251             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
252                 auto r = vreg_accum(i_load, i_ur);
253                 vadd(r, r, output_ptr(i_load, i_ur));
254             }
255
256         L(store_noadd);
257
258         Label store_nopostproc;
259         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
260         jz(store_nopostproc, T_NEAR);
261
262         int eltwise_inj_idx = 0;
263         int depthwise_inj_idx = 0;
264         const auto &p = attr_.post_ops_;
265
266         for (int i = 0; i < p.len_; i++) {
267             auto& post_op = p.entry_[i];
268             if (post_op.is_eltwise()) {
269                 if (jcp.ver == ver_4vnni) {
270                     zmm_t zmm_zero = vreg_bcast;
271                     vpxord(zmm_zero, zmm_zero, zmm_zero);
272
273                     for (int i_ur = 0; i_ur < ur; ++i_ur) {
274                         for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
275                             Zmm zmm = vreg_accum(i_load, i_ur);
276                             vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os);
277                             vpmulld(zmm | k1, zmm, zmm_zero);
278                         }
279                     }
280                 } else {
281                     eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
282                 }
283                 eltwise_inj_idx++;
284             } else if (post_op.is_depthwise()) {
285                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
286                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
287
288                 add(reg_d_weights, reg_oc_off);
289                 add(reg_d_bias, reg_oc_off);
290
291                 for (int j = 0; j < load_loop_blk; ++j) {
292                     int start_idx = vreg_accum(j, 0).getIdx();
293                     int end_idx = start_idx + ur;
294
295                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
296                             start_idx, end_idx, reg_d_weights, reg_d_bias);
297
298                     add(reg_d_weights, jcp.oc_block * sizeof(float));
299                     add(reg_d_bias, jcp.oc_block * sizeof(float));
300                 }
301
302                 depthwise_inj_idx++;
303             }
304         }
305
306         L(store_nopostproc);
307
308         auto store_output = [=](bool output_is_aligned) {
309             for (int i_ur = 0; i_ur < ur; ++i_ur)
310                 for (int i_load = 0; i_load < load_loop_blk; ++i_load)
311                     if (output_is_aligned && jcp.use_vmovntps)
312                         vmovntps(output_ptr(i_load, i_ur),
313                             vreg_accum(i_load, i_ur));
314                     else
315                         vmovups(output_ptr(i_load, i_ur),
316                             vreg_accum(i_load, i_ur));
317         };
318
319         Label unaligned_store, end_store;
320         test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1);
321         jnz(unaligned_store, T_NEAR);
322         store_output(true);
323         jmp(end_store, T_NEAR);
324         L(unaligned_store); {
325             store_output(false);
326         }
327         L(end_store);
328     };
329
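    // prefetch_callback is invoked once per FMA in fma_block and spreads software
    // prefetches over the inner loop: input (bcast) prefetches get the highest
    // priority, and the remaining slots are shared between L1/L2 prefetches of the
    // weights and of the output, paced by the *_trigger counters computed below.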
330     auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load,
331         bool last_block, bool wraparound, int reduce_step)
332     {
333         bool pf_ker_l1 = true;
334         bool pf_ker_l2 = wraparound;
335         int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk;
336         int i_op = (i_reduce / reduce_step) * ur * load_loop_blk +
337             i_ur * load_loop_blk + i_load;
338
339         int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0;
340         int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0;
341         int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur;
342
343         int pf_inp_ops = n_ops / 2; // # of operations during which to pf input
344         int pf_inp_trigger;
345         if (jcp.prop_kind == backward_weights)
346             pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block);
347         else
348             pf_inp_trigger = nstl::max(1, pf_inp_ops / ur);
349
350         int n_other_pf =
351             load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1);
352         int n_other_pf_ops = n_ops - pf_inp_ops;
353         int other_pf_trigger
354                 = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0;
355
356         if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) {
357             // Input prefetches have the highest priority because the
358             // first iteration of the kernel block touches all the
359             // cache lines.
360             int i_pf = i_op / pf_inp_trigger;
361             auto pf_reg = wraparound && last_block
362                                   ? reg_bcast_data
363                                   : (last_block ? aux1_reg_bcast_data
364                                                 : aux_reg_bcast_data);
365             int offt = i_pf;
366             if (jcp.prop_kind == backward_weights) {
367                 offt += wraparound && last_block
368                                     ? 0
369                                     : (last_block ? jcp.is : jcp.reduce_block);
370                 offt *= jcp.bcast_block;
371             } else {
372                 offt += wraparound && last_block
373                                     ? 0
374                                     : (last_block ? jcp.ur : jcp.bcast_dim);
375                 offt *= jcp.reduce_block;
376             }
377             mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
378         } else if (i_op >= pf_inp_ops && n_other_pf) {
379             // remaining prefetches are spread among the rest of the
380             // operations; prefetches for output take priority
381             // TODO: spread L2 prefetches among L1 prefetches
382             i_op -= pf_inp_ops;
383             if (i_op % other_pf_trigger == 0) {
384                 int i_pf = i_op / (load_loop_blk * other_pf_trigger);
385                 if (i_pf < n_pf_ker_l2) {
386                     int offt = (i_pf + (i_load + 1) * jcp.reduce_dim)
387                         * jcp.load_block;
388                     if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
389                         offt = (i_pf + (i_load + 1) * jcp.reduce_block)
390                                 * jcp.load_block;
391
392                     mic_prefetcht1(ptr[aux_reg_load_data
393                                     + offt * jcp.typesize_in]);
394                 } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) {
395                     i_pf -= n_pf_ker_l2;
396                     auto pf_reg = last_block ? reg_load_data
397                                              : aux_reg_load_data;
398                     int offt = (i_pf + i_load * jcp.reduce_dim
399                         + (last_block
400                             ? (wraparound ? jcp.reduce_dim : 0)
401                             : jcp.reduce_block))
402                         * jcp.load_block;
403                     mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
404                 } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) {
405                     i_pf -= n_pf_ker_l1 + n_pf_ker_l2;
406                     int offt = i_pf * jcp.load_block;
407                     mic_prefetcht0(ptr[aux_reg_output_data
408                                     + offt * jcp.typesize_out]);
409                 }
410             }
411         }
412     };
413
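    // fma_block emits one unrolled pass over reduce_loop_unroll input elements.
    // Depending on jcp.ver it uses v4fmaddps (ver_4fma) or vp4dpwssd (ver_4vnni)
    // on groups of fma_step load registers, or plain vfmadd231ps, either with an
    // explicit vbroadcastss into vreg_bcast (avx512_core with expl_bcast) or with
    // an embedded-broadcast memory operand.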
414     auto fma_block = [=](bool last_block) {
415         assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
416
417         int reduce_step = jcp.fma_step;
418         if (jcp.ver == ver_4vnni)
419             reduce_step *= 2;
420
421         for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
422                 i_reduce += reduce_step) {
423             int load_scale = (jcp.ver == ver_4vnni) ? 2 : 1;
424             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
425                 // If the input data is transposed and the spatial size is
426                 // not divisible by the transpose step (4), then on the last
427                 // reduce step load only the required load registers and
428                 // clear (zero) the remaining ones.
429                 if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block
430                         && i_reduce == jcp.reduce_loop_unroll - reduce_step) {
431                     Label load_all;
432                     Label load_finish;
433                     test(reg_reduce_pos_flag, FLAG_SP_LAST);
434                     jz(load_all, T_NEAR);
435
436                     const int n_loads = jcp.is % jcp.fma_step;
437                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
438                         if (i_fma < n_loads)
439                             vmovups(vreg_load(i_load, i_fma),
440                                     load_ptr(i_reduce + load_scale * i_fma,
441                                             i_load));
442                         else
443                             vpxord(vreg_load(i_load, i_fma),
444                                     vreg_load(i_load, i_fma),
445                                     vreg_load(i_load, i_fma));
446                     }
447                     jmp(load_finish);
448
449                     L(load_all);
450                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
451                         vmovups(vreg_load(i_load, i_fma),
452                             load_ptr(i_reduce + load_scale * i_fma, i_load));
453                     }
454                     L(load_finish);
455                 } else {
456                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
457                         vmovups(vreg_load(i_load, i_fma),
458                             load_ptr(i_reduce
459                                 + load_scale * i_fma,
460                                 i_load));
461                     }
462                 }
463             }
464
465             for (int i_ur = 0; i_ur < ur; ++i_ur) {
466                 if (jcp.ver == ver_avx512_core && jcp.expl_bcast
467                         && load_loop_blk > 1)
468                     vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
469                 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
470                     if (jcp.ver == ver_4fma)
471                         v4fmaddps(vreg_accum(i_load, i_ur),
472                                     vreg_load(i_load, 0),
473                                     bcast_ptr(i_reduce, i_ur, false));
474                     else if (jcp.ver == ver_4vnni)
475                         vp4dpwssd(vreg_accum(i_load, i_ur),
476                                 vreg_load(i_load, 0),
477                                 bcast_ptr(i_reduce, i_ur, false));
478                     else if (jcp.ver == ver_avx512_core && jcp.expl_bcast
479                             && load_loop_blk > 1)
480                         vfmadd231ps(vreg_accum(i_load, i_ur),
481                                 vreg_load(i_load, 0), vreg_bcast);
482                     else
483                         vfmadd231ps(vreg_accum(i_load, i_ur),
484                                 vreg_load(i_load, 0),
485                                 bcast_ptr(i_reduce, i_ur, true));
486                     prefetch_callback(ur, i_reduce, i_ur, i_load,
487                                     last_block, wraparound, reduce_step);
488                 }
489             }
490         }
491     };
492     Label reduce_loop;
493     Label reduce_loop_tail;
494
495     mov(aux_reg_load_data, reg_load_data);
496
497     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
498     init();
499
500     mov(reduce_loop_iter, reg_reduce_loop_work);
501     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
502     jle(reduce_loop_tail, T_NEAR);
503
504     L(reduce_loop); {
505         fma_block(false);
506         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
507         add(aux_reg_load_data, jcp.reduce_loop_load_step);
508         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
509         jg(reduce_loop, T_NEAR);
510     }
511
512     L(reduce_loop_tail);
513     fma_block(true);
514
515     store();
516 }
517
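// generate() emits the complete kernel: it creates the post-op injectors,
// loads the call arguments from jit_1x1_conv_call_s, and then dispatches on
// reg_load_loop_work to one of several load_loop_blk variants (handling 1 to
// num_ur_cases load blocks per iteration), each of which repeatedly runs
// bcast_loop and advances the data pointers in load_loop_body.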
518 void jit_avx512_common_1x1_conv_kernel::generate()
519 {
520     const auto &p = attr_.post_ops_;
521     for (int i = 0; i < p.len_; i++) {
522         auto &post_op = p.entry_[i];
523         if (post_op.is_eltwise()) {
524             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
525                     this,
526                     post_op.eltwise.alg,
527                     post_op.eltwise.alpha,
528                     post_op.eltwise.beta
529             ));
530         } else if (post_op.is_depthwise()) {
531             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
532                     this,
533                     post_op.depthwise.alg
534             ));
535         }
536     }
537
538     preamble();
539
540     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
541     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
542     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
543
544     sub(rsp, stack_space_needed);
545
546     if (jcp.with_bias)
547         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
548
549     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
550     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
551     mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
552     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
553     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
554     if (one_of(jcp.prop_kind, forward_training, forward_inference))
555         mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
556     if (jcp.prop_kind == backward_weights)
557         mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
558     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
559
560     auto load_loop_body = [=](int load_loop_blk) {
561         bcast_loop(load_loop_blk);
562         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
563         switch (jcp.prop_kind) {
564         case forward_training:
565         case forward_inference:
566             add(reg_bias_data,
567                 load_loop_blk * jcp.load_block * jcp.typesize_out);
568             add(reg_output_data,
569                 load_loop_blk * jcp.bcast_dim * jcp.load_block *
570                     jcp.typesize_out);
571             break;
572         case backward_data:
573             add(reg_output_data,
574                 load_loop_blk * jcp.bcast_dim * jcp.load_block *
575                     jcp.typesize_out);
576             break;
577         case backward_weights:
578             for (int i_load = 0; i_load < load_loop_blk; i_load++)
579                 add(reg_output_data, reg_output_stride);
580             break;
581         default:
582             assert(!"invalid prop_kind");
583         }
584         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
585         add(reg_oc_off, load_loop_blk * jcp.oc_block * jcp.typesize_out);
586     };
587
588     const int simd_w = 16;
589
590     Label load_loop_blk[7];
591
592     static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 };
593     static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
594     static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 };
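    // Roughly, each table entry is the largest jcp.ur for which a given number of
    // load blocks still fits in the 32 zmm registers; the last entry corresponds
    // to a single load block and earlier entries to progressively more blocks.
    // The dispatch below picks the widest variant that jcp.ur and
    // reg_load_loop_work allow.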
595
596     const int size_ur_cases_fma
597             = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
598             sizeof(ur_cases_fma_expl_bcast) :
599             sizeof(ur_cases_fma_embd_bcast);
600     const int size_ur_cases_4fma = sizeof(ur_cases_4fma);
601
602     const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
603             ur_cases_fma_expl_bcast :
604             ur_cases_fma_embd_bcast;
605     const int *ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni)
606         ? ur_cases_4fma : ur_cases_fma;
607     const int num_ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni ?
608                                              size_ur_cases_4fma :
609                                              size_ur_cases_fma)
610             / sizeof(*ur_cases);
611
612     for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
613         int label_idx = num_ur_cases - ur_idx - 1;
614         if (jcp.ur <= ur_cases[ur_idx]) {
615             cmp(reg_load_loop_work, simd_w * (label_idx + 1));
616             jle(load_loop_blk[label_idx], T_NEAR);
617         }
618     }
619
620     for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
621         if (jcp.ur <= ur_cases[ur_idx]) {
622             int label_idx = num_ur_cases - ur_idx - 1;
623             L(load_loop_blk[label_idx]);
624             {
625                 if (label_idx == 0) {
626                     cmp(reg_load_loop_work, 0);
627                     je(load_loop_blk[num_ur_cases], T_NEAR);
628                 }
629                 load_loop_body(label_idx + 1);
630                 if (label_idx - 1 > 0) {
631                     cmp(reg_load_loop_work, 2 * label_idx * simd_w);
632                     je(load_loop_blk[label_idx - 1], T_NEAR);
633                 }
634                 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
635                 jge(load_loop_blk[label_idx]);
636             }
637             for (int idx = label_idx - 1; idx > 0; --idx) {
638                 cmp(reg_load_loop_work, simd_w * (idx + 1));
639                 je(load_loop_blk[idx], T_NEAR);
640             }
641             if (ur_idx < num_ur_cases - 2) {
642                 cmp(reg_load_loop_work, simd_w);
643                 jle(load_loop_blk[0], T_NEAR);
644             }
645         }
646     }
647     L(load_loop_blk[num_ur_cases]);
648
649     add(rsp, stack_space_needed);
650
651     postamble();
652
653     for (auto& inj : eltwise_injectors)
654         inj->prepare_table();
655 }
656
657 bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
658         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
659     const auto &p = attr.post_ops_;
660
661     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
662     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
663     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
664     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
665
666     switch (p.len_) {
667     case 0: return true;
668     case 1: return is_simple(0) || is_sum(0);
669     case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
670     case 3: return is_sum(0) && is_simple(1) && is_simple(2);
671     default: return false;
672     }
673
674     return false;
675 }
676
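// init_conf validates the descriptor (1x1 kernel, no padding, unit strides,
// nCw16c/nChw16c activations), picks the code generation flavour (ver_4vnni for
// s16 data with avx512_mic_4ops, ver_4fma or ver_avx512_core/ver_fma for f32),
// and then derives the register blocking (jcp.ur) and the load/bcast/reduce
// blocking used by the driver, largely from cache-size heuristics.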
677 status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
678         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
679         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
680         const primitive_attr_t &attr, int nthreads, bool reduce_src) {
681     if (!mayiuse(avx512_common)) return status::unimplemented;
682
683     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
684     const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
685     const int ndims = src_d.ndims();
686
687     jcp.prop_kind = cd.prop_kind;
688
689     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
690     jcp.mb = src_d.dims()[0];
691
692     jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
693     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
694     jcp.ic = src_d.dims()[1] / jcp.ngroups;
695
696     bool ok_to_pad_channels = true
697         && jcp.ngroups == 1
698         && src_d.data_type() == data_type::f32;
699     if (ok_to_pad_channels) {
700         jcp.oc = rnd_up(jcp.oc, simd_w);
701         jcp.ic = rnd_up(jcp.ic, simd_w);
702     }
703
704     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
705     jcp.iw = src_d.dims()[ndims - 1];
706     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
707     jcp.ow = dst_d.dims()[ndims - 1];
708
709     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
710     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
711
712     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
713     jcp.l_pad = cd.padding[0][ndims - 3];
714
715     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
716     jcp.stride_w = cd.strides[ndims - 3];
717
718     jcp.src_fmt = src_d.format();
719     jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format,
720             memory_format::undef, cd.diff_bias_desc.format)
721         != memory_format::undef;
722
723     jcp.os = jcp.oh * jcp.ow;
724     jcp.is = jcp.ih * jcp.iw;
725     jcp.tr_is = rnd_up(jcp.is, 4);
726
727     if (!post_ops_ok(jcp, attr))
728         return status::unimplemented;
729
730     const auto &p = attr.post_ops_;
731     jcp.with_sum = p.find(primitive_kind::sum) != -1;
732     const int eltwise_ind = p.find(primitive_kind::eltwise);
733     jcp.with_eltwise = eltwise_ind != -1;
734     if (jcp.with_eltwise) {
735         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
736         if (dst_d.data_type() == data_type::s32) return status::unimplemented;
737     }
738
739     bool args_ok = true
740         && jcp.ngroups == 1
741         && everyone_is(pick(ndims - 3, nCw16c, nChw16c), src_d.format(),
742             dst_d.format())
743         && one_of(cd.bias_desc.format, memory_format::undef, any, x);
744     if (!args_ok) return status::unimplemented;
745
746     args_ok = true
747         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
748         && jcp.t_pad == 0 && jcp.l_pad == 0
749         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
750         && jcp.kh == 1 && jcp.kw == 1;
751     if (!args_ok) return status::unimplemented;
752
753     jcp.ic_block = jcp.oc_block = simd_w;
754     jcp.transpose_src = false;
755
756     if (mayiuse(avx512_mic_4ops)
757         && ((one_of(jcp.prop_kind, forward_training, forward_inference)
758             && src_d.data_type() == data_type::s16
759             && weights_d.data_type() == data_type::s16
760             && dst_d.data_type() == data_type::s32)
761         || (jcp.prop_kind == backward_data
762             && src_d.data_type() == data_type::s32
763             && weights_d.data_type() == data_type::s16
764             && dst_d.data_type() == data_type::s16)))
765     {
766         const int is_bwd_d = jcp.prop_kind == backward_data;
767         memory_format_t weights_format = with_groups
768             ? pick(2 * ndims - 6 + is_bwd_d, gOIw8i16o2i, gOIw8o16i2o,
769                 gOIhw8i16o2i, gOIhw8o16i2o)
770             : pick(2 * ndims - 6 + is_bwd_d, OIw8i16o2i, OIw8o16i2o,
771                 OIhw8i16o2i, OIhw8o16i2o);
772
773         if (weights_d.format() != weights_format)
774             return status::unimplemented;
775
776         jcp.ver = ver_4vnni;
777         jcp.fma_step = 4;
778         jcp.typesize_in = sizeof(prec_traits<data_type::s16>::type);
779         jcp.typesize_out = sizeof(prec_traits<data_type::s32>::type);
780     }
781     else if (everyone_is(data_type::f32, src_d.data_type(),
782                             weights_d.data_type(), dst_d.data_type()))
783     {
784         const int is_bwd_d = jcp.prop_kind == backward_data;
785         memory_format_t weights_format = with_groups
786             ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
787                 gOIhw16i16o, gIOhw16o16i)
788             : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
789                 OIhw16i16o, IOhw16o16i);
790
791         if (weights_d.format() != weights_format)
792             return status::unimplemented;
793         if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) &&
794             ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4
795             == 0) {
796             jcp.ver = ver_4fma;
797             jcp.fma_step = 4;
798         } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops)
799                 && !reduce_src
800                 /* Heuristic condition on the ratio of src size to oc:
801                    otherwise the src transposition overhead exceeds the
802                    benefit from 4fma. */
803                 && ((jcp.is * jcp.ic) / jcp.oc <= 2048)
804                 && mkldnn_thr_syncable()
805                 )
806         {
807             jcp.transpose_src = true;
808             jcp.ver = ver_4fma;
809             jcp.fma_step = 4;
810         } else {
811             jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma;
812             jcp.fma_step = 1;
813         }
814         jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
815         jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
816     } else {
817         return status::unimplemented;
818     }
819
820     /* once all the formats are set, check the padding consistency */
821     args_ok = true
822         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
823         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
824         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
825         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
826     if (!args_ok) return status::unimplemented;
827
828     const int SMALL_SPATIAL = 10;
829     const int BIG_SPATIAL = 28;
830     const int BIG_REDUCE_DIM = 1024;
831     const int BIG_LOAD_DIM = 256;
832
833     int load_blocking{ 0 };
834     int load_blocking_max{ 0 };
835     int bcast_blocking{ 0 };
836     int bcast_blocking_max{ 0 };
837     int reduce_blocking{ 0 };
838     int reduce_blocking_max{ 0 };
839
840     jcp.load_grp_count = 1;
841
842     const int L1_capacity = get_cache_size(1, true) / sizeof(float);
843     const int L2_size = get_cache_size(2, true) / sizeof(float);
844     const int L2_capacity = (L2_size * 3) / 4;
845
846     if (one_of(jcp.prop_kind, forward_training, forward_inference,
847                 backward_data)) {
848         if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
849             jcp.reduce_dim = jcp.ic;
850             jcp.reduce_block = jcp.ic_block;
851
852             jcp.load_dim = jcp.oc;
853             jcp.load_block = jcp.oc_block;
854
855             jcp.bcast_dim = jcp.is;
856         } else {
857             jcp.reduce_dim = jcp.oc;
858             jcp.reduce_block = jcp.oc_block;
859
860             jcp.load_dim = jcp.ic;
861             jcp.load_block = jcp.ic_block;
862
863             jcp.bcast_dim = jcp.os;
864         }
865         jcp.reduce_loop_unroll = jcp.reduce_block;
866         jcp.reduce_loop_bcast_step
867                 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
868
869         if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni) {
870             jcp.reduce_loop_load_step
871                     = jcp.reduce_loop_unroll * jcp.ic * jcp.typesize_in;
872             jcp.load_loop_load_step
873                     = jcp.oc_block * jcp.ic_block * jcp.typesize_in;
874         } else {
875             jcp.reduce_loop_load_step
876                     = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
877             jcp.load_loop_load_step
878                     = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
879         }
880
881         // adjust the register blocking
882         int max_regs, min_regs, size_treshold, ur_step;
883         const int spatial
884                 = (one_of(jcp.prop_kind, forward_training, forward_inference)) ?
885                 jcp.oh :
886                 jcp.ih;
887         if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) {
888             max_regs = 9;
889             min_regs = 6;
890             size_treshold = 14;
891             ur_step = 1;
892             jcp.expl_bcast = true;
893
894             if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
895                     && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
896                 max_regs = 6;
897                 min_regs = 5;
898             }
899         } else {
900             bool is4ops = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni);
901
902             max_regs = is4ops ? 28 : 30;
903             min_regs = 9;
904             size_treshold = is4ops ? 28 : 14;
905             ur_step = is4ops ? 4 : 1;
906             jcp.expl_bcast = false;
907             jcp.use_vmovntps = true;
908         }
909         jcp.ur = 1;
910         for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
911             if ((spatial >= size_treshold && spatial % ur_w == 0)
912                     || (spatial < size_treshold && jcp.os % ur_w == 0)) {
913                 jcp.ur = ur_w;
914                 break;
915             }
916         }
917         if (jcp.ur == 1) {
918             jcp.ur = nstl::min(max_regs, jcp.os);
919             int os_tail = jcp.os % max_regs;
920             for (int i = max_regs; i >= min_regs; i -= ur_step) {
921                 int i_tail = jcp.os % i;
922                 if (i_tail > os_tail || i_tail == 0) {
923                     jcp.ur = i;
924                     os_tail = i_tail;
925                     if (i_tail == 0)
926                         break;
927                 }
928             }
929         }
930
931         jcp.reduce_loop_unroll = jcp.reduce_block;
932         jcp.reduce_loop_bcast_step
933                 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
934
935         jcp.bcast_block = jcp.ur;
936
937         jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out;
938         jcp.bcast_loop_output_substep = -1; // unused
939         jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in;
940         jcp.bcast_loop_bcast_substep = -1; // unused
941
942         jcp.load_loop_iter_step = jcp.load_block;
943
944         if (jcp.prop_kind == backward_data)
945             jcp.loop_order = loop_lbr;
946         else
947             jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
948
949         int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
950         int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
951         int nb_load = div_up(jcp.load_dim, jcp.load_block);
952
953         if (jcp.ver == ver_avx512_core && jcp.expl_bcast) {
954             if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
955                     && spatial < BIG_SPATIAL)
956                 reduce_blocking = nstl::min(jcp.reduce_dim, 80);
957             else if (spatial > SMALL_SPATIAL)
958                 reduce_blocking = nstl::min(jcp.reduce_dim, 512);
959             else
960                 reduce_blocking = nstl::min(jcp.reduce_dim, 256);
961
962             if ((jcp.mb > 28 && spatial >= 28)
963                     || (jcp.mb > 112 && spatial >= 17))
964                 jcp.use_vmovntps = true;
965             else
966                 jcp.use_vmovntps = false;
967         } else {
968
969             reduce_blocking = nb_reduce;
970             if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
971                 reduce_blocking = 16;
972             else if (spatial > SMALL_SPATIAL
973                     && jcp.reduce_dim >= BIG_REDUCE_DIM)
974                 reduce_blocking = 8;
975             reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
976             reduce_blocking *= jcp.reduce_block;
977         }
978
979         // Check for input data cache aliasing.
980         // The constants below may need to be updated for other ISAs.
981         // 64 * 1024 is one way of a 1MB 16-way L2 cache.
982         // 7 is an empirical value, roughly half of the 16 ways, so about
983         // half of each set is left for other data (weights, dst).
984         int way_size = (64 * 1024) / jcp.typesize_in;
985         int max_hits = 7;
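        // Illustration with assumed numbers (not from this source): for fp32,
        // typesize_in = 4, so way_size = 64 * 1024 / 4 = 16384 floats. With, say,
        // bcast_dim = 3136 and reduce_blocking = 64, 3136 * 64 = 200704 exceeds
        // 16384 * 7, so the scan below walks offsets one L2 way apart and shrinks
        // reduce_blocking until at most max_hits lines fall into the same set.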
986         if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
987             int nrb = reduce_blocking / simd_w;
988             int sp = jcp.bcast_dim;
989             int wl = way_size / simd_w;
990             for (int start_off = 0; start_off < jcp.ur; start_off++) {
991                 for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
992                     if (off % sp >= jcp.ur || ++hits < max_hits)
993                         continue;
994                     int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
995                     reduce_blocking
996                             = nstl::min(reduce_blocking, max_r_blocking);
997                     break;
998                 }
999             }
1000         }
1001
1002         if (reduce_blocking < jcp.reduce_dim) {
1003             jcp.use_vmovntps = false;
1004             if (jcp.prop_kind == backward_data)
1005                 jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
1006             else
1007                 jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
1008         }
1009         load_blocking = jcp.load_dim;
1010
1011         int load_size = jcp.load_dim * jcp.reduce_dim;
1012         int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
1013
1014         if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads
1015                 && nb_load * nb_bcast > nthreads) {
1016             // Some heuristic here
1017             float calc_koef = 0.01, best_cost = FLT_MAX;
1018             int n_lgc = nthreads;
1019             float ratio = (float)load_size / (float)bcast_size;
1020             int best_lgc = ratio > 1 ? n_lgc : 1;
1021             auto calc_job_cost = [&](int lb, int tg, float mem_k) {
1022                 int bb_size = jcp.mb * div_up(nb_bcast, tg);
1023                 float calc_size = (float)(bb_size * jcp.ur)
1024                         * (lb * jcp.load_block) * jcp.reduce_dim;
1025                 float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
1026                         * jcp.reduce_dim;
1027                 return calc_koef * calc_size + mem_k * mem_size;
1028             };
1029             for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
1030                 lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
1031                 int min_lb = nb_load / lgc;
1032                 int max_lb = div_up(nb_load, lgc);
1033                 int min_tg = nthreads / lgc;
1034                 int max_tg = div_up(nthreads, lgc);
1035                 // Some heuristic here
1036                 float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
1037                 float job_cost = 0.;
1038                 if (nthreads % lgc < nb_load % lgc) {
1039                     job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
1040                 } else {
1041                     auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
1042                     auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
1043                     job_cost = nstl::max(job_cost1, job_cost2);
1044                 }
1045
1046                 if (job_cost < best_cost) {
1047                     best_lgc = lgc;
1048                     best_cost = job_cost;
1049                 }
1050             }
1051             jcp.load_grp_count = best_lgc;
1052             load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
1053         } else {
1054             jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
1055             jcp.load_grp_count = best_divider(
1056                 nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
1057         }
1058
1059         if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64
1060                 && load_size >= L2_size) {
1061             jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
1062         } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
1063                 && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
1064             jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
1065             load_blocking = jcp.load_block;
1066         }
1067
1068         if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
1069                 && jcp.oh * jcp.ow > 64
1070                 && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
1071             /* Look for the load-dimension blocking that gives
1072             * the best thread utilization and data read/write efficiency,
1073             * by finding the optimal 'load_chunk' value.
1074             * Example:
1075             * for 72 threads and a convolution with mb=1, ih=iw=7, oc=512,
1076             * the best load_chunk value is 1.
1077             * TODO: remove the heuristic constants in the condition above
1078             * TODO: check this blocking for other ISAs
1079             */
1080             float best_eff = -1.f;
1081             int best_lgc = 1;
1082
1083             for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
1084                 int lgc = div_up(nb_load, load_chunk);
1085                 if (lgc > nthreads)
1086                     continue;
1087                 int thr_per_grp = div_up(nthreads, lgc);
1088                 int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
1089                         * jcp.bcast_block;
1090                 int load_per_thr = load_chunk * simd_w;
1091                 float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
1092                 float data_eff = (bcast_per_thr * load_per_thr)
1093                         / (data_norm * data_norm);
1094                 float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
1095                         / div_up(nthreads, lgc);
1096                 float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
1097                         / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
1098                 float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
1099                 float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
1100                 float overall_eff = data_eff + thr_eff + load_eff;
1101                 if (overall_eff > best_eff) {
1102                     best_eff = overall_eff;
1103                     best_lgc = lgc;
1104                 }
1105             }
1106             jcp.load_grp_count = best_lgc;
1107             load_blocking
1108                     = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
1109         }
1110         bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
1111                                  div_up(nthreads, jcp.load_grp_count))
1112                 * jcp.bcast_block;
1113         bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
1114         bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
1115
1116         int space_for_bcast
1117                 = (L2_capacity - /* kernel_size - */
1118                     2 * jcp.load_block * reduce_blocking
1119                         - jcp.ur * reduce_blocking - 3 * 1024);
1120         if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
1121             space_for_bcast /= 2;
1122
1123         int bcast_in_cache
1124                 = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
1125         bcast_blocking = nstl::min(
1126                 bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
1127
1128         load_blocking_max = load_blocking;
1129         bcast_blocking_max = bcast_blocking * 3 / 2;
1130         reduce_blocking_max = reduce_blocking;
1131
1132     } else if (jcp.prop_kind == backward_weights) {
1133
1134         jcp.use_vmovntps = false;
1135         if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma)
1136             jcp.use_vmovntps = true;
1137
1138         if (jcp.transpose_src)
1139             jcp.reduce_dim = jcp.tr_is;
1140         else
1141             jcp.reduce_dim = jcp.is;
1142
1143         if (jcp.ver == ver_4fma) {
1144             // reduce_block must be divisible by fma_step
1145             jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4);
1146         } else {
1147             jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
1148             if (jcp.reduce_dim % jcp.reduce_block != 0)
1149                 jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
1150             if (jcp.reduce_block > 256) {
1151                 jcp.reduce_block = 1;
1152             }
1153
1154         }
1155
1156         jcp.load_dim = jcp.oc;
1157         jcp.load_block = jcp.oc_block;
1158
1159         jcp.bcast_dim = jcp.ic;
1160         jcp.bcast_block = jcp.ic_block;
1161
1162         if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) {
1163             // If reduce_block is large, the generated JIT code may become
1164             // large for small ur, because reduce_loop_unroll = reduce_block.
1165             jcp.ur = jcp.bcast_block / 2;
1166             jcp.expl_bcast = true;
1167         } else {
1168             jcp.ur = jcp.bcast_block;
1169             jcp.expl_bcast = false;
1170         }
1171
1172         jcp.reduce_loop_unroll = jcp.reduce_block;
1173         jcp.reduce_loop_bcast_step
1174             = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in;
1175         jcp.reduce_loop_load_step
1176             = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in;
1177
1178         jcp.bcast_loop_output_step =
1179                                 jcp.oc_block * jcp.ic_block * jcp.typesize_out;
1180         jcp.bcast_loop_output_substep =
1181             jcp.oc_block * jcp.ur * jcp.typesize_out;
1182         jcp.bcast_loop_bcast_step =
1183                 jcp.ic_block * jcp.reduce_dim * jcp.typesize_in;
1184         jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;
1185
1186         jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in;
1187         jcp.load_loop_iter_step = jcp.oc_block;
1188
1189         /* --- */
1190         balance(jcp, nthreads);
1191
1192         load_blocking = div_up(jcp.load_dim, jcp.load_block);
1193         load_blocking = best_divider(load_blocking, 16, load_blocking, false);
1194         load_blocking *= jcp.load_block;
1195
1196         load_blocking_max = load_blocking;
1197         assert(jcp.load_dim % load_blocking == 0);
1198
1199         int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1200         int min_bcast_blocking = 5;
1201
1202         bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1203         bcast_blocking = best_divider(
1204                 bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
1205         bcast_blocking *= jcp.bcast_block;
1206         bcast_blocking_max = bcast_blocking;
1207         assert(jcp.bcast_dim % bcast_blocking == 0);
1208
1209         // for reduction balance
1210         if (jcp.ver == ver_avx512_core) {
1211             int max_reduce_blocking
1212                     = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
1213             int min_reduce_blocking = nstl::min(
1214                     L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
1215             reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
1216                     max_reduce_blocking, true);
1217             reduce_blocking
1218                     = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
1219                             jcp.reduce_block);
1220         } else {
1221             int max_reduce_blocking = L2_capacity
1222                     / ((bcast_blocking + load_blocking) * jcp.reduce_block);
1223             max_reduce_blocking = nstl::min(max_reduce_blocking,
1224                     (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block);
1225
1226             int num_jobs = div_up(jcp.load_dim, load_blocking)
1227                     * div_up(jcp.bcast_dim, bcast_blocking);
1228             int threads_per_job = nstl::max(1, nthreads / num_jobs);
1229             reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block);
1230             reduce_blocking = div_up(reduce_blocking, threads_per_job);
1231
1232             reduce_blocking = best_divider(reduce_blocking,
1233                     max_reduce_blocking - 2, max_reduce_blocking, true);
1234             reduce_blocking *= jcp.reduce_block;
1235         }
1236
1237         reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
1238     } else
1239         return status::unimplemented;
1240
1241     assert(load_blocking);
1242     assert(load_blocking_max);
1243     assert(bcast_blocking);
1244     assert(bcast_blocking_max);
1245     assert(reduce_blocking);
1246     assert(reduce_blocking_max);
1247     assert(load_blocking % jcp.load_block == 0);
1248     assert(reduce_blocking % jcp.reduce_block == 0);
1249     assert(load_blocking_max % jcp.load_block == 0);
1250     assert(reduce_blocking_max % jcp.reduce_block == 0);
1251     if (jcp.ver == ver_4fma || jcp.ver == ver_4vnni) {
1252         if (jcp.ver == ver_4fma)
1253             assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
1254         if (jcp.ver == ver_4vnni)
1255             assert(jcp.reduce_loop_unroll % (2 * jcp.fma_step) == 0);
1256         assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
1257     }
1258
1259     assert(jcp.bcast_block % jcp.ur == 0);
1260     assert(jcp.reduce_dim % jcp.reduce_block == 0);
1261
1262     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
1263
1264     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
1265     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
1266     jcp.nb_load_blocking = load_blocking / jcp.load_block;
1267     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
1268     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
1269     jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
1270
1271     jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1272     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
1273     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1274
1275     return status::success;
1276 }
1277
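// init_scratchpad books the auxiliary buffers the driver may need: a padded
// bias copy when oc was rounded up to the SIMD width, a per-thread weights
// reduction buffer for backward_weights, and the transposed-src buffer plus its
// barrier contexts when transpose_src is enabled.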
1278 void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
1279         memory_tracking::registrar_t &scratchpad,
1280         const jit_1x1_conv_conf_t &jcp) {
1281     using namespace mkldnn::impl::memory_tracking::names;
1282
1283     if (jcp.prop_kind != backward_data && jcp.with_bias
1284             && jcp.oc != jcp.oc_without_padding)
1285         scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
1286
1287     if (jcp.prop_kind == backward_weights) {
1288         const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
1289         scratchpad.book(key_conv_wei_reduction,
1290                 jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
1291     }
1292
1293     if (jcp.transpose_src) {
1294         const size_t tr_src_size =
1295             (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
1296         scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
1297         scratchpad.book(key_conv_tr_src_bctx,
1298                 sizeof(simple_barrier::ctx_t) * jcp.nthr);
1299     }
1300 }
1301
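// balance() splits the available threads for backward_weights between the
// minibatch/reduce dimension (nthr_mb), output-channel blocks (nthr_oc_b) and
// input-channel blocks (nthr_ic_b), with nthr_g reserved for groups. It scans
// the candidate decompositions and keeps the one with the lowest estimated
// per-thread memory traffic as computed by calc_mem_cost.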
1302 void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
1303         int nthreads)
1304 {
1305     // initialize jcp reduction threading properties
1306     jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
1307     if (nthreads < jcp.ngroups) {
1308         /* simplification... fortunately it doesn't hurt much */
1309         return;
1310     }
1311     const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1312     const int nb_load = div_up(jcp.load_dim, jcp.load_block);
1313     const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1314
1315     jcp.nthr_g = jcp.ngroups;
1316     const int nthr = nthreads / jcp.nthr_g;
1317
1318     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
1319         /* Calculate the per-thread memory cost (reads/writes); the high-level
1320         * optimizer tries to minimize memory traffic. A few notes:
1321         *  (n1) unclear why, but dividing by the strides helps the first convolution;
1322         *  (n2) assuming the reduction over the minibatch is always present,
1323         *       the output coefficient should be 5 rather than 8 (write ~= 2 reads):
1324         *       the kernel does 1 write to the temporary workspace and the reduction
1325         *       does 1 read from it plus 1 write to diff_wei; however, experiments
1326         *       showed that 8 works better than 5 or 6. */
1327         int bcast_koeff = 1;
1328         int load_koeff = 1;
1329         int output_koeff = 12;
1330         if (jcp.transpose_src) {
1331             bcast_koeff = 5;
1332             load_koeff = 1;
1333             output_koeff = 8;
1334         }
1335         return 0
1336             + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1337             * div_up(jcp.ngroups, jcp.nthr_g)
1338             * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block
1339             / jcp.stride_h / jcp.stride_w /* (n1) */
1340             + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1341             * div_up(jcp.ngroups, jcp.nthr_g)
1342             * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block
1343             + (size_t)output_koeff /* (n2) */
1344             * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
1345             * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block
1346             * jcp.oc_block;
1347     };
1348
1349     int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
1350     auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1351
1352     /* step 1: find the best thread distribution with lowest memory cost */
1353     const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
1354     for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
1355         const int nthr_par = nthr / nthr_mb;
1356         const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
1357         for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
1358             nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
1359             auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1360             if (mem_cost <= best_mem_cost) {
1361                 best_mem_cost = mem_cost;
1362                 jcp.nthr_mb = nthr_mb;
1363                 jcp.nthr_oc_b = nthr_oc_b;
1364                 jcp.nthr_ic_b = nthr_ic_b;
1365             }
1366         }
1367
1368         if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
1369     }
1370     if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
1371         jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
1372
1373     jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
1374     assert(jcp.nthr <= nthreads);
1375 }
1376
1377 }
1378 }
1379 }