updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/memory_tracking.hpp>
18 #include <common/primitive_attr.hpp>
19 #include "c_types_map.hpp"
20 #include "nstl.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23 #include "cpu_memory.hpp"
24
25 #include "jit_uni_x8s8s32x_conv_kernel.hpp"
26
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
37
38 using namespace Xbyak;
39
40 template <cpu_isa_t isa>
41 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
42         const Xbyak::Operand &op, bool scalar_load) {
43     Xmm xmm_in = Xmm(vmm_in.getIdx());
44
45     switch (type_in) {
46         case data_type::f32:
47         case data_type::s32:
48             if (scalar_load) {
49                 movsd(xmm_in, op);
50             } else {
51                 uni_vmovups(vmm_in, op);
52             }
53             break;
54         case data_type::s8:
55             if (scalar_load) {
56                 movsx(reg_tmp_32, op);
57                 movq(xmm_in, reg_tmp_64);
58             } else {
59                 uni_vpmovsxbd(vmm_in, op);
60             }
61             break;
62         case data_type::u8:
63             if (scalar_load) {
64                 movzx(reg_tmp_32, op);
65                 movq(xmm_in, reg_tmp_64);
66             } else {
67                 uni_vpmovzxbd(vmm_in, op);
68             }
69             break;
70         default: assert(!"unsupported data type");
71     }
72
73     if (type_in != data_type::f32)
74         uni_vcvtdq2ps(vmm_in, vmm_in);
75 }
76
77 template <cpu_isa_t isa>
78 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store, bool need_pack) {
79     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
80     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
81
82     switch (jcp.dst_dt) {
83         case data_type::f32:
84         case data_type::s32:
85             if (scalar_store) {
86                 movq(reg_tmp_64, xmm_dst);
87                 mov(op, reg_tmp_32);
88             } else {
89                 uni_vmovups(op, vmm_dst);
90             }
91             break;
92         case data_type::s8:
93             if (need_pack)
94                 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
95
96             if (isa != sse42 && !scalar_store && need_pack)
97                 vpermq(ymm_dst, ymm_dst, 0x08);
98
99             if (need_pack)
100                 uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
101
102             if (scalar_store) {
103                 movq(reg_tmp_64, xmm_dst);
104                 mov(op, reg_tmp_8);
105             } else {
106                 if (isa != sse42)
107                     vmovq(op, xmm_dst);
108                 else
109                     movd(op, xmm_dst);
110             }
111             break;
112         case data_type::u8:
113             if (need_pack)
114                 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
115
116             if (isa != sse42 && !scalar_store && need_pack)
117                 vpermq(ymm_dst, ymm_dst, 0x08);
118
119             if (need_pack)
120                 uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
121
122             if (scalar_store) {
123                 movq(reg_tmp_64, xmm_dst);
124                 mov(op, reg_tmp_8);
125             } else {
126                 if (isa != sse42)
127                     vmovq(op, xmm_dst);
128                 else
129                     movd(op, xmm_dst);
130             }
131
132             break;
133         default:
134             assert(!"unknown dst_dt");
135     }
136 }
137
138 template <cpu_isa_t isa>
139 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
140         int tail_size, bool h_padded) {
141     int kw = jcp.kw;
142     int kh = jcp.kh;
143     int kd = jcp.kd;
144     int nb_ic = jcp.nb_ic;
145     int stride_w = jcp.stride_w;
146     int dilate_w = jcp.dilate_w + 1;
147     int ic_blk = jcp.ic_block;
148     int oc_blk = jcp.oc_block;
149
150     int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
151
152     for (int ki = 0; ki < kw; ki++) {
153         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
154         int jj_end = ur_w - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
155
156         int _start = (jcp.signed_input) ? 0 : jj_start;
157         int _end = (jcp.signed_input) ? ur_w : jj_end;
158
159         for (int r = 0; r < repeats; r++) {
160             for (int jj = _start; jj < _end; jj++) {
161                 int inp_off = (ki * dilate_w + jj * stride_w - pad_l) * jcp.ic * jcp.ngroups;
162                 if (tail_size > 0) {
163                     if (h_padded || jj < jj_start || jj >= jj_end) {
164                         uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
165                         uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
166                     } else {
167                         uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
168
169                         if (jcp.signed_input) {
170                             uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
171                         }
172                     }
173                 } else {
174                     if (h_padded || jj < jj_start || jj >= jj_end) {
175                         uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
176                     } else {
177                         uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
178                     }
179
180                     if (jcp.signed_input)
181                         uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
182                 }
183             }
184
185             for (int ii = 0; ii < oc_blocks; ii++) {
186                 int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ki * ic_blk * oc_blk + r * ic_blk * (oc_blk / 2);
187                 uni_vmovups(get_ker_reg(0), ptr[aux1_reg_kernel + jcp.typesize_in * ker_off]);
188
189                 for (int jj = _start; jj < _end; jj++) {
190                     Vmm vmm_src = get_src_reg(jj);
191                     if (isa == sse42) {
192                         uni_vmovups(get_tmp_reg(0), vmm_src);
193                         uni_vpmaddubsw(get_tmp_reg(0), get_tmp_reg(0), get_ker_reg(0));
194                     } else {
195                         uni_vpmaddubsw(get_tmp_reg(0), vmm_src, get_ker_reg(0));
196                     }
197                     uni_vpmaddwd(get_tmp_reg(0), get_tmp_reg(0), vmm_one);
198                     uni_vpaddd(get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
199                                get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj), get_tmp_reg(0));
200                 }
201             }
202         }
203     }
204 }
205
206 template <cpu_isa_t isa>
207 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w,
208         int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
209     int kw = jcp.kw;
210     int ic_blk = jcp.ic_block;
211     int oc_blk = jcp.oc_block;
212
213     Label ic_main_loop;
214     Label ic_tail;
215     Label exit;
216
217     mov(aux1_reg_input, aux_reg_input);
218     mov(aux1_reg_kernel, aux_reg_kernel);
219
220     mov(reg_ic_iter, jcp.ic);
221
222     L(ic_main_loop); {
223         cmp(reg_ic_iter, ic_blk);
224         jl(ic_tail, T_NEAR);
225
226         apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 0, h_padded);
227
228         add(aux1_reg_input, ic_blk * jcp.typesize_in);
229         add(aux1_reg_kernel, kw * ic_blk * oc_blk * jcp.typesize_in);
230         sub(reg_ic_iter, ic_blk);
231         jmp(ic_main_loop, T_NEAR);
232     }
233
234     L(ic_tail);
235     int ic_tail_size = jcp.ic % jcp.ic_block;
236
237     if (ic_tail_size > 0)
238         apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, ic_tail_size, h_padded);
239
240     L(exit);
241 }
242
243 template <cpu_isa_t isa>
244 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
245     int iw = jcp.iw;
246     int kw = jcp.kw;
247     int dilate_h = jcp.dilate_h + 1;
248     const int inp_mult = jcp.ic * dilate_h * jcp.ngroups;
249
250     Label t_overflow_label, no_t_overflow_label,
251           b_overflow_label, no_b_overflow_label;
252
253     auto h_overflow_func = [&] () {
254         Label h_overflow_label, no_h_overflow_label;
255         cmp(reg_overflow, 0);
256         je(no_h_overflow_label, T_NEAR);
257         L(h_overflow_label); {
258             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
259
260             add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
261             dec(reg_overflow);
262             cmp(reg_overflow, 0);
263             jg(h_overflow_label, T_NEAR);
264         }
265         L(no_h_overflow_label);
266     };
267
268     if (jcp.signed_input) {
269         mov(reg_overflow,  ptr[param1 + GET_OFF(t_overflow)]);
270         h_overflow_func();
271     }
272
273     Label skip_kh_loop;
274     mov(reg_kj, ptr[this->param1 + GET_OFF(kh_padding)]);
275     if ((jcp.signed_input) || (!jcp.signed_input &&
276                                (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
277         cmp(reg_kj, 0);
278         je(skip_kh_loop, T_NEAR);
279     }
280
281     Label kh_label;
282     L(kh_label);
283     {
284         oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);
285
286         add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
287         add(aux_reg_input, jcp.typesize_in * iw * inp_mult);
288
289         dec(reg_kj);
290         cmp(reg_kj, 0);
291         jg(kh_label, T_NEAR);
292     }
293
294     L(skip_kh_loop);
295
296     if (jcp.signed_input) {
297         mov(reg_overflow,  ptr[param1 + GET_OFF(b_overflow)]);
298         h_overflow_func();
299     }
300 }
301
302 template <cpu_isa_t isa>
303 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::kd_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
304     int iw = jcp.iw;
305     int ih = jcp.ih;
306     int kw = jcp.kw;
307     int kh = jcp.kh;
308     int dilate_d = jcp.dilate_d + 1;
309
310     auto d_overflow_func = [&] () {
311         Label d_overflow_label, no_d_overflow_label, aux_kh_label;
312         cmp(reg_overflow, 0);
313         je(no_d_overflow_label, T_NEAR);
314         L(d_overflow_label);
315         {
316             push(reg_overflow);
317
318             mov(reg_kj, jcp.kh);
319             L(aux_kh_label);
320             {
321                 oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
322
323                 add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
324                 dec(reg_kj);
325                 cmp(reg_kj, 0);
326                 jg(aux_kh_label, T_NEAR);
327             }
328
329             pop(reg_overflow);
330
331             add(aux_reg_ker_d, jcp.typesize_in * kh * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
332             mov(aux_reg_kernel, aux_reg_ker_d);
333             dec(reg_overflow);
334             cmp(reg_overflow, 0);
335             jg(d_overflow_label);
336         }
337         L(no_d_overflow_label);
338     };
339
340     push(reg_output);
341     push(reg_oi_iter);
342     push(reg_compensation_base);
343
344     mov(aux_reg_inp_d, reg_input);
345     mov(aux_reg_ker_d, reg_kernel);
346
347     if (jcp.signed_input) {
348         mov(reg_overflow, ptr[param1 + GET_OFF(front_overflow)]);
349         d_overflow_func();
350     }
351
352     Label skip_kd_loop;
353     mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
354     cmp(reg_kd, 0);
355     je(skip_kd_loop, T_NEAR);
356
357     Label kd_label;
358     L(kd_label);
359     {
360         kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
361
362         add(aux_reg_inp_d, jcp.typesize_in * dilate_d * ih * iw * jcp.ic * jcp.ngroups);
363         add(aux_reg_ker_d, jcp.typesize_in * kh * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
364         mov(aux_reg_input, aux_reg_inp_d);
365         mov(aux_reg_kernel, aux_reg_ker_d);
366
367         dec(reg_kd);
368         cmp(reg_kd, 0);
369         jg(kd_label, T_NEAR);
370     }
371
372     L(skip_kd_loop);
373
374     if (jcp.signed_input) {
375         mov(reg_overflow, ptr[param1 + GET_OFF(back_overflow)]);
376         d_overflow_func();
377     }
378
379     pop(reg_compensation_base);
380     pop(reg_oi_iter);
381     pop(reg_output);
382 }
383
384 template <cpu_isa_t isa>
385 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
386 {
387     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
388
389     for (int r = 0; r < repeats; r++)
390         for (int ii = 0; ii < oc_blocks; ii++)
391             for (int jj = 0; jj < ur_w; jj++)
392                 uni_vpxor(get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
393                           get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
394                           get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj));
395
396     mov(imm_addr64, l_table);
397     uni_vmovups(vmm_one,   ptr[imm_addr64 + 0 * vlen]);
398     uni_vmovups(vmm_shift, ptr[imm_addr64 + 1 * vlen]);
399
400     mov(aux_reg_input, reg_input);
401     mov(aux_reg_kernel, reg_kernel);
402
403     if (jcp.ndims == 5) {
404         kd_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
405     } else {
406         kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
407     }
408
409     pop(reg_oc_off);
410     pop(reg_scales_base);
411
412     mov(imm_addr64, l_table);
413     uni_vmovups(vmm_bias_alpha, ptr[imm_addr64 + 2 * vlen]);
414
415     const auto &p = attr_.post_ops_;
416     const int sum_idx = p.find(primitive_kind::sum);
417     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
418
419     for (int r = 0; r < repeats; r++) {
420         auto get_dst_off = [=](int ii, int jj) {
421             if (jcp.with_dw_conv)
422                 return (ii * jcp_dw.kh * jcp.ow + jj) * jcp.oc_block + r * (jcp.oc_block / 2);
423             else
424                 return ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
425         };
426
427         int tail_size = isa == avx2 ? oc_step : nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2);
428         bool is_scalar_store = isa == avx2 ? tail_size < jcp.oc_block : tail_size < jcp.oc_block / 2;
429
430         for (int ii = 0; ii < oc_blocks; ii++) {
431             if (jcp.with_bias) {
432                 int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
433                 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
434
435                 if (jcp.signed_input)
436                     uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
437             }
438
439             for (int jj = 0; jj < ur_w; jj++) {
440                 Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
441                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
442
443                 if (jcp.signed_input) {
444                     int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
445                     cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
446                 }
447
448                 if (jcp.signed_input)
449                     uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
450                 if (jcp.with_bias)
451                     uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
452
453                 int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
454                 cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
455                 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
456             }
457         }
458
459         int eltwise_inj_idx = 0;
460         int depthwise_inj_idx = 0;
461         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
462         for (int i = 0; i < end_idx; i++) {
463             int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
464
465             auto& post_op = p.entry_[i];
466             if (post_op.is_eltwise()) {
467                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
468                 eltwise_inj_idx++;
469             } else if (post_op.is_depthwise()) {
470                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
471                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
472
473                 add(reg_d_weights, reg_oc_off);
474                 add(reg_d_bias, reg_oc_off);
475
476                 if (r == 1) {
477                     add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
478                     add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
479                 }
480
481                 for (int ii = 0; ii < oc_blocks; ii++) {
482                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
483                             start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
484
485                     add(reg_d_weights, jcp.oc_block * sizeof(float));
486                     add(reg_d_bias, jcp.oc_block * sizeof(float));
487                 }
488
489                 depthwise_inj_idx++;
490             } else if (post_op.is_sum(false)) {
491                 for (int ii = 0; ii < oc_blocks; ii++) {
492                     for (int jj = 0; jj < ur_w; jj++) {
493                         Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
494                         int o_off = get_dst_off(ii, jj);
495
496                         if (is_scalar_store) {
497                             for (int oc = 0; oc < tail_size; oc++) {
498                                 uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
499                                 cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + (o_off + oc) * jcp.typesize_out], true);
500
501                                 if (oc < jcp.oc_block / 2) {
502                                     uni_vpslldq(vmm_prev_dst, vmm_prev_dst, oc * sizeof(float));
503                                 } else {
504                                     Ymm ymm_prev_dst = Ymm(vmm_prev_dst.getIdx());
505                                     vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
506                                     vpslldq(vmm_prev_dst, vmm_prev_dst, (oc - jcp.oc_block / 2) * sizeof(float));
507                                 }
508
509                                 if (p_sum_scale == 1.f) {
510                                     uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
511                                 } else {
512                                     uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
513                                 }
514                             }
515                         } else {
516                             cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
517
518                             if (p_sum_scale == 1.f) {
519                                 uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
520                             } else {
521                                 uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
522                             }
523                         }
524                     }
525                 }
526             }
527         }
528
529         for (int ii = 0; ii < oc_blocks; ii++) {
530             for (int jj = 0; jj < ur_w; jj++) {
531                 Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
532                 int o_off = get_dst_off(ii, jj);
533
534                 if (jcp.dst_dt != data_type::f32) {
535                     if (attr_.round_mode_ == round_mode::nearest)
536                         uni_vcvtps2dq(vmm_dst, vmm_dst);
537                     else if (attr_.round_mode_ == round_mode::down) {
538                         uni_vroundps(vmm_dst, vmm_dst, 1);
539                         uni_vcvtps2dq(vmm_dst, vmm_dst);
540                     } else
541                         assert(!"unimplemented");
542                 }
543
544                 if (is_scalar_store) {
545                     for (int oc = 0; oc < tail_size; oc++) {
546                         store_dst(ptr[reg_output + (o_off + oc) * jcp.typesize_out], vmm_dst, true, oc == 0);
547
548                         if (isa == avx2) {
549                             Ymm ymm_dst = Ymm(vmm_dst.getIdx());
550                             vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
551                             vpalignr(ymm_dst, ymm_tmp, ymm_dst, jcp.typesize_out);
552                         } else {
553                             psrldq(vmm_dst, jcp.typesize_out);
554                         }
555                     }
556                 } else {
557                     store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
558                 }
559             }
560         }
561     }
562
563     push(reg_scales_base);
564     push(reg_oc_off);
565 }
566
567 template <cpu_isa_t isa>
568 inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
569 {
570     int ur_w = jcp.ur_w;
571     int ur_w_tail = jcp.ur_w_tail;
572     int n_oi = jcp.ow / ur_w;
573     int iw = jcp.iw;
574     int kw = jcp.kw;
575     int dilate_w = jcp.dilate_w + 1;
576     int str_w = jcp.stride_w;
577     const int inp_mult = jcp.ic * jcp.ngroups;
578     const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.oc * jcp.ngroups;
579
580     int l_pad = jcp.l_pad;
581     int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
582             - (iw + l_pad - 1));
583     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
584             - (iw + l_pad - 1);
585     if (r_pad1 > 0) n_oi--;
586
587     mov(reg_input, reg_input_base);
588     mov(reg_output, reg_output_base);
589     mov(reg_kernel, reg_kernel_base);
590
591     push(reg_input_base);
592     push(reg_output_base);
593     push(reg_kernel_base);
594     push(reg_scales_base);
595     push(reg_oc_off);
596
597     if (l_pad > 0) {
598         n_oi--;
599         if (n_oi < 0 && r_pad1 > 0)
600             width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
601         else
602             width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
603         add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
604         add(reg_output, jcp.typesize_out * ur_w * out_mult);
605     }
606
607     Label ow_loop_label;
608     xor_(reg_oi_iter, reg_oi_iter);
609
610     if (n_oi > 0) {
611         L(ow_loop_label);
612
613         width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
614         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
615         add(reg_output, jcp.typesize_out * ur_w * out_mult);
616
617         inc(reg_oi_iter);
618         cmp(reg_oi_iter, n_oi);
619         jl(ow_loop_label, T_NEAR);
620     }
621
622     if (r_pad1 > 0 && n_oi >=0) {
623         width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
624         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
625         add(reg_output, jcp.typesize_out * ur_w * out_mult);
626     }
627
628     if (ur_w_tail != 0)
629         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
630
631     pop(reg_oc_off);
632     pop(reg_scales_base);
633     pop(reg_kernel_base);
634     pop(reg_output_base);
635     pop(reg_input_base);
636 }
637
638 template <cpu_isa_t isa>
639 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::generate()
640 {
641     const auto &p = attr_.post_ops_;
642     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
643     for (int i = 0; i < end_idx; i++) {
644         auto &post_op = p.entry_[i];
645         if (post_op.is_eltwise()) {
646             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
647                     this,
648                     post_op.eltwise.alg,
649                     post_op.eltwise.alpha,
650                     post_op.eltwise.beta
651             ));
652         } else if (post_op.is_depthwise()) {
653             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
654                     this,
655                     post_op.depthwise.alg
656             ));
657         }
658     }
659
660     this->preamble();
661
662     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
663     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
664     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
665     mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
666     if (jcp.with_bias)
667         mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
668     mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
669     if (jcp.signed_input)
670         mov(reg_compensation_base, ptr[param1 + GET_OFF(compensation)]);
671     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
672
673     Label main_loop_label;
674     Label tail_label;
675     Label exit_label;
676
677     cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
678     jne(main_loop_label, T_NEAR);
679
680     solve_common(jcp.nb_oc_blocking, jcp.oc_block);
681
682     sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
683
684     jmp(exit_label, T_NEAR);
685
686     L(main_loop_label); {
687         cmp(reg_oc_work, jcp.oc_block);
688         jl(tail_label, T_NEAR);
689
690         solve_common(1, jcp.oc_block);
691
692         sub(reg_oc_work, jcp.oc_block);
693         add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block * jcp.typesize_in);
694         if (jcp.with_dw_conv)
695             add(reg_output_base, jcp.oc_block * jcp_dw.kh * jcp.ow * jcp.typesize_out);
696         else
697             add(reg_output_base, jcp.oc_block * jcp.typesize_out);
698         add(reg_bias_base, jcp.oc_block * jcp.typesize_bia);
699         add(reg_scales_base, jcp.is_oc_scale * jcp.oc_block * sizeof(float));
700         add(reg_compensation_base, jcp.oc_block * sizeof(int32_t));
701         add(reg_oc_off, jcp.oc_block * sizeof(float));
702
703         jmp(main_loop_label, T_NEAR);
704     }
705
706     L(tail_label);
707
708     if (jcp.oc % jcp.oc_block != 0)
709         solve_common(1, jcp.oc % jcp.oc_block);
710
711     L(exit_label);
712
713     this->postamble();
714
715     prepare_table();
716
717     for (auto& inj : eltwise_injectors)
718         inj->prepare_table();
719 }
720
721 template <cpu_isa_t isa>
722 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::prepare_table() {
723     const auto &p = attr_.post_ops_;
724     const int sum_idx = p.find(primitive_kind::sum);
725     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
726
727     const uint16_t cvals_one[] = {
728         0x0001,
729     };
730
731     const int8_t cvals_shift[] = {
732         -128,
733     };
734
735     const int32_t cvals_scale[] = {
736         float2int(jcp.wei_adj_scale)
737     };
738
739     const int32_t cvals_sum_scale[] = {
740         float2int(p_sum_scale)
741     };
742
743     align(64);
744     L(l_table);
745     for (size_t i = 0; i < sizeof(cvals_one) / sizeof(cvals_one[0]); ++i) {
746         for (size_t d = 0; d < vlen / sizeof(uint16_t); ++d) {
747             dw(cvals_one[i]);
748         }
749     }
750
751     for (size_t i = 0; i < sizeof(cvals_shift) / sizeof(cvals_shift[0]); ++i) {
752         for (size_t d = 0; d < vlen / sizeof(int8_t); ++d) {
753             db(cvals_shift[i]);
754         }
755     }
756
757     for (size_t i = 0; i < sizeof(cvals_scale) / sizeof(cvals_scale[0]); ++i) {
758         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
759             dd(cvals_scale[i]);
760         }
761     }
762
763     for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
764         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
765             dd(cvals_sum_scale[i]);
766         }
767     }
768 }
769
770 template <cpu_isa_t isa>
771 bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::post_ops_ok(
772         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
773     const auto &p = attr.post_ops_;
774
775     int dw_conv_idx = p.find(primitive_kind::convolution);
776     bool with_dw_conv = dw_conv_idx != -1;
777
778     auto all_post_ops_supported = [&]() {
779         bool ok = true;
780
781         int end_idx = with_dw_conv ? dw_conv_idx : p.len_;
782         for (int i = 0; i < end_idx; i++) {
783             ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
784         }
785         return ok;
786     };
787     auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx) != -1; };
788     auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, 0, dw_conv_idx); };
789
790     return all_post_ops_supported() &&
791            count(primitive_kind::sum) <= 1 &&
792            IMPLICATION(with_dw_conv, !contain(primitive_kind::sum));
793 }
794
795 template <cpu_isa_t isa>
796 status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
797         const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
798         cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
799         cpu_memory_t::pd_t &bias_pd,
800         const primitive_attr_t &attr)
801 {
802     if (!mayiuse(isa)) return status::unimplemented;
803
804     const memory_desc_wrapper src_d(&src_pd);
805     const memory_desc_wrapper weights_d(&weights_pd);
806     const memory_desc_wrapper dst_d(&dst_pd);
807     const memory_desc_wrapper bias_d(&bias_pd);
808
809     jcp.prop_kind = cd.prop_kind;
810
811     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
812     int ndims = src_d.ndims();
813     jcp.ndims = ndims;
814
815     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
816     jcp.mb = src_d.dims()[0];
817
818     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
819     jcp.ic = src_d.dims()[1] / jcp.ngroups;
820
821     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
822     jcp.ih = src_d.dims()[ndims - 2];
823     jcp.iw = src_d.dims()[ndims - 1];
824     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
825     jcp.oh = dst_d.dims()[ndims - 2];
826     jcp.ow = dst_d.dims()[ndims - 1];
827
828     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
829     jcp.kh = weights_d.dims()[with_groups + ndims - 2];
830     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
831
832     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
833     jcp.t_pad = cd.padding[0][ndims - 4];
834     jcp.l_pad = cd.padding[0][ndims - 3];
835
836     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
837     jcp.stride_h = cd.strides[ndims - 4];
838     jcp.stride_w = cd.strides[ndims - 3];
839
840     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
841     jcp.dilate_h = cd.dilates[ndims - 4];
842     jcp.dilate_w = cd.dilates[ndims - 3];
843
844     jcp.src_fmt = src_d.format();
845     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
846
847     jcp.signed_input = src_d.data_type() == data_type::s8;
848
849     const int simd_w = 8;
850
851     jcp.ic_block = 4;
852     jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
853
854     jcp.oc_block = simd_w;
855     jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
856     jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
857
858     if (jcp.ngroups != 1) {
859         if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
860             return status::unimplemented;
861     }
862
863     jcp.src_dt = cd.src_desc.data_type;
864     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
865     jcp.dst_dt = cd.dst_desc.data_type;
866
867     if (!post_ops_ok(jcp, attr))
868         return status::unimplemented;
869
870     const auto &p = attr.post_ops_;
871
872     int dw_conv_ind = p.find(primitive_kind::convolution);
873     jcp.with_dw_conv = dw_conv_ind != -1;
874     if (jcp.with_dw_conv) {
875         if (ndims == 5) return status::unimplemented;
876
877         jcp.dw_conv_oh = jcp.oh;
878         jcp.dw_conv_ow = jcp.ow;
879         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
880         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
881
882         jcp.dw_conv_dst_dt = jcp.dst_dt;
883         jcp.dst_dt = p.entry_[dw_conv_ind].dw_conv.in_dt;
884     }
885
886     auto desired_act_fmt = (ndims == 5) ? ndhwc : nhwc;
887     auto desired_wei_fmt = (ndims == 5) ? with_groups ? jcp.signed_input ? gOdhIw8o4i_s8s8 : gOdhIw8o4i
888                                                       : jcp.signed_input ? OdhIw8o4i_s8s8 : OdhIw8o4i
889                                         : with_groups ? jcp.signed_input ? gOhIw8o4i_s8s8 : gOhIw8o4i
890                                                       : jcp.signed_input ? OhIw8o4i_s8s8 : OhIw8o4i;
891
892     if (src_d.format() == any)
893         CHECK(src_pd.set_format(desired_act_fmt));
894     if (src_d.format() != desired_act_fmt)
895         return status::unimplemented;
896
897     if (dst_d.format() == any)
898         CHECK(dst_pd.set_format(desired_act_fmt));
899     if (dst_d.format() != desired_act_fmt)
900         return status::unimplemented;
901
902     if (weights_d.format() == any)
903         CHECK(weights_pd.set_format(desired_wei_fmt));
904     if (weights_d.format() != desired_wei_fmt)
905         return status::unimplemented;
906
907     if (jcp.with_bias) {
908         if (bias_d.format() == any)
909             CHECK(bias_pd.set_format(x));
910         if (bias_d.format() != x)
911             return status::unimplemented;
912     }
913
914     jcp.typesize_in = types::data_type_size(src_d.data_type());
915     jcp.typesize_out = types::data_type_size(dst_d.data_type());
916     jcp.typesize_acc = sizeof(int32_t);
917     jcp.typesize_bia = jcp.with_bias
918                        ? types::data_type_size(bias_d.data_type())
919                        : 0;
920
921     const auto &oscales = attr.output_scales_;
922     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
923
924     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
925
926     jcp.ur_h = 1; /* no code-unrolling by h so far */
927     jcp.ur_w = isa == avx2 ? 4 : 2;
928     jcp.nb_oc_blocking = nstl::min(2, jcp.nb_oc);
929     jcp.max_regs_ur = 12;
930
931     // WA to prevent fallback on gemm implementation
932     if (isa == sse42 && jcp.ic == 3) {
933         jcp.ur_w = 4;
934         jcp.nb_oc_blocking = 1;
935     }
936
937     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
938     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
939
940     bool args_ok = true
941         && jcp.l_pad <= jcp.ur_w
942         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
943                 || (jcp.stride_w == 1 && jcp.stride_h == 1));
944     if (!args_ok) return status::unimplemented;
945
946     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
947         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
948     if (r_pad_no_tail > jcp.ur_w)
949         return status::unimplemented;
950
951     if (jcp.l_pad > jcp.ur_w)
952         return status::unimplemented;
953
954     jcp.wei_adj_scale = (jcp.signed_input) ? (1.0f / 2.0f) : 1.0f;
955
956     return status::success;
957 }
958
959 template <cpu_isa_t isa>
960 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_scratchpad(
961         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw,
962         const primitive_attr_t &attr) {
963     if (jcp.oc != jcp.oc_padded)
964         scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
965
966     if (jcp.signed_input) {
967         size_t count = nstl::max(attr.output_scales_.count_, 8);
968         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
969
970         if (jcp.oc != jcp.oc_padded)
971             scratchpad.book(key_conv_padded_compensation, sizeof(int32_t) * jcp.oc_padded);
972     }
973
974     if (jcp.with_dw_conv) {
975         const int nthreads = mkldnn_get_max_threads();
976         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
977         scratchpad.book(key_dw_conv_buffer, jcp_dw.typesize_in * dw_conv_buffer_size_ * nthreads);
978
979         if (jcp.oc != jcp.oc_padded)
980             scratchpad.book(key_dw_conv_padded_bias, (size_t)jcp_dw.typesize_bia * jcp.oc_padded);
981     }
982 }
983
984 template struct jit_uni_x8s8s32x_conv_fwd_kernel<avx2>;
985 template struct jit_uni_x8s8s32x_conv_fwd_kernel<sse42>;
986
987 }
988 }
989 }