inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <float.h>
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "nstl.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "cpu_memory.hpp"
28 #include "cpu_barrier.hpp"
29
30 #include "jit_uni_1x1_conv_utils.hpp"
31 #include "jit_avx512_common_1x1_conv_kernel.hpp"
32
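   // Byte offset of a field inside the jit_1x1_conv_call_s argument structure;
   // used below to read kernel arguments from the param1 pointer.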
33 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
34
35 namespace mkldnn {
36 namespace impl {
37 namespace cpu {
38
39 using namespace mkldnn::impl::prop_kind;
40 using namespace mkldnn::impl::memory_format;
41 using namespace mkldnn::impl::utils;
42
43 using namespace Xbyak;
44
45 void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk)
46 {
47     mov(aux1_reg_bcast_data, reg_bcast_data);
48     mov(aux_reg_bcast_data, reg_bcast_data);
49
50     mov(aux_reg_output_data, reg_output_data);
51     mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
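       // Two variants follow: for ver_4fma the last iteration is peeled into a
       // wraparound path (full jcp.ur or jcp.ur_tail, run with wraparound = true);
       // otherwise a plain main loop plus an optional ur_tail iteration is emitted.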
52
53     if (jcp.ver == ver_4fma)
54     {
55         Label bcast_loop;
56         Label bcast_loop_wraparound;
57         Label bcast_loop_out;
58         Label bcast_loop_ur_full;
59
60         cmp(bcast_loop_iter, jcp.ur);
61         jle(bcast_loop_wraparound, T_NEAR);
62
63         L(bcast_loop); {
64             assert(jcp.bcast_block % jcp.ur == 0);
65             int num_substeps = jcp.bcast_block / jcp.ur;
66             assert(num_substeps > 0 && num_substeps < 10);
67             for (int i = 0; i < num_substeps; i++) {
68                 reduce_loop(load_loop_blk, jcp.ur, i, false);
69                 if (i < num_substeps - 1) {
70                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
71                     add(aux_reg_output_data, jcp.bcast_loop_output_substep);
72                 }
73                 else {
74                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
75                         - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
76                     add(aux_reg_output_data, jcp.bcast_loop_output_step
77                         - (num_substeps - 1) * jcp.bcast_loop_output_substep);
78                 }
79             }
80             sub(bcast_loop_iter, jcp.bcast_block);
81             cmp(bcast_loop_iter, jcp.bcast_block);
82             jg(bcast_loop, T_NEAR);
83         }
84
85         L(bcast_loop_wraparound);
86         if (jcp.ur_tail) {
87             je(bcast_loop_ur_full, T_NEAR);
88             reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
89             jmp(bcast_loop_out, T_NEAR);
90         }
91         L(bcast_loop_ur_full);
92         reduce_loop(load_loop_blk, jcp.ur, 0, true);
93         L(bcast_loop_out);
94     }
95     else
96     {
97         Label bcast_loop;
98         Label bcast_loop_tail;
99
100         cmp(bcast_loop_iter, jcp.ur);
101         jl(bcast_loop_tail, T_NEAR);
102
103         L(bcast_loop); {
104             assert(jcp.bcast_block % jcp.ur == 0);
105             int num_substeps = jcp.bcast_block / jcp.ur;
106             assert(num_substeps > 0 && num_substeps < 10);
107             for (int i = 0; i < num_substeps; i++) {
108                 reduce_loop(load_loop_blk, jcp.ur, i, false);
109                 if (i < num_substeps - 1) {
110                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
111                     add(aux_reg_output_data, jcp.bcast_loop_output_substep);
112                 }
113                 else {
114                     add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
115                         - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
116                     add(aux_reg_output_data, jcp.bcast_loop_output_step
117                         - (num_substeps - 1) * jcp.bcast_loop_output_substep);
118                 }
119             }
120             sub(bcast_loop_iter, jcp.bcast_block);
121             cmp(bcast_loop_iter, jcp.bcast_block);
122             jge(bcast_loop, T_NEAR);
123         }
124
125         L(bcast_loop_tail);
126         if (jcp.ur_tail) {
127             Label bcast_loop_tail_out;
128             cmp(bcast_loop_iter, 0);
129             jz(bcast_loop_tail_out, T_NEAR);
130             reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
131             L(bcast_loop_tail_out);
132         }
133     }
134 }
135
136 void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
137          int ur, int substep, bool wraparound)
138 {
139     auto vreg_load = [=](int i_load, int i_fma) {
140         return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
141                     + jcp.fma_step * i_load + i_fma);
142     };
143
144     auto vreg_accum = [=](int i_load, int i_ur) {
145         return Zmm(i_ur + i_load * ur);
146     };
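        // Register layout implied by the two helpers above: accumulators
        // vreg_accum(i_load, i_ur) occupy zmm0 .. zmm(load_loop_blk * ur - 1);
        // the temporary load registers vreg_load(i_load, i_fma) start right after,
        // at an index rounded up to a multiple of jcp.fma_step.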
147
148     auto bias_ptr = [=](int i_load) {
149         return EVEX_compress_addr(reg_bias_data,
150                                   jcp.typesize_out * jcp.oc_block * i_load);
151     };
152
153     auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
154         assert(i_ur < jcp.ur);
155         assert(i_reduce <= jcp.reduce_loop_unroll);
156         int offt;
157         if (one_of(jcp.prop_kind, forward_training, forward_inference,
158                    backward_data)) {
159             assert(jcp.reduce_loop_unroll == jcp.reduce_block);
160             offt = (i_reduce == jcp.reduce_loop_unroll)
161                     ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll
162                     : i_ur * jcp.reduce_loop_unroll + i_reduce;
163         } else {
164             if (jcp.transpose_src) {
165                 const int reduce_group = i_reduce / 4;
166                 const int reduce_shift = i_reduce % 4;
167                 offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift;
168             }
169             else
170                 offt = i_reduce * jcp.ic_block + i_ur;
171         }
172         return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
173                                 bcast);
174     };
175
176     auto load_ptr = [=](int i_reduce, int i_load) {
177         int offt;
178         int u0 = i_reduce % jcp.reduce_loop_unroll;
179         int u1 = i_reduce / jcp.reduce_loop_unroll;
180         if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
181             offt = (i_load * jcp.reduce_block + u0) * jcp.load_block;
182         else
183             offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
184         return EVEX_compress_addr(aux_reg_load_data,
185                                   u1 * jcp.reduce_loop_load_step
186                                   + jcp.typesize_in * offt);
187     };
188
189     auto output_ptr = [=](int i_load, int i_ur) {
190         if (one_of(jcp.prop_kind, forward_training, forward_inference,
191                    backward_data))
192             return EVEX_compress_addr(aux_reg_output_data,
193                     (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
194                     * jcp.typesize_out);
195         else
196             return ptr[aux_reg_output_data +
197                        (i_load
198                             ? reg_output_stride * i_load
199                             : 0) // TODO: Xbyak should allow 0 scale
200                        + jcp.typesize_out * jcp.load_block * i_ur];
201     };
202
203     auto init = [=]() {
204         Label init_done;
205         Label init_zero;
206
207         if (jcp.with_sum) {
208             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
209                 for (int i_ur = 0; i_ur < ur; ++i_ur) {
210                     mic_prefetcht1(output_ptr(i_load, i_ur));
211                 }
212             }
213         }
214
215         if (jcp.with_bias
216             && one_of(jcp.prop_kind, forward_training, forward_inference)) {
217             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
218             jz(init_zero, T_NEAR);
219
220             for (int i_load = 0; i_load < load_loop_blk; i_load++)
221                 for (int i_ur = 0; i_ur < ur; ++i_ur)
222                     vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load));
223             jmp(init_done, T_NEAR);
224         }
225
226         L(init_zero);
227         for (int i_load = 0; i_load < load_loop_blk; ++i_load)
228             for (int i_ur = 0; i_ur < ur; ++i_ur) {
229                 auto r = vreg_accum(i_load, i_ur);
230                 vpxord(r, r, r);
231             }
232         L(init_done);
233     };
234
235     auto vadd = [=](const Xmm& x1, const Xmm& x2, const Operand& op) {
236         if (jcp.ver == ver_4vnni)
237             vpaddd(x1, x2, op);
238         else
239             vaddps(x1, x2, op);
240     };
241
242     auto store = [=]() {
243
244         Label store_noadd;
245         if (!jcp.with_sum) {
246             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
247             jnz(store_noadd, T_NEAR);
248         }
249
250         for (int i_ur = 0; i_ur < ur; ++i_ur)
251             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
252                 auto r = vreg_accum(i_load, i_ur);
253                 vadd(r, r, output_ptr(i_load, i_ur));
254             }
255
256         L(store_noadd);
257
258         Label store_nopostproc;
259         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
260         jz(store_nopostproc, T_NEAR);
261
262         int eltwise_inj_idx = 0;
263         int depthwise_inj_idx = 0;
264         const auto &p = attr_.post_ops_;
265
266         for (int i = 0; i < p.len_; i++) {
267             auto& post_op = p.entry_[i];
268             if (post_op.is_eltwise()) {
269                 if (jcp.ver == ver_4vnni) {
270                     zmm_t zmm_zero = vreg_bcast;
271                     vpxord(zmm_zero, zmm_zero, zmm_zero);
272
273                     for (int i_ur = 0; i_ur < ur; ++i_ur) {
274                         for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
275                             Zmm zmm = vreg_accum(i_load, i_ur);
276                             vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os);
277                             vpmulld(zmm | k1, zmm, zmm_zero);
278                         }
279                     }
280                 } else {
281                     eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
282                 }
283                 eltwise_inj_idx++;
284             } else if (post_op.is_depthwise()) {
285                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
286                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
287
288                 add(reg_d_weights, reg_oc_off);
289                 add(reg_d_bias, reg_oc_off);
290
291                 for (int j = 0; j < load_loop_blk; ++j) {
292                     int start_idx = vreg_accum(j, 0).getIdx();
293                     int end_idx = start_idx + ur;
294
295                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
296                             start_idx, end_idx, reg_d_weights, reg_d_bias);
297
298                     add(reg_d_weights, jcp.oc_block * sizeof(float));
299                     add(reg_d_bias, jcp.oc_block * sizeof(float));
300                 }
301
302                 depthwise_inj_idx++;
303             }
304         }
305
306         L(store_nopostproc);
307
308         auto store_output = [=](bool output_is_aligned) {
309             for (int i_ur = 0; i_ur < ur; ++i_ur)
310                 for (int i_load = 0; i_load < load_loop_blk; ++i_load)
311                     if (output_is_aligned && jcp.use_vmovntps)
312                         vmovntps(output_ptr(i_load, i_ur),
313                             vreg_accum(i_load, i_ur));
314                     else
315                         vmovups(output_ptr(i_load, i_ur),
316                             vreg_accum(i_load, i_ur));
317         };
318
319         Label unaligned_store, end_store;
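            // Non-temporal stores (vmovntps) require 64-byte (vector-size) alignment,
            // so dispatch on the low bits of the output pointer.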
320         test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1);
321         jnz(unaligned_store, T_NEAR);
322         store_output(true);
323         jmp(end_store, T_NEAR);
324         L(unaligned_store); {
325             store_output(false);
326         }
327         L(end_store);
328     };
329
330     auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load,
331         bool last_block, bool wraparound, int reduce_step)
332     {
333         bool pf_ker_l1 = true;
334         bool pf_ker_l2 = wraparound;
335         int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk;
336         int i_op = (i_reduce / reduce_step) * ur * load_loop_blk +
337             i_ur * load_loop_blk + i_load;
338
339         int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0;
340         int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0;
341         int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur;
342
343         int pf_inp_ops = n_ops / 2; // # of operations during which to pf input
344         int pf_inp_trigger;
345         if (jcp.prop_kind == backward_weights)
346             pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block);
347         else
348             pf_inp_trigger = nstl::max(1, pf_inp_ops / ur);
349
350         int n_other_pf =
351             load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1);
352         int n_other_pf_ops = n_ops - pf_inp_ops;
353         int other_pf_trigger
354                 = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0;
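            // Illustrative numbers (assuming ur = 8, load_loop_blk = 2,
            // reduce_loop_unroll = 16, reduce_step = 1, forward prop):
            // n_ops = 16 * 8 * 2 = 256, pf_inp_ops = 128, pf_inp_trigger = 128 / 8 = 16,
            // i.e. an input prefetch is issued on every 16th fma below.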
355
356         if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) {
357             // input prefetches have the highest priority b/c the
358             // first iteration of the kernel block touches all the
359             // cache lines
360             int i_pf = i_op / pf_inp_trigger;
361             auto pf_reg = wraparound && last_block
362                                   ? reg_bcast_data
363                                   : (last_block ? aux1_reg_bcast_data
364                                                 : aux_reg_bcast_data);
365             int offt = i_pf;
366             if (jcp.prop_kind == backward_weights) {
367                 offt += wraparound && last_block
368                                     ? 0
369                                     : (last_block ? jcp.is : jcp.reduce_block);
370                 offt *= jcp.bcast_block;
371             } else {
372                 offt += wraparound && last_block
373                                     ? 0
374                                     : (last_block ? jcp.ur : jcp.bcast_dim);
375                 offt *= jcp.reduce_block;
376             }
377             mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
378         } else if (i_op >= pf_inp_ops && n_other_pf) {
379             // remaining prefetches are spread among the rest of the
380             // operations; prefetches for output take priority
381             // TODO: spread L2 prefetches among L1 prefetches
382             i_op -= pf_inp_ops;
383             if (i_op % other_pf_trigger == 0) {
384                 int i_pf = i_op / (load_loop_blk * other_pf_trigger);
385                 if (i_pf < n_pf_ker_l2) {
386                     int offt = (i_pf + (i_load + 1) * jcp.reduce_dim)
387                         * jcp.load_block;
388                     if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
389                         offt = (i_pf + (i_load + 1) * jcp.reduce_block)
390                                 * jcp.load_block;
391
392                     mic_prefetcht1(ptr[aux_reg_load_data
393                                     + offt * jcp.typesize_in]);
394                 } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) {
395                     i_pf -= n_pf_ker_l2;
396                     auto pf_reg = last_block ? reg_load_data
397                                              : aux_reg_load_data;
398                     int offt = (i_pf + i_load * jcp.reduce_dim
399                         + (last_block
400                             ? (wraparound ? jcp.reduce_dim : 0)
401                             : jcp.reduce_block))
402                         * jcp.load_block;
403                     mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
404                 } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) {
405                     i_pf -= n_pf_ker_l1 + n_pf_ker_l2;
406                     int offt = i_pf * jcp.load_block;
407                     mic_prefetcht0(ptr[aux_reg_output_data
408                                     + offt * jcp.typesize_out]);
409                 }
410             }
411         }
412     };
413
414     auto fma_block = [=](bool last_block) {
415         assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
416
417         int reduce_step = jcp.fma_step;
418         if (jcp.ver == ver_4vnni)
419             reduce_step *= 2;
420
421         for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
422                 i_reduce += reduce_step) {
423             int load_scale = (jcp.ver == ver_4vnni) ? 2 : 1;
424             for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
425                 // If transposed input data is used and the spatial size is
426                 // not divisible by the transpose step (4), then for the last
427                 // reduce step we should load only the needed load registers
428                 // and clear the remaining ones.
429                 if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block
430                         && i_reduce == jcp.reduce_loop_unroll - reduce_step) {
431                     Label load_all;
432                     Label load_finish;
433                     test(reg_reduce_pos_flag, FLAG_SP_LAST);
434                     jz(load_all, T_NEAR);
435
436                     const int n_loads = jcp.is % jcp.fma_step;
437                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
438                         if (i_fma < n_loads)
439                             vmovups(vreg_load(i_load, i_fma),
440                                     load_ptr(i_reduce + load_scale * i_fma,
441                                             i_load));
442                         else
443                             vpxord(vreg_load(i_load, i_fma),
444                                     vreg_load(i_load, i_fma),
445                                     vreg_load(i_load, i_fma));
446                     }
447                     jmp(load_finish);
448
449                     L(load_all);
450                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
451                         vmovups(vreg_load(i_load, i_fma),
452                             load_ptr(i_reduce + load_scale * i_fma, i_load));
453                     }
454                     L(load_finish);
455                 } else {
456                     for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
457                         vmovups(vreg_load(i_load, i_fma),
458                             load_ptr(i_reduce
459                                 + load_scale * i_fma,
460                                 i_load));
461                     }
462                 }
463             }
464
465             for (int i_ur = 0; i_ur < ur; ++i_ur) {
466                 if (jcp.ver == ver_avx512_core && jcp.expl_bcast
467                         && load_loop_blk > 1)
468                     vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
469                 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
470                     if (jcp.ver == ver_4fma)
471                         v4fmaddps(vreg_accum(i_load, i_ur),
472                                     vreg_load(i_load, 0),
473                                     bcast_ptr(i_reduce, i_ur, false));
474                     else if (jcp.ver == ver_4vnni)
475                         vp4dpwssd(vreg_accum(i_load, i_ur),
476                                 vreg_load(i_load, 0),
477                                 bcast_ptr(i_reduce, i_ur, false));
478                     else if (jcp.ver == ver_avx512_core && jcp.expl_bcast
479                             && load_loop_blk > 1)
480                         vfmadd231ps(vreg_accum(i_load, i_ur),
481                                 vreg_load(i_load, 0), vreg_bcast);
482                     else
483                         vfmadd231ps(vreg_accum(i_load, i_ur),
484                                 vreg_load(i_load, 0),
485                                 bcast_ptr(i_reduce, i_ur, true));
486                     prefetch_callback(ur, i_reduce, i_ur, i_load,
487                                     last_block, wraparound, reduce_step);
488                 }
489             }
490         }
491     };
492     Label reduce_loop;
493     Label reduce_loop_tail;
494
495     mov(aux_reg_load_data, reg_load_data);
496
497     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
498     init();
499
500     mov(reduce_loop_iter, reg_reduce_loop_work);
501     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
502     jle(reduce_loop_tail, T_NEAR);
503
504     L(reduce_loop); {
505         fma_block(false);
506         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
507         add(aux_reg_load_data, jcp.reduce_loop_load_step);
508         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
509         jg(reduce_loop, T_NEAR);
510     }
511
512     L(reduce_loop_tail);
513     fma_block(true);
514
515     store();
516 }
517
518 void jit_avx512_common_1x1_conv_kernel::generate()
519 {
520     const auto &p = attr_.post_ops_;
521     for (int i = 0; i < p.len_; i++) {
522         auto &post_op = p.entry_[i];
523         if (post_op.is_eltwise()) {
524             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
525                     this,
526                     post_op.eltwise.alg,
527                     post_op.eltwise.alpha,
528                     post_op.eltwise.beta
529             ));
530         } else if (post_op.is_depthwise()) {
531             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
532                     this,
533                     post_op.depthwise.alg
534             ));
535         }
536     }
537
538     preamble();
539
540     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
541     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
542     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
543
544     sub(rsp, stack_space_needed);
545
546     if (jcp.with_bias)
547         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
548
549     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
550     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
551     mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
552     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
553     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
554     if (one_of(jcp.prop_kind, forward_training, forward_inference))
555         mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
556     if (jcp.prop_kind == backward_weights)
557         mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
558     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
559
560     auto load_loop_body = [=](int load_loop_blk) {
561         bcast_loop(load_loop_blk);
562         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
563         switch (jcp.prop_kind) {
564         case forward_training:
565         case forward_inference:
566             add(reg_bias_data,
567                 load_loop_blk * jcp.load_block * jcp.typesize_out);
568             add(reg_output_data,
569                 load_loop_blk * jcp.bcast_dim * jcp.load_block *
570                     jcp.typesize_out);
571             break;
572         case backward_data:
573             add(reg_output_data,
574                 load_loop_blk * jcp.bcast_dim * jcp.load_block *
575                     jcp.typesize_out);
576             break;
577         case backward_weights:
578             for (int i_load = 0; i_load < load_loop_blk; i_load++)
579                 add(reg_output_data, reg_output_stride);
580             break;
581         default:
582             assert(!"invalid prop_kind");
583         }
584         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
585         add(reg_oc_off, load_loop_blk * jcp.oc_block * jcp.typesize_out);
586     };
587
588     const int simd_w = 16;
589
590     Label load_loop_blk[7];
591
592     static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 };
593     static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
594     static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 };
595
596     const int size_ur_cases_fma
597             = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
598             sizeof(ur_cases_fma_expl_bcast) :
599             sizeof(ur_cases_fma_embd_bcast);
600     const int size_ur_cases_4fma = sizeof(ur_cases_4fma);
601
602     const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
603             ur_cases_fma_expl_bcast :
604             ur_cases_fma_embd_bcast;
605     const int *ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni)
606         ? ur_cases_4fma : ur_cases_fma;
607     const int num_ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni ?
608                                              size_ur_cases_4fma :
609                                              size_ur_cases_fma)
610             / sizeof(*ur_cases);
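        // Example of the dispatch below (illustrative): with ver_fma and jcp.ur = 8,
        // ur_cases = {2, 4, 5, 8, 14, 32}, so label indices 0..2 are generated and
        // load_loop_body() is emitted for 1, 2 and 3 load blocks (16, 32 and 48
        // output channels), selected at run time from reg_load_loop_work.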
611
612     for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
613         int label_idx = num_ur_cases - ur_idx - 1;
614         if (jcp.ur <= ur_cases[ur_idx]) {
615             cmp(reg_load_loop_work, simd_w * (label_idx + 1));
616             jle(load_loop_blk[label_idx], T_NEAR);
617         }
618     }
619
620     for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
621         if (jcp.ur <= ur_cases[ur_idx]) {
622             int label_idx = num_ur_cases - ur_idx - 1;
623             L(load_loop_blk[label_idx]);
624             {
625                 if (label_idx == 0) {
626                     cmp(reg_load_loop_work, 0);
627                     je(load_loop_blk[num_ur_cases], T_NEAR);
628                 }
629                 load_loop_body(label_idx + 1);
630                 if (label_idx - 1 > 0) {
631                     cmp(reg_load_loop_work, 2 * label_idx * simd_w);
632                     je(load_loop_blk[label_idx - 1], T_NEAR);
633                 }
634                 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
635                 jge(load_loop_blk[label_idx]);
636             }
637             for (int idx = label_idx - 1; idx > 0; --idx) {
638                 cmp(reg_load_loop_work, simd_w * (idx + 1));
639                 je(load_loop_blk[idx], T_NEAR);
640             }
641             if (ur_idx < num_ur_cases - 2) {
642                 cmp(reg_load_loop_work, simd_w);
643                 jle(load_loop_blk[0], T_NEAR);
644             }
645         }
646     }
647     L(load_loop_blk[num_ur_cases]);
648
649     add(rsp, stack_space_needed);
650
651     postamble();
652
653     for (auto& inj : eltwise_injectors)
654         inj->prepare_table();
655 }
656
657 bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
658         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
659     const auto &p = attr.post_ops_;
660
661     auto all_post_ops_supported = [&]() {
662         bool ok = true;
663
664         for (int i = 0; i < p.len_; i++) {
665             ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
666         }
667         return ok;
668     };
669     auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind) != -1; };
670     auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind); };
671     auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
672
673     return all_post_ops_supported() &&
674            count(primitive_kind::sum) <= 1 &&
675            IMPLICATION(contain(primitive_kind::sum), position(primitive_kind::sum) == 0);
676 }
677
678 status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
679         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
680         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
681         const primitive_attr_t &attr, int nthreads, bool reduce_src) {
682     if (!mayiuse(avx512_common)) return status::unimplemented;
683
684     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
685     const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
686     const int ndims = src_d.ndims();
687
688     jcp.prop_kind = cd.prop_kind;
689
690     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
691     jcp.mb = src_d.dims()[0];
692
693     jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
694     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
695     jcp.ic = src_d.dims()[1] / jcp.ngroups;
696
697     bool ok_to_pad_channels = true
698         && jcp.ngroups == 1
699         && src_d.data_type() == data_type::f32;
700     if (ok_to_pad_channels) {
701         jcp.oc = rnd_up(jcp.oc, simd_w);
702         jcp.ic = rnd_up(jcp.ic, simd_w);
703     }
704
705     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
706     jcp.iw = src_d.dims()[ndims - 1];
707     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
708     jcp.ow = dst_d.dims()[ndims - 1];
709
710     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
711     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
712
713     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
714     jcp.l_pad = cd.padding[0][ndims - 3];
715
716     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
717     jcp.stride_w = cd.strides[ndims - 3];
718
719     jcp.src_fmt = src_d.format();
720     jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format,
721             memory_format::undef, cd.diff_bias_desc.format)
722         != memory_format::undef;
723
724     jcp.os = jcp.oh * jcp.ow;
725     jcp.is = jcp.ih * jcp.iw;
726     jcp.tr_is = rnd_up(jcp.is, 4);
727
728     if (!post_ops_ok(jcp, attr))
729         return status::unimplemented;
730
731     const auto &p = attr.post_ops_;
732     jcp.with_sum = p.find(primitive_kind::sum) != -1;
733     const int eltwise_ind = p.find(primitive_kind::eltwise);
734     jcp.with_eltwise = eltwise_ind != -1;
735     if (jcp.with_eltwise) {
736         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
737         if (dst_d.data_type() == data_type::s32) return status::unimplemented;
738     }
739
740     bool args_ok = true
741         && jcp.ngroups == 1
742         && everyone_is(pick(ndims - 3, nCw16c, nChw16c), src_d.format(),
743             dst_d.format())
744         && one_of(cd.bias_desc.format, memory_format::undef, any, x);
745     if (!args_ok) return status::unimplemented;
746
747     args_ok = true
748         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
749         && jcp.t_pad == 0 && jcp.l_pad == 0
750         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
751         && jcp.kh == 1 && jcp.kw == 1;
752     if (!args_ok) return status::unimplemented;
753
754     jcp.ic_block = jcp.oc_block = simd_w;
755     jcp.transpose_src = false;
756
757     if (mayiuse(avx512_mic_4ops)
758         && ((one_of(jcp.prop_kind, forward_training, forward_inference)
759             && src_d.data_type() == data_type::s16
760             && weights_d.data_type() == data_type::s16
761             && dst_d.data_type() == data_type::s32)
762         || (jcp.prop_kind == backward_data
763             && src_d.data_type() == data_type::s32
764             && weights_d.data_type() == data_type::s16
765             && dst_d.data_type() == data_type::s16)))
766     {
767         const int is_bwd_d = jcp.prop_kind == backward_data;
768         memory_format_t weights_format = with_groups
769             ? pick(2 * ndims - 6 + is_bwd_d, gOIw8i16o2i, gOIw8o16i2o,
770                 gOIhw8i16o2i, gOIhw8o16i2o)
771             : pick(2 * ndims - 6 + is_bwd_d, OIw8i16o2i, OIw8o16i2o,
772                 OIhw8i16o2i, OIhw8o16i2o);
773
774         if (weights_d.format() != weights_format)
775             return status::unimplemented;
776
777         jcp.ver = ver_4vnni;
778         jcp.fma_step = 4;
779         jcp.typesize_in = sizeof(prec_traits<data_type::s16>::type);
780         jcp.typesize_out = sizeof(prec_traits<data_type::s32>::type);
781     }
782     else if (everyone_is(data_type::f32, src_d.data_type(),
783                             weights_d.data_type(), dst_d.data_type()))
784     {
785         const int is_bwd_d = jcp.prop_kind == backward_data;
786         memory_format_t weights_format = with_groups
787             ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
788                 gOIhw16i16o, gIOhw16o16i)
789             : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
790                 OIhw16i16o, IOhw16o16i);
791
792         if (weights_d.format() != weights_format)
793             return status::unimplemented;
794         if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) &&
795             ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4
796             == 0) {
797             jcp.ver = ver_4fma;
798             jcp.fma_step = 4;
799         } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops)
800                 && !reduce_src
801                 /* Heuristic condition relating src size to oc. Otherwise the
802                    src transposition overhead exceeds the benefit from 4fma.
803                 */
804                 && ((jcp.is * jcp.ic) / jcp.oc <= 2048)
805                 && mkldnn_thr_syncable()
806                 )
807         {
808             jcp.transpose_src = true;
809             jcp.ver = ver_4fma;
810             jcp.fma_step = 4;
811         } else {
812             jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma;
813             jcp.fma_step = 1;
814         }
815         jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
816         jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
817     } else {
818         return status::unimplemented;
819     }
820
821     /* once all the formats are set, check the padding consistency */
822     args_ok = true
823         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
824         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
825         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
826         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
827     if (!args_ok) return status::unimplemented;
828
829     const int SMALL_SPATIAL = 10;
830     const int BIG_SPATIAL = 28;
831     const int BIG_REDUCE_DIM = 1024;
832     const int BIG_LOAD_DIM = 256;
833
834     int load_blocking{ 0 };
835     int load_blocking_max{ 0 };
836     int bcast_blocking{ 0 };
837     int bcast_blocking_max{ 0 };
838     int reduce_blocking{ 0 };
839     int reduce_blocking_max{ 0 };
840
841     jcp.load_grp_count = 1;
842
843     const int L1_capacity = get_cache_size(1, true) / sizeof(float);
844     const int L2_size = get_cache_size(2, true) / sizeof(float);
845     const int L2_capacity = (L2_size * 3) / 4;
846
847     if (one_of(jcp.prop_kind, forward_training, forward_inference,
848                 backward_data)) {
849         if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
850             jcp.reduce_dim = jcp.ic;
851             jcp.reduce_block = jcp.ic_block;
852
853             jcp.load_dim = jcp.oc;
854             jcp.load_block = jcp.oc_block;
855
856             jcp.bcast_dim = jcp.is;
857         } else {
858             jcp.reduce_dim = jcp.oc;
859             jcp.reduce_block = jcp.oc_block;
860
861             jcp.load_dim = jcp.ic;
862             jcp.load_block = jcp.ic_block;
863
864             jcp.bcast_dim = jcp.os;
865         }
866         jcp.reduce_loop_unroll = jcp.reduce_block;
867         jcp.reduce_loop_bcast_step
868                 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
869
870         if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni) {
871             jcp.reduce_loop_load_step
872                     = jcp.reduce_loop_unroll * jcp.ic * jcp.typesize_in;
873             jcp.load_loop_load_step
874                     = jcp.oc_block * jcp.ic_block * jcp.typesize_in;
875         } else {
876             jcp.reduce_loop_load_step
877                     = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
878             jcp.load_loop_load_step
879                     = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
880         }
881
882         // adjusting register blocking
883         int max_regs, min_regs, size_threshold, ur_step;
884         const int spatial
885                 = (one_of(jcp.prop_kind, forward_training, forward_inference)) ?
886                 jcp.oh :
887                 jcp.ih;
888         if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) {
889             max_regs = 9;
890             min_regs = 6;
891             size_threshold = 14;
892             ur_step = 1;
893             jcp.expl_bcast = true;
894
895             if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
896                     && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
897                 max_regs = 6;
898                 min_regs = 5;
899             }
900         } else {
901             bool is4ops = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni);
902
903             max_regs = is4ops ? 28 : 30;
904             min_regs = 9;
905             size_threshold = is4ops ? 28 : 14;
906             ur_step = is4ops ? 4 : 1;
907             jcp.expl_bcast = false;
908             jcp.use_vmovntps = true;
909         }
910         jcp.ur = 1;
911         for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
912             if ((spatial >= size_threshold && spatial % ur_w == 0)
913                     || (spatial < size_threshold && jcp.os % ur_w == 0)) {
914                 jcp.ur = ur_w;
915                 break;
916             }
917         }
918         if (jcp.ur == 1) {
919             jcp.ur = nstl::min(max_regs, jcp.os);
920             int os_tail = jcp.os % max_regs;
921             for (int i = max_regs; i >= min_regs; i -= ur_step) {
922                 int i_tail = jcp.os % i;
923                 if (i_tail > os_tail || i_tail == 0) {
924                     jcp.ur = i;
925                     os_tail = i_tail;
926                     if (i_tail == 0)
927                         break;
928                 }
929             }
930         }
931
932         jcp.reduce_loop_unroll = jcp.reduce_block;
933         jcp.reduce_loop_bcast_step
934                 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
935
936         jcp.bcast_block = jcp.ur;
937
938         jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out;
939         jcp.bcast_loop_output_substep = -1; // unused
940         jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in;
941         jcp.bcast_loop_bcast_substep = -1; // unused
942
943         jcp.load_loop_iter_step = jcp.load_block;
944
945         if (jcp.prop_kind == backward_data)
946             jcp.loop_order = loop_lbr;
947         else
948             jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
949
950         int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
951         int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
952         int nb_load = div_up(jcp.load_dim, jcp.load_block);
953
954         if (jcp.ver == ver_avx512_core && jcp.expl_bcast) {
955             if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
956                     && spatial < BIG_SPATIAL)
957                 reduce_blocking = nstl::min(jcp.reduce_dim, 80);
958             else if (spatial > SMALL_SPATIAL)
959                 reduce_blocking = nstl::min(jcp.reduce_dim, 512);
960             else
961                 reduce_blocking = nstl::min(jcp.reduce_dim, 256);
962
963             if ((jcp.mb > 28 && spatial >= 28)
964                     || (jcp.mb > 112 && spatial >= 17))
965                 jcp.use_vmovntps = true;
966             else
967                 jcp.use_vmovntps = false;
968         } else {
969
970             reduce_blocking = nb_reduce;
971             if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
972                 reduce_blocking = 16;
973             else if (spatial > SMALL_SPATIAL
974                     && jcp.reduce_dim >= BIG_REDUCE_DIM)
975                 reduce_blocking = 8;
976             reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
977             reduce_blocking *= jcp.reduce_block;
978         }
979
980         // Check input data cache aliasing.
981         // For other ISAs these constants may need to be updated.
982         // 64 * 1024 is chosen because of the 1MB, 16-way L2 cache (way size = 64KB).
983         // 7 is an empirical value, about half of 16, so roughly half of each
984         // cache set is left for other data (weights, dst).
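            // Rough illustration (f32, typesize_in = 4): way_size = 64 * 1024 / 4 = 16384
            // floats, and max_hits * way_size = 114688. E.g. a 56x56 spatial
            // (bcast_dim = 3136) with reduce_blocking = 64 gives 3136 * 64 = 200704,
            // which exceeds that bound, so the loop below may shrink reduce_blocking.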
985         int way_size = (64 * 1024) / jcp.typesize_in;
986         int max_hits = 7;
987         if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
988             int nrb = reduce_blocking / simd_w;
989             int sp = jcp.bcast_dim;
990             int wl = way_size / simd_w;
991             for (int start_off = 0; start_off < jcp.ur; start_off++) {
992                 for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
993                     if (off % sp >= jcp.ur || ++hits < max_hits)
994                         continue;
995                     int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
996                     reduce_blocking
997                             = nstl::min(reduce_blocking, max_r_blocking);
998                     break;
999                 }
1000             }
1001         }
1002
1003         if (reduce_blocking < jcp.reduce_dim) {
1004             jcp.use_vmovntps = false;
1005             if (jcp.prop_kind == backward_data)
1006                 jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
1007             else
1008                 jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
1009         }
1010         load_blocking = jcp.load_dim;
1011
1012         int load_size = jcp.load_dim * jcp.reduce_dim;
1013         int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
1014
1015         if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads
1016                 && nb_load * nb_bcast > nthreads) {
1017             // Some heuristic here
1018             float calc_koef = 0.01, best_cost = FLT_MAX;
1019             int n_lgc = nthreads;
1020             float ratio = (float)load_size / (float)bcast_size;
1021             int best_lgc = ratio > 1 ? n_lgc : 1;
1022             auto calc_job_cost = [&](int lb, int tg, float mem_k) {
1023                 int bb_size = jcp.mb * div_up(nb_bcast, tg);
1024                 float calc_size = (float)(bb_size * jcp.ur)
1025                         * (lb * jcp.load_block) * jcp.reduce_dim;
1026                 float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
1027                         * jcp.reduce_dim;
1028                 return calc_koef * calc_size + mem_k * mem_size;
1029             };
1030             for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
1031                 lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
1032                 int min_lb = nb_load / lgc;
1033                 int max_lb = div_up(nb_load, lgc);
1034                 int min_tg = nthreads / lgc;
1035                 int max_tg = div_up(nthreads, lgc);
1036                 // Some heuristic here
1037                 float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
1038                 float job_cost = 0.;
1039                 if (nthreads % lgc < nb_load % lgc) {
1040                     job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
1041                 } else {
1042                     auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
1043                     auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
1044                     job_cost = nstl::max(job_cost1, job_cost2);
1045                 }
1046
1047                 if (job_cost < best_cost) {
1048                     best_lgc = lgc;
1049                     best_cost = job_cost;
1050                 }
1051             }
1052             jcp.load_grp_count = best_lgc;
1053             load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
1054         } else {
1055             jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
1056             jcp.load_grp_count = best_divider(
1057                 nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
1058         }
1059
1060         if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64
1061                 && load_size >= L2_size) {
1062             jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
1063         } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
1064                 && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
1065             jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
1066             load_blocking = jcp.load_block;
1067         }
1068
1069         if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
1070                 && jcp.oh * jcp.ow > 64
1071                 && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
1072             /* Look for the best load-dimension blocking, i.e. the optimal
1073             * 'load_chunk' value, to get the best thread and data
1074             * read/write efficiency.
1075             * Example:
1076             * for 72 threads and a convolution with mb=1, ih=iw=7, oc=512
1077             * the 'best' load_chunk value should be 1
1078             * TODO: remove the heuristic constants in the condition above
1079             * TODO: check this blocking for other ISAs
1080             */
1081             float best_eff = -1.f;
1082             int best_lgc = 1;
1083
1084             for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
1085                 int lgc = div_up(nb_load, load_chunk);
1086                 if (lgc > nthreads)
1087                     continue;
1088                 int thr_per_grp = div_up(nthreads, lgc);
1089                 int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
1090                         * jcp.bcast_block;
1091                 int load_per_thr = load_chunk * simd_w;
1092                 float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
1093                 float data_eff = (bcast_per_thr * load_per_thr)
1094                         / (data_norm * data_norm);
1095                 float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
1096                         / div_up(nthreads, lgc);
1097                 float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
1098                         / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
1099                 float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
1100                 float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
1101                 float overall_eff = data_eff + thr_eff + load_eff;
1102                 if (overall_eff > best_eff) {
1103                     best_eff = overall_eff;
1104                     best_lgc = lgc;
1105                 }
1106             }
1107             jcp.load_grp_count = best_lgc;
1108             load_blocking
1109                     = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
1110         }
1111         bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
1112                                  div_up(nthreads, jcp.load_grp_count))
1113                 * jcp.bcast_block;
1114         bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
1115         bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
1116
1117         int space_for_bcast
1118                 = (L2_capacity - /* kernel_size - */
1119                     2 * jcp.load_block * reduce_blocking
1120                         - jcp.ur * reduce_blocking - 3 * 1024);
1121         if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
1122             space_for_bcast /= 2;
1123
1124         int bcast_in_cache
1125                 = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
1126         bcast_blocking = nstl::min(
1127                 bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
1128
1129         load_blocking_max = load_blocking;
1130         bcast_blocking_max = bcast_blocking * 3 / 2;
1131         reduce_blocking_max = reduce_blocking;
1132
1133     } else if (jcp.prop_kind == backward_weights) {
1134
1135         jcp.use_vmovntps = false;
1136         if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma)
1137             jcp.use_vmovntps = true;
1138
1139         if (jcp.transpose_src)
1140             jcp.reduce_dim = jcp.tr_is;
1141         else
1142             jcp.reduce_dim = jcp.is;
1143
1144         if (jcp.ver == ver_4fma) {
1145             // reduce_block should be divisible by fma_step
1146             jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4);
1147         } else {
1148             jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
1149             if (jcp.reduce_dim % jcp.reduce_block != 0)
1150                 jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
1151             if (jcp.reduce_block > 256) {
1152                 jcp.reduce_block = 1;
1153             }
1154
1155         }
1156
1157         jcp.load_dim = jcp.oc;
1158         jcp.load_block = jcp.oc_block;
1159
1160         jcp.bcast_dim = jcp.ic;
1161         jcp.bcast_block = jcp.ic_block;
1162
1163         if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) {
1164             // if reduce_block is big then the generated JIT code may be big
1165             // even for small values of ur, because reduce_loop_unroll = reduce_block
1166             jcp.ur = jcp.bcast_block / 2;
1167             jcp.expl_bcast = true;
1168         } else {
1169             jcp.ur = jcp.bcast_block;
1170             jcp.expl_bcast = false;
1171         }
1172
1173         jcp.reduce_loop_unroll = jcp.reduce_block;
1174         jcp.reduce_loop_bcast_step
1175             = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in;
1176         jcp.reduce_loop_load_step
1177             = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in;
1178
1179         jcp.bcast_loop_output_step =
1180                                 jcp.oc_block * jcp.ic_block * jcp.typesize_out;
1181         jcp.bcast_loop_output_substep =
1182             jcp.oc_block * jcp.ur * jcp.typesize_out;
1183         jcp.bcast_loop_bcast_step =
1184                 jcp.ic_block * jcp.reduce_dim * jcp.typesize_in;
1185         jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;
1186
1187         jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in;
1188         jcp.load_loop_iter_step = jcp.oc_block;
1189
1190         /* --- */
1191         balance(jcp, nthreads);
1192
1193         load_blocking = div_up(jcp.load_dim, jcp.load_block);
1194         load_blocking = best_divider(load_blocking, 16, load_blocking, false);
1195         load_blocking *= jcp.load_block;
1196
1197         load_blocking_max = load_blocking;
1198         assert(jcp.load_dim % load_blocking == 0);
1199
1200         int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1201         int min_bcast_blocking = 5;
1202
1203         bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1204         bcast_blocking = best_divider(
1205                 bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
1206         bcast_blocking *= jcp.bcast_block;
1207         bcast_blocking_max = bcast_blocking;
1208         assert(jcp.bcast_dim % bcast_blocking == 0);
1209
1210         // for reduction balance
1211         if (jcp.ver == ver_avx512_core) {
1212             int max_reduce_blocking
1213                     = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
1214             int min_reduce_blocking = nstl::min(
1215                     L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
1216             reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
1217                     max_reduce_blocking, true);
1218             reduce_blocking
1219                     = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
1220                             jcp.reduce_block);
1221         } else {
1222             int max_reduce_blocking = L2_capacity
1223                     / ((bcast_blocking + load_blocking) * jcp.reduce_block);
1224             max_reduce_blocking = nstl::min(max_reduce_blocking,
1225                     (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block);
1226
1227             int num_jobs = div_up(jcp.load_dim, load_blocking)
1228                     * div_up(jcp.bcast_dim, bcast_blocking);
1229             int threads_per_job = nstl::max(1, nthreads / num_jobs);
1230             reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block);
1231             reduce_blocking = div_up(reduce_blocking, threads_per_job);
1232
1233             reduce_blocking = best_divider(reduce_blocking,
1234                     max_reduce_blocking - 2, max_reduce_blocking, true);
1235             reduce_blocking *= jcp.reduce_block;
1236         }
1237
1238         reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
1239     } else
1240         return status::unimplemented;
1241
1242     assert(load_blocking);
1243     assert(load_blocking_max);
1244     assert(bcast_blocking);
1245     assert(bcast_blocking_max);
1246     assert(reduce_blocking);
1247     assert(reduce_blocking_max);
1248     assert(load_blocking % jcp.load_block == 0);
1249     assert(reduce_blocking % jcp.reduce_block == 0);
1250     assert(load_blocking_max % jcp.load_block == 0);
1251     assert(reduce_blocking_max % jcp.reduce_block == 0);
1252     if (jcp.ver == ver_4fma || jcp.ver == ver_4vnni) {
1253         if (jcp.ver == ver_4fma)
1254             assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
1255         if (jcp.ver == ver_4vnni)
1256             assert(jcp.reduce_loop_unroll % (2 * jcp.fma_step) == 0);
1257         assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
1258     }
1259
1260     assert(jcp.bcast_block % jcp.ur == 0);
1261     assert(jcp.reduce_dim % jcp.reduce_block == 0);
1262
1263     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
1264
1265     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
1266     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
1267     jcp.nb_load_blocking = load_blocking / jcp.load_block;
1268     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
1269     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
1270     jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
1271
1272     jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1273     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
1274     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1275
1276     return status::success;
1277 }
1278
1279 void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
1280         memory_tracking::registrar_t &scratchpad,
1281         const jit_1x1_conv_conf_t &jcp) {
1282     using namespace mkldnn::impl::memory_tracking::names;
1283
1284     if (jcp.prop_kind != backward_data && jcp.with_bias
1285             && jcp.oc != jcp.oc_without_padding)
1286         scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
1287
1288     if (jcp.prop_kind == backward_weights) {
1289         const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
1290         scratchpad.book(key_conv_wei_reduction,
1291                 jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
1292     }
1293
1294     if (jcp.transpose_src) {
1295         const size_t tr_src_size =
1296             (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
1297         scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
1298         scratchpad.book(key_conv_tr_src_bctx,
1299                 sizeof(simple_barrier::ctx_t) * jcp.nthr);
1300     }
1301 }
1302
1303 void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
1304         int nthreads)
1305 {
1306     // initialize jcp reduction threading properties
1307     jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
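         // Thread partitioning used below: nthr_mb splits the minibatch x reduce
         // blocks, nthr_g the groups, nthr_oc_b the load (oc) blocks and
         // nthr_ic_b the bcast (ic) blocks.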
1308     if (nthreads < jcp.ngroups) {
1309         /* simplification... fortunately it doesn't hurt much */
1310         return;
1311     }
1312     const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1313     const int nb_load = div_up(jcp.load_dim, jcp.load_block);
1314     const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1315
1316     jcp.nthr_g = jcp.ngroups;
1317     const int nthr = nthreads / jcp.nthr_g;
1318
1319     auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
1320         /* Calculate the per-thread memory cost (read/write). The high-level
1321         * optimizer tries to minimize memory consumption. A few notes: (n1)
1322         * unclear why, but that essentially helps the first convolution...
1323         *  (n2) assuming the reduction over minibatch is always there:
1324         *    - instead of 8 it should be 5 here (a write costs ~2 reads):
1325         *      kernel: temporal workspace, 1 write
1326         *      reduction: 1 read from the workspace and 1 write to the diff_wei
1327         *    - but experiments showed 8 works better than 5 or 6... */
1328         int bcast_koeff = 1;
1329         int load_koeff = 1;
1330         int output_koeff = 12;
1331         if (jcp.transpose_src) {
1332             bcast_koeff = 5;
1333             load_koeff = 1;
1334             output_koeff = 8;
1335         }
1336         return 0
1337             + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1338             * div_up(jcp.ngroups, jcp.nthr_g)
1339             * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block
1340             / jcp.stride_h / jcp.stride_w /* (n1) */
1341             + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1342             * div_up(jcp.ngroups, jcp.nthr_g)
1343             * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block
1344             + (size_t)output_koeff /* (n2) */
1345             * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
1346             * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block
1347             * jcp.oc_block;
1348     };
1349
1350     int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
1351     auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1352
1353     /* step 1: find the best thread distribution with lowest memory cost */
1354     const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
1355     for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
1356         const int nthr_par = nthr / nthr_mb;
1357         const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
1358         for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
1359             nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
1360             auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1361             if (mem_cost <= best_mem_cost) {
1362                 best_mem_cost = mem_cost;
1363                 jcp.nthr_mb = nthr_mb;
1364                 jcp.nthr_oc_b = nthr_oc_b;
1365                 jcp.nthr_ic_b = nthr_ic_b;
1366             }
1367         }
1368
1369         if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
1370     }
1371     if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
1372         jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
1373
1374     jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
1375     assert(jcp.nthr <= nthreads);
1376 }
1377
1378 }
1379 }
1380 }