Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_1x1_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_sse42_1x1_conv_kernel_f32.hpp"
24
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
37 void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
38 {
39     mov(aux1_reg_bcast_data, reg_bcast_data);
40     mov(aux_reg_output_data, reg_output_data);
41     mov(bcast_loop_iter, reg_bcast_loop_work);
42
43     Label bcast_loop;
44     Label bcast_loop_tail;
45
46     cmp(bcast_loop_iter, jcp.ur);
47     jl(bcast_loop_tail, T_NEAR);
48
49     L(bcast_loop); {
50         assert(jcp.bcast_block % jcp.ur == 0);
51         int num_substeps = jcp.bcast_block / jcp.ur;
52         assert(num_substeps > 0 && num_substeps < 10);
53         for (int i = 0; i < num_substeps; i++) {
54             generate_reduce_loop(load_loop_blk, jcp.ur);
55             if (i < num_substeps - 1) {
56                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
57                 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
58             } else {
59                 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
60                         - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
61                 add(aux_reg_output_data, jcp.bcast_loop_output_step
62                         - (num_substeps - 1) * jcp.bcast_loop_output_substep);
63             }
64         }
65         sub(bcast_loop_iter, jcp.bcast_block);
66         cmp(bcast_loop_iter, jcp.bcast_block);
67         jge(bcast_loop, T_NEAR);
68     }
69
70     L(bcast_loop_tail);
71     if (jcp.ur_tail) {
72         Label bcast_loop_tail_out;
73         cmp(bcast_loop_iter, 0);
74         jz(bcast_loop_tail_out, T_NEAR);
75         generate_reduce_loop(load_loop_blk, jcp.ur_tail);
76         L(bcast_loop_tail_out);
77     }
78 }
79
80 void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
81         int load_loop_blk, int ur)
82 {
83     auto reg_load = [=](int i, int n) {
84         return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
85     };
86
87     auto reg_accum = [=](int i, int j, int n) {
88         return Xmm(n*load_loop_blk*ur + i*ur + j + 1);
89     };
90
91     auto bias_ptr = [=](int i, int n) {
92         return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
93     };
94
95     auto bcast_ptr = [=](int u, int j) {
96         assert(j < jcp.ur);
97         assert(u <= jcp.reduce_loop_unroll);
98         size_t offt;
99         if (one_of(jcp.prop_kind,
100                     forward_training, forward_inference, backward_data)) {
101             assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
102                     ? jcp.oc_block : jcp.ic_block);
103             auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
104             offt = (u == jcp.reduce_loop_unroll)
105                 ? (height + j) * jcp.reduce_loop_unroll
106                 : j * jcp.reduce_loop_unroll + u;
107         } else
108             offt = u * jcp.ic_block + j;
109         return ptr[aux_reg_bcast_data + sizeof(float) * offt];
110     };
111
112     auto load_ptr = [=](int u, int i, int n) {
113         size_t offt;
114         size_t u0 = u % jcp.reduce_loop_unroll;
115         size_t u1 = u / jcp.reduce_loop_unroll;
116         switch (jcp.prop_kind) {
117         case backward_data:
118             offt = (i * jcp.oc_block + u0) * jcp.ic_block;
119             break;
120         case backward_weights:
121             offt = (i * jcp.os + u0) * jcp.oc_block;
122             break;
123         default:
124             offt = (i * jcp.ic + u0) * jcp.oc_block;
125         }
126         return ptr[aux_reg_load_data
127             + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
128     };
129
130     auto output_ptr = [=](int i, int j, int n) {
131         switch (jcp.prop_kind) {
132         case backward_data:
133             return ptr[aux_reg_output_data +
134                 (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
135         case backward_weights:
136             return ptr[aux_reg_output_data
137                 + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
138                 + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
139         default:
140             if (jcp.with_dw_conv)
141                 return ptr[aux_reg_output_data +
142                            (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
143             else
144                 return ptr[aux_reg_output_data +
145                     (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
146         }
147     };
148
149     auto init = [=]() {
150         Label init_done;
151         Label init_zero;
152
153         if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
154                     forward_inference)) {
155             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
156             jz(init_zero);
157
158             for (int i = 0; i < load_loop_blk; i++)
159                 for (int j = 0; j < ur; ++j) {
160                     movups(reg_accum(i, j, 0), bias_ptr(i, 0));
161                     movups(reg_accum(i, j, 1), bias_ptr(i, 1));
162                 }
163             jmp(init_done);
164         }
165
166         L(init_zero);
167         for (int i = 0; i < load_loop_blk; ++i)
168             for (int j = 0; j < ur; ++j) {
169                 auto r0 = reg_accum(i, j, 0);
170                 auto r1 = reg_accum(i, j, 1);
171                 xorps(r0, r0);
172                 xorps(r1, r1);
173             }
174
175         L(init_done);
176
177         // load weights
178         for (int i = 0; i < load_loop_blk; ++i) {
179             movups(reg_load(i, 0), load_ptr(0, i, 0));
180             movups(reg_load(i, 1), load_ptr(0, i, 1));
181         }
182
183         movss(reg_bcast, bcast_ptr(0, 0));
184         shufps(reg_bcast, reg_bcast, 0);
185     }; // init()
186
187     auto store = [=]() {
188         Label store_noadd;
189
190         if (!jcp.with_sum) {
191             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
192             jnz(store_noadd, T_NEAR);
193         }
194
195         for (int j = 0; j < ur; ++j)
196             for (int i = 0; i < load_loop_blk; ++i) {
197                 auto r0 = reg_accum(i, j, 0);
198                 auto r1 = reg_accum(i, j, 1);
199                 addps(r0, output_ptr(i, j, 0));
200                 addps(r1, output_ptr(i, j, 1));
201             }
202
203         L(store_noadd);
204
205         Label store_no_postops;
206         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
207         jz(store_no_postops, T_NEAR);
208
209         int eltwise_inj_idx = 0;
210         int depthwise_inj_idx = 0;
211         const auto &p = attr_.post_ops_;
212
213         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
214         for (int i = 0; i < end_idx; i++) {
215             auto& post_op = p.entry_[i];
216             if (post_op.is_eltwise()) {
217                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
218                 eltwise_inj_idx++;
219             } else if (post_op.is_depthwise()) {
220                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
221                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
222
223                 add(reg_d_weights, reg_oc_off);
224                 add(reg_d_bias, reg_oc_off);
225
226                 for (int j = 0; j < load_loop_blk; ++j) {
227                     for (int k = 0; k < 2; k++) {
228                         int start_idx = reg_accum(j, 0, k).getIdx();
229                         int end_idx = reg_accum(j, ur, k).getIdx();
230
231                         depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
232                                 start_idx, end_idx, reg_d_weights, reg_d_bias);
233
234                         add(reg_d_weights, 4 * sizeof(float));
235                         add(reg_d_bias, 4 * sizeof(float));
236                     }
237                 }
238
239                 depthwise_inj_idx++;
240             }
241         }
242
243         L(store_no_postops);
244
245         for (int j = 0; j < ur; ++j)
246             for (int i = 0; i < load_loop_blk; ++i) {
247                 movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
248                 movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
249             }
250     };
251
252     auto fma_block = [=](bool last_block) {
253         for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
254             for (int j = 0; j < ur; ++j) {
255                 for (int i = 0; i < load_loop_blk; ++i) {
256                     mulps(reg_load(i, 0), reg_bcast);
257                     mulps(reg_load(i, 1), reg_bcast);
258                     addps(reg_accum(i, j, 0), reg_load(i, 0));
259                     addps(reg_accum(i, j, 1), reg_load(i, 1));
260
261                     if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
262                         movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
263                         movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
264                     }
265                 }
266                 if (j < ur - 1) {
267                     movss(reg_bcast, bcast_ptr(u, j + 1));
268                     shufps(reg_bcast, reg_bcast, 0);
269                 }
270             } // for ur
271             if (!last_block || u < jcp.reduce_loop_unroll - 1) {
272                 movss(reg_bcast, bcast_ptr(u + 1, 0));
273                 shufps(reg_bcast, reg_bcast, 0);
274             }
275         } // for reduce_loop_unroll
276     };
277
278     Label reduce_loop;
279     Label reduce_loop_tail;
280
281     mov(aux_reg_load_data, reg_load_data);
282     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
283
284     init();
285
286     mov(reduce_loop_iter, reg_reduce_loop_work);
287     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
288     jle(reduce_loop_tail, T_NEAR);
289
290     L(reduce_loop); {
291         fma_block(false);
292         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
293         add(aux_reg_load_data, jcp.reduce_loop_load_step);
294         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
295         jg(reduce_loop, T_NEAR);
296     }
297
298     L(reduce_loop_tail);
299     fma_block(true);
300
301     store();
302 } // reduce_loop()
303
304 void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
305 {
306     if (!jcp.with_bias || jcp.prop_kind != backward_weights)
307         return;
308
309     Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
310     Label diff_bias_load;
311
312     auto diff_bias_ptr = [=](int i, int n) {
313         return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
314     };
315
316     auto load_ptr = [=](int u, int i, int n) {
317         return ptr[aux_reg_load_data
318             + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
319     };
320
321     auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };
322
323     mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
324     cmp(reg_diff_bias_data, 0);
325     je(diff_bias_loop_out, T_NEAR);
326
327     test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
328     jz(diff_bias_load, T_NEAR);
329
330     for (int i = 0; i < load_loop_blk; ++i) {
331         auto r0 = diff_bias_reg(i, 0);
332         auto r1 = diff_bias_reg(i, 1);
333         xorps(r0, r0);
334         xorps(r1, r1);
335     }
336     jmp(diff_bias_init_out, T_NEAR);
337
338     L(diff_bias_load);
339     for (int i = 0; i < load_loop_blk; ++i) {
340         movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
341         movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
342     }
343
344     L(diff_bias_init_out);
345     mov(aux_reg_load_data, reg_load_data);
346     mov(reduce_loop_iter, reg_reduce_loop_work);
347     L(diff_bias_loop); {
348         for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
349             for (int i = 0; i < load_loop_blk; ++i) {
350                 addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
351                 addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
352             }
353         assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
354         add(aux_reg_load_data, jcp.reduce_loop_load_step);
355         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
356         jnz(diff_bias_loop, T_NEAR);
357     }
358
359     for (int i = 0; i < load_loop_blk; i++) {
360         movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
361         movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
362     }
363
364     add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
365     mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
366
367     L(diff_bias_loop_out);
368 }
369
370 void jit_sse42_1x1_conv_kernel_f32::generate()
371 {
372     const auto &p = attr_.post_ops_;
373     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
374     for (int i = 0; i < end_idx; i++) {
375         auto &post_op = p.entry_[i];
376         if (post_op.is_eltwise()) {
377             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
378                     this,
379                     post_op.eltwise.alg,
380                     post_op.eltwise.alpha,
381                     post_op.eltwise.beta
382             ));
383         } else if (post_op.is_depthwise()) {
384             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
385                     this,
386                     post_op.depthwise.alg
387             ));
388         }
389     }
390
391     preamble();
392
393     mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
394     mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
395     mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
396     if (jcp.with_bias) {
397         if (jcp.prop_kind == backward_weights) {
398             sub(rsp, stack_space_needed);
399             mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
400             mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
401         } else
402             mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
403     }
404
405     mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
406     mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
407     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
408     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
409     if (jcp.prop_kind == backward_weights)
410         mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
411     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
412
413     auto generate_load_loop_body = [=] (int load_loop_blk) {
414         generate_bcast_loop(load_loop_blk);
415         add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
416         switch (jcp.prop_kind) {
417         case forward_training:
418         case forward_inference:
419             add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
420             if (jcp.with_dw_conv)
421                 add(reg_output_data,
422                     load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
423             else
424                 add(reg_output_data,
425                         load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
426             break;
427         case backward_data:
428             add(reg_output_data,
429                     load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
430             break;
431         case backward_weights:
432             for (int i = 0; i < load_loop_blk; i++)
433                 add(reg_output_data, reg_output_stride);
434             break;
435         default:
436             assert(!"invalid prop_kind");
437         }
438         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
439         add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
440     };
441
442     Label load_loop_blk_8;
443     Label load_loop_blk_16;
444     Label load_loop_blk_24;
445     Label load_loop_blk_end;
446
447     cmp(reg_load_loop_work, 8);
448     jle(load_loop_blk_8, T_NEAR);
449
450     cmp(reg_load_loop_work, 32);
451     je(load_loop_blk_16, T_NEAR);
452
453     cmp(reg_load_loop_work, 16);
454     jle(load_loop_blk_16, T_NEAR);
455
456     L(load_loop_blk_24); {
457         generate_diff_bias_loop(3);
458         generate_load_loop_body(3);
459         cmp(reg_load_loop_work, 32);
460         je(load_loop_blk_16);
461         cmp(reg_load_loop_work, 24);
462         jge(load_loop_blk_24);
463     }
464
465     cmp(reg_load_loop_work, 8);
466     jle(load_loop_blk_8, T_NEAR);
467
468     L(load_loop_blk_16); {
469         generate_diff_bias_loop(2);
470         generate_load_loop_body(2);
471         cmp(reg_load_loop_work, 16);
472         jge(load_loop_blk_16);
473     }
474
475     L(load_loop_blk_8); {
476         cmp(reg_load_loop_work, 0);
477         je(load_loop_blk_end, T_NEAR);
478         generate_diff_bias_loop(1);
479         generate_load_loop_body(1);
480     }
481
482     L(load_loop_blk_end);
483
484     if (jcp.with_bias && jcp.prop_kind == backward_weights)
485         add(rsp, stack_space_needed);
486
487     postamble();
488
489     for (auto& inj : eltwise_injectors)
490         inj->prepare_table();
491 }
492
493 bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
494         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
495     const auto &p = attr.post_ops_;
496
497     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
498     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
499     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
500     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
501     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
502
503     switch (p.len_) {
504         case 0: return true;
505         case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
506         case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
507                        (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
508                        (is_simple(0) && is_simple(1));
509         case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
510                        (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
511                        (is_sum(0) && is_simple(1) && is_simple(2));
512         case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
513         default: return false;
514     }
515
516     return false;
517 }
518
519 status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
520         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
521         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
522         const primitive_attr_t &attr)
523 {
524     if (!mayiuse(sse42))
525         return status::unimplemented;
526
527     // TODO (Roma): this code is duplicated from the generic kernel; maybe the
528     // configuration struct could do some stuff below
529     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
530     const int ndims = src_d.ndims();
531
532     jcp.prop_kind = cd.prop_kind;
533
534     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
535     jcp.mb = src_d.dims()[0];
536
537     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
538     jcp.oc_without_padding = jcp.oc;
539     jcp.ic = src_d.dims()[1] / jcp.ngroups;
540
541     jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
542     jcp.iw = src_d.dims()[ndims - 1];
543     jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
544     jcp.ow = dst_d.dims()[ndims - 1];
545
546     jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
547     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
548
549     jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
550     jcp.l_pad = cd.padding[0][ndims - 3];
551
552     jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
553     jcp.stride_w = cd.strides[ndims - 3];
554
555     jcp.src_fmt = src_d.format();
556     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
557
558     if (!post_ops_ok(jcp, attr))
559         return status::unimplemented;
560
561     const auto &p = attr.post_ops_;
562
563     int dw_conv_ind = p.find(primitive_kind::convolution);
564     jcp.with_dw_conv = dw_conv_ind != -1;
565     if (jcp.with_dw_conv) {
566         jcp.dw_conv_oh = jcp.oh;
567         jcp.dw_conv_ow = jcp.ow;
568         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
569         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
570     }
571
572     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
573
574     jcp.src_dt = cd.src_desc.data_type;
575     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
576     jcp.dst_dt = cd.dst_desc.data_type;
577
578     jcp.os = jcp.oh * jcp.ow;
579     jcp.is = jcp.ih * jcp.iw;
580
581     const int is_bwd_d = jcp.prop_kind == backward_data;
582     memory_format_t weights_format = with_groups
583         ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
584             gOIhw8o8i)
585         : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
586             OIhw8o8i);
587
588     bool args_ok = true
589         && jcp.ngroups == 1
590         && one_of(src_d.format(), nCw8c, nChw8c)
591         && weights_d.format() == weights_format
592         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
593         && one_of(dst_d.format(), nCw8c, nChw8c);
594     if (!args_ok) return status::unimplemented;
595
596     const int simd_w = 4;
597
598     jcp.oc = rnd_up(jcp.oc, simd_w*2);
599     jcp.ic = rnd_up(jcp.ic, simd_w*2);
600
601     jcp.ic_block = jcp.oc_block = simd_w*2;
602
603     args_ok = true
604         && jcp.oc % jcp.oc_block == 0
605         && jcp.ic % jcp.ic_block == 0
606         && jcp.t_pad == 0 && jcp.l_pad == 0
607         && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
608         && jcp.kh == 1 && jcp.kw == 1;
609     if (!args_ok) return status::unimplemented;
610
611     jcp.ur = 1;
612
613     int load_blocking{ 0 };
614     int load_blocking_max{ 0 };
615     int bcast_blocking{ 0 };
616     int bcast_blocking_max{ 0 };
617     int reduce_blocking{ 0 };
618
619     if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
620         jcp.reduce_dim = jcp.ic;
621         jcp.reduce_block = jcp.ic_block;
622
623         jcp.load_dim = jcp.oc;
624         jcp.load_block = jcp.oc_block;
625
626         jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
627         jcp.bcast_block = jcp.ur;
628
629         jcp.reduce_loop_unroll = jcp.reduce_block;
630         jcp.reduce_loop_bcast_step
631             = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
632         jcp.reduce_loop_load_step
633             = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
634
635         jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
636         jcp.bcast_loop_output_substep = -1; // unused
637         jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
638         jcp.bcast_loop_bcast_substep = -1; // unused
639
640         jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
641         jcp.load_loop_iter_step = jcp.oc_block;
642
643         load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
644         load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
645         bcast_blocking = 128; // affects load balancing across threads
646         bcast_blocking_max = 192;
647         reduce_blocking = 128; // affects L1$ utilization
648     } else if (jcp.prop_kind == backward_data) {
649         jcp.reduce_dim = jcp.oc;
650         jcp.reduce_block = jcp.oc_block;
651
652         jcp.load_dim = jcp.ic;
653         jcp.load_block = jcp.oc_block;
654
655         jcp.bcast_dim = jcp.os;
656         jcp.bcast_block = jcp.ur;
657
658         jcp.reduce_loop_unroll = jcp.reduce_block;
659         jcp.reduce_loop_bcast_step
660             = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
661         jcp.reduce_loop_load_step
662             = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
663
664         jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
665         jcp.bcast_loop_output_substep = -1; // unused
666         jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
667         jcp.bcast_loop_bcast_substep = -1; // unused
668
669         jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
670         jcp.load_loop_iter_step = jcp.ic_block;
671
672         load_blocking = 96; // assumes the kernel is jcp.ur x 3
673         load_blocking_max = 144;
674         bcast_blocking = 128; // affects load balancing across threads
675         bcast_blocking_max = 196;
676         reduce_blocking = 64; // affects L1$ utilization
677     } else if (jcp.prop_kind == backward_weights) {
678         jcp.reduce_dim = jcp.os;
679         jcp.reduce_block = 1;
680
681         jcp.load_dim = jcp.oc;
682         jcp.load_block = jcp.oc_block;
683
684         jcp.bcast_dim = jcp.ic;
685         jcp.bcast_block = jcp.ic_block;
686
687         jcp.reduce_loop_unroll = jcp.reduce_block;
688         jcp.reduce_loop_bcast_step
689             = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
690         jcp.reduce_loop_load_step
691             = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
692
693         jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
694         jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
695         jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
696         jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
697
698         jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
699         jcp.load_loop_iter_step = jcp.oc_block;
700
701         /* --- */
702
703         load_blocking = div_up(jcp.load_dim, jcp.load_block);
704         while (true) {
705             if (load_blocking <= 32) break;
706             else if (load_blocking % 2 == 0) load_blocking /= 2;
707             else if (load_blocking % 3 == 0) load_blocking /= 3;
708             else break;
709         }
710         load_blocking *= jcp.load_block;
711         load_blocking_max = load_blocking;
712         assert(jcp.load_dim % load_blocking == 0);
713
714         bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
715         while (true) {
716             if (bcast_blocking <= 9) break;
717             else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
718             else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
719             else break;
720         }
721         bcast_blocking *= jcp.bcast_block;
722         bcast_blocking_max = bcast_blocking;
723         assert(jcp.bcast_dim % bcast_blocking == 0);
724
725         reduce_blocking = 128; // affects L1$ utilization
726     } else
727         return status::unimplemented;
728
729     assert(load_blocking);
730     assert(load_blocking_max);
731     assert(bcast_blocking);
732     assert(bcast_blocking_max);
733     assert(reduce_blocking);
734
735     assert(jcp.bcast_block % jcp.ur == 0);
736     jcp.ur_tail = jcp.bcast_dim % jcp.ur;
737
738     jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
739     jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
740     jcp.nb_load_blocking = load_blocking / jcp.load_block;
741     jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
742     jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
743
744     jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
745     jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
746     jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
747
748     return status::success;
749 }
750
751 void jit_sse42_1x1_conv_kernel_f32::init_scratchpad(
752         memory_tracking::registrar_t &scratchpad,
753         const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
754     using namespace mkldnn::impl::memory_tracking::names;
755
756     if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
757         scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
758
759     if (jcp.with_dw_conv) {
760         const int nthreads = mkldnn_get_max_threads();
761         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
762         scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
763
764         if (jcp.oc != jcp.oc_without_padding)
765             scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
766     }
767 }
768
769 }
770 }
771 }