// inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_bin_conv_kernel.cpp
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <common/primitive_attr.hpp>
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "cpu_memory.hpp"

#include "jit_uni_bin_conv_kernel.hpp"

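// Byte offset of a field inside the kernel call-arguments structure
// (jit_conv_call_s), used when loading arguments from param1 in the
// generated code.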
#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

using namespace Xbyak;

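// Load a value of type `type_in` from `op` into `vmm_in` and convert it to
// f32. With `scalar_load` a single element is read through a GPR into the
// low lane; otherwise a full vector is loaded (with sign/zero extension
// for s8/u8).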
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
    Xmm xmm_in = Xmm(vmm_in.getIdx());

    switch (type_in) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_load) {
                mov(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vmovups(vmm_in, op);
            }
            break;
        case data_type::s8:
            if (scalar_load) {
                movsx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovsxbd(vmm_in, op);
            }
            break;
        case data_type::u8:
            if (scalar_load) {
                movzx(reg_tmp_32, op);
                movq(xmm_in, reg_tmp_64);
            } else {
                uni_vpmovzxbd(vmm_in, op);
            }
            break;
        default: assert(!"unsupported data type");
    }

    if (type_in != data_type::f32)
        uni_vcvtdq2ps(vmm_in, vmm_in);
}

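// Store `vmm_dst` to `op`, down-converting the f32/s32 accumulator to the
// destination data type (packing to s8/u8 lanes where needed). With
// `scalar_store` only the lowest element is written through a GPR.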
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
    Xmm xmm_dst = Xmm(vmm_dst.getIdx());

    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32:
            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_32);
            } else {
                uni_vmovups(op, vmm_dst);
            }
            break;
        case data_type::s8:
            uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }
            break;
        case data_type::u8:
        case data_type::bin:
            uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);

            if (isa != sse42 && !scalar_store)
                vpermq(ymm_dst, ymm_dst, 0x08);

            uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);

            if (scalar_store) {
                movq(reg_tmp_64, xmm_dst);
                mov(op, reg_tmp_8);
            } else {
                if (isa != sse42)
                    vmovq(op, xmm_dst);
                else
                    movd(op, xmm_dst);
            }

            break;
        default:
            assert(!"unknown dst_dt");
    }
}

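// Inner loop of the binary convolution. For each kernel column, output pixel
// and output-channel block it XORs 32 packed input bits with the packed
// weight bits and accumulates the popcount of the result. With inputs and
// weights interpreted as +/-1 values encoded as single bits, the dot product
// over n bits is recovered later (in width_blk_step) as
//     dot = n - 2 * popcnt(x ^ w),
// which matches the fma with scale = -2.0f and shift = ic*kw*kh taken from
// the constant table. Without vpopcntd, the per-byte popcount uses the
// classic SSSE3 nibble-lookup trick; a scalar sketch of the same idea
// (illustrative only, not part of the kernel):
//
//     static const uint8_t lut[16] = {0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4};
//     uint8_t popcnt8(uint8_t b) { return lut[b & 0xF] + lut[b >> 4]; }
//
// The vectorized form applies the lookup with pshufb and reduces the byte
// counts to dwords via pmaddubsw/pmaddwd (or vpdpbusd on VNNI machines).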
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
        int ic_blocks, bool last_icb, bool h_padded)
{
    int kw = jcp.kw;
    int kh = jcp.kh;
    int stride_w = jcp.stride_w;
    int dilate_w = jcp.dilate_w + 1;
    int ic_blk = jcp.ic_block;
    int oc_blk = jcp.oc_block;

    int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
    int nbits = 8;

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
        int jj_end = ur_w - nstl::max(0, div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w, stride_w));

        int _start = (!jcp.exclude_pad) ? 0 : jj_start;
        int _end = (!jcp.exclude_pad) ? ur_w : jj_end;

        for (int ifm2 = 0; ifm2 < ic_blocks; ifm2++) {
            for (int jj = _start; jj < _end; jj++) {
                int inp_off = ((ki * dilate_w + jj * stride_w - pad_l) * div_up(jcp.ic, nbits) + ifm2 * div_up(ic_blk, nbits)) * jcp.typesize_in;

                if (h_padded || jj < jj_start || jj >= jj_end) {
                    uni_vmovups(vmm_src, ptr[reg_table + 8 * vlen]);
                } else {
                    uni_vpbroadcastd(vmm_src, ptr[aux1_reg_input + inp_off]);
                }

                for (int r = 0; r < repeats; r++) {
                    for (int ii = 0; ii < oc_blocks; ii++) {
                        int ker_off = (ifm2 * kw * div_up(ic_blk, nbits) * oc_blk
                                       + ii * jcp.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk
                                       + ki * div_up(ic_blk, nbits) * oc_blk + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * jcp.typesize_in;

                        uni_vmovups(vmm_tmp, ptr[aux1_reg_kernel + ker_off]);

                        uni_vpxor(vmm_tmp, vmm_tmp, vmm_src);
                        if (jcp.ic_padded != jcp.ic && last_icb && ifm2 == (ic_blocks - 1))
                            uni_vandps(vmm_tmp, vmm_tmp, ptr[reg_table + 7 * vlen]);

                        if (mayiuse(avx512_vpopcnt)) {
                            vpopcntd(vmm_tmp, vmm_tmp);
                            uni_vpaddd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
                                       Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
                        } else {
                            if (isa == sse42) {
                                movups(vmm_tmp1, vmm_tmp);
                                pand(vmm_tmp1, vmm_mask);
                            } else {
                                uni_vandps(vmm_tmp1, vmm_mask, vmm_tmp);
                            }

                            uni_vpsrld(vmm_tmp, vmm_tmp, 4);
                            uni_vandps(vmm_tmp, vmm_tmp, vmm_mask);

                            if (isa == sse42) {
                                movups(vmm_tmp2, vmm_lookup);
                                pshufb(vmm_tmp2, vmm_tmp);
                                movups(vmm_tmp, vmm_lookup);
                                pshufb(vmm_tmp, vmm_tmp1);
                                paddb(vmm_tmp, vmm_tmp2);
                            } else {
                                uni_vpshufb(vmm_tmp, vmm_lookup, vmm_tmp);
                                uni_vpshufb(vmm_tmp1, vmm_lookup, vmm_tmp1);
                                uni_vpaddb(vmm_tmp, vmm_tmp, vmm_tmp1);
                            }

                            if (mayiuse(avx512_core_vnni)) {
                                vpdpbusd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp, vmm_one_u8);
                            } else {
                                uni_vpmaddubsw(vmm_tmp, vmm_tmp, vmm_one_u8);
                                uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one_s16);
                                uni_vpaddd(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
                                           Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
                            }
                        }
                    }
                }
            }
        }
    }
}

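// One pass over the kernel width for every input-channel block: iterate over
// jcp.nb_ic blocks, calling apply_filter per block; the final block is
// flagged (`last_icb`) so tail channels beyond jcp.ic can be masked out.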
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
    int kw = jcp.kw;

    int nbits = 8;
    int inp_mult = div_up(jcp.ic_block, nbits);
    int out_mult = jcp.oc_block;

    Label icb_main_loop;
    Label icb_tail;

    mov(aux1_reg_input, aux_reg_input);
    mov(aux1_reg_kernel, aux_reg_kernel);

    mov(reg_icb_iter, jcp.nb_ic);
    L(icb_main_loop);
    {
        cmp(reg_icb_iter, 1);
        jle(icb_tail, T_NEAR);

        apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, false, h_padded);

        add(aux1_reg_input, inp_mult * jcp.typesize_in);
        add(aux1_reg_kernel, kw * inp_mult * out_mult * jcp.typesize_in);
        sub(reg_icb_iter, 1);
        jmp(icb_main_loop, T_NEAR);
    }

    L(icb_tail);

    apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, true, h_padded);
}

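// Loop over the kernel height. The constant-table entries (popcount lookup,
// nibble mask, byte/word ones) are preloaded here. When padding contributes
// to the result (pad_value != 0, i.e. !exclude_pad), kernel rows that
// overflow the top/bottom of the input are processed separately with
// h_padded = true so the padded pixels take the encoded pad value.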
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
    int iw = jcp.iw;
    int kw = jcp.kw;
    int dilate_h = jcp.dilate_h + 1;

    int nbits = 8;
    const int inp_mult = dilate_h * div_up(jcp.ic, nbits);

    Label t_overflow_label, no_t_overflow_label,
          b_overflow_label, no_b_overflow_label;

    mov(aux_reg_input, reg_input);
    mov(aux_reg_kernel, reg_kernel_base);

    uni_vmovups(vmm_lookup,  ptr[reg_table + 0 * vlen]);
    uni_vmovups(vmm_mask,    ptr[reg_table + 1 * vlen]);
    uni_vmovups(vmm_one_u8,  ptr[reg_table + 5 * vlen]);
    uni_vmovups(vmm_one_s16, ptr[reg_table + 6 * vlen]);

    if (!jcp.exclude_pad) {
        mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
        cmp(reg_overflow, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label); {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);

            add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }

    Label skip_kh_loop;
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    if (!jcp.exclude_pad || (jcp.exclude_pad &&
                               (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
        cmp(reg_kh, 0);
        je(skip_kh_loop, T_NEAR);
    }

    Label kh_label;
    L(kh_label);
    {
        oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);

        add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
        add(aux_reg_input, jcp.typesize_in * iw * inp_mult);

        dec(reg_kh);
        cmp(reg_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(skip_kh_loop);

    if (!jcp.exclude_pad) {
        mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
        cmp(reg_overflow, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label); {
            oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);

            add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }
}

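// Compute a block of ur_w output pixels for `oc_blocks` output-channel
// blocks: zero the accumulators, run the kh/kw/ic loops, convert the integer
// popcount accumulators to f32 dot products (acc = shift - 2 * acc, where the
// shift is the number of contributing bits), apply the fused post-ops
// (eltwise, depthwise, sum) and store the results, optionally packing them
// into single bits when binarization is fused.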
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
{
    int nbits = 8;
    int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;

    for (int r = 0; r < repeats; r++)
        for (int ii = 0; ii < oc_blocks; ii++)
            for (int jj = 0; jj < ur_w; jj++)
                uni_vpxor(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
                          Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
                          Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));

    kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);

    if (isa == avx512_common && oc_step != jcp.oc_block) {
        int mask = (1 << oc_step) - 1;
        mov(reg_tmp_32, mask);
        kmovw(ktail_mask, reg_tmp_32);
    }

    const auto &p = attr_.post_ops_;
    for (int r = 0; r < repeats; r++) {
        int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
        bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;

#ifdef _MSC_BUILD
        auto kw_padding = make_vla<int>(ur_w);
#else
        int kw_padding[ur_w];
#endif  // _MSC_BUILD

        if (jcp.exclude_pad) {
            mov(reg_tmp_32, jcp.ic);
            imul(reg_tmp_32, ptr[param1 + GET_OFF(kh_padding)]);

            for (int jj = 0; jj < ur_w; jj++)
                kw_padding[jj] = 0;

            for (int ki = 0; ki < jcp.kw; ki++) {
                int jj_start = nstl::max(0, div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
                int jj_end = ur_w - nstl::max(0, div_up(ki * (jcp.dilate_w + 1) + pad_r -
                                                        (jcp.kw - 1) * (jcp.dilate_w + 1), jcp.stride_w));
                for (int jj = jj_start; jj < jj_end; jj++) {
                    kw_padding[jj]++;
                }
            }
        } else {
            uni_vmovups(vmm_shift, ptr[reg_table + 4 * vlen]);
        }
        uni_vmovups(vmm_scale, ptr[reg_table + 3 * vlen]);

        for (int jj = 0; jj < ur_w; jj++) {
            if (jcp.exclude_pad) {
                mov(reg_shift, kw_padding[jj]);
                imul(reg_shift, reg_tmp_32);
                movq(Xmm(vmm_shift.getIdx()), reg_shift);
                uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
                uni_vcvtdq2ps(vmm_shift, vmm_shift);
            }

            for (int ii = 0; ii < oc_blocks; ii++) {
                uni_vcvtdq2ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
                uni_vfmadd213ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_scale, vmm_shift);
            }
        }

        int eltwise_inj_idx = 0;
        int depthwise_inj_idx = 0;
        int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
        for (int i = 0; i < end_idx; i++) {
            int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;

            auto& post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
                eltwise_inj_idx++;
            } else if (post_op.is_depthwise()) {
                pop(reg_oc_off);

                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

                add(reg_d_weights, reg_oc_off);
                add(reg_d_bias, reg_oc_off);

                if (r == 1) {
                    add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
                    add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
                }

                for (int ii = 0; ii < oc_blocks; ii++) {
                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
                            start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);

                    add(reg_d_weights, jcp.oc_block * sizeof(float));
                    add(reg_d_bias, jcp.oc_block * sizeof(float));
                }

                depthwise_inj_idx++;

                push(reg_oc_off);
            } else if (post_op.is_sum(false)) {
                for (int ii = 0; ii < oc_blocks; ii++) {
                    for (int jj = 0; jj < ur_w; jj++) {
                        Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);

                        if (is_scalar_store) {
                            if (isa == avx512_common) {
                                int o_off = jj * jcp.oc * jcp.ngroups;

                                Vmm vmm_in = vmm_sum | ktail_mask | T_z;

                                vmovups(vmm_in, ptr[reg_output + o_off * jcp.typesize_out]);
                                uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
                            } else {
                                for (int oc = 0; oc < tail_size; oc++) {
                                    int o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;

                                    uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
                                    cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);

                                    if (oc < jcp.oc_block / 2) {
                                        uni_vpslldq(vmm_sum, vmm_sum, oc * sizeof(float));
                                    } else {
                                        Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
                                        vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
                                        vpslldq(vmm_sum, vmm_sum, (oc - jcp.oc_block / 2) * sizeof(float));
                                    }

                                    uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
                                }
                            }
                        } else {
                            size_t o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);

                            cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
                            uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
                        }
                    }
                }
            }
        }
    }

    if (jcp.with_binarization) {
        int binarization_idx = p.find(primitive_kind::binarization);

        pop(reg_oc_off);

        mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
        mov(reg_b_out_mask, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.output_mask_data));
        add(reg_b_weights, reg_oc_off);
        add(reg_b_out_mask, reg_oc_off);

        push(reg_oc_off);

        for (int ii = 0; ii < oc_blocks; ii++) {
            for (int jj = 0; jj < ur_w; jj++) {
                for (int r = 0; r < repeats; r++) {
                    int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
                    mov(reg_b_mask, (1 << tail_size) - 1);
                    uni_vmovups(vmm_thr, ptr[reg_b_weights + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);
                    uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);

                    Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);

                    if (isa == avx512_common) {
                        vcmpps(bin_mask0, vmm_dst, vmm_thr, _cmp_gt_os);
                        vptestmd(bin_mask1, vmm_out_mask, vmm_out_mask);
                        kxnorw(bin_mask0, bin_mask0, bin_mask1);
                    } else {
                        uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
                        uni_vpcmpeqd(vmm_dst, vmm_dst, vmm_out_mask);
                    }

                    if (r == 0) {
                        if (isa == avx512_common) {
                            kmovw(reg_tmp_32, bin_mask0);
                        } else {
                            uni_vmovmskps(reg_tmp_32, vmm_dst);
                        }
                        and_(reg_tmp_64, reg_b_mask);
                    } else {
                        uni_vmovmskps(reg_tmp2_32, vmm_dst);
                        and_(reg_tmp2_64, reg_b_mask);
                        shl(reg_tmp2_32, 4);
                        or_(reg_tmp_32, reg_tmp2_32);
                    }

                    if (r == repeats - 1) {
                        if (isa == avx512_common && oc_step > nbits) {
                            const size_t o_off = (2 * ii + jj * div_up(jcp.oc, nbits));
                            mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_16);
                        } else {
                            const size_t o_off = (ii + jj * div_up(jcp.oc, nbits));
                            mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
                        }
                    }
                }
            }
        }
    } else {
        for (int r = 0; r < repeats; r++) {
            int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
            bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
            if (is_scalar_store) {
                for (int jj = 0; jj < ur_w; jj++) {
                    Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);

                    if (isa == avx512_common) {
                        size_t o_off;
                        if (jcp.with_dw_conv)
                            o_off = jj * jcp.oc_block;
                        else
                            o_off = jj * jcp.oc * jcp.ngroups;

                        uni_vmovups(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst | ktail_mask);
                    } else {
                        for (int oc = 0; oc < tail_size; oc++) {
                            size_t o_off;
                            if (jcp.with_dw_conv)
                                o_off = jj * jcp.oc_block + oc + r * (jcp.oc_block / 2);
                            else
                                o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;

                            store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);

                            if (isa == sse42) {
                                psrldq(vmm_dst, jcp.typesize_out);
                            } else {
                                Ymm ymm_dst = Ymm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);

                                vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
                                vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
                            }
                        }
                    }
                }
            } else {
                for (int ii = 0; ii < oc_blocks; ii++) {
                    for (int jj = 0; jj < ur_w; jj++) {
                        Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);

                        size_t o_off;
                        if (jcp.with_dw_conv)
                            o_off = ((size_t) ii * jcp_dw_conv.kh * jcp.ow + jj) * jcp.oc_block +
                                    r * (jcp.oc_block / 2);
                        else
                            o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);

                        store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
                    }
                }
            }
        }
    }
}

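// Drive width_blk_step across one output row: an optional left-padded block,
// n_oi "middle" blocks without padding, an optional right-padded block, and
// a tail of ur_w_tail pixels.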
template <cpu_isa_t isa>
inline void jit_uni_bin_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
{
    int ur_w = jcp.ur_w;
    int ur_w_tail = jcp.ur_w_tail;
    int n_oi = jcp.ow / ur_w;
    int iw = jcp.iw;
    int kw = jcp.kw;
    int dilate_w = jcp.dilate_w + 1;
    int str_w = jcp.stride_w;

    int nbits = 8;
    const int inp_mult = div_up(jcp.ic, nbits);
    const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.with_binarization ? div_up(jcp.oc, nbits) : jcp.oc;

    int l_pad = jcp.l_pad;
    int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1));
    int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
            - (iw + l_pad - 1);
    if (r_pad1 > 0) n_oi--;

    mov(reg_input, reg_input_base);
    mov(reg_output, reg_output_base);

    push(reg_input_base);
    push(reg_output_base);
    push(reg_oc_work);
    push(reg_oc_off);

    if (l_pad > 0) {
        n_oi--;
        if (n_oi < 0 && r_pad1 > 0)
            width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
        else
            width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
        add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
        add(reg_output, jcp.typesize_out * ur_w * out_mult);
    }

    Label ow_loop_label;
    xor_(oi_iter, oi_iter);

    if (n_oi > 0) {
        L(ow_loop_label);

        width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
        add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
        add(reg_output, jcp.typesize_out * ur_w * out_mult);

        inc(oi_iter);
        cmp(oi_iter, n_oi);
        jl(ow_loop_label, T_NEAR);
    }

    if (r_pad1 > 0 && n_oi >= 0) {
        width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
        add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
        add(reg_output, jcp.typesize_out * ur_w * out_mult);
    }

    if (ur_w_tail != 0)
        width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"

    pop(reg_oc_off);
    pop(reg_oc_work);
    pop(reg_output_base);
    pop(reg_input_base);
}

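// Kernel entry point: instantiate the post-op injectors, emit the preamble,
// load the call arguments, and dispatch between the blocked main loop over
// output channels and the output-channel tail.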
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::generate()
{
    const auto &p = attr_.post_ops_;
    int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
    for (int i = 0; i < end_idx; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);

    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);

    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
    mov(reg_table, l_table);

    Label main_loop_label;
    Label tail_label;
    Label exit_label;

    cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
    jne(main_loop_label, T_NEAR);

    solve_common(jcp.nb_oc_blocking, jcp.oc_block);

    sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);

    jmp(exit_label, T_NEAR);

    int nbits = 8;

    L(main_loop_label); {
        cmp(reg_oc_work, jcp.oc_block);
        jl(tail_label, T_NEAR);

        solve_common(1, jcp.oc_block);

        sub(reg_oc_work, jcp.oc_block);
        add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * div_up(jcp.ic_block, nbits) * jcp.typesize_in);

        if (jcp.with_dw_conv) {
            add(reg_output_base, jcp.oc_block * jcp_dw_conv.kh * jcp.ow * jcp.typesize_out);
        } else {
            if (jcp.with_binarization)
                add(reg_output_base, div_up(jcp.oc_block, nbits) * jcp.typesize_out);
            else
                add(reg_output_base, jcp.oc_block * jcp.typesize_out);
        }

        add(reg_oc_off, jcp.oc_block * sizeof(float));

        jmp(main_loop_label, T_NEAR);
    }

    L(tail_label);

    if (jcp.oc % jcp.oc_block != 0)
        solve_common(1, jcp.oc % jcp.oc_block);

    L(exit_label);

    this->postamble();

    prepare_table();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}

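// Emit the constant table addressed through reg_table. Layout, one vector
// width (vlen) per entry:
//   offset 0: per-nibble popcount lookup table for pshufb (0,1,1,2, ...)
//   offset 1: 0x0f0f0f0f nibble mask
//   offset 2: 0x000000ff dword mask
//   offset 3: -2.0f scale used to turn popcounts into +/-1 dot products
//   offset 4: float(ic * kw * kh), the matching shift
//   offset 5: 0x01 bytes (pmaddubsw / vpdpbusd multiplier)
//   offset 6: 0x0001 words (pmaddwd multiplier)
//   offset 7: bit mask for the input-channel tail of the last ic block
//   offset 8: bit pattern of the padding value (all ones or all zeros)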
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::prepare_table() {
    const unsigned int cvals[] = {
            0x02010100, // 0 1 1 2
            0x03020201, // 1 2 2 3
            0x03020201, // 1 2 2 3
            0x04030302, // 2 3 3 4
            0x0f0f0f0f,
            0x000000ff,
            0xc0000000, // -2.0f
            0x01010101,
            0x00010001
    };

    size_t simd_w = vlen / sizeof(int32_t);

    align(64);
    L(l_table);
    // offset = 0
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[d % 4]);
    }
    // offset = 1
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[4]);
    }
    // offset = 2
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[5]);
    }
    // offset = 3
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[6]);
    }

    // offset = 4
    for (size_t d = 0; d < simd_w; ++d) {
        dd(float2int(jcp.ic * jcp.kw * jcp.kh));
    }

    // offset = 5
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[7]);
    }
    // offset = 6
    for (size_t d = 0; d < simd_w; ++d) {
        dd(cvals[8]);
    }
    // offset = 7
    for (size_t d = 0; d < simd_w; ++d) {
        uint32_t mask = 0xffffffff >> (jcp.ic_padded - jcp.ic);
        dd(mask);
    }
    // offset = 8
    for (size_t d = 0; d < simd_w; ++d) {
        uint32_t val = jcp.pad_value == 1.0f ? 0xffffffff : 0x00000000;
        dd(val);
    }
}

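// Validate the fused post-op chain: every entry must be sum, eltwise,
// depthwise or binarization; at most one sum and one binarization are
// allowed; binarization must be the last post-op and cannot be combined
// with sum; a fused depthwise convolution excludes both sum and
// binarization.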
template <cpu_isa_t isa>
bool jit_uni_bin_conv_fwd_kernel<isa>::post_ops_ok(jit_bin_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    int dw_conv_idx = p.find(primitive_kind::convolution);
    bool with_dw_conv = dw_conv_idx != -1;

    auto all_post_ops_supported = [&]() {
        bool ok = true;

        int end_idx = with_dw_conv ? dw_conv_idx : p.len_;
        for (int i = 0; i < end_idx; i++) {
            ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise,
                                                       primitive_kind::binarization);
        }
        return ok;
    };
    auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx) != -1; };
    auto position = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx); };
    auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, 0, dw_conv_idx); };

    return all_post_ops_supported() &&
           count(primitive_kind::sum) <= 1 &&
           count(primitive_kind::binarization) <= 1 &&
           IMPLICATION(contain(primitive_kind::binarization), position(primitive_kind::binarization) == p.len_ - 1 &&
                                                              !contain(primitive_kind::sum)) &&
           IMPLICATION(with_dw_conv, !contain(primitive_kind::sum) && !contain(primitive_kind::binarization));
}

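// Fill the jit_bin_conv_conf_t descriptor from the operation descriptor and
// memory descriptors, and check that the shape, formats (nhwc source and
// destination, OhIw{8,16}o32i weights) and padding fit the kernel's
// blocking scheme.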
template <cpu_isa_t isa>
status_t jit_uni_bin_conv_fwd_kernel<isa>::init_conf(jit_bin_conv_conf_t &jcp,
        const binary_convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr)
{
    if (!mayiuse(isa)) return status::unimplemented;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;

    if (jcp.ngroups != 1)
        return status::unimplemented;

    jcp.mb = src_d.dims()[0];

    int simd_w = isa == avx512_common ? 16 : 8;

    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;

    jcp.oc_padded = rnd_up(jcp.oc, simd_w);

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.src_fmt = src_d.format();

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    jcp.pad_value = cd.pad_value;
    jcp.exclude_pad = jcp.pad_value == 0.0f;

    jcp.src_dt = cd.src_desc.data_type;
    jcp.bia_dt = mkldnn_f32;
    jcp.dst_dt = cd.dst_desc.data_type;

    const auto &p = attr.post_ops_;
    int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    if (jcp.with_dw_conv) {
        jcp.dw_conv_oh = jcp.oh;
        jcp.dw_conv_ow = jcp.ow;
        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;

        jcp.dw_conv_dst_dt = jcp.dst_dt;
        jcp.dst_dt = p.entry_[dw_conv_ind].dw_conv.in_dt;
    }
    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    jcp.with_binarization = p.find(primitive_kind::binarization, 0, dw_conv_ind) != -1;

    if (with_groups)
        return status::unimplemented;

    auto desired_weights_format = isa == avx512_common ? OhIw16o32i : OhIw8o32i;
    bool args_ok = true
        && src_d.format() == nhwc
        && weights_d.format() == desired_weights_format
        && dst_d.format() == nhwc;
    if (!args_ok) return status::unimplemented;

    jcp.ur_h = 1;
    jcp.ur_w = isa == avx512_common ? 4 : 2;
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    jcp.ic_block = 32;
    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
    jcp.ic_padded = rnd_up(jcp.ic, jcp.ic_block);

    jcp.oc_block = simd_w;
    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);

    jcp.nb_ic_blocking = 1;
    jcp.nb_oc_blocking = nstl::min(isa == sse42 ? 2 : isa == avx2 ? 4 : 6, jcp.nb_oc);

    jcp.typesize_in = types::data_type_size(jcp.src_dt);
    jcp.typesize_out = types::data_type_size(jcp.dst_dt);
    jcp.typesize_acc = sizeof(int32_t);

    args_ok = true
        && jcp.l_pad <= jcp.ur_w
        && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
                || (jcp.stride_w == 1 && jcp.stride_h == 1));
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
        + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
    if (r_pad_no_tail > jcp.ur_w)
        return status::unimplemented;

    if (jcp.l_pad > jcp.ur_w) return status::unimplemented;

    return status::success;
}

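// Book scratchpad memory for the fused depthwise convolution: a per-thread
// row buffer of kh * iw * ch_block * nb_oc_blocking floats and, when the
// output channels are padded, a padded bias buffer.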
template <cpu_isa_t isa>
void jit_uni_bin_conv_fwd_kernel<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_bin_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw_conv) {
    if (jcp.with_dw_conv) {
        const int nthreads = mkldnn_get_max_threads();
        size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);

        if (jcp.oc != jcp.oc_padded)
            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc_padded);
    }
}

template struct jit_uni_bin_conv_fwd_kernel<sse42>;
template struct jit_uni_bin_conv_fwd_kernel<avx2>;
template struct jit_uni_bin_conv_fwd_kernel<avx512_common>;

}
}
}