Add a section of how to link IE with CMake project (#99)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_1x1_conv_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "utils.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_uni_x8s8s32x_1x1_conv_kernel.hpp"
24
25 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
26
27 #include <iostream>
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
36 using namespace mkldnn::impl::types;
37
38 using namespace Xbyak;
39
40 template <cpu_isa_t isa>
41 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in,
42         Vmm vmm_in, const Xbyak::Operand &op) {
43     switch (type_in) {
44     case data_type::f32:
45     case data_type::s32: vmovups(vmm_in, op); break;
46     case data_type::s8: vpmovsxbd(vmm_in, op); break;
47     case data_type::u8: vpmovzxbd(vmm_in, op); break;
48     default: assert(!"unsupported data type");
49     }
50     if (type_in != data_type::f32)
51         vcvtdq2ps(vmm_in, vmm_in);
52 }
53
54 template <cpu_isa_t isa>
55 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::loop_os(int oc_loop_blk)
56 {
57     mov(aux_reg_dst_data, reg_dst_data);
58
59     Label loop_os;
60     Label loop_ow_tail;
61
62     mov(reg_ow_loop_work, jcp.ow);
63
64     L(loop_os); {
65         assert(jcp.os_block == jcp.ur);
66         cmp(reg_ow_loop_work, jcp.ow_tail);
67         je(loop_ow_tail, T_NEAR);
68
69         ic_loop(oc_loop_blk, jcp.ur);
70
71         sub(reg_ow_loop_work, jcp.ur);
72
73         add(reg_src_data, jcp.os_loop_src_step);
74         add(aux_reg_dst_data, jcp.os_loop_dst_step);
75
76         sub(reg_loop_os_iter, jcp.os_block);
77         cmp(reg_loop_os_iter, jcp.os_block);
78         jge(loop_os, T_NEAR);
79
80         L(loop_ow_tail); {
81             if (jcp.ow_tail > 0) {
82                 ic_loop(oc_loop_blk, jcp.ow_tail);
83             }
84
85             add(reg_src_data, jcp.os_loop_src_tail_step);
86             add(aux_reg_dst_data, jcp.os_loop_dst_tail_step);
87
88             mov(reg_ow_loop_work, jcp.ow);
89
90             sub(reg_loop_os_iter, jcp.ow_tail);
91             cmp(reg_loop_os_iter, 0);
92             jg(loop_os, T_NEAR);
93         }
94     }
95 }
96
97 template <cpu_isa_t isa>
98 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::ic_loop(int oc_loop_blk, int ur)
99 {
100     auto vreg_wei = [=](int i) {
101         return Vmm(ur * oc_loop_blk + i);
102     };
103
104     auto vreg_accum_vmm = [=](int i, int j) {
105         return Vmm(j * oc_loop_blk + i);
106     };
107
108     auto vreg_accum_xmm = [=](int i, int j) {
109         return Xmm(j * oc_loop_blk + i);
110     };
111
112     auto src_ptr = [=](int u, int j) {
113         size_t offt = j * jcp.ic * jcp.stride_w + u*jcp.ic_block;
114         return ptr[aux_reg_src_data + jcp.typesize_in * offt];
115     };
116
117     auto wei_ptr = [=](int u, int i) {
118         size_t offt = i*jcp.nb_ic*jcp.oc_block*jcp.ic_block + u*jcp.ic_block * jcp.oc_block;
119         return ptr[aux_reg_weight_data + offt * jcp.typesize_in];
120     };
121
122     auto output_ptr = [=](int i, int j) {
123         return ptr[aux_reg_dst_data + (i * jcp.oc_block + j * jcp.oc) *
124                                               jcp.typesize_out];
125     };
126
127     auto init = [&]() {
128         for (int i = 0; i < oc_loop_blk; ++i) {
129             for (int j = 0; j < ur; ++j) {
130                 auto vmm_acc = vreg_accum_vmm(i, j);
131                 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
132             }
133         }
134
135         for (int i = 0; i < oc_loop_blk; ++i)
136             uni_vmovdqu(vreg_wei(i), wei_ptr(0, i));
137
138         uni_vpbroadcastd(vreg_src, src_ptr(0, 0));
139     };
140
141     auto store = [=]() {
142         mov(reg_scales, ptr[this->param1 + GET_OFF(scales)]);
143         uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
144
145         for (int j = 0; j < ur; ++j)
146             for (int i = 0; i < oc_loop_blk; ++i) {
147                 int b_off = i*jcp.oc_block;
148
149                 if (jcp.with_bias) {
150                     switch (jcp.bia_dt) {
151                         case data_type::f32:
152                         case data_type::s32: vmovups(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
153                         case data_type::s8: vpmovsxbd(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
154                         case data_type::u8: vpmovzxbd(vmm_bias, ptr[reg_bias_data + b_off*jcp.typesize_bia]); break;
155                         default: assert(!"unsupported dst data type");
156                     }
157                 }
158                 if (jcp.bia_dt != data_type::f32)
159                     vcvtdq2ps(vmm_bias, vmm_bias);
160
161                 Vmm vmm_dst = vreg_accum_vmm(i, j);
162                 Xmm xmm_dst = vreg_accum_xmm(i, j);
163
164                 vcvtdq2ps(vmm_dst, vmm_dst);
165
166                 if (jcp.with_bias)
167                     vaddps(vmm_dst, vmm_dst, vmm_bias);
168
169                 int s_off = jcp.is_oc_scale * (sizeof(float) * (i*jcp.oc_block));
170                 vmulps(vmm_dst, vmm_dst, ptr[reg_scales + s_off]);
171
172                 if (jcp.with_sum) {
173                     Ymm vmm_prev_dst = Ymm(12);
174                     cvt2ps(jcp.dst_dt, vmm_prev_dst, output_ptr(i, j));
175                     vaddps(vmm_dst, vmm_prev_dst);
176                 }
177
178                 if (maybe_relu(0))
179                     vmaxps(vmm_dst, vmm_zero, vmm_dst);
180
181                 if (maybe_relu(1))
182                     vmaxps(vmm_dst, vmm_zero, vmm_dst);
183
184                 if (jcp.dst_dt != data_type::f32) {
185                     if (attr_.round_mode_ == round_mode::nearest)
186                         if (isa == avx512_common) {
187                             vcvtps2dq(vmm_dst | T_rn_sae, vmm_dst);
188                         } else {
189                             vcvtps2dq(vmm_dst, vmm_dst);
190                         }
191                     else if (attr_.round_mode_ == round_mode::down) {
192                         if (isa == avx512_common) {
193                             vcvtps2dq(vmm_dst | T_rd_sae, vmm_dst);
194                         } else {
195                             vroundps(vmm_dst, vmm_dst, 1);
196                             vcvtps2dq(vmm_dst, vmm_dst);
197                         }
198                     } else
199                         assert(!"unimplemented");
200                 }
201
202                 switch (jcp.dst_dt) {
203                     case data_type::f32:
204                     case data_type::s32: vmovups(output_ptr(i, j), vmm_dst); break;
205                     case data_type::s8:
206                         if (isa == avx512_common) {
207                             vpmovsdb(xmm_dst, vmm_dst);
208                             vmovups(output_ptr(i, j), xmm_dst);
209                         } else if (isa == avx2) {
210                             Ymm ymm_dst = Ymm(vmm_dst.getIdx());
211
212                             vpackssdw(ymm_dst, ymm_dst, ymm_dst);
213                             vpermq(ymm_dst, ymm_dst, 0x08);
214                             vpacksswb(xmm_dst, xmm_dst, xmm_dst);
215                             vmovq(output_ptr(i, j), xmm_dst);
216                         }
217                         break;
218                     case data_type::u8:
219                         if (isa == avx512_common) {
220                             vpmovusdb(xmm_dst, vmm_dst);
221                             vmovups(output_ptr(i, j), xmm_dst);
222                         } else if (isa == avx2) {
223                             Ymm ymm_dst = Ymm(vmm_dst.getIdx());
224
225                             vpackusdw(ymm_dst, ymm_dst, ymm_dst);
226                             vpermq(ymm_dst, ymm_dst, 0x08);
227                             vpackuswb(xmm_dst, xmm_dst, xmm_dst);
228                             vmovq(output_ptr(i, j), xmm_dst);
229                         }
230                         break;
231                     default: assert(!"unknown dst_dt");
232                 }
233             }
234     };
235
236     auto fma_block = [=]() {
237         for (int j = 0; j < ur; ++j) {
238             for (int i = 0; i < oc_loop_blk; i++) {
239                 vpmaddubsw(vreg_sum_0, vreg_src, vreg_wei(i));
240                 vpmaddwd(vreg_sum_0, vreg_sum_0, vmm_one);
241                 vpaddd(vreg_accum_vmm(i, j), vreg_accum_vmm(i, j), vreg_sum_0);
242
243                 if (j == ur - 1) {
244                     uni_vmovdqu(vreg_wei(i), wei_ptr(1, i));
245                 }
246             }
247
248             if (j < ur - 1)
249                 uni_vpbroadcastd(vreg_src, src_ptr(0, j + 1));
250         }
251
252         uni_vpbroadcastd(vreg_src, src_ptr(1, 0));
253     };
254
255     mov(aux_reg_weight_data, reg_weight_data);
256     mov(aux_reg_src_data, reg_src_data);
257
258     init();
259
260     Label ic_loop;
261     Label exit;
262
263     xor_(reg_loop_ic_iter, reg_loop_ic_iter);
264     L(ic_loop); {
265         cmp(reg_loop_ic_iter, jcp.nb_ic);
266         jge(exit, T_NEAR);
267
268         fma_block();
269
270         add(aux_reg_src_data, jcp.ic_block * jcp.typesize_in);
271         add(aux_reg_weight_data, jcp.ic_block * jcp.oc_block * jcp.typesize_in);
272         inc(reg_loop_ic_iter);
273         jmp(ic_loop, T_NEAR);
274     }
275
276     L(exit);
277
278     store();
279 }
280
281 template <cpu_isa_t isa>
282 void jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::generate()
283 {
284     preamble();
285
286     mov(reg_scratch, 0x1);
287     movq(xmm_one, reg_scratch);
288     vpbroadcastw(vmm_one, xmm_one);
289
290     mov(reg_weight_data, ptr[param1 + GET_OFF(oc_data)]);
291     mov(reg_dst_data,    ptr[param1 + GET_OFF(output_data)]);
292     if (jcp.with_bias) {
293         mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
294     }
295
296     mov(reg_oc_loop_work, ptr[param1 + GET_OFF(oc_dim)]);
297     mov(reg_src_data, ptr[param1 + GET_OFF(is_data)]);
298     mov(reg_loop_os_iter,  ptr[param1 + GET_OFF(os_dim)]);
299
300     Label oc_blocks_tail_label;
301     Label exit_label;
302
303     int oc_blocks_tail = jcp.nb_oc % jcp.nb_oc_blocking;
304
305     cmp(reg_oc_loop_work, jcp.nb_oc_blocking);
306     jne(oc_blocks_tail ? oc_blocks_tail_label : exit_label, T_NEAR);
307
308     loop_os(jcp.nb_oc_blocking); // channel main loop
309     jmp(exit_label, T_NEAR);
310
311     if (oc_blocks_tail) {
312         L(oc_blocks_tail_label);
313
314         cmp(reg_oc_loop_work, oc_blocks_tail);
315         jne(exit_label, T_NEAR);
316
317         loop_os(oc_blocks_tail); // channel tail loop
318     }
319
320     L(exit_label);
321
322     postamble();
323 }
324
325 template <cpu_isa_t isa>
326 bool jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::post_ops_ok(
327         jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
328     const auto &p = attr.post_ops_;
329
330     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
331     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
332
333     switch (p.len_) {
334         case 0: return true; // no post_ops
335         case 1: return !jcp.with_eltwise && (is_relu(0) || is_sum(0)); // sum OR relu
336         case 2: return !jcp.with_eltwise && (is_sum(0) && is_relu(1)); // sum->relu
337         default: return false;
338     }
339
340     return false;
341 }
342
343 template <cpu_isa_t isa>
344 bool jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::maybe_relu(int position) {
345     using namespace primitive_kind;
346     const auto &p = attr_.post_ops_;
347
348     if (position == 0) {
349         /* relu before sum */
350         return false
351                || jcp.with_eltwise
352                || p.contain(eltwise, 0)
353                || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
354     } else if (position == 1) {
355         /* relu after sum */
356         const int sum_idx = p.contain(sum, 0)
357                             ? 0 : (p.contain(sum, 1) ? 1 : -1);
358         if (sum_idx == -1)
359             return false;
360
361         return false
362                || p.contain(eltwise, sum_idx + 1)
363                || jcp.dst_dt == data_type::u8;
364     }
365
366     return false;
367 }
368
369 template <cpu_isa_t isa>
370 status_t jit_uni_x8s8s32x_1x1_conv_fwd_kernel<isa>::init_conf(jit_1x1_conv_conf_t &jcp,
371         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
372         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
373         const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr,
374         bool with_relu, float relu_negative_slope)
375 {
376     if (!mayiuse(isa)) return status::unimplemented;
377
378     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
379
380     jcp.prop_kind = cd.prop_kind;
381
382     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
383     jcp.mb = src_d.dims()[0];
384
385     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
386     jcp.ic = src_d.dims()[1] / jcp.ngroups;
387
388     jcp.ih = src_d.dims()[2];
389     jcp.iw = src_d.dims()[3];
390     jcp.oh = dst_d.dims()[2];
391     jcp.ow = dst_d.dims()[3];
392
393     jcp.kh = weights_d.dims()[with_groups + 2];
394     jcp.kw = weights_d.dims()[with_groups + 3];
395
396     jcp.t_pad = cd.padding[0][0];
397     jcp.l_pad = cd.padding[0][1];
398
399     jcp.stride_h = cd.strides[0];
400     jcp.stride_w = cd.strides[1];
401
402     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
403     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
404     jcp.dst_dt = cd.dst_desc.data_type;
405
406     jcp.src_fmt = src_d.format();
407     jcp.with_eltwise = with_relu;
408     jcp.eltwise_alpha = relu_negative_slope;
409
410     jcp.os = jcp.oh * jcp.ow;
411     jcp.is = jcp.ih * jcp.iw;
412
413     auto desired_wei_fmt = OhIw8o4i;
414     auto desired_gr_wei_fmt = gOhIw8o4i;
415
416     int simd_w = isa == avx512_common ? 16 : 8;
417
418     bool args_ok = true
419         && jcp.ngroups == 1
420         && src_d.format() == nhwc
421         && one_of(weights_d.format(), desired_wei_fmt, desired_gr_wei_fmt)
422         && one_of(cd.bias_desc.format, memory_format::undef, any, x)
423         && dst_d.format() == nhwc
424         && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
425         && jcp.t_pad == 0 && jcp.l_pad == 0
426         && jcp.kh == 1 && jcp.kw == 1
427         && jcp.stride_h == 1 && jcp.stride_w == 1;
428
429     if (!args_ok) return status::unimplemented;
430
431     jcp.ic_block = 4;
432     jcp.oc_block = simd_w;
433
434     jcp.ur = 2;
435     jcp.ow_tail = jcp.ow % jcp.ur;
436
437     int oc_blocking{ 0 };
438     int oc_blocking_max{ 0 };
439     int os_blocking{ 0 };
440     int os_blocking_max{ 0 };
441     int ic_blocking{ 0 };
442
443     jcp.ic_dim = jcp.ic;
444     jcp.oc_dim = jcp.oc;
445     jcp.is_dim = jcp.is;
446     jcp.os_block = jcp.ur;
447
448     jcp.typesize_in = types::data_type_size(src_d.data_type());
449     jcp.typesize_out = types::data_type_size(dst_d.data_type());
450     jcp.typesize_acc = sizeof(int32_t);
451     jcp.typesize_bia = jcp.with_bias
452                        ? types::data_type_size(bias_pd.data_type())
453                        : 0;
454
455     const auto &oscales = attr.output_scales_;
456     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
457
458     const auto &p = attr.post_ops_;
459     jcp.with_sum = p.find(primitive_kind::sum) != -1;
460
461     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
462
463     jcp.ic_loop_src_step = jcp.ic_block * jcp.ic_loop_unroll * jcp.typesize_in;
464     jcp.ic_loop_wei_step = jcp.ic_block * jcp.ic_loop_unroll * jcp.oc_block * jcp.typesize_in;
465
466     jcp.os_loop_dst_step = jcp.ur * jcp.oc * jcp.typesize_out;
467     jcp.os_loop_acc_step = jcp.ur * jcp.oc_block * jcp.typesize_acc;
468     jcp.os_loop_src_step = jcp.stride_w * jcp.ur * jcp.ic * jcp.typesize_in;
469     jcp.os_loop_dst_tail_step = jcp.ow_tail * jcp.oc * jcp.typesize_out;
470     jcp.os_loop_acc_tail_step = jcp.ow_tail * jcp.oc_block * jcp.typesize_acc;
471     jcp.os_loop_src_tail_step = jcp.stride_w * jcp.ow_tail * jcp.ic * jcp.typesize_in
472              + ((jcp.stride_h-1)*jcp.iw*jcp.ic*jcp.typesize_in);
473
474     oc_blocking     = 4 * jcp.oc_block;
475     oc_blocking_max = 4 * jcp.oc_block;
476     os_blocking     = 48; // affects oc balancing across threads
477     os_blocking_max = 320;
478     ic_blocking     = 4*128; // affects L1$ utilization
479
480     assert(oc_blocking);
481     assert(oc_blocking_max);
482     assert(os_blocking);
483     assert(os_blocking_max);
484     assert(ic_blocking);
485
486     assert(jcp.os_block % jcp.ur == 0);
487     jcp.ur_tail = jcp.is_dim % jcp.ur;
488
489     jcp.nb_oh_blocking     = nstl::max(1, os_blocking     / jcp.ow);
490     jcp.nb_oh_blocking_max = nstl::max(1, os_blocking_max / jcp.ow);
491     jcp.nb_oc_blocking     = oc_blocking / jcp.oc_block;
492     jcp.nb_oc_blocking_max = oc_blocking_max / jcp.oc_block;
493     jcp.nb_ic_blocking     = ic_blocking / jcp.ic_block;
494
495     jcp.nb_oc = div_up(jcp.oc_dim, jcp.oc_block);
496
497     jcp.nb_ic = jcp.ic / jcp.ic_block;
498
499     return status::success;
500 }
501
502 template struct jit_uni_x8s8s32x_1x1_conv_fwd_kernel<avx2>;
503 template struct jit_uni_x8s8s32x_1x1_conv_fwd_kernel<sse42>;
504
505 }
506 }
507 }