Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_planar_conv_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/primitive_attr.hpp>
18 #include "c_types_map.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_uni_planar_conv_kernel_f32.hpp"
25 #include "cpu_isa_traits.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
36
37 using namespace Xbyak;
38
39 template <cpu_isa_t isa>
40 void jit_uni_planar_conv_fwd_kernel_f32<isa>::load_src_scalar(int ur_h) {
41     Label init_done_label;
42     Label init_first_label;
43
44     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
45     if (jcp.with_bias)
46         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
47
48     if (!jcp.with_sum) {
49         test(reg_ci_flag, FLAG_IC_FIRST);
50         jne(init_first_label, T_NEAR);
51     }
52
53     for (int kk = 0; kk < ur_h; kk++) {
54         size_t offt = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step);
55         movss(Xmm(kk), make_safe_addr(reg_output, offt, reg_long_offt));
56     }
57
58     if (jcp.with_sum && jcp.with_bias) {
59         test(reg_ci_flag, FLAG_IC_FIRST);
60         je(init_done_label, T_NEAR);
61
62         movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
63         for (int kk = 0; kk < ur_h; kk++) {
64             uni_vaddps(Vmm(kk), Vmm(kk), vmm_tmp);
65         }
66     }
67
68     jmp(init_done_label, T_NEAR);
69
70     L(init_first_label);
71     if (this->jcp.with_bias) {
72         movss(xmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
73         for (int kk = 0; kk < ur_h; kk++) {
74             uni_vmovups(Vmm(kk), vmm_tmp);
75         }
76     } else {
77         for (int kk = 0; kk < ur_h; kk++) {
78             uni_vpxor(Vmm(kk), Vmm(kk), Vmm(kk));
79         }
80     }
81
82     L(init_done_label);
83 }
84
85 template <cpu_isa_t isa>
86 void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter_scalar(int ur_h) {
87     Label iter_exit_label;
88
89     int iw = jcp.iw;
90     int ih = jcp.ih;
91     int id = jcp.id;
92     int dilate_w = jcp.dilate_w + 1;
93     int ic_blk = jcp.ic_block;
94     int kw = jcp.kw;
95     int kh = jcp.kh;
96     int kd = jcp.kd;
97
98     cmp(reg_kw, 0);
99     je(iter_exit_label, T_NEAR);
100
101     mov(aux_reg_input_w, aux_reg_input_h);
102     mov(aux_reg_kernel_w, aux_reg_kernel_h);
103     mov(kw_iter, reg_kw);
104
105     Label kw_label;
106     L(kw_label);
107     {
108         for (size_t ifm2 = 0; ifm2 < (size_t)ic_blk; ifm2++) {
109             for (int kk = 0; kk < ur_h; kk++) {
110                 size_t inp_off = sizeof(float) * (ifm2 * id * ih * iw + kk * jcp.iw * jcp.oh_block_step);
111                 movss(xmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt));
112
113                 size_t ker_off = sizeof(float) * (ifm2 * kd * kh * kw);
114                 movss(xmm_ker, ptr[aux_reg_kernel_w + ker_off]);
115
116                 uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker);
117             }
118         }
119
120         add(aux_reg_kernel_w, sizeof(float));
121         add(aux_reg_input_w, dilate_w * sizeof(float));
122
123         dec(kw_iter);
124         cmp(kw_iter, 0);
125         jg(kw_label, T_NEAR);
126     }
127
128     L(iter_exit_label);
129 }
130
131 template <cpu_isa_t isa>
132 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_filter_scalar(int ur_h) {
133     int iw = jcp.iw;
134     int kw = jcp.kw;
135     int dilate_h = jcp.dilate_h + 1;
136     int dilate_d = jcp.dilate_h + 1;
137     const int inp_mult_h = dilate_h;
138     const int inp_mult_d = dilate_d;
139
140     Label skip_kh_loop, skip_kd_loop, kd_label;
141     if (jcp.ndims == 5) {
142         push(reg_kernel);
143         push(reg_output);
144
145         mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
146         mov(aux_reg_ker_d, aux_reg_kernel_h);
147         mov(aux_reg_inp_d, aux_reg_input_h);
148
149         cmp(reg_kd, 0);
150         je(skip_kd_loop, T_NEAR);
151
152         L(kd_label);
153         mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]);
154     } else {
155         mov(kh_iter, reg_kh);
156     }
157
158     if (jcp.ndims == 5) {
159         mov(aux_reg_input_h, aux_reg_inp_d);
160         mov(aux_reg_kernel_h, aux_reg_ker_d);
161     }
162
163     cmp(kh_iter, 0);
164     je(skip_kh_loop, T_NEAR);
165
166     Label kh_label;
167     L(kh_label);
168     {
169         filter_scalar(ur_h);
170
171         add(aux_reg_kernel_h, sizeof(float) * kw);
172         add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h);
173
174         dec(kh_iter);
175         cmp(kh_iter, 0);
176         jg(kh_label, T_NEAR);
177     }
178
179     L(skip_kh_loop);
180
181     if (jcp.ndims == 5) {
182         add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh);
183         add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d);
184
185         dec(reg_kd);
186         cmp(reg_kd, 0);
187         jg(kd_label, T_NEAR);
188         L(skip_kd_loop);
189
190         pop(reg_output);
191         pop(reg_kernel);
192     }
193 }
194
195 template <cpu_isa_t isa>
196 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_postprocess_scalar(int ur_h) {
197     Label regular_store_label;
198
199     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
200     test(reg_ci_flag, FLAG_IC_LAST);
201     je(regular_store_label, T_NEAR);
202
203     int eltwise_inj_idx = 0;
204     const auto &p = attr_.post_ops_;
205
206
207     for (int i = 0; i < p.len_; i++) {
208         auto& post_op = p.entry_[i];
209         if (post_op.is_eltwise()) {
210             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_h);
211             eltwise_inj_idx++;
212         }
213     }
214
215     L(regular_store_label);
216 }
217
218 template <cpu_isa_t isa>
219 void jit_uni_planar_conv_fwd_kernel_f32<isa>::store_dst_scalar(int ur_h) {
220     for (int kk = 0; kk < ur_h; kk++) {
221         size_t o_off = sizeof(float) * (kk * jcp.ow * jcp.oh_block_step);
222         movss(make_safe_addr(reg_output, o_off, reg_long_offt), Xmm(kk));
223     }
224 }
225
226 template <cpu_isa_t isa>
227 void jit_uni_planar_conv_fwd_kernel_f32<isa>::load_src(int ur_h, int ur_w) {
228     Label init_done_label;
229     Label init_first_label;
230
231     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
232     if (jcp.with_bias)
233         mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
234
235     if (!jcp.with_sum) {
236         test(reg_ci_flag, FLAG_IC_FIRST);
237         jne(init_first_label, T_NEAR);
238     }
239
240     for (int kk = 0; kk < ur_h; kk++) {
241         for (int jj = 0; jj < ur_w; jj++) {
242             size_t offt = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step);
243             uni_vmovups(Vmm(kk * ur_w + jj), make_safe_addr(reg_output, offt, reg_long_offt));
244         }
245     }
246
247     if (jcp.with_sum && jcp.with_bias) {
248         test(reg_ci_flag, FLAG_IC_FIRST);
249         je(init_done_label, T_NEAR);
250
251         uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
252         for (int kk = 0; kk < ur_h; kk++) {
253             for (int jj = 0; jj < ur_w; jj++) {
254                 uni_vaddps(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), vmm_tmp);
255             }
256         }
257     }
258
259     jmp(init_done_label, T_NEAR);
260
261     L(init_first_label);
262     if (this->jcp.with_bias) {
263         uni_vbroadcastss(vmm_tmp, make_safe_addr(reg_bias, 0, reg_long_offt));
264         for (int kk = 0; kk < ur_h; kk++) {
265             for (int jj = 0; jj < ur_w; jj++) {
266                 uni_vmovups(Vmm(kk * ur_w + jj), vmm_tmp);
267             }
268         }
269     } else {
270         for (int kk = 0; kk < ur_h; kk++) {
271             for (int jj = 0; jj < ur_w; jj++) {
272                 uni_vpxor(Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj), Vmm(kk * ur_w + jj));
273             }
274         }
275     }
276
277     L(init_done_label);
278 }
279
280 template <cpu_isa_t isa>
281 void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter_unrolled(int ur_h, int ur_w) {
282     int iw = jcp.iw;
283     int ih = jcp.ih;
284     int id = jcp.id;
285     int stride_w = jcp.stride_w;
286     int dilate_w = jcp.dilate_w + 1;
287     int ic_blk = jcp.ic_block;
288     int kw = jcp.kw;
289     int kh = jcp.kh;
290     int kd = jcp.kd;
291     int ow_blk = jcp.ow_block;
292
293     for (int ki = 0; ki < kw; ki++) {
294         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
295             for (int kk = 0; kk < ur_h; kk++) {
296                 for (int jj = 0; jj < ur_w; jj++) {
297                     size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + ki * dilate_w +
298                             jj * stride_w * ow_blk + kk * jcp.ow * jcp.oh_block_step);
299                     uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_h, inp_off, reg_long_offt));
300
301                     int ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw + ki);
302                     uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_h + ker_off]);
303
304                     uni_vfmadd231ps(Vmm(kk * ur_w + jj), vmm_src, vmm_ker);
305                 }
306             }
307         }
308     }
309 }
310
311 template <cpu_isa_t isa>
312 void jit_uni_planar_conv_fwd_kernel_f32<isa>::filter(int ur_h) {
313     Label iter_exit_label;
314
315     int iw = jcp.iw;
316     int ih = jcp.ih;
317     int id = jcp.id;
318     int dilate_w = jcp.dilate_w + 1;
319     int ic_blk = jcp.ic_block;
320     int kw = jcp.kw;
321     int kh = jcp.kh;
322     int kd = jcp.kd;
323
324     cmp(reg_kw, 0);
325     je(iter_exit_label, T_NEAR);
326
327     mov(aux_reg_input_w, aux_reg_input_h);
328     mov(aux_reg_kernel_w, aux_reg_kernel_h);
329     mov(kw_iter, reg_kw);
330
331     Label kw_label;
332     L(kw_label);
333     {
334         for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
335             for (int kk = 0; kk < ur_h; kk++) {
336                 size_t inp_off = sizeof(float) * ((size_t) ifm2 * id * ih * iw + kk * jcp.ow * jcp.oh_block_step);
337                 uni_vmovups(vmm_src, make_safe_addr(aux_reg_input_w, inp_off, reg_long_offt));
338
339                 size_t ker_off = sizeof(float) * ((size_t) ifm2 * kd * kh * kw);
340                 uni_vbroadcastss(vmm_ker, ptr[aux_reg_kernel_w + ker_off]);
341
342                 uni_vfmadd231ps(Vmm(kk), vmm_src, vmm_ker);
343             }
344         }
345
346         add(aux_reg_kernel_w, sizeof(float));
347         add(aux_reg_input_w, dilate_w * sizeof(float));
348
349         dec(kw_iter);
350         cmp(kw_iter, 0);
351         jg(kw_label, T_NEAR);
352     }
353
354     L(iter_exit_label);
355 }
356
357 template <cpu_isa_t isa>
358 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_filter(int ur_h, int ur_w) {
359     int iw = jcp.iw;
360     int kw = jcp.kw;
361     int dilate_h = jcp.dilate_h + 1;
362     int dilate_d = jcp.dilate_h + 1;
363     const int inp_mult_h = dilate_h;
364     const int inp_mult_d = dilate_d;
365
366     Label skip_kh_loop, skip_kd_loop, kd_label;
367     if (jcp.ndims == 5) {
368         push(reg_kernel);
369         push(reg_output);
370
371         mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
372         mov(aux_reg_ker_d, aux_reg_kernel_h);
373         mov(aux_reg_inp_d, aux_reg_input_h);
374
375         cmp(reg_kd, 0);
376         je(skip_kd_loop, T_NEAR);
377
378         L(kd_label);
379         mov(kh_iter, ptr[param1 + GET_OFF(kh_padding)]);
380     } else {
381         mov(kh_iter, reg_kh);
382     }
383
384     if (jcp.ndims == 5) {
385         mov(aux_reg_input_h, aux_reg_inp_d);
386         mov(aux_reg_kernel_h, aux_reg_ker_d);
387     }
388
389     cmp(kh_iter, 0);
390     je(skip_kh_loop, T_NEAR);
391
392     Label kh_label;
393     L(kh_label);
394     {
395         if (ur_w == jcp.nb_ow_blocking)
396             filter_unrolled(ur_h, ur_w);
397         else
398             filter(ur_h);
399
400         add(aux_reg_kernel_h, sizeof(float) * kw);
401         add(aux_reg_input_h, sizeof(float) * iw * inp_mult_h);
402
403         dec(kh_iter);
404         cmp(kh_iter, 0);
405         jg(kh_label, T_NEAR);
406     }
407
408     L(skip_kh_loop);
409
410     if (jcp.ndims == 5) {
411         add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh);
412         add(aux_reg_inp_d, sizeof(float) * jcp.ih * jcp.iw * inp_mult_d);
413
414         dec(reg_kd);
415         cmp(reg_kd, 0);
416         jg(kd_label, T_NEAR);
417         L(skip_kd_loop);
418
419         pop(reg_output);
420         pop(reg_kernel);
421     }
422 }
423
424 template <cpu_isa_t isa>
425 void jit_uni_planar_conv_fwd_kernel_f32<isa>::apply_postprocess(int ur_h, int ur_w) {
426     Label regular_store_label;
427
428     mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
429     test(reg_ci_flag, FLAG_IC_LAST);
430     je(regular_store_label, T_NEAR);
431
432     int eltwise_inj_idx = 0;
433     const auto &p = attr_.post_ops_;
434
435     for (int i = 0; i < p.len_; i++) {
436         auto& post_op = p.entry_[i];
437         if (post_op.is_eltwise()) {
438             eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur_w * ur_h);
439             eltwise_inj_idx++;
440         }
441     }
442
443     L(regular_store_label);
444 }
445
446 template <cpu_isa_t isa>
447 void jit_uni_planar_conv_fwd_kernel_f32<isa>::store_dst(int ur_h, int ur_w) {
448     for (int kk = 0; kk < ur_h; kk++) {
449         for (int jj = 0; jj < ur_w; jj++) {
450             size_t o_off = sizeof(float) * (jj * jcp.ow_block + kk * jcp.ow * jcp.oh_block_step);
451             uni_vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), Vmm(kk * ur_w + jj));
452         }
453     }
454 }
455
456 template <cpu_isa_t isa>
457 void jit_uni_planar_conv_fwd_kernel_f32<isa>::solve_common(int ur_h) {
458     auto solve_loop = [&](int ur_w, int step_w) {
459         Label loop_label;
460         Label exit_label;
461
462         L(loop_label);
463         {
464             if (step_w == 1) {
465                 load_src_scalar(ur_h);
466                 apply_filter_scalar(ur_h);
467                 apply_postprocess_scalar(ur_h);
468                 store_dst_scalar(ur_h);
469             } else {
470                 load_src(ur_h, ur_w);
471                 apply_filter(ur_h, ur_w);
472                 apply_postprocess(ur_h, ur_w);
473                 store_dst(ur_h, ur_w);
474             }
475
476             add(reg_input, sizeof(float) * step_w * jcp.stride_w);
477             add(reg_output, sizeof(float) * step_w);
478         }
479
480         L(exit_label);
481     };
482
483     Label left_border_label;
484     Label main_loop_unrolled_label;
485     Label main_loop_label;
486     Label right_border_label;
487     Label exit_label;
488
489     xor_(reg_ow, reg_ow);
490     sub(reg_input, sizeof(float) * jcp.l_pad);
491
492     auto adjust_indexes_left = [&]() {
493         Label border_indexes_label;
494         Label border_indexes_exit_label;
495
496         mov(reg_wj, jcp.l_pad);
497         sub(reg_wj, reg_ow);
498         L(border_indexes_label);
499         {
500             cmp(reg_wj, 0);
501             jle(border_indexes_exit_label, T_NEAR);
502
503             add(aux_reg_kernel_h, sizeof(float));
504             add(aux_reg_input_h, sizeof(float) * (jcp.dilate_w + 1));
505             dec(reg_kw);
506             sub(reg_wj, jcp.dilate_w + 1);
507
508             jmp(border_indexes_label);
509
510             L(border_indexes_exit_label);
511         }
512     };
513
514     auto adjust_indexes_right = [&]() {
515         Label border_indexes_right_label;
516         Label border_indexes_right_exit_label;
517
518         imul(reg_wj, reg_ow, jcp.stride_w);
519         add(reg_wj, (jcp.kw-1) * (jcp.dilate_w+1) - jcp.l_pad+1 - jcp.iw);
520
521         L(border_indexes_right_label);
522         {
523             cmp(reg_wj, 0);
524             jle(border_indexes_right_exit_label, T_NEAR);
525
526             dec(reg_kw);
527             sub(reg_wj, jcp.dilate_w + 1);
528
529             jmp(border_indexes_right_label);
530
531             L(border_indexes_right_exit_label);
532         }
533     };
534
535     int left_border_end = nstl::min(div_up(jcp.l_pad, jcp.stride_w), jcp.ow);
536     L(left_border_label); {
537         cmp(reg_ow, left_border_end);
538         jge(main_loop_unrolled_label, T_NEAR);
539
540         mov(aux_reg_input_h, reg_input);
541         mov(aux_reg_kernel_h, reg_kernel);
542         mov(reg_kw, jcp.kw);
543
544         adjust_indexes_left();
545         adjust_indexes_right();
546
547         solve_loop(1, 1); // scalar
548
549         inc(reg_ow);
550         jmp(left_border_label, T_NEAR);
551     }
552
553     int main_loop_end = (jcp.iw - (jcp.kw - 1)*(jcp.dilate_w + 1) + jcp.l_pad - 1) / jcp.stride_w + 1;
554     L(main_loop_unrolled_label); {
555         cmp(reg_ow, main_loop_end - jcp.nb_ow_blocking * jcp.ow_block);
556         jg(main_loop_label, T_NEAR);
557
558         mov(aux_reg_input_h, reg_input);
559         mov(aux_reg_kernel_h, reg_kernel);
560         mov(reg_kw, jcp.kw);
561
562         solve_loop(jcp.nb_ow_blocking, jcp.nb_ow_blocking * jcp.ow_block);
563
564         add(reg_ow, jcp.nb_ow_blocking * jcp.ow_block);
565         jmp(main_loop_unrolled_label, T_NEAR);
566     }
567
568     L(main_loop_label); {
569         cmp(reg_ow, main_loop_end - jcp.ow_block);
570         jg(right_border_label, T_NEAR);
571
572         mov(aux_reg_input_h, reg_input);
573         mov(aux_reg_kernel_h, reg_kernel);
574         mov(reg_kw, jcp.kw);
575
576         solve_loop(1, jcp.ow_block); // vectorized
577
578         add(reg_ow, jcp.ow_block);
579         jmp(main_loop_label, T_NEAR);
580     }
581
582     int right_border_end = jcp.ow;
583     L(right_border_label); {
584         cmp(reg_ow, right_border_end);
585         jge(exit_label, T_NEAR);
586
587         mov(aux_reg_input_h, reg_input);
588         mov(aux_reg_kernel_h, reg_kernel);
589         mov(reg_kw, jcp.kw);
590
591         adjust_indexes_left();
592         adjust_indexes_right();
593
594         solve_loop(1, 1); // scalar
595
596         inc(reg_ow);
597         jmp(right_border_label, T_NEAR);
598     }
599
600     L(exit_label);
601 }
602
603 template <cpu_isa_t isa>
604 void jit_uni_planar_conv_fwd_kernel_f32<isa>::generate() {
605     const auto &p = attr_.post_ops_;
606     for (int i = 0; i < p.len_; i++) {
607         auto &post_op = p.entry_[i];
608         if (post_op.is_eltwise()) {
609             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
610                     this,
611                     post_op.eltwise.alg,
612                     post_op.eltwise.alpha,
613                     post_op.eltwise.beta
614             ));
615         }
616     }
617
618     this->preamble();
619
620     mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
621     mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
622     mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
623     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
624     mov(reg_oh_blocks, ptr[this->param1 + GET_OFF(oh_blocks)]);
625
626     Label tail_label;
627     Label exit_label;
628
629     solve_common(1);
630
631     this->postamble();
632
633     for (auto& inj : eltwise_injectors)
634         inj->prepare_table();
635 }
636
637 template <cpu_isa_t isa>
638 bool jit_uni_planar_conv_fwd_kernel_f32<isa>::post_ops_ok(
639         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
640     const auto &p = attr.post_ops_;
641
642     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
643     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
644     auto is_simple = [&](int idx) { return is_eltwise(idx); };
645
646     switch (p.len_) {
647     case 0: return true; // no post_ops
648     case 1:
649         return true // sum OR eltwise OR depthwise
650                 && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
651     case 2:
652         return true // sum->relu
653                 && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
654                                          (is_simple(0) && is_simple(1)));
655     case 3:
656         return true // sum->relu
657                 && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
658     default: return false;
659     }
660
661     return false;
662 }
663
664 template <cpu_isa_t isa>
665 status_t jit_uni_planar_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
666         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
667         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
668         const primitive_attr_t &attr) {
669     if (!mayiuse(isa)) return status::unimplemented;
670
671     jcp.prop_kind = cd.prop_kind;
672
673     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
674     int ndims = src_d.ndims();
675     jcp.ndims = ndims;
676
677     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
678     jcp.mb = src_d.dims()[0];
679
680     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
681     jcp.oc_without_padding = jcp.oc;
682     jcp.ic = src_d.dims()[1] / jcp.ngroups;
683
684     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
685     jcp.ih = src_d.dims()[ndims-2];
686     jcp.iw = src_d.dims()[ndims-1];
687     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
688     jcp.oh = dst_d.dims()[ndims-2];
689     jcp.ow = dst_d.dims()[ndims-1];
690     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
691     jcp.kh = weights_d.dims()[with_groups + ndims-2];
692     jcp.kw = weights_d.dims()[with_groups + ndims-1];
693
694     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
695     jcp.t_pad = cd.padding[0][ndims-4];
696     jcp.l_pad = cd.padding[0][ndims-3];
697     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
698     jcp.stride_h = cd.strides[ndims-4];
699     jcp.stride_w = cd.strides[ndims-3];
700
701     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
702     jcp.dilate_h = cd.dilates[ndims-4];
703     jcp.dilate_w = cd.dilates[ndims-3];
704
705     jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
706             - (jcp.ih + jcp.t_pad - 1);
707
708     jcp.src_fmt = src_d.format();
709     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
710     jcp.with_eltwise = false;
711
712     if (!post_ops_ok(jcp, attr))
713         return status::unimplemented;
714
715     const auto &p = attr.post_ops_;
716     jcp.with_sum = p.find(primitive_kind::sum) != -1;
717
718     const int simd_w = isa == avx512_common ? 16 : 8;
719
720     bool args_ok = true
721         && one_of(src_d.format(), nchw, ncdhw)
722         && one_of(weights_d.format(), oihw, oidhw)
723         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
724         && one_of(dst_d.format(), nchw, ncdhw);
725     if (!args_ok) return status::unimplemented;
726
727     // This convolution implementation was introduced as workaround to provide competitive performance on MSD topology.
728     // The conditions below are needed to bound applicability scope.
729     args_ok = jcp.ngroups == 1 &&
730               jcp.oc == 1 &&
731               jcp.stride_d == 1 && jcp.stride_h == 1 && jcp.stride_w == 1;
732
733     if (!args_ok) return status::unimplemented;
734
735     jcp.ur_w = 1;
736
737     jcp.ow_block = simd_w;
738     jcp.nb_ow_blocking = isa == avx512_common ? 3 : 3;
739
740     jcp.oh_block = 1;
741     jcp.nb_oh_blocking = 1;
742     jcp.oh_block_step = 1; // (jcp.dilate_h + 1);
743
744     jcp.oc_block = 1;
745     jcp.nb_oc = jcp.oc / jcp.oc_block;
746     jcp.nb_oc_blocking = 1;
747
748     jcp.ic_block = 1;
749     jcp.nb_ic = jcp.ic / jcp.ic_block;
750     jcp.nb_ic_blocking = 1;
751
752     return status::success;
753 }
754
755 template struct jit_uni_planar_conv_fwd_kernel_f32<avx512_common>;
756 template struct jit_uni_planar_conv_fwd_kernel_f32<avx2>;
757
758 }
759 }
760 }