Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_bin_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <common/primitive_attr.hpp>
18 #include "c_types_map.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "jit_uni_bin_conv_kernel.hpp"
25
26 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace mkldnn::impl::prop_kind;
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::memory_tracking::names;
35 using namespace mkldnn::impl::utils;
36
37 using namespace Xbyak;
38
39 template <cpu_isa_t isa>
40 void jit_uni_bin_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in, const Operand &op, bool scalar_load) {
41     Xmm xmm_in = Xmm(vmm_in.getIdx());
42
43     switch (type_in) {
44         case data_type::f32:
45         case data_type::s32:
46             if (scalar_load) {
47                 mov(reg_tmp_32, op);
48                 movq(xmm_in, reg_tmp_64);
49             } else {
50                 uni_vmovups(vmm_in, op);
51             }
52             break;
53         case data_type::s8:
54             if (scalar_load) {
55                 movsx(reg_tmp_32, op);
56                 movq(xmm_in, reg_tmp_64);
57             } else {
58                 uni_vpmovsxbd(vmm_in, op);
59             }
60             break;
61         case data_type::u8:
62             if (scalar_load) {
63                 movzx(reg_tmp_32, op);
64                 movq(xmm_in, reg_tmp_64);
65             } else {
66                 uni_vpmovzxbd(vmm_in, op);
67             }
68             break;
69         default: assert(!"unsupported data type");
70     }
71
72     if (type_in != data_type::f32)
73         uni_vcvtdq2ps(vmm_in, vmm_in);
74 }
75
76 template <cpu_isa_t isa>
77 void jit_uni_bin_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
78     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
79     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
80
81     switch (jcp.dst_dt) {
82         case data_type::f32:
83         case data_type::s32:
84             if (scalar_store) {
85                 movq(reg_tmp_64, xmm_dst);
86                 mov(op, reg_tmp_32);
87             } else {
88                 uni_vmovups(op, vmm_dst);
89             }
90             break;
91         case data_type::s8:
92             uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
93
94             if (isa != sse42 && !scalar_store)
95                 vpermq(ymm_dst, ymm_dst, 0x08);
96
97             uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
98
99             if (scalar_store) {
100                 movq(reg_tmp_64, xmm_dst);
101                 mov(op, reg_tmp_8);
102             } else {
103                 if (isa != sse42)
104                     vmovq(op, xmm_dst);
105                 else
106                     movd(op, xmm_dst);
107             }
108             break;
109         case data_type::u8:
110         case data_type::bin:
111             uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
112
113             if (isa != sse42 && !scalar_store)
114                 vpermq(ymm_dst, ymm_dst, 0x08);
115
116             uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
117
118             if (scalar_store) {
119                 movq(reg_tmp_64, xmm_dst);
120                 mov(op, reg_tmp_8);
121             } else {
122                 if (isa != sse42)
123                     vmovq(op, xmm_dst);
124                 else
125                     movd(op, xmm_dst);
126             }
127
128             break;
129         default:
130             assert(!"unknown dst_dt");
131     }
132 }
133
134 template <cpu_isa_t isa>
135 void jit_uni_bin_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
136         int ic_blocks, bool last_icb, bool h_padded)
137 {
138     int kw = jcp.kw;
139     int kh = jcp.kh;
140     int stride_w = jcp.stride_w;
141     int dilate_w = jcp.dilate_w + 1;
142     int ic_blk = jcp.ic_block;
143     int oc_blk = jcp.oc_block;
144
145     int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
146     int nbits = 8;
147
148     for (int ki = 0; ki < kw; ki++) {
149         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
150         int jj_end = ur_w  - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
151
152         int _start = (!jcp.exclude_pad) ? 0 : jj_start;
153         int _end = (!jcp.exclude_pad) ? ur_w : jj_end;
154
155         for (int ifm2 = 0; ifm2 < ic_blocks; ifm2++) {
156             for (int jj = _start; jj < _end; jj++) {
157                 int inp_off = ((ki*dilate_w + jj*stride_w - pad_l)*div_up(jcp.ic, nbits) + ifm2 * div_up(ic_blk, nbits)) * jcp.typesize_in;
158
159                 if (h_padded || jj < jj_start || jj >= jj_end) {
160                     uni_vmovups(vmm_src, ptr[reg_table + 256]);
161                 } else {
162                     uni_vpbroadcastd(vmm_src, ptr[aux1_reg_input + inp_off]);
163                 }
164
165                 for (int r = 0; r < repeats; r++) {
166                     for (int ii = 0; ii < oc_blocks; ii++) {
167                         int ker_off = (ifm2 * kw * div_up(ic_blk, nbits) * oc_blk
168                                        + ii * jcp.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk
169                                        + ki * div_up(ic_blk, nbits) * oc_blk + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * jcp.typesize_in;
170
171                         uni_vmovups(vmm_tmp, ptr[aux1_reg_kernel + ker_off]);
172
173                         uni_vpxor(vmm_tmp, vmm_tmp, vmm_src);
174                         if (jcp.ic_padded != jcp.ic && last_icb && ifm2 == (ic_blocks - 1))
175                             uni_vandps(vmm_tmp, vmm_tmp, ptr[reg_table + 224]);
176
177                         if (isa == sse42) {
178                             movups(vmm_tmp1, vmm_tmp);
179                             pand(vmm_tmp1, vmm_mask);
180                         } else {
181                             uni_vandps(vmm_tmp1, vmm_mask, vmm_tmp);
182                         }
183
184                         uni_vpsrld(vmm_tmp, vmm_tmp, 4);
185                         uni_vandps(vmm_tmp, vmm_tmp, vmm_mask);
186
187                         if (isa == sse42) {
188                             movups(vmm_tmp2, vmm_lookup);
189                             pshufb(vmm_tmp2, vmm_tmp);
190                             movups(vmm_tmp, vmm_lookup);
191                             pshufb(vmm_tmp, vmm_tmp1);
192                             paddb(vmm_tmp, vmm_tmp2);
193                         } else {
194                             uni_vpshufb(vmm_tmp, vmm_lookup, vmm_tmp);
195                             uni_vpshufb(vmm_tmp1, vmm_lookup, vmm_tmp1);
196                             uni_vpaddb(vmm_tmp, vmm_tmp, vmm_tmp1);
197                         }
198
199                         uni_vpmaddubsw(vmm_tmp, vmm_tmp, vmm_one_u8);
200                         uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one_s16);
201                         uni_vpaddd(Vmm(1 + r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
202                                    Vmm(1 + r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj), vmm_tmp);
203                     }
204                 }
205             }
206         }
207     }
208 }
209
210 template <cpu_isa_t isa>
211 void jit_uni_bin_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
212     int kw = jcp.kw;
213
214     int nbits = 8;
215     int inp_mult = div_up(jcp.ic_block, nbits);
216     int out_mult = jcp.oc_block;
217
218     Label icb_main_loop;
219     Label icb_tail;
220
221     mov(aux1_reg_input, aux_reg_input);
222     mov(aux1_reg_kernel, aux_reg_kernel);
223
224     mov(reg_icb_iter, jcp.nb_ic);
225     L(icb_main_loop);
226     {
227         cmp(reg_icb_iter, 1);
228         jle(icb_tail, T_NEAR);
229
230         apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, false, h_padded);
231
232         add(aux1_reg_input, inp_mult * jcp.typesize_in);
233         add(aux1_reg_kernel, kw * inp_mult * out_mult * jcp.typesize_in);
234         sub(reg_icb_iter, 1);
235         jmp(icb_main_loop, T_NEAR);
236     }
237
238     L(icb_tail);
239
240     apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 1, true, h_padded);
241 }
242
243 template <cpu_isa_t isa>
244 void jit_uni_bin_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
245     int iw = jcp.iw;
246     int kw = jcp.kw;
247     int dilate_h = jcp.dilate_h + 1;
248
249     int nbits = 8;
250     const int inp_mult = dilate_h * div_up(jcp.ic, nbits);
251
252     Label t_overflow_label, no_t_overflow_label,
253           b_overflow_label, no_b_overflow_label;
254
255     mov(aux_reg_input, reg_input);
256     mov(aux_reg_kernel, reg_kernel_base);
257
258     uni_vmovups(vmm_lookup,  ptr[reg_table]);
259     uni_vmovups(vmm_mask,    ptr[reg_table + 32]);
260     uni_vmovups(vmm_one_u8,  ptr[reg_table + 160]);
261     uni_vmovups(vmm_one_s16, ptr[reg_table + 192]);
262
263     if (!jcp.exclude_pad) {
264         mov(reg_overflow,  ptr[param1 + GET_OFF(t_overflow)]);
265         cmp(reg_overflow, 0);
266         je(no_t_overflow_label, T_NEAR);
267         L(t_overflow_label); {
268             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
269
270             add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
271             dec(reg_overflow);
272             cmp(reg_overflow, 0);
273             jg(t_overflow_label, T_NEAR);
274         }
275         L(no_t_overflow_label);
276     }
277
278     Label skip_kh_loop;
279     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
280     if (!jcp.exclude_pad || (jcp.exclude_pad &&
281                                (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
282         cmp(reg_kh, 0);
283         je(skip_kh_loop, T_NEAR);
284     }
285
286     Label kh_label;
287     L(kh_label);
288     {
289         oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);
290
291         add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
292         add(aux_reg_input, jcp.typesize_in * iw * inp_mult);
293
294         dec(reg_kh);
295         cmp(reg_kh, 0);
296         jg(kh_label, T_NEAR);
297     }
298
299     L(skip_kh_loop);
300
301     if (!jcp.exclude_pad) {
302         mov(reg_overflow,  ptr[param1 + GET_OFF(b_overflow)]);
303         cmp(reg_overflow, 0);
304         je(no_b_overflow_label, T_NEAR);
305         L(b_overflow_label); {
306             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
307
308             add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * jcp.nb_ic * div_up(jcp.ic_block, nbits));
309             dec(reg_overflow);
310             cmp(reg_overflow, 0);
311             jg(b_overflow_label, T_NEAR);
312         }
313         L(no_b_overflow_label);
314     }
315 }
316
317 template <cpu_isa_t isa>
318 void jit_uni_bin_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
319 {
320     int nbits = 8;
321     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
322
323     for (int r = 0; r < repeats; r++)
324         for (int ii = 0; ii < oc_blocks; ii++)
325             for (int jj = 0; jj < ur_w; jj++)
326                 uni_vpxor(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
327                           Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj),
328                           Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
329
330     kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
331
332     const auto &p = attr_.post_ops_;
333     for (int r = 0; r < repeats; r++) {
334         int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
335         bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
336
337         int kw_padding[ur_w];
338         if (jcp.exclude_pad) {
339             mov(reg_tmp_32, jcp.ic);
340             imul(reg_tmp_32,  ptr[param1 + GET_OFF(kh_padding)]);
341
342             for (int jj = 0; jj < ur_w; jj++)
343                 kw_padding[jj] = 0;
344
345             for (int ki = 0; ki < jcp.kw; ki++) {
346                 int jj_start = nstl::max(0, div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
347                 int jj_end = ur_w - nstl::max(0, div_up(ki * (jcp.dilate_w + 1) + pad_r -
348                                                         (jcp.kw - 1) * (jcp.dilate_w + 1), jcp.stride_w));
349                 for (int jj = jj_start; jj < jj_end; jj++) {
350                     kw_padding[jj]++;
351                 }
352             }
353         } else {
354             uni_vmovups(vmm_shift, ptr[reg_table + 128]);
355         }
356         uni_vmovups(vmm_scale, ptr[reg_table + 96]);
357
358         for (int jj = 0; jj < ur_w; jj++) {
359             if (jcp.exclude_pad) {
360                 mov(reg_shift, kw_padding[jj]);
361                 imul(reg_shift, reg_tmp_32);
362                 movq(Xmm(vmm_shift.getIdx()), reg_shift);
363                 uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
364                 uni_vcvtdq2ps(vmm_shift, vmm_shift);
365             }
366
367             for (int ii = 0; ii < oc_blocks; ii++) {
368                 uni_vcvtdq2ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj));
369                 uni_vfmadd213ps(Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj), vmm_scale, vmm_shift);
370             }
371         }
372
373         int eltwise_inj_idx = 0;
374         int depthwise_inj_idx = 0;
375         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
376         for (int i = 0; i < end_idx; i++) {
377             int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
378
379             auto& post_op = p.entry_[i];
380             if (post_op.is_eltwise()) {
381                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
382                 eltwise_inj_idx++;
383             } else if (post_op.is_depthwise()) {
384                 pop(reg_oc_off);
385
386                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
387                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
388
389                 add(reg_d_weights, reg_oc_off);
390                 add(reg_d_bias, reg_oc_off);
391
392                 if (r == 1) {
393                     add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
394                     add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
395                 }
396
397                 for (int ii = 0; ii < oc_blocks; ii++) {
398                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
399                             start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
400
401                     add(reg_d_weights, jcp.oc_block * sizeof(float));
402                     add(reg_d_bias, jcp.oc_block * sizeof(float));
403                 }
404
405                 depthwise_inj_idx++;
406
407                 push(reg_oc_off);
408             } else if (post_op.is_sum(false)) {
409                 for (int ii = 0; ii < oc_blocks; ii++) {
410                     for (int jj = 0; jj < ur_w; jj++) {
411                         Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
412
413                         if (is_scalar_store) {
414                             for (int oc = 0; oc < tail_size; oc++) {
415                                 int o_off =  jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
416
417                                 uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
418                                 cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], true);
419
420                                 if (oc < jcp.oc_block / 2) {
421                                     uni_vpslldq(vmm_sum, vmm_sum, oc * sizeof(float));
422                                 } else {
423                                     Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
424                                     vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
425                                     vpslldq(vmm_sum, vmm_sum, (oc - jcp.oc_block / 2) * sizeof(float));
426                                 }
427
428                                 uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
429                             }
430                         } else {
431                             size_t o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
432
433                             cvt2ps(jcp.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp.typesize_out], false);
434                             uni_vaddps(vmm_dst, vmm_dst, vmm_sum);
435                         }
436                     }
437                 }
438             }
439         }
440     }
441
442     if (jcp.with_binarization) {
443         int binarization_idx = p.find(primitive_kind::binarization);
444
445         pop(reg_oc_off);
446
447         mov(reg_b_weights, reinterpret_cast<size_t>(p.entry_[binarization_idx].binarization.weights_data));
448         add(reg_b_weights, reg_oc_off);
449
450         push(reg_oc_off);
451
452         for (int ii = 0; ii < oc_blocks; ii++) {
453             for (int jj = 0; jj < ur_w; jj++) {
454                 for (int r = 0; r < repeats; r++) {
455                     int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
456                     mov(reg_b_mask, (1 << tail_size) - 1);
457                     uni_vmovups(vmm_thr, ptr[reg_b_weights + (ii * jcp.oc_block + r * (jcp.oc_block / 2)) * sizeof(float)]);
458
459                     Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
460
461                     uni_vcmpgtps(vmm_dst, vmm_dst, vmm_thr);
462
463                     if (r == 0) {
464                         uni_vmovmskps(reg_tmp_32, vmm_dst);
465                         and_(reg_tmp_64, reg_b_mask);
466                     } else {
467                         uni_vmovmskps(reg_tmp2_32, vmm_dst);
468                         and_(reg_tmp2_64, reg_b_mask);
469                         shl(reg_tmp2_32, 4);
470                         or_(reg_tmp_32, reg_tmp2_32);
471                     }
472
473                     if (r == repeats - 1) {
474                         const size_t o_off = (ii + jj * div_up(jcp.oc, nbits));
475                         mov(ptr[reg_output + o_off * jcp.typesize_out], reg_tmp_8);
476                     }
477                 }
478             }
479         }
480     } else {
481         for (int r = 0; r < repeats; r++) {
482             int tail_size = isa == sse42 ? nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2) : oc_step;
483             bool is_scalar_store = isa == sse42 ? tail_size < jcp.oc_block / 2 : tail_size < jcp.oc_block;
484             if (is_scalar_store) {
485                 for (int jj = 0; jj < ur_w; jj++) {
486                     Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
487                     Ymm ymm_dst = Ymm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + jj);
488
489                     for (int oc = 0; oc < tail_size; oc++) {
490                         size_t o_off;
491                         if (jcp.with_dw_conv)
492                             o_off = jj * jcp.oc_block + oc + r * (jcp.oc_block / 2);
493                         else
494                             o_off = jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2) + oc;
495
496                         store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
497
498                         if (isa == sse42) {
499                             psrldq(vmm_dst, jcp.typesize_out);
500                         } else {
501                             vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
502                             vpalignr(ymm_dst, vmm_tmp, ymm_dst, jcp.typesize_out);
503                         }
504                     }
505                 }
506             } else {
507                 for (int ii = 0; ii < oc_blocks; ii++) {
508                     for (int jj = 0; jj < ur_w; jj++) {
509                         Vmm vmm_dst = Vmm(1 + r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
510
511                         size_t o_off;
512                         if (jcp.with_dw_conv)
513                             o_off = ((size_t) ii * jcp_dw_conv.kh * jcp.ow + jj) * jcp.oc_block +
514                                     r * (jcp.oc_block / 2);
515                         else
516                             o_off = ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
517
518                         store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
519                     }
520                 }
521             }
522         }
523     }
524 }
525
526 template <cpu_isa_t isa>
527 inline void jit_uni_bin_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
528 {
529     int ur_w = jcp.ur_w;
530     int ur_w_tail = jcp.ur_w_tail;
531     int n_oi = jcp.ow / ur_w;
532     int iw = jcp.iw;
533     int kw = jcp.kw;
534     int dilate_w = jcp.dilate_w + 1;
535     int str_w = jcp.stride_w;
536
537     int nbits = 8;
538     const int inp_mult = div_up(jcp.ic, nbits);
539     const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.with_binarization ? div_up(jcp.oc, nbits) : jcp.oc;
540
541     int l_pad = jcp.l_pad;
542     int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
543             - (iw + l_pad - 1));
544     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
545             - (iw + l_pad - 1);
546     if (r_pad1 > 0) n_oi--;
547
548     mov(reg_input, reg_input_base);
549     mov(reg_output, reg_output_base);
550
551     push(reg_input_base);
552     push(reg_output_base);
553     push(reg_oc_work);
554     push(reg_oc_off);
555
556     if (l_pad > 0) {
557         n_oi--;
558         if (n_oi < 0 && r_pad1 > 0)
559             width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
560         else
561             width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
562         add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
563         add(reg_output, jcp.typesize_out * ur_w * out_mult);
564     }
565
566     Label ow_loop_label;
567     xor_(oi_iter, oi_iter);
568
569     if (n_oi > 0) {
570         L(ow_loop_label);
571
572         width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
573         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
574         add(reg_output, jcp.typesize_out * ur_w * out_mult);
575
576         inc(oi_iter);
577         cmp(oi_iter, n_oi);
578         jl(ow_loop_label, T_NEAR);
579     }
580
581     if (r_pad1 > 0 && n_oi >=0) {
582         width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
583         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
584         add(reg_output, jcp.typesize_out * ur_w * out_mult);
585     }
586
587     if (ur_w_tail != 0)
588         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
589
590     pop(reg_oc_off);
591     pop(reg_oc_work);
592     pop(reg_output_base);
593     pop(reg_input_base);
594 }
595
596 template <cpu_isa_t isa>
597 void jit_uni_bin_conv_fwd_kernel<isa>::generate()
598 {
599     const auto &p = attr_.post_ops_;
600     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
601     for (int i = 0; i < end_idx; i++) {
602         auto &post_op = p.entry_[i];
603         if (post_op.is_eltwise()) {
604             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
605                     this,
606                     post_op.eltwise.alg,
607                     post_op.eltwise.alpha,
608                     post_op.eltwise.beta
609             ));
610         } else if (post_op.is_depthwise()) {
611             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
612                     this,
613                     post_op.depthwise.alg
614             ));
615         }
616     }
617
618     this->preamble();
619
620     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
621     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
622     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
623
624     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
625     mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
626
627     mov(reg_oc_off,  ptr[param1 + GET_OFF(oc_off)]);
628     mov(reg_table, l_table);
629
630     Label main_loop_label;
631     Label tail_label;
632     Label exit_label;
633
634     cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
635     jne(main_loop_label, T_NEAR);
636
637     solve_common(jcp.nb_oc_blocking, jcp.oc_block);
638
639     sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
640
641     jmp(exit_label, T_NEAR);
642
643     int nbits = 8;
644
645     L(main_loop_label); {
646         cmp(reg_oc_work, jcp.oc_block);
647         jl(tail_label, T_NEAR);
648
649         solve_common(1, jcp.oc_block);
650
651         sub(reg_oc_work, jcp.oc_block);
652         add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kh * jcp.kw * div_up(jcp.ic_block, nbits) * jcp.typesize_in);
653
654         if (jcp.with_dw_conv) {
655             add(reg_output_base, jcp.oc_block * jcp_dw_conv.kh * jcp.ow * jcp.typesize_out);
656         } else {
657             if (jcp.with_binarization)
658                 add(reg_output_base, jcp.typesize_out);
659             else
660                 add(reg_output_base, jcp.oc_block * jcp.typesize_out);
661         }
662
663         add(reg_oc_off, jcp.oc_block * sizeof(float));
664
665         jmp(main_loop_label, T_NEAR);
666     }
667
668     L(tail_label);
669
670     if (jcp.oc % jcp.oc_block != 0)
671         solve_common(1, jcp.oc % jcp.oc_block);
672
673     L(exit_label);
674
675     this->postamble();
676
677     prepare_table();
678
679     for (auto& inj : eltwise_injectors)
680         inj->prepare_table();
681 }
682
683 template <cpu_isa_t isa>
684 void jit_uni_bin_conv_fwd_kernel<isa>::prepare_table() {
685     const unsigned int cvals[] = {
686             0x02010100, // 0 1 1 2
687             0x03020201, // 1 2 2 3
688             0x03020201, // 1 2 2 3
689             0x04030302,  // 2 3 3 4
690             0x02010100, // 0 1 1 2
691             0x03020201, // 1 2 2 3
692             0x03020201, // 1 2 2 3
693             0x04030302,  // 2 3 3 4
694             0x0f0f0f0f,
695             0x000000ff,
696             0xc0000000, // -2.0f
697             0x01010101,
698             0x00010001
699     };
700
701     align(64);
702     L(l_table);
703     // offset = 0
704     for (size_t d = 0; d < 8; ++d) {
705         dd(cvals[d % 8]);
706     }
707     // offset = 32
708     for (size_t d = 0; d < 8; ++d) {
709         dd(cvals[8]);
710     }
711     // offset = 64
712     for (size_t d = 0; d < 8; ++d) {
713         dd(cvals[9]);
714     }
715     // offset = 96
716     for (size_t d = 0; d < 8; ++d) {
717         dd(cvals[10]);
718     }
719
720     // offset = 128
721     for (size_t d = 0; d < 8; ++d) {
722         dd(float2int(jcp.ic * jcp.kw * jcp.kh));
723     }
724
725     // offset = 160
726     for (size_t d = 0; d < 8; ++d) {
727         dd(cvals[11]);
728     }
729     // offset = 192
730     for (size_t d = 0; d < 8; ++d) {
731         dd(cvals[12]);
732     }
733     // offset = 224
734     for (size_t d = 0; d < 8; ++d) {
735         uint32_t mask = 0xffffffff >> (jcp.ic_padded - jcp.ic);
736         dd(mask);
737     }
738     // offset = 256
739     for (size_t d = 0; d < 8; ++d) {
740         uint32_t val = jcp.pad_value == 1.0f ? 0xffffffff : 0x00000000;
741         dd(val);
742     }
743 }
744
745 template <cpu_isa_t isa>
746 bool jit_uni_bin_conv_fwd_kernel<isa>::post_ops_ok(jit_bin_conv_conf_t &jcp, const primitive_attr_t &attr) {
747     const auto &p = attr.post_ops_;
748
749     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
750     auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
751     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
752     auto is_dw_conv = [&](int idx) { return p.entry_[idx].is_dw_conv(); };
753     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
754     auto is_binarization = [&](int idx) { return p.entry_[idx].is_binarization(); };
755
756     switch (p.len_) {
757     case 0: return true; // no post_ops
758     case 1:
759         return (is_simple(0) || is_sum(0) || is_dw_conv(0) || is_binarization(0));
760     case 2:
761         return ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_simple(1)) ||
762                 (is_simple(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
763                 (is_simple(0) && is_simple(1)) || (is_simple(0) && is_binarization(1)) ||
764                 (is_dw_conv(0) && is_binarization(1)) || (is_simple(0) && is_sum(1)));
765     case 3:
766         return ((is_simple(0) && is_dw_conv(1) && is_simple(2)) ||
767                 (is_dw_conv(0) && is_sum(1) && is_simple(2)) ||
768                 (is_sum(0) && is_simple(1) && is_simple(2)) ||
769                 (is_simple(0) && is_sum(1) && is_simple(2)) ||
770                 (is_simple(0) && is_dw_conv(1) && is_binarization(2)) ||
771                 (is_simple(0) && is_simple(1) && is_dw_conv(2)));
772     case 4: return ((is_simple(0) && is_dw_conv(1) && is_sum(2) && is_simple(3)) ||
773                     (is_simple(0) && is_dw_conv(1) && is_simple(2) && is_binarization(3)) ||
774                     (is_simple(0) && is_simple(1) && is_dw_conv(2) && is_binarization(3)) ||
775                     (is_simple(0) && is_simple(1) && is_simple(2) && is_binarization(3)) ||
776                     (is_simple(0) && is_simple(1) && is_dw_conv(2) && is_simple(3)));
777     default: return false;
778     }
779
780     return false;
781 }
782
783 template <cpu_isa_t isa>
784 status_t jit_uni_bin_conv_fwd_kernel<isa>::init_conf(jit_bin_conv_conf_t &jcp,
785         const binary_convolution_desc_t &cd, const memory_desc_wrapper &src_d,
786         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, const primitive_attr_t &attr)
787 {
788     if (!mayiuse(isa)) return status::unimplemented;
789
790     jcp.prop_kind = cd.prop_kind;
791
792     jcp.dst_dt = cd.dst_desc.data_type;
793
794     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
795
796     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
797
798     if (jcp.ngroups != 1)
799         return status::unimplemented;
800
801     jcp.mb = src_d.dims()[0];
802
803     int simd_w = isa == avx512_common ? 16 : 8;
804
805     jcp.ic = src_d.dims()[1] / jcp.ngroups;
806     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
807
808     jcp.oc_padded = rnd_up(jcp.oc, simd_w);
809
810     jcp.ih = src_d.dims()[2];
811     jcp.iw = src_d.dims()[3];
812     jcp.oh = dst_d.dims()[2];
813     jcp.ow = dst_d.dims()[3];
814
815     jcp.kh = weights_d.dims()[with_groups + 2];
816     jcp.kw = weights_d.dims()[with_groups + 3];
817
818     jcp.t_pad = cd.padding[0][0];
819     jcp.l_pad = cd.padding[0][1];
820
821     jcp.stride_h = cd.strides[0];
822     jcp.stride_w = cd.strides[1];
823
824     jcp.dilate_h = cd.dilates[0];
825     jcp.dilate_w = cd.dilates[1];
826
827     jcp.src_fmt = src_d.format();
828
829     if (!post_ops_ok(jcp, attr))
830         return status::unimplemented;
831
832     jcp.pad_value = cd.pad_value;
833     jcp.exclude_pad = jcp.pad_value == 0.0f;
834
835     const auto &p = attr.post_ops_;
836     int dw_conv_ind = p.find(primitive_kind::convolution);
837     jcp.with_dw_conv = dw_conv_ind != -1;
838     if (jcp.with_dw_conv) {
839         jcp.dw_conv_oh = jcp.oh;
840         jcp.dw_conv_ow = jcp.ow;
841         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
842         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
843     }
844     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
845     jcp.with_binarization = p.find(primitive_kind::binarization, 0, dw_conv_ind) != -1;
846
847     if (with_groups)
848         return status::unimplemented;
849
850     auto desired_weights_format = isa == avx512_common ? OhIw16o32i : OhIw8o32i;
851     bool args_ok = true
852         && src_d.format() == nhwc
853         && weights_d.format() == desired_weights_format
854         && dst_d.format() == nhwc;
855     if (!args_ok) return status::unimplemented;
856
857     jcp.ur_h = 1; /* no code-unrolling by h so far */
858     jcp.ur_w = 2;
859     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
860     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
861
862     jcp.nb_oc_blocking = isa == sse42 ? 2 : 4; /* the optimal value for the kernel */
863
864     args_ok = true
865         && jcp.l_pad <= jcp.ur_w
866         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
867                 || (jcp.stride_w == 1 && jcp.stride_h == 1));
868     if (!args_ok) return status::unimplemented;
869
870     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
871         + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
872
873     if (r_pad_no_tail > jcp.ur_w) {
874         /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
875         jcp.ur_w = r_pad_no_tail + 1;
876         jcp.nb_oc_blocking = ((16 - 1)-jcp.ur_w)/jcp.ur_w;
877         jcp.ur_w_tail = jcp.ow % jcp.ur_w;
878         /* check again ... */
879         r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
880             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
881         if ((r_pad_no_tail > jcp.ur_w) || (jcp.ow < jcp.ur_w))
882             return status::unimplemented;
883     }
884     if (jcp.l_pad > jcp.ur_w) return status::unimplemented;
885
886     jcp.ic_block = 32;
887     jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
888     jcp.ic_padded = rnd_up(jcp.ic, jcp.ic_block);
889
890     jcp.oc_block = simd_w;
891     jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
892
893     jcp.nb_ic_blocking = 1;
894
895     jcp.src_dt = cd.src_desc.data_type;
896     jcp.bia_dt = mkldnn_f32;
897     jcp.dst_dt = jcp.with_binarization ? mkldnn_bin : mkldnn_f32;
898
899     jcp.typesize_in = types::data_type_size(jcp.src_dt);
900     jcp.typesize_out = types::data_type_size(jcp.dst_dt);
901     jcp.typesize_acc = sizeof(int32_t);
902
903     return status::success;
904 }
905
906 template <cpu_isa_t isa>
907 void jit_uni_bin_conv_fwd_kernel<isa>::init_scratchpad(
908         memory_tracking::registrar_t &scratchpad, const jit_bin_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw_conv) {
909     if (jcp.with_dw_conv) {
910         const int nthreads = mkldnn_get_max_threads();
911         size_t dw_conv_buffer_size_ = (size_t)jcp_dw_conv.kh * jcp_dw_conv.iw * jcp_dw_conv.ch_block * jcp.nb_oc_blocking;
912         scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
913
914         if (jcp.oc != jcp.oc_padded)
915             scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc_padded);
916     }
917 }
918
919 template struct jit_uni_bin_conv_fwd_kernel<sse42>;
920 template struct jit_uni_bin_conv_fwd_kernel<avx2>;
921 template struct jit_uni_bin_conv_fwd_kernel<avx512_common>;
922
923 }
924 }
925 }