inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp
/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "cpu_memory.hpp"

#include "jit_avx2_1x1_conv_kernel_f32.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

using namespace Xbyak;

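// Outer loop over the broadcast (spatial) dimension: emits full
// jcp.bcast_block iterations of the reduce loop and a separate jcp.ur_tail
// pass when the dimension is not a multiple of the block size.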
void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
{
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop, bcast_loop_tail;

    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop); {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            generate_reduce_loop(load_loop_blk, jcp.ur);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
                        - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_step
                        - (num_substeps - 1) * jcp.bcast_loop_output_substep);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        generate_reduce_loop(load_loop_blk, jcp.ur_tail);
        L(bcast_loop_tail_out);
    }
}

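// Emits one register-blocked tile: accumulators are initialized from the bias
// or zeroed, the reduce dimension is traversed in chunks of
// jcp.reduce_loop_unroll FMA steps, and the result is stored with post-ops
// applied when the FLAG_REDUCE_LAST flag is set.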
void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
        int load_loop_blk, int ur)
{
    auto vreg_load = [=](int i) {
        return Ymm(ur * load_loop_blk + i);
    };

    auto vreg_accum = [=](int i, int j) {
        return Ymm(j + i*ur);
    };

    auto bias_ptr = [=](int i) {
        return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
    };

    auto bcast_ptr = [=](int u, int j) {
        assert(j < jcp.ur);
        assert(u <= jcp.reduce_loop_unroll);
        size_t offt;
        if (one_of(jcp.prop_kind,
                    forward_training, forward_inference, backward_data))
        {
            assert(jcp.reduce_loop_unroll == ((jcp.prop_kind == backward_data)
                    ? jcp.oc_block : jcp.ic_block));
            auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
            offt = (u == jcp.reduce_loop_unroll)
                ? (height + j) * jcp.reduce_loop_unroll
                : j * jcp.reduce_loop_unroll + u;
        } else
            offt = u * jcp.ic_block + j;
        return ptr[aux_reg_bcast_data + sizeof(float) * offt];
    };

    auto load_ptr = [=](int u, int i) {
        size_t offt;
        size_t u0 = u % jcp.reduce_loop_unroll;
        size_t u1 = u / jcp.reduce_loop_unroll;
        switch (jcp.prop_kind) {
        case backward_data:
            offt = (i * jcp.oc_block + u0) * jcp.ic_block;
            break;
        case backward_weights:
            offt = (i * jcp.os + u0) * jcp.oc_block;
            break;
        default:
            offt = (i * jcp.ic + u0) * jcp.oc_block;
        }
        return ptr[aux_reg_load_data
            + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt];
    };

    auto output_ptr = [=](int i, int j) {
        switch (jcp.prop_kind) {
        case backward_data:
            return ptr[aux_reg_output_data +
                (i * jcp.is + j) * jcp.ic_block * sizeof(float)];
        case backward_weights:
            return ptr[aux_reg_output_data
                + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
                + sizeof(float) * jcp.oc_block * j];
        default:
            if (jcp.with_dw_conv) {
                return ptr[aux_reg_output_data +
                           (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float)];
            } else {
                return ptr[aux_reg_output_data +
                           (i * jcp.os + j) * jcp.oc_block * sizeof(float)];
            }
        }
    };

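    // Initializes the ur x load_loop_blk accumulator tile: loads the bias on
    // the first reduction chunk of a forward pass, zeroes the registers
    // otherwise, and preloads the first load vector and broadcast value.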
    auto init = [=]() {
        Label init_done, init_zero;

        if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
                    forward_inference)) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jz(init_zero);

            for (int i = 0; i < load_loop_blk; i++)
                for (int j = 0; j < ur; ++j)
                    vmovups(vreg_accum(i, j), bias_ptr(i));
            jmp(init_done);
        }

        L(init_zero);
        for (int i = 0; i < load_loop_blk; ++i)
            for (int j = 0; j < ur; ++j) {
                auto r = vreg_accum(i, j);
                vxorps(r, r, r);
            }

        L(init_done);
        for (int i = 0; i < load_loop_blk; ++i)
            vmovups(vreg_load(i), load_ptr(0, i));
        vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
    };

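    // Stores the accumulator tile: adds the previously stored output unless
    // this is the first reduction chunk (always added when a sum post-op is
    // fused), applies eltwise/depthwise post-ops on the last chunk only, then
    // writes the registers back to the output.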
    auto store = [=]() {
        Label store_noadd;

        if (!jcp.with_sum) {
            test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            jnz(store_noadd, T_NEAR);
        }

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                auto r = vreg_accum(i, j);
                vaddps(r, r, output_ptr(i, j));
            }

        L(store_noadd);

        Label store_norelu;
        test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
        jz(store_norelu, T_NEAR);

        int eltwise_inj_idx = 0;
        int depthwise_inj_idx = 0;
        const auto &p = attr_.post_ops_;

        int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
        for (int i = 0; i < end_idx; i++) {
            auto& post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
                eltwise_inj_idx++;
            } else if (post_op.is_depthwise()) {
                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

                add(reg_d_weights, reg_oc_off);
                add(reg_d_bias, reg_oc_off);

                for (int j = 0; j < load_loop_blk; ++j) {
                    int start_idx = vreg_accum(j, 0).getIdx();
                    int end_idx = start_idx + ur;

                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                            start_idx, end_idx, reg_d_weights, reg_d_bias);

                    add(reg_d_weights, jcp.oc_block * sizeof(float));
                    add(reg_d_bias, jcp.oc_block * sizeof(float));
                }

                depthwise_inj_idx++;
            }
        }

        L(store_norelu);

        for (int j = 0; j < ur; ++j)
            for (int i = 0; i < load_loop_blk; ++i) {
                vmovups(output_ptr(i, j), vreg_accum(i, j));
            }
    };

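    // One unrolled chunk of the reduction: multiplies the broadcast value by
    // each preloaded load vector and accumulates, using a single FMA on AVX2
    // or a vmulps/vaddps pair on plain AVX, while the next load vector and
    // broadcast value are fetched in flight.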
    auto fma_block = [=](bool last_block) {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
            for (int j = 0; j < ur; ++j) {
                for (int i = 0; i < load_loop_blk; ++i) {
                    if (mayiuse(avx2))
                        vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
                    else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
                        vmulps(vtmp, vreg_bcast, vreg_load(i));
                        vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
                    }
                    if (j == ur - 1 && !(last_block
                                && u == jcp.reduce_loop_unroll - 1))
                        vmovups(vreg_load(i), load_ptr(u + 1, i));
                }
                if (j < ur - 1)
                    vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
            }
            if (!last_block || u < jcp.reduce_loop_unroll - 1)
                vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
        }
    };

    Label reduce_loop, reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);

    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop); {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}

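// For backward-by-weights with bias: accumulates the load-side data over the
// reduce dimension into one Ymm per output-channel block and writes the
// partial diff_bias back through the pointer kept on the stack.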
void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
{
    if (!jcp.with_bias || jcp.prop_kind != backward_weights)
        return;

    Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
    Label diff_bias_load;

    auto diff_bias_ptr = [=](int i) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
    };

    auto load_ptr = [=](int u, int i) {
        return ptr[aux_reg_load_data
            + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
    };

    auto diff_bias_reg = [=](int i) { return Ymm(i); };

    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r = diff_bias_reg(i);
        vxorps(r, r, r);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i)
        vmovups(diff_bias_reg(i), diff_bias_ptr(i));

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop); {
        for (int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i)
                vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    for (int i = 0; i < load_loop_blk; i++)
        vmovups(diff_bias_ptr(i), diff_bias_reg(i));
    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}

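// Kernel entry point: builds the post-op injectors, loads the call arguments
// from jit_1x1_conv_call_s, and dispatches to 3x/2x/1x register-blocked
// variants of the load loop depending on the remaining load_dim work.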
void jit_avx2_1x1_conv_kernel_f32::generate()
{
    const auto &p = attr_.post_ops_;
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx2>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    preamble();

    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            sub(rsp, stack_space_needed);
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);

    auto generate_load_loop_body = [=] (int load_loop_blk) {
        generate_bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
        case forward_training:
        case forward_inference:
            add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
            if (jcp.with_dw_conv)
                add(reg_output_data,
                    load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
            else
                add(reg_output_data,
                    load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
            break;
        case backward_data:
            add(reg_output_data,
                    load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
            break;
        case backward_weights:
            for (int i = 0; i < load_loop_blk; i++)
                add(reg_output_data, reg_output_stride);
            break;
        default:
            assert(!"invalid prop_kind");
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
    };

    Label load_loop_blk_8;
    Label load_loop_blk_16;
    Label load_loop_blk_24;
    Label load_loop_blk_end;

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24); {
        generate_diff_bias_loop(3);
        generate_load_loop_body(3);
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16); {
        generate_diff_bias_loop(2);
        generate_load_loop_body(2);
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8); {
        cmp(reg_load_loop_work, 0);
        je(load_loop_blk_end, T_NEAR);
        generate_diff_bias_loop(1);
        generate_load_loop_body(1);
    }

    L(load_loop_blk_end);

    if (jcp.with_bias && jcp.prop_kind == backward_weights)
        add(rsp, stack_space_needed);

    postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}

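// Accepts only the supported post-op chains: combinations of sum, eltwise,
// depthwise and a fused depthwise convolution in the orders listed below.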
bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
        jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    switch (p.len_) {
        case 0: return true;
        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
        case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
                       (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
                       (is_simple(0) && is_simple(1));
        case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
                       (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
                       (is_sum(0) && is_simple(1) && is_simple(2));
        case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
        default: return false;
    }

    return false;
}

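// Fills jit_1x1_conv_conf_t: derives the problem sizes from the memory
// descriptors, checks the 1x1 / no-padding / unit-stride and blocked-layout
// restrictions, and selects register and cache blocking per propagation kind.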
status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr)
{
    if (!mayiuse(avx)) return status::unimplemented;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could do some stuff below
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;

    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    if (jcp.with_dw_conv) {
        jcp.dw_conv_oh = jcp.oh;
        jcp.dw_conv_ow = jcp.ow;
        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
    }

    if (jcp.with_dw_conv && !mayiuse(avx2))
        return status::unimplemented;

    if (!mayiuse(avx2)) {
        for (int i = 0; i < p.len_; i++) {
            auto &post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                if (post_op.eltwise.alg != alg_kind::eltwise_relu)
                    return status::unimplemented;
            } else if (post_op.is_depthwise()) {
                return status::unimplemented;
            }
        }
    }

    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;

    jcp.src_dt = cd.src_desc.data_type;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;

    const int is_bwd_d = jcp.prop_kind == backward_data;
    memory_format_t weights_format = with_groups
        ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
            gOIhw8o8i)
        : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
            OIhw8o8i);

    const int simd_w = 8;

    jcp.oc = rnd_up(jcp.oc, simd_w);
    jcp.ic = rnd_up(jcp.ic, simd_w);

    bool args_ok = true
        && jcp.ngroups == 1
        && one_of(src_d.format(), nCw8c, nChw8c)
        && weights_d.format() == weights_format
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && one_of(dst_d.format(), nCw8c, nChw8c);
    if (!args_ok) return status::unimplemented;

    args_ok = true
        && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
        && jcp.t_pad == 0 && jcp.l_pad == 0
        && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
        && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    // TODO: remove this restriction
    // optimized 1x1 bwd_w does not support Intel AVX
    if (jcp.prop_kind == backward_weights && !mayiuse(avx2))
        return status::unimplemented;

    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support

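    // Blocking sizes in elements, chosen per propagation kind below and
    // converted to the nb_* block counts at the end of this function.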
    int load_blocking{ 0 };
    int load_blocking_max{ 0 };
    int bcast_blocking{ 0 };
    int bcast_blocking_max{ 0 };
    int reduce_blocking{ 0 };

    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */

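        // Shrink the load and bcast block counts by factors of 2 and 3 until
        // they are small enough (<= 32 and <= 9 blocks, respectively) so that
        // load_dim and bcast_dim remain evenly divisible by the blocking.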
        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        while (true) {
            if (load_blocking <= 32) break;
            else if (load_blocking % 2 == 0) load_blocking /= 2;
            else if (load_blocking % 3 == 0) load_blocking /= 3;
            else break;
        }
        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        while (true) {
            if (bcast_blocking <= 9) break;
            else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
            else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
            else break;
        }
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        reduce_blocking = 128; // affects L1$ utilization
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = jcp.bcast_dim % jcp.ur;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;

    jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}

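// Books scratchpad memory: a padded-bias buffer when the output channels are
// rounded up to the block size, and a per-thread row buffer for the fused
// depthwise convolution.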
void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
    using namespace mkldnn::impl::memory_tracking::names;

    if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);

    if (jcp.with_dw_conv) {
        const int nthreads = mkldnn_get_max_threads();
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);

        if (jcp.oc != jcp.oc_without_padding)
            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
    }
}

} // namespace cpu
} // namespace impl
} // namespace mkldnn