Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_uni_x8s8s32x_dw_conv_kernel.hpp"
24
25 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
37 template <cpu_isa_t isa>
38 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::load_src(int ur_ch_blocks, int ch_step, int ur_w) {
39     int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
40     for (int i = 0; i < repeats; i++) {
41         for (int ch = 0; ch < ur_ch_blocks; ch++) {
42             for (int ow = 0; ow < ur_w; ow++) {
43                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
44
45                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
46             }
47         }
48     }
49 }
50
51 template <cpu_isa_t isa>
52 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter(int ur_ch_blocks, int ch_step, int ur_w) {
53     int ch_blk = jcp.ch_block;
54     int dilate_h = jcp.dilate_h + 1;
55     int dilate_w = jcp.dilate_w + 1;
56     int stride_w = jcp.stride_w;
57
58     Label iter_exit_label;
59
60     cmp(reg_kh, 0);
61     je(iter_exit_label, T_NEAR);
62     cmp(reg_kw, 0);
63     je(iter_exit_label, T_NEAR);
64
65     mov(iter_kh, reg_kh);
66     Label kh_label;
67     L(kh_label); {
68         mov(iter_kw, reg_kw);
69         mov(aux1_reg_input, aux_reg_input);
70         mov(aux1_reg_kernel, aux_reg_kernel);
71
72         Label kw_label;
73         L(kw_label); {
74             int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
75             for (int i = 0; i < repeats; i++) {
76                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
77                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*(ch_blk / 2);
78                     Vmm vmm_ker = get_ker_reg(0);
79                     Xmm xmm_ker = Xmm(vmm_ker.getIdx());
80
81                     if (ch_step == 1) {
82                         movsx(reg_tmp_32, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
83                         movq(xmm_ker, reg_tmp_64);
84                     } else {
85                         uni_vpmovsxbd(vmm_ker, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
86                     }
87
88                     for (int ow = 0; ow < ur_w; ow++) {
89                         int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + i*(ch_blk / 2);
90                         Vmm vmm_src = get_src_reg(0);
91                         Xmm xmm_src = Xmm(vmm_src.getIdx());
92
93                         if (ch_step == 1) {
94                             movzx(reg_tmp_32, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
95                             movq(xmm_src, reg_tmp_64);
96                         } else {
97                             uni_vpmovzxbd(vmm_src, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
98                         }
99
100                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
101                         uni_vpmulld(vmm_src, vmm_src, vmm_ker);
102                         uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
103                     }
104                 }
105             }
106             add(aux1_reg_kernel, ch_blk*jcp.typesize_in);
107             add(aux1_reg_input, jcp.oc*dilate_w*jcp.typesize_in);
108
109             dec(iter_kw);
110             cmp(iter_kw, 0);
111             jg(kw_label, T_NEAR);
112         }
113         add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
114         add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
115
116         dec(iter_kh);
117         cmp(iter_kh, 0);
118         jg(kh_label, T_NEAR);
119     }
120
121     L(iter_exit_label);
122 }
123
124 template <cpu_isa_t isa>
125 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter_unrolled(int ur_ch_blocks, int ch_step, int ur_w) {
126     int ch_blk = jcp.ch_block;
127     int dilate_h = jcp.dilate_h + 1;
128     int dilate_w = jcp.dilate_w + 1;
129     int stride_w = jcp.stride_w;
130
131     Label iter_exit_label;
132
133     cmp(reg_kh, 0);
134     je(iter_exit_label, T_NEAR);
135
136     mov(iter_kh, reg_kh);
137     Label kh_label;
138     L(kh_label); {
139         int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
140         for (int i = 0; i < repeats; i++) {
141             for (int ch = 0; ch < ur_ch_blocks; ch++) {
142                 for (int kw = 0; kw < jcp.kw; kw++) {
143                     int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*(ch_blk / 2);
144                     Vmm vmm_ker = get_ker_reg(0);
145                     Xmm xmm_ker = Xmm(vmm_ker.getIdx());
146
147                     if (ch_step == 1) {
148                         movsx(reg_tmp_32, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
149                         movq(xmm_ker, reg_tmp_64);
150                     } else {
151                         uni_vpmovsxbd(vmm_ker, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
152                     }
153
154                     for (int ow = 0; ow < ur_w; ow++) {
155                         int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + kw*jcp.oc*dilate_w + i*(ch_blk / 2);
156                         Vmm vmm_src = get_src_reg(0);
157                         Xmm xmm_src = Xmm(vmm_src.getIdx());
158
159                         if (ch_step == 1) {
160                             movzx(reg_tmp_32, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
161                             movq(xmm_src, reg_tmp_64);
162                         } else {
163                             uni_vpmovzxbd(vmm_src, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
164                         }
165
166                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
167                         uni_vpmulld(vmm_src, vmm_src, vmm_ker);
168                         uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
169                     }
170                 }
171             }
172         }
173
174         add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
175         add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
176
177         dec(iter_kh);
178         cmp(iter_kh, 0);
179         jg(kh_label, T_NEAR);
180     }
181
182     L(iter_exit_label);
183 }
184
185 template <cpu_isa_t isa>
186 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
187     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
188     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
189
190     switch (jcp.dst_dt) {
191         case data_type::f32:
192         case data_type::s32:
193             if (scalar_store) {
194                 movq(reg_tmp_64, xmm_dst);
195                 mov(op, reg_tmp_32);
196             } else {
197                 uni_vmovups(op, vmm_dst);
198             }
199             break;
200         case data_type::s8:
201             uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
202
203             if (isa != sse42 && !scalar_store)
204                 vpermq(ymm_dst, ymm_dst, 0x08);
205
206             uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
207
208             if (scalar_store) {
209                 movq(reg_tmp_64, xmm_dst);
210                 mov(op, reg_tmp_8);
211             } else {
212                 if (isa != sse42)
213                     vmovq(op, xmm_dst);
214                 else
215                     movd(op, xmm_dst);
216             }
217             break;
218         case data_type::u8:
219             uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
220
221             if (isa != sse42 && !scalar_store)
222                 vpermq(ymm_dst, ymm_dst, 0x08);
223
224             uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
225
226             if (scalar_store) {
227                 movq(reg_tmp_64, xmm_dst);
228                 mov(op, reg_tmp_8);
229             } else {
230                 if (isa != sse42)
231                     vmovq(op, xmm_dst);
232                 else
233                     movd(op, xmm_dst);
234             }
235
236             break;
237         default:
238             assert(!"unknown dst_dt");
239     }
240 }
241
242 template <cpu_isa_t isa>
243 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
244         const Xbyak::Operand &op, bool scalar_load) {
245     Xmm xmm_in = Xmm(vmm_in.getIdx());
246
247     switch (type_in) {
248         case data_type::f32:
249         case data_type::s32:
250             if (scalar_load) {
251                 movsd(xmm_in, op);
252             } else {
253                 uni_vmovups(vmm_in, op);
254             }
255             break;
256         case data_type::s8:
257             if (scalar_load) {
258                 movsx(reg_tmp_32, op);
259                 movq(xmm_in, reg_tmp_64);
260             } else {
261                 uni_vpmovsxbd(vmm_in, op);
262             }
263             break;
264         case data_type::u8:
265             if (scalar_load) {
266                 movzx(reg_tmp_32, op);
267                 movq(xmm_in, reg_tmp_64);
268             } else {
269                 uni_vpmovzxbd(vmm_in, op);
270             }
271             break;
272         default: assert(!"unsupported data type");
273     }
274
275     if (type_in != data_type::f32)
276         uni_vcvtdq2ps(vmm_in, vmm_in);
277 }
278
279 template <cpu_isa_t isa>
280 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
281     int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
282
283     pop(reg_oc_off);
284     pop(reg_scales_base);
285
286     mov(imm_addr64, l_table);
287
288     const auto &p = attr_.post_ops_;
289     const int sum_idx = p.find(primitive_kind::sum);
290     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
291
292     bool is_scalar_store = ch_step < jcp.ch_block;
293
294     for (int r = 0; r < repeats; r++) {
295         for (int ii = 0; ii < ur_ch_blocks; ii++) {
296             if (jcp.with_bias) {
297                 int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
298                 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
299             }
300
301             for (int jj = 0; jj < ur_w; jj++) {
302                 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
303                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
304
305                 if (jcp.with_bias)
306                     uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
307
308                 int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
309                 cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
310                 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
311             }
312         }
313
314         int eltwise_inj_idx = 0;
315         int depthwise_inj_idx = 0;
316         for (int i = 0; i < p.len_; i++) {
317             int start_idx = 4 + r * ur_ch_blocks*ur_w;
318
319             auto& post_op = p.entry_[i];
320             if (post_op.is_eltwise()) {
321                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
322                 eltwise_inj_idx++;
323             } else if (post_op.is_depthwise()) {
324                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
325                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
326
327                 add(reg_d_weights, reg_oc_off);
328                 add(reg_d_bias, reg_oc_off);
329
330                 if (r == 1) {
331                     add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
332                     add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
333                 }
334
335                 for (int ii = 0; ii < ur_ch_blocks; ii++) {
336                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
337                             start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
338
339                     add(reg_d_weights, jcp.ch_block * sizeof(float));
340                     add(reg_d_bias, jcp.ch_block * sizeof(float));
341                 }
342
343                 depthwise_inj_idx++;
344             } else if (post_op.is_sum(false)) {
345                 for (int ii = 0; ii < ur_ch_blocks; ii++) {
346                     for (int jj = 0; jj < ur_w; jj++) {
347                         Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
348                         int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
349
350                         cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);
351
352                         if (p_sum_scale == 1.f) {
353                             uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
354                         } else {
355                             uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
356                         }
357                     }
358                 }
359             }
360         }
361
362         for (int ii = 0; ii < ur_ch_blocks; ii++) {
363             for (int jj = 0; jj < ur_w; jj++) {
364                 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
365                 int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
366
367                 if (jcp.dst_dt != data_type::f32) {
368                     if (attr_.round_mode_ == round_mode::nearest)
369                         uni_vcvtps2dq(vmm_dst, vmm_dst);
370                     else if (attr_.round_mode_ == round_mode::down) {
371                         uni_vroundps(vmm_dst, vmm_dst, 1);
372                         uni_vcvtps2dq(vmm_dst, vmm_dst);
373                     } else
374                         assert(!"unimplemented");
375                 }
376
377                 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
378             }
379         }
380     }
381
382     push(reg_scales_base);
383     push(reg_oc_off);
384 }
385
386 template <cpu_isa_t isa>
387 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int ch_step) {
388     Label unrolled_w_label;
389     Label tail_w_label;
390     Label exit_label;
391
392     mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
393     mov(reg_input, reg_input_base);
394     mov(reg_output, reg_output_base);
395     mov(reg_kernel, reg_kernel_base);
396
397     push(reg_input_base);
398     push(reg_output_base);
399     push(reg_kernel_base);
400     push(reg_ch_work);
401     push(reg_scales_base);
402     push(reg_oc_off);
403
404     L(unrolled_w_label); {
405         int ur_w = jcp.ur_w;
406
407         cmp(reg_ur_w, ur_w);
408         jl(tail_w_label, T_NEAR);
409
410         mov(aux_reg_input, reg_input);
411         mov(aux_reg_kernel, reg_kernel);
412
413         load_src(ur_ch_blocks, ch_step, ur_w);
414         apply_filter_unrolled(ur_ch_blocks, ch_step, ur_w);
415         store_dst(ur_ch_blocks, ch_step, ur_w);
416
417         add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
418         add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
419
420         sub(reg_ur_w, ur_w);
421         jmp(unrolled_w_label);
422     }
423
424     L(tail_w_label); {
425         int ur_w = 1;
426
427         cmp(reg_ur_w, ur_w);
428         jl(exit_label, T_NEAR);
429
430         mov(aux_reg_input, reg_input);
431         mov(aux_reg_kernel, reg_kernel);
432
433         load_src(ur_ch_blocks, ch_step, ur_w);
434         apply_filter(ur_ch_blocks, ch_step, ur_w);
435         store_dst(ur_ch_blocks, ch_step, ur_w);
436
437         add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
438         add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
439
440         sub(reg_ur_w, ur_w);
441         jmp(tail_w_label);
442     }
443
444     L(exit_label);
445
446     pop(reg_oc_off);
447     pop(reg_scales_base);
448     pop(reg_ch_work);
449     pop(reg_kernel_base);
450     pop(reg_output_base);
451     pop(reg_input_base);
452 }
453
454 template <cpu_isa_t isa>
455 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
456     const auto &p = attr_.post_ops_;
457     for (int i = 0; i < p.len_; i++) {
458         auto &post_op = p.entry_[i];
459         if (post_op.is_eltwise()) {
460             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
461                     this,
462                     post_op.eltwise.alg,
463                     post_op.eltwise.alpha,
464                     post_op.eltwise.beta
465             ));
466         } else if (post_op.is_depthwise()) {
467             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
468                     this,
469                     post_op.depthwise.alg
470             ));
471         }
472     }
473
474     this->preamble();
475
476     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
477     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
478     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
479     if (jcp.with_bias)
480         mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
481     mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
482     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
483     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
484     mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
485     mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);
486
487     Label main_loop_label;
488     Label tail_loop_label;
489     Label exit_label;
490
491     cmp(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
492     jne(main_loop_label, T_NEAR);
493
494     loop_body(jcp.nb_ch_blocking, jcp.nb_ch_blocking * jcp.ch_block);
495
496     sub(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
497
498     jmp(exit_label, T_NEAR);
499
500     L(main_loop_label); {
501         cmp(reg_ch_work, jcp.ch_block);
502         jl(tail_loop_label, T_NEAR);
503
504         loop_body(1, jcp.ch_block);
505
506         sub(reg_ch_work, jcp.ch_block);
507         add(reg_input_base, jcp.ch_block * jcp.typesize_in);
508         add(reg_output_base, jcp.ch_block * jcp.typesize_out);
509         add(reg_kernel_base, jcp.ch_block * jcp.kh * jcp.kw * jcp.typesize_in);
510         add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
511         add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
512         add(reg_oc_off, jcp.ch_block * sizeof(float));
513
514         jmp(main_loop_label, T_NEAR);
515     }
516
517     L(tail_loop_label); {
518         cmp(reg_ch_work, 1);
519         jl(exit_label, T_NEAR);
520
521         loop_body(1, 1);
522
523         sub(reg_ch_work, 1);
524         add(reg_input_base, 1 * jcp.typesize_in);
525         add(reg_output_base, 1 * jcp.typesize_out);
526         add(reg_kernel_base, 1 * jcp.typesize_in);
527         add(reg_bias_base, 1 * jcp.typesize_bia);
528         add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
529         add(reg_oc_off, 1 * sizeof(float));
530
531         jmp(tail_loop_label, T_NEAR);
532     }
533
534     L(exit_label);
535
536     this->postamble();
537
538     prepare_table();
539
540     for (auto& inj : eltwise_injectors)
541         inj->prepare_table();
542 }
543
544 template <cpu_isa_t isa>
545 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
546     const auto &p = attr_.post_ops_;
547     const int sum_idx = p.find(primitive_kind::sum);
548     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
549
550     const int32_t cvals_sum_scale[] = {
551         float2int(p_sum_scale)
552     };
553
554     align(64);
555     L(l_table);
556     for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
557         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
558             dd(cvals_sum_scale[i]);
559         }
560     }
561 }
562
563 template <cpu_isa_t isa>
564 bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::post_ops_ok(
565         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
566     const auto &p = attr.post_ops_;
567
568     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
569     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
570     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
571     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
572
573     switch (p.len_) {
574         case 0: return true;
575         case 1: return is_simple(0) || is_sum(0);
576         case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
577                        (is_simple(0) && is_simple(1));
578         case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
579         default: return false;
580     }
581
582     return false;
583 }
584
585 template <cpu_isa_t isa>
586 status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
587         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
588         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
589         const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
590 {
591     if (!mayiuse(isa)) return status::unimplemented;
592
593     if (!(src_d.data_type() == data_type::u8 &&
594           weights_d.data_type() == data_type::s8 &&
595           one_of(dst_d.data_type(), data_type::f32, data_type::s32, data_type::s8, data_type::u8)))
596         return status::unimplemented;
597
598     jcp.prop_kind = cd.prop_kind;
599
600     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
601     if (!with_groups) return status::unimplemented;
602
603     jcp.ngroups = weights_d.dims()[0];
604     jcp.mb = src_d.dims()[0];
605
606     jcp.oc = dst_d.dims()[1];
607     jcp.ic = src_d.dims()[1];
608
609     jcp.ih = src_d.dims()[2];
610     jcp.iw = src_d.dims()[3];
611     jcp.oh = dst_d.dims()[2];
612     jcp.ow = dst_d.dims()[3];
613
614     jcp.kh = weights_d.dims()[3];
615     jcp.kw = weights_d.dims()[4];
616
617     jcp.t_pad = cd.padding[0][0];
618     jcp.l_pad = cd.padding[0][1];
619     jcp.b_pad = cd.padding[1][0];
620     jcp.r_pad = cd.padding[1][1];
621
622     jcp.stride_h = cd.strides[0];
623     jcp.stride_w = cd.strides[1];
624
625     jcp.dilate_h = cd.dilates[0];
626     jcp.dilate_w = cd.dilates[1];
627
628     jcp.src_fmt = src_d.format();
629     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
630
631     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
632
633     if (jcp.signed_input)
634         return status::unimplemented;
635
636     const int simd_w = isa == avx512_common ? 16 : 8;
637     jcp.ch_block = simd_w;
638     jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);
639
640     if (!post_ops_ok(jcp, attr))
641         return status::unimplemented;
642
643     const auto &p = attr.post_ops_;
644     jcp.with_sum = p.find(primitive_kind::sum) != -1;
645     const int eltwise_ind = p.find(primitive_kind::eltwise);
646     jcp.with_eltwise = eltwise_ind != -1;
647     if (jcp.with_eltwise)
648         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
649
650     auto desired_act_fmt = nhwc;
651     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;
652
653     bool args_ok = true
654         && jcp.oc == jcp.ngroups
655         && jcp.ic == jcp.ngroups
656         && src_d.format() == desired_act_fmt
657         && weights_d.format() == desired_wei_fmt
658         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
659         && dst_d.format() == desired_act_fmt;
660     if (!args_ok) return status::unimplemented;
661
662     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
663     jcp.dst_dt = cd.dst_desc.data_type;
664
665     jcp.typesize_in = types::data_type_size(src_d.data_type());
666     jcp.typesize_out = types::data_type_size(dst_d.data_type());
667     jcp.typesize_acc = sizeof(int32_t);
668     jcp.typesize_bia = jcp.with_bias
669                        ? types::data_type_size(bias_pd.data_type())
670                        : 0;
671
672     const auto &oscales = attr.output_scales_;
673     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
674
675     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
676
677     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
678
679     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
680     if (jcp.nb_ch < jcp.nb_ch_blocking)
681         jcp.nb_ch_blocking = jcp.nb_ch;
682
683     return status::success;
684 }
685
686 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<avx2>;
687 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<sse42>;
688
689 }
690 }
691 }