updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_uni_x8s8s32x_dw_conv_kernel.hpp"
24
25 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
34
35 using namespace Xbyak;
36
37 template <cpu_isa_t isa>
38 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::load_src(int ur_ch_blocks, int ch_step, int ur_w) {
39     int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
40     for (int i = 0; i < repeats; i++) {
41         for (int ch = 0; ch < ur_ch_blocks; ch++) {
42             for (int ow = 0; ow < ur_w; ow++) {
43                 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
44
45                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
46             }
47         }
48     }
49 }
50
51 template <cpu_isa_t isa>
52 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter(int ur_ch_blocks, int ch_step, int ur_w) {
53     int ch_blk = jcp.ch_block;
54     int dilate_d = jcp.dilate_d + 1;
55     int dilate_h = jcp.dilate_h + 1;
56     int dilate_w = jcp.dilate_w + 1;
57     int stride_w = jcp.stride_w;
58
59     Label iter_exit_label;
60     Label kd_label, iter_d_exit_label;
61
62     if (jcp.ndims == 5) {
63         push(reg_input);
64         push(reg_kernel);
65         push(reg_bias_base);
66
67         mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]);
68         cmp(reg_kd, 0);
69         je(iter_d_exit_label, T_NEAR);
70
71         mov(aux_reg_inp_d, aux_reg_input);
72         mov(aux_reg_ker_d, aux_reg_kernel);
73
74         L(kd_label);
75
76         mov(aux_reg_input, aux_reg_inp_d);
77         mov(aux_reg_kernel, aux_reg_ker_d);
78     }
79
80     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
81     cmp(reg_kh, 0);
82     je(iter_exit_label, T_NEAR);
83     cmp(reg_kw, 0);
84     je(iter_exit_label, T_NEAR);
85
86     mov(iter_kh, reg_kh);
87     Label kh_label;
88     L(kh_label); {
89         mov(iter_kw, reg_kw);
90         mov(aux1_reg_input, aux_reg_input);
91         mov(aux1_reg_kernel, aux_reg_kernel);
92
93         Label kw_label;
94         L(kw_label); {
95             int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
96             for (int i = 0; i < repeats; i++) {
97                 for (int ch = 0; ch < ur_ch_blocks; ch++) {
98                     int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + i*(ch_blk / 2);
99                     Vmm vmm_ker = get_ker_reg(0);
100                     Xmm xmm_ker = Xmm(vmm_ker.getIdx());
101
102                     if (ch_step == 1) {
103                         movsx(reg_tmp_32, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
104                         movq(xmm_ker, reg_tmp_64);
105                     } else {
106                         uni_vpmovsxbd(vmm_ker, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
107                     }
108
109                     for (int ow = 0; ow < ur_w; ow++) {
110                         int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + i*(ch_blk / 2);
111                         Vmm vmm_src = get_src_reg(0);
112                         Xmm xmm_src = Xmm(vmm_src.getIdx());
113
114                         if (ch_step == 1) {
115                             movzx(reg_tmp_32, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
116                             movq(xmm_src, reg_tmp_64);
117                         } else {
118                             uni_vpmovzxbd(vmm_src, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
119                         }
120
121                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
122                         uni_vpmulld(vmm_src, vmm_src, vmm_ker);
123                         uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
124                     }
125                 }
126             }
127             add(aux1_reg_kernel, ch_blk*jcp.typesize_in);
128             add(aux1_reg_input, jcp.oc*dilate_w*jcp.typesize_in);
129
130             dec(iter_kw);
131             cmp(iter_kw, 0);
132             jg(kw_label, T_NEAR);
133         }
134         add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
135         add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
136
137         dec(iter_kh);
138         cmp(iter_kh, 0);
139         jg(kh_label, T_NEAR);
140     }
141
142     L(iter_exit_label);
143
144     if (jcp.ndims == 5) {
145         add(aux_reg_inp_d, dilate_d * jcp.ih * jcp.iw * jcp.ic * jcp.typesize_in);
146         add(aux_reg_ker_d, jcp.kh * jcp.kw * ch_blk * jcp.typesize_in);
147         mov(aux_reg_input, aux_reg_inp_d);
148         mov(aux_reg_kernel, aux_reg_ker_d);
149
150         dec(reg_kd);
151         cmp(reg_kd, 0);
152         jg(kd_label, T_NEAR);
153
154         L(iter_d_exit_label);
155
156         pop(reg_bias_base);
157         pop(reg_kernel);
158         pop(reg_input);
159     }
160 }
161
162 template <cpu_isa_t isa>
163 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter_unrolled(int ur_ch_blocks, int ch_step, int ur_w) {
164     int ch_blk = jcp.ch_block;
165     int dilate_d = jcp.dilate_d + 1;
166     int dilate_h = jcp.dilate_h + 1;
167     int dilate_w = jcp.dilate_w + 1;
168     int stride_w = jcp.stride_w;
169
170     Label iter_exit_label;
171     Label kd_label, iter_d_exit_label;
172
173     if (jcp.ndims == 5) {
174         push(reg_input);
175         push(reg_kernel);
176         push(reg_bias_base);
177
178         mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]);
179         cmp(reg_kd, 0);
180         je(iter_d_exit_label, T_NEAR);
181
182         mov(aux_reg_inp_d, aux_reg_input);
183         mov(aux_reg_ker_d, aux_reg_kernel);
184
185         L(kd_label);
186
187         mov(aux_reg_input, aux_reg_inp_d);
188         mov(aux_reg_kernel, aux_reg_ker_d);
189     }
190
191     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
192     cmp(reg_kh, 0);
193     je(iter_exit_label, T_NEAR);
194
195     mov(iter_kh, reg_kh);
196     Label kh_label;
197     L(kh_label); {
198         int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
199         for (int i = 0; i < repeats; i++) {
200             for (int ch = 0; ch < ur_ch_blocks; ch++) {
201                 for (int kw = 0; kw < jcp.kw; kw++) {
202                     int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*(ch_blk / 2);
203                     Vmm vmm_ker = get_ker_reg(0);
204                     Xmm xmm_ker = Xmm(vmm_ker.getIdx());
205
206                     if (ch_step == 1) {
207                         movsx(reg_tmp_32, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
208                         movq(xmm_ker, reg_tmp_64);
209                     } else {
210                         uni_vpmovsxbd(vmm_ker, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
211                     }
212
213                     for (int ow = 0; ow < ur_w; ow++) {
214                         int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + kw*jcp.oc*dilate_w + i*(ch_blk / 2);
215                         Vmm vmm_src = get_src_reg(0);
216                         Xmm xmm_src = Xmm(vmm_src.getIdx());
217
218                         if (ch_step == 1) {
219                             movzx(reg_tmp_32, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
220                             movq(xmm_src, reg_tmp_64);
221                         } else {
222                             uni_vpmovzxbd(vmm_src, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
223                         }
224
225                         Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
226                         uni_vpmulld(vmm_src, vmm_src, vmm_ker);
227                         uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
228                     }
229                 }
230             }
231         }
232
233         add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
234         add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
235
236         dec(iter_kh);
237         cmp(iter_kh, 0);
238         jg(kh_label, T_NEAR);
239     }
240
241     L(iter_exit_label);
242
243     if (jcp.ndims == 5) {
244         add(aux_reg_inp_d, dilate_d * jcp.ih * jcp.iw * jcp.ic * jcp.typesize_in);
245         add(aux_reg_ker_d, jcp.kh * jcp.kw * ch_blk * jcp.typesize_in);
246         mov(aux_reg_input, aux_reg_inp_d);
247         mov(aux_reg_kernel, aux_reg_ker_d);
248
249         dec(reg_kd);
250         cmp(reg_kd, 0);
251         jg(kd_label, T_NEAR);
252
253         L(iter_d_exit_label);
254
255         pop(reg_bias_base);
256         pop(reg_kernel);
257         pop(reg_input);
258     }
259 }
260
261 template <cpu_isa_t isa>
262 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
263     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
264     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
265
266     switch (jcp.dst_dt) {
267         case data_type::f32:
268         case data_type::s32:
269             if (scalar_store) {
270                 movq(reg_tmp_64, xmm_dst);
271                 mov(op, reg_tmp_32);
272             } else {
273                 uni_vmovups(op, vmm_dst);
274             }
275             break;
276         case data_type::s8:
277             uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
278
279             if (isa != sse42 && !scalar_store)
280                 vpermq(ymm_dst, ymm_dst, 0x08);
281
282             uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
283
284             if (scalar_store) {
285                 movq(reg_tmp_64, xmm_dst);
286                 mov(op, reg_tmp_8);
287             } else {
288                 if (isa != sse42)
289                     vmovq(op, xmm_dst);
290                 else
291                     movd(op, xmm_dst);
292             }
293             break;
294         case data_type::u8:
295             uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
296
297             if (isa != sse42 && !scalar_store)
298                 vpermq(ymm_dst, ymm_dst, 0x08);
299
300             uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
301
302             if (scalar_store) {
303                 movq(reg_tmp_64, xmm_dst);
304                 mov(op, reg_tmp_8);
305             } else {
306                 if (isa != sse42)
307                     vmovq(op, xmm_dst);
308                 else
309                     movd(op, xmm_dst);
310             }
311
312             break;
313         default:
314             assert(!"unknown dst_dt");
315     }
316 }
317
318 template <cpu_isa_t isa>
319 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
320         const Xbyak::Operand &op, bool scalar_load) {
321     Xmm xmm_in = Xmm(vmm_in.getIdx());
322
323     switch (type_in) {
324         case data_type::f32:
325         case data_type::s32:
326             if (scalar_load) {
327                 movsd(xmm_in, op);
328             } else {
329                 uni_vmovups(vmm_in, op);
330             }
331             break;
332         case data_type::s8:
333             if (scalar_load) {
334                 movsx(reg_tmp_32, op);
335                 movq(xmm_in, reg_tmp_64);
336             } else {
337                 uni_vpmovsxbd(vmm_in, op);
338             }
339             break;
340         case data_type::u8:
341             if (scalar_load) {
342                 movzx(reg_tmp_32, op);
343                 movq(xmm_in, reg_tmp_64);
344             } else {
345                 uni_vpmovzxbd(vmm_in, op);
346             }
347             break;
348         default: assert(!"unsupported data type");
349     }
350
351     if (type_in != data_type::f32)
352         uni_vcvtdq2ps(vmm_in, vmm_in);
353 }
354
355 template <cpu_isa_t isa>
356 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
357     int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
358
359     pop(reg_oc_off);
360     pop(reg_scales_base);
361
362     mov(imm_addr64, l_table);
363
364     const auto &p = attr_.post_ops_;
365     const int sum_idx = p.find(primitive_kind::sum);
366     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
367
368     bool is_scalar_store = ch_step < jcp.ch_block;
369
370     for (int r = 0; r < repeats; r++) {
371         for (int ii = 0; ii < ur_ch_blocks; ii++) {
372             if (jcp.with_bias) {
373                 int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
374                 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
375             }
376
377             for (int jj = 0; jj < ur_w; jj++) {
378                 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
379                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
380
381                 if (jcp.with_bias)
382                     uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
383
384                 int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
385                 cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
386                 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
387             }
388         }
389
390         int eltwise_inj_idx = 0;
391         int depthwise_inj_idx = 0;
392         for (int i = 0; i < p.len_; i++) {
393             int start_idx = 4 + r * ur_ch_blocks*ur_w;
394
395             auto& post_op = p.entry_[i];
396             if (post_op.is_eltwise()) {
397                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
398                 eltwise_inj_idx++;
399             } else if (post_op.is_depthwise()) {
400                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
401                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
402
403                 add(reg_d_weights, reg_oc_off);
404                 add(reg_d_bias, reg_oc_off);
405
406                 if (r == 1) {
407                     add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
408                     add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
409                 }
410
411                 for (int ii = 0; ii < ur_ch_blocks; ii++) {
412                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
413                             start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
414
415                     add(reg_d_weights, jcp.ch_block * sizeof(float));
416                     add(reg_d_bias, jcp.ch_block * sizeof(float));
417                 }
418
419                 depthwise_inj_idx++;
420             } else if (post_op.is_sum(false)) {
421                 for (int ii = 0; ii < ur_ch_blocks; ii++) {
422                     for (int jj = 0; jj < ur_w; jj++) {
423                         Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
424                         int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
425
426                         cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);
427
428                         if (p_sum_scale == 1.f) {
429                             uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
430                         } else {
431                             uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
432                         }
433                     }
434                 }
435             }
436         }
437
438         for (int ii = 0; ii < ur_ch_blocks; ii++) {
439             for (int jj = 0; jj < ur_w; jj++) {
440                 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
441                 int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
442
443                 if (jcp.dst_dt != data_type::f32) {
444                     if (attr_.round_mode_ == round_mode::nearest)
445                         uni_vcvtps2dq(vmm_dst, vmm_dst);
446                     else if (attr_.round_mode_ == round_mode::down) {
447                         uni_vroundps(vmm_dst, vmm_dst, 1);
448                         uni_vcvtps2dq(vmm_dst, vmm_dst);
449                     } else
450                         assert(!"unimplemented");
451                 }
452
453                 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
454             }
455         }
456     }
457
458     push(reg_scales_base);
459     push(reg_oc_off);
460 }
461
462 template <cpu_isa_t isa>
463 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int ch_step) {
464     Label unrolled_w_label;
465     Label tail_w_label;
466     Label exit_label;
467
468     mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
469     mov(reg_input, reg_input_base);
470     mov(reg_output, reg_output_base);
471     mov(reg_kernel, reg_kernel_base);
472
473     push(reg_input_base);
474     push(reg_output_base);
475     push(reg_kernel_base);
476     push(reg_ch_work);
477     push(reg_scales_base);
478     push(reg_oc_off);
479
480     L(unrolled_w_label); {
481         int ur_w = jcp.ur_w;
482
483         cmp(reg_ur_w, ur_w);
484         jl(tail_w_label, T_NEAR);
485
486         mov(aux_reg_input, reg_input);
487         mov(aux_reg_kernel, reg_kernel);
488
489         load_src(ur_ch_blocks, ch_step, ur_w);
490         apply_filter_unrolled(ur_ch_blocks, ch_step, ur_w);
491         store_dst(ur_ch_blocks, ch_step, ur_w);
492
493         add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
494         add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
495
496         sub(reg_ur_w, ur_w);
497         jmp(unrolled_w_label);
498     }
499
500     L(tail_w_label); {
501         int ur_w = 1;
502
503         cmp(reg_ur_w, ur_w);
504         jl(exit_label, T_NEAR);
505
506         mov(aux_reg_input, reg_input);
507         mov(aux_reg_kernel, reg_kernel);
508
509         load_src(ur_ch_blocks, ch_step, ur_w);
510         apply_filter(ur_ch_blocks, ch_step, ur_w);
511         store_dst(ur_ch_blocks, ch_step, ur_w);
512
513         add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
514         add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
515
516         sub(reg_ur_w, ur_w);
517         jmp(tail_w_label);
518     }
519
520     L(exit_label);
521
522     pop(reg_oc_off);
523     pop(reg_scales_base);
524     pop(reg_ch_work);
525     pop(reg_kernel_base);
526     pop(reg_output_base);
527     pop(reg_input_base);
528 }
529
530 template <cpu_isa_t isa>
531 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
532     const auto &p = attr_.post_ops_;
533     for (int i = 0; i < p.len_; i++) {
534         auto &post_op = p.entry_[i];
535         if (post_op.is_eltwise()) {
536             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
537                     this,
538                     post_op.eltwise.alg,
539                     post_op.eltwise.alpha,
540                     post_op.eltwise.beta
541             ));
542         } else if (post_op.is_depthwise()) {
543             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
544                     this,
545                     post_op.depthwise.alg
546             ));
547         }
548     }
549
550     this->preamble();
551
552     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
553     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
554     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
555     if (jcp.with_bias)
556         mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
557     mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
558     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
559     mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
560     mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);
561
562     Label main_loop_label;
563     Label tail_loop_label;
564     Label exit_label;
565
566     cmp(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
567     jne(main_loop_label, T_NEAR);
568
569     loop_body(jcp.nb_ch_blocking, jcp.nb_ch_blocking * jcp.ch_block);
570
571     sub(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
572
573     jmp(exit_label, T_NEAR);
574
575     L(main_loop_label); {
576         cmp(reg_ch_work, jcp.ch_block);
577         jl(tail_loop_label, T_NEAR);
578
579         loop_body(1, jcp.ch_block);
580
581         sub(reg_ch_work, jcp.ch_block);
582         add(reg_input_base, jcp.ch_block * jcp.typesize_in);
583         add(reg_output_base, jcp.ch_block * jcp.typesize_out);
584         add(reg_kernel_base, jcp.ch_block * jcp.kd * jcp.kh * jcp.kw * jcp.typesize_in);
585         add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
586         add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
587         add(reg_oc_off, jcp.ch_block * sizeof(float));
588
589         jmp(main_loop_label, T_NEAR);
590     }
591
592     L(tail_loop_label); {
593         cmp(reg_ch_work, 1);
594         jl(exit_label, T_NEAR);
595
596         loop_body(1, 1);
597
598         sub(reg_ch_work, 1);
599         add(reg_input_base, 1 * jcp.typesize_in);
600         add(reg_output_base, 1 * jcp.typesize_out);
601         add(reg_kernel_base, 1 * jcp.typesize_in);
602         add(reg_bias_base, 1 * jcp.typesize_bia);
603         add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
604         add(reg_oc_off, 1 * sizeof(float));
605
606         jmp(tail_loop_label, T_NEAR);
607     }
608
609     L(exit_label);
610
611     this->postamble();
612
613     prepare_table();
614
615     for (auto& inj : eltwise_injectors)
616         inj->prepare_table();
617 }
618
619 template <cpu_isa_t isa>
620 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
621     const auto &p = attr_.post_ops_;
622     const int sum_idx = p.find(primitive_kind::sum);
623     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
624
625     const int32_t cvals_sum_scale[] = {
626         float2int(p_sum_scale)
627     };
628
629     align(64);
630     L(l_table);
631     for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
632         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
633             dd(cvals_sum_scale[i]);
634         }
635     }
636 }
637
638 template <cpu_isa_t isa>
639 bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::post_ops_ok(
640         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
641     const auto &p = attr.post_ops_;
642
643     auto all_post_ops_supported = [&]() {
644         bool ok = true;
645
646         for (int i = 0; i < p.len_; i++) {
647             ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
648         }
649         return ok;
650     };
651     auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
652
653     return all_post_ops_supported() &&
654            count(primitive_kind::sum) <= 1;
655 }
656
657 template <cpu_isa_t isa>
658 status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
659         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
660         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
661         const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
662 {
663     if (!mayiuse(isa)) return status::unimplemented;
664
665     if (!(src_d.data_type() == data_type::u8 &&
666           weights_d.data_type() == data_type::s8 &&
667           one_of(dst_d.data_type(), data_type::f32, data_type::s32, data_type::s8, data_type::u8)))
668         return status::unimplemented;
669
670     jcp.prop_kind = cd.prop_kind;
671
672     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
673     if (!with_groups) return status::unimplemented;
674
675     int ndims = src_d.ndims();
676     jcp.ndims = ndims;
677
678     jcp.ngroups = weights_d.dims()[0];
679     jcp.mb = src_d.dims()[0];
680
681     jcp.oc = dst_d.dims()[1];
682     jcp.ic = src_d.dims()[1];
683
684     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
685     jcp.ih = src_d.dims()[ndims - 2];
686     jcp.iw = src_d.dims()[ndims - 1];
687     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
688     jcp.oh = dst_d.dims()[ndims - 2];
689     jcp.ow = dst_d.dims()[ndims - 1];
690
691     jcp.kd = (ndims == 5) ? weights_d.dims()[3] : 1;
692     jcp.kh = weights_d.dims()[ndims - 1];
693     jcp.kw = weights_d.dims()[ndims];
694
695     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
696     jcp.t_pad = cd.padding[0][ndims - 4];
697     jcp.l_pad = cd.padding[0][ndims - 3];
698     jcp.back_pad = (ndims == 5) ? cd.padding[1][0] : 0;
699     jcp.b_pad = cd.padding[1][ndims - 4];
700     jcp.r_pad = cd.padding[1][ndims - 3];
701
702     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
703     jcp.stride_h = cd.strides[ndims - 4];
704     jcp.stride_w = cd.strides[ndims - 3];
705
706     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
707     jcp.dilate_h = cd.dilates[ndims - 4];
708     jcp.dilate_w = cd.dilates[ndims - 3];
709
710     jcp.src_fmt = src_d.format();
711     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
712
713     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
714
715     if (jcp.signed_input)
716         return status::unimplemented;
717
718     const int simd_w = isa == avx512_common ? 16 : 8;
719     jcp.ch_block = simd_w;
720     jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);
721
722     if (!post_ops_ok(jcp, attr))
723         return status::unimplemented;
724
725     const auto &p = attr.post_ops_;
726     jcp.with_sum = p.find(primitive_kind::sum) != -1;
727     const int eltwise_ind = p.find(primitive_kind::eltwise);
728     jcp.with_eltwise = eltwise_ind != -1;
729     if (jcp.with_eltwise)
730         jcp.eltwise = p.entry_[eltwise_ind].eltwise;
731
732     auto desired_act_fmt = (ndims == 5) ? ndhwc : nhwc;
733     auto desired_wei_fmt = (ndims == 5) ? isa == avx512_common ? Goidhw16g : Goidhw8g
734                                         : isa == avx512_common ? Goihw16g : Goihw8g;
735
736     bool args_ok = true
737         && jcp.oc == jcp.ngroups
738         && jcp.ic == jcp.ngroups
739         && src_d.format() == desired_act_fmt
740         && weights_d.format() == desired_wei_fmt
741         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
742         && dst_d.format() == desired_act_fmt;
743     if (!args_ok) return status::unimplemented;
744
745     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
746     jcp.dst_dt = cd.dst_desc.data_type;
747
748     jcp.typesize_in = types::data_type_size(src_d.data_type());
749     jcp.typesize_out = types::data_type_size(dst_d.data_type());
750     jcp.typesize_acc = sizeof(int32_t);
751     jcp.typesize_bia = jcp.with_bias
752                        ? types::data_type_size(bias_pd.data_type())
753                        : 0;
754
755     const auto &oscales = attr.output_scales_;
756     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
757
758     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
759
760     jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
761
762     jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
763     if (jcp.nb_ch < jcp.nb_ch_blocking)
764         jcp.nb_ch_blocking = jcp.nb_ch;
765
766     return status::success;
767 }
768
769 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<avx2>;
770 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<sse42>;
771
772 }
773 }
774 }