b3c7ba44e7bec2544af2011f809de49472e21b10
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_1x1_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_sse42_1x1_conv_kernel_f32.hpp"
24
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
void jit_sse42_1x1_conv_kernel_f32::bcast_loop(int load_loop_blk,
        char load_loop_tag)
{
    // Emits the loop over the broadcast (spatial) dimension: counts
    // bcast_loop_iter down in steps of jcp.bcast_block, invoking
    // reduce_loop() once per jcp.ur-sized substep, then emits a single
    // tail reduce_loop for the jcp.ur_tail remainder.
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, reg_bcast_loop_work);

    // Labels carry the caller-supplied tag so that each unrolled
    // instantiation of this loop gets unique label names.
    jit_tagged_label bcast_loop("bcast_loop", load_loop_tag);
    jit_tagged_label bcast_loop_tail("bcast_loop_tail", load_loop_tag);

    // Fewer than jcp.ur iterations left -> go straight to the tail.
    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop); {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        // '0' + i below requires a single digit per substep.
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            reduce_loop(load_loop_blk, jcp.ur, load_loop_tag, '0' + i);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                // Last substep: advance by one full block step, compensating
                // for the substep offsets accumulated above.
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
                        - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_step
                        - (num_substeps - 1) * jcp.bcast_loop_output_substep);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        // Remainder (bcast_dim % ur); emitted only when a tail exists.
        jit_tagged_label bcast_loop_tail_out(
                "bcast_loop_tail_out", load_loop_tag);
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        reduce_loop(load_loop_blk, jcp.ur_tail, load_loop_tag, '1');
        L(bcast_loop_tail_out);
    }
}
81
82 void jit_sse42_1x1_conv_kernel_f32::reduce_loop(int load_loop_blk, int ur,
83         char load_loop_tag, char bcast_loop_tag)
84 {
85     auto reg_load = [=](int i, int n) {
86         return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
87     };
88
89     auto reg_accum = [=](int i, int j, int n) {
90         return Xmm(n*load_loop_blk*ur + i*ur + j + 1);
91     };
92
93     auto bias_ptr = [=](int i, int n) {
94         return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
95     };
96
97     auto bcast_ptr = [=](int u, int j) {
98         assert(j < jcp.ur);
99         assert(u <= jcp.reduce_loop_unroll);
100         size_t offt;
101         if (one_of(jcp.prop_kind,
102                     forward_training, forward_inference, backward_data)) {
103             assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
104                     ? jcp.oc_block : jcp.ic_block);
105             auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
106             offt = (u == jcp.reduce_loop_unroll)
107                 ? (height + j) * jcp.reduce_loop_unroll
108                 : j * jcp.reduce_loop_unroll + u;
109         } else
110             offt = u * jcp.ic_block + j;
111         return ptr[aux_reg_bcast_data + sizeof(float) * offt];
112     };
113
114     auto load_ptr = [=](int u, int i, int n) {
115         size_t offt;
116         size_t u0 = u % jcp.reduce_loop_unroll;
117         size_t u1 = u / jcp.reduce_loop_unroll;
118         switch (jcp.prop_kind) {
119         case backward_data:
120             offt = (i * jcp.oc_block + u0) * jcp.ic_block;
121             break;
122         case backward_weights:
123             offt = (i * jcp.os + u0) * jcp.oc_block;
124             break;
125         default:
126             offt = (i * jcp.ic + u0) * jcp.oc_block;
127         }
128         return ptr[aux_reg_load_data
129             + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
130     };
131
132     auto output_ptr = [=](int i, int j, int n) {
133         switch (jcp.prop_kind) {
134         case backward_data:
135             return ptr[aux_reg_output_data +
136                 (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
137         case backward_weights:
138             return ptr[aux_reg_output_data
139                 + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
140                 + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
141         default:
142             if (jcp.with_dw_conv)
143                 return ptr[aux_reg_output_data +
144                            (i * jcp.dw_conv_ker_h * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
145             else
146                 return ptr[aux_reg_output_data +
147                     (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
148         }
149     };
150
151     auto init = [=]() {
152         jit_tagged_label init_done("init_done", load_loop_tag, bcast_loop_tag);
153         jit_tagged_label init_zero("init_zero", load_loop_tag, bcast_loop_tag);
154
155         if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
156                     forward_inference)) {
157             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
158             jz(init_zero);
159
160             for (int i = 0; i < load_loop_blk; i++)
161                 for (int j = 0; j < ur; ++j) {
162                     movups(reg_accum(i, j, 0), bias_ptr(i, 0));
163                     movups(reg_accum(i, j, 1), bias_ptr(i, 1));
164                 }
165             jmp(init_done);
166         }
167
168         L(init_zero);
169         for (int i = 0; i < load_loop_blk; ++i)
170             for (int j = 0; j < ur; ++j) {
171                 auto r0 = reg_accum(i, j, 0);
172                 auto r1 = reg_accum(i, j, 1);
173                 xorps(r0, r0);
174                 xorps(r1, r1);
175             }
176
177         L(init_done);
178
179         // load weights
180         for (int i = 0; i < load_loop_blk; ++i) {
181             movups(reg_load(i, 0), load_ptr(0, i, 0));
182             movups(reg_load(i, 1), load_ptr(0, i, 1));
183         }
184
185         movss(reg_bcast, bcast_ptr(0, 0));
186         shufps(reg_bcast, reg_bcast, 0);
187     }; // init()
188
189     auto store = [=]() {
190         jit_tagged_label store_done(
191                 "store_done", load_loop_tag, bcast_loop_tag);
192         jit_tagged_label store_noadd(
193                 "store_noadd", load_loop_tag, bcast_loop_tag);
194
195         if (!jcp.with_sum) {
196             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
197             jnz(store_noadd, T_NEAR);
198         }
199
200         for (int j = 0; j < ur; ++j)
201             for (int i = 0; i < load_loop_blk; ++i) {
202                 auto r0 = reg_accum(i, j, 0);
203                 auto r1 = reg_accum(i, j, 1);
204                 addps(r0, output_ptr(i, j, 0));
205                 addps(r1, output_ptr(i, j, 1));
206             }
207
208         L(store_noadd);
209
210         jit_tagged_label store_norelu(
211                 "store_norelu", load_loop_tag, bcast_loop_tag);
212         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
213         jz(store_norelu, T_NEAR);
214
215         int eltwise_inj_idx = 0;
216         int depthwise_inj_idx = 0;
217         const auto &p = attr_.post_ops_;
218
219         if (p.len_ == 0 && eltwise_injectors.size() == 1) {
220             eltwise_injectors[0]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
221         }
222
223         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
224         for (int i = 0; i < end_idx; i++) {
225             auto& post_op = p.entry_[i];
226             if (post_op.is_eltwise()) {
227                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
228                 eltwise_inj_idx++;
229             } else if (post_op.is_depthwise()) {
230                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
231                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
232
233                 add(reg_d_weights, reg_oc_off);
234                 add(reg_d_bias, reg_oc_off);
235
236                 for (int j = 0; j < load_loop_blk; ++j) {
237                     for (int k = 0; k < 2; k++) {
238                         int start_idx = reg_accum(j, 0, k).getIdx();
239                         int end_idx = reg_accum(j, ur, k).getIdx();
240
241                         depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
242                                 start_idx, end_idx, reg_d_weights, reg_d_bias);
243
244                         add(reg_d_weights, 4 * sizeof(float));
245                         add(reg_d_bias, 4 * sizeof(float));
246                     }
247                 }
248
249                 depthwise_inj_idx++;
250             }
251         }
252
253         L(store_norelu);
254
255         for (int j = 0; j < ur; ++j)
256             for (int i = 0; i < load_loop_blk; ++i) {
257                 movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
258                 movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
259             }
260
261         L(store_done);
262     };
263
264     auto fma_block = [=](bool last_block) {
265         for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
266             for (int j = 0; j < ur; ++j) {
267                 for (int i = 0; i < load_loop_blk; ++i) {
268                     mulps(reg_load(i, 0), reg_bcast);
269                     mulps(reg_load(i, 1), reg_bcast);
270                     addps(reg_accum(i, j, 0), reg_load(i, 0));
271                     addps(reg_accum(i, j, 1), reg_load(i, 1));
272
273                     if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
274                         movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
275                         movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
276                     }
277                 }
278                 if (j < ur - 1) {
279                     movss(reg_bcast, bcast_ptr(u, j + 1));
280                     shufps(reg_bcast, reg_bcast, 0);
281                 }
282             } // for ur
283             if (!last_block || u < jcp.reduce_loop_unroll - 1) {
284                 movss(reg_bcast, bcast_ptr(u + 1, 0));
285                 shufps(reg_bcast, reg_bcast, 0);
286             }
287         } // for reduce_loop_unroll
288     };
289
290     jit_tagged_label reduce_loop("reduce_loop", load_loop_tag, bcast_loop_tag);
291     jit_tagged_label reduce_loop_tail(
292             "reduce_loop_tail", load_loop_tag, bcast_loop_tag);
293
294     mov(aux_reg_load_data, reg_load_data);
295     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
296
297     init();
298
299     mov(reduce_loop_iter, reg_reduce_loop_work);
300     sub(reduce_loop_iter, jcp.reduce_loop_unroll);
301     jle(reduce_loop_tail, T_NEAR);
302
303     L(reduce_loop); {
304         fma_block(false);
305         add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
306         add(aux_reg_load_data, jcp.reduce_loop_load_step);
307         sub(reduce_loop_iter, jcp.reduce_loop_unroll);
308         jg(reduce_loop, T_NEAR);
309     }
310
311     L(reduce_loop_tail);
312     fma_block(true);
313
314     store();
315 } // reduce_loop()
316
void jit_sse42_1x1_conv_kernel_f32::diff_bias_loop(int load_loop_blk,
        char load_loop_tag)
{
    // Emits bias-gradient accumulation for backward_weights: sums the
    // diff_dst values over the reduce dimension into the diff-bias buffer.
    // No-op for other propagation kinds or when bias is absent.
    if (!jcp.with_bias || jcp.prop_kind != backward_weights)
        return;

    jit_tagged_label diff_bias_loop("diff_bias_loop", load_loop_tag);
    jit_tagged_label diff_bias_loop_out("diff_bias_loop_out", load_loop_tag);
    jit_tagged_label diff_bias_init_out("diff_bias_init_out", load_loop_tag);
    jit_tagged_label diff_bias_load("diff_bias_load", load_loop_tag);

    auto diff_bias_ptr = [=](int i, int n) {
        return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
    };

    auto load_ptr = [=](int u, int i, int n) {
        return ptr[aux_reg_load_data
            + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
    };

    // Two xmm halves per 8-wide oc block (n selects the half).
    auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };

    // The diff-bias pointer is kept on the stack (see generate()); a null
    // pointer means bias gradient is not requested for this call.
    mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
    cmp(reg_diff_bias_data, 0);
    je(diff_bias_loop_out, T_NEAR);

    // First reduce chunk starts from zero; later chunks continue from the
    // partial sums already stored in the buffer.
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
    jz(diff_bias_load, T_NEAR);

    for (int i = 0; i < load_loop_blk; ++i) {
        auto r0 = diff_bias_reg(i, 0);
        auto r1 = diff_bias_reg(i, 1);
        xorps(r0, r0);
        xorps(r1, r1);
    }
    jmp(diff_bias_init_out, T_NEAR);

    L(diff_bias_load);
    for (int i = 0; i < load_loop_blk; ++i) {
        movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
        movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
    }

    L(diff_bias_init_out);
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    L(diff_bias_loop); {
        for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
            for (int i = 0; i < load_loop_blk; ++i) {
                addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
                addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
            }
        // The loop exits via ZF from `sub`, so the reduce dimension must be
        // an exact multiple of the unroll factor.
        assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jnz(diff_bias_loop, T_NEAR);
    }

    for (int i = 0; i < load_loop_blk; i++) {
        movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
        movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
    }

    // Advance the stacked pointer past the oc blocks just processed so the
    // next load-loop iteration continues where this one stopped.
    add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
    mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);

    L(diff_bias_loop_out);
}
385
void jit_sse42_1x1_conv_kernel_f32::generate()
{
    // Top-level code generation: builds the post-op injectors, then emits
    // the kernel as a 3/2/1-wide load-loop dispatch over load_dim.

    // Legacy with_relu path registers a single ReLU injector.
    // NOTE(review): injectors are raw `new`-ed into vectors; ownership /
    // deletion is presumably handled by the kernel class elsewhere — not
    // visible in this file.
    if (jcp.with_eltwise) {
        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
        ));
    }

    // One injector per eltwise/depthwise post-op preceding any fused
    // depthwise convolution.
    const auto &p = attr_.post_ops_;
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<sse42>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    preamble();

    // Load kernel arguments from the jit_1x1_conv_call_s struct.
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
    if (jcp.with_bias) {
        if (jcp.prop_kind == backward_weights) {
            // diff-bias pointer is spilled to the stack because registers
            // are exhausted on the backward-weights path.
            sub(rsp, stack_space_needed);
            mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
            mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
        } else
            mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    }

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
    if (jcp.prop_kind == backward_weights)
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);

    // Emits one load-loop iteration for a given width, then advances the
    // data/bias/output pointers by the amount just consumed.
    auto load_loop_body = [=] (int load_loop_blk, char bcast_loop_tag) {
        bcast_loop(load_loop_blk, bcast_loop_tag);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
        case forward_training:
        case forward_inference:
            add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
            if (jcp.with_dw_conv)
                add(reg_output_data,
                    load_loop_blk * jcp.ow * jcp.oc_block * sizeof(float));
            else
                add(reg_output_data,
                        load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
            break;
        case backward_data:
            add(reg_output_data,
                    load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
            break;
        case backward_weights:
            for (int i = 0; i < load_loop_blk; i++)
                add(reg_output_data, reg_output_stride);
            break;
        default:
            assert(!"invalid prop_kind");
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
    };

    const char *load_loop_blk_8 = "load_loop_blk_8";
    const char *load_loop_blk_16 = "load_loop_blk_16";
    const char *load_loop_blk_24 = "load_loop_blk_24";
    const char *load_loop_blk_end = "load_loop_blk_end";

    // Dispatch: prefer the widest (3-block) body, fall back to 2 and 1.
    // The `== 32` special case routes a 32-wide remainder through two
    // 2-block passes instead of 3 + 1.
    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    cmp(reg_load_loop_work, 32);
    je(load_loop_blk_16, T_NEAR);

    cmp(reg_load_loop_work, 16);
    jle(load_loop_blk_16, T_NEAR);

    L(load_loop_blk_24); {
        diff_bias_loop(3, '3');
        load_loop_body(3, '3');
        cmp(reg_load_loop_work, 32);
        je(load_loop_blk_16);
        cmp(reg_load_loop_work, 24);
        jge(load_loop_blk_24);
    }

    cmp(reg_load_loop_work, 8);
    jle(load_loop_blk_8, T_NEAR);

    L(load_loop_blk_16); {
        diff_bias_loop(2, '2');
        load_loop_body(2, '2');
        cmp(reg_load_loop_work, 16);
        jge(load_loop_blk_16);
    }

    L(load_loop_blk_8); {
        cmp(reg_load_loop_work, 0);
        je(load_loop_blk_end, T_NEAR);
        diff_bias_loop(1, '1');
        load_loop_body(1, '1');
    }

    L(load_loop_blk_end);

    // Release the stack slot reserved for the diff-bias pointer above.
    if (jcp.with_bias && jcp.prop_kind == backward_weights)
        add(rsp, stack_space_needed);

    postamble();

    // Eltwise lookup tables are emitted after the kernel body.
    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}
514
515 bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
516         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
517     const auto &p = attr.post_ops_;
518
519     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
520     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
521     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
522     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
523     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
524
525     switch (p.len_) {
526         case 0: return true; // no post_ops
527         case 1:
528             return true // sum OR eltwise OR dw_conv
529                    && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
530         case 2:
531             return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
532                    // eltwise->depthwise OR depthwise->depthwise
533                    && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
534                                             (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
535                                             (is_simple(0) && is_simple(1)));
536         case 3:
537             return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
538                    // sum->depthwise->eltwise OR sum->depthwise->depthwise
539                    && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
540                                             (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
541                                             (is_sum(0) && is_simple(1) && is_simple(2)));
542         case 4: return true // eltwise->dw_conv->sum->eltwise
543                        && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
544         default: return false;
545     }
546
547     return false;
548 }
549
status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
{
    // Fills the jit_1x1_conv_conf_t for the SSE4.2 1x1 kernel and checks
    // that the problem is supported. Returns status::unimplemented on any
    // unsupported shape/format/attribute combination.
    if (!mayiuse(sse42))
        return status::unimplemented;

    // TODO (Roma): this code is duplicated from the generic kernel; maybe the
    // configuration struct could do some stuff below
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    // Legacy API: a with_relu flag outside the post-ops attribute.
    jcp.with_eltwise = with_relu;
    jcp.eltwise_alg = mkldnn_eltwise_relu;
    jcp.eltwise_alpha = relu_negative_slope;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    // Capture fused depthwise-convolution parameters, if present.
    const auto &p = attr.post_ops_;
    jcp.with_dw_conv = false;
    int dw_conv_ind = p.find(primitive_kind::convolution);
    if (dw_conv_ind != -1) {
        jcp.with_dw_conv = true;
        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
    }

    // Eltwise fused AFTER the depthwise conv is recorded separately.
    if (jcp.with_dw_conv) {
        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
        if (dw_conv_eltwise_ind != -1) {
            jcp.dw_conv_with_eltwise = true;
            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
        }
    }

    // Sums before/after the dw conv are distinguished by the search range.
    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    if (jcp.with_dw_conv) {
        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
    }

    // With a fused dw conv, this kernel's "output" is the dw conv's input.
    if (jcp.with_dw_conv) {
        jcp.oh = jcp.dw_conv_in_h;
        jcp.ow = jcp.dw_conv_in_w;
    }

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;

    // Expected weights layout, indexed by [with_groups][is_backward_data].
    constexpr memory_format_t weights_formats[2][2] = {
        { OIhw8i8o, OIhw8o8i },
        { gOIhw8i8o, gOIhw8o8i }
    };
    memory_format_t weights_format
        = weights_formats[with_groups][jcp.prop_kind == backward_data];

    bool args_ok = true
        && jcp.ngroups == 1
        && src_d.format() == nChw8c
        && weights_d.format() == weights_format
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && dst_d.format() == nChw8c;
    if (!args_ok) return status::unimplemented;

    // SSE vector holds 4 floats; channel blocks are two vectors (8) wide.
    const int simd_w = 4;

    jcp.oc = rnd_up(jcp.oc, simd_w*2);
    jcp.ic = rnd_up(jcp.ic, simd_w*2);

    jcp.ic_block = jcp.oc_block = simd_w*2;

    args_ok = true
        && jcp.oc % jcp.oc_block == 0
        && jcp.ic % jcp.ic_block == 0
        && jcp.t_pad == 0 && jcp.l_pad == 0
        && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
        && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.ur = 1;

    int load_blocking{ 0 };
    int load_blocking_max{ 0 };
    int bcast_blocking{ 0 };
    int bcast_blocking_max{ 0 };
    int reduce_blocking{ 0 };

    // Per-prop_kind blocking setup: which dims play the load / bcast /
    // reduce roles and the byte strides the generated loops advance by.
    if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
        jcp.reduce_dim = jcp.ic;
        jcp.reduce_block = jcp.ic_block;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.with_dw_conv ? jcp.iw : jcp.is;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        load_blocking = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 120; // assumes the kernel is jcp.ur x 3
        load_blocking_max = jcp.with_dw_conv ? nstl::min(3 * jcp.load_block, jcp.oc) : 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 192;
        reduce_blocking = 128; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_data) {
        jcp.reduce_dim = jcp.oc;
        jcp.reduce_block = jcp.oc_block;

        jcp.load_dim = jcp.ic;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.os;
        jcp.bcast_block = jcp.ur;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);

        jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = -1; // unused
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.load_loop_iter_step = jcp.ic_block;

        load_blocking = 96; // assumes the kernel is jcp.ur x 3
        load_blocking_max = 144;
        bcast_blocking = 128; // affects load balancing across threads
        bcast_blocking_max = 196;
        reduce_blocking = 64; // affects L1$ utilization
    } else if (jcp.prop_kind == backward_weights) {
        jcp.reduce_dim = jcp.os;
        jcp.reduce_block = 1;

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
        jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);

        jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
        jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
        jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
        jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);

        jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */

        // Shrink the block counts by factors of 2/3 until they are small
        // enough, keeping them exact divisors of the dimension.
        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        while (true) {
            if (load_blocking <= 32) break;
            else if (load_blocking % 2 == 0) load_blocking /= 2;
            else if (load_blocking % 3 == 0) load_blocking /= 3;
            else break;
        }
        load_blocking *= jcp.load_block;
        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        while (true) {
            if (bcast_blocking <= 9) break;
            else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
            else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
            else break;
        }
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        reduce_blocking = 128; // affects L1$ utilization
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);

    assert(jcp.bcast_block % jcp.ur == 0);
    jcp.ur_tail = jcp.bcast_dim % jcp.ur;

    // Convert byte/element blockings into block counts used by the driver.
    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;

    jcp.nb_bcast = jcp.with_dw_conv ? jcp.ih : div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}
802
803 }
804 }
805 }