// File: inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_core_i8i8_pooling.cpp
// (upstream commit subject: "Add a section of how to link IE with CMake project (#99)")
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "mkldnn_types.h"
20
21 #include "mkldnn_thread.hpp"
22 #include "utils.hpp"
23
24 #include "jit_generator.hpp"
25
26 #include "jit_avx512_core_i8i8_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33
34 using namespace mkldnn::impl::utils;
35 using namespace mkldnn::impl::memory_format;
36 using namespace mkldnn::impl::utils;
37 using namespace mkldnn::impl::types;
38 using namespace alg_kind;
39
/* JIT kernel emitting avx512_core code for forward int8/int32 pooling
 * (max and average). One generated call processes a single output spatial
 * point across all channels; the (mb, oh, ow) loop lives in the caller. */
struct jit_avx512_core_i8i8_pool_fwd_ker_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_i8i8_pool_fwd_ker_t)

    /* Per-call arguments handed to the generated kernel. */
    struct call_params_t {
        const char *src_i8;  // raw pointer to the top-left input of the window
        const char *dst_i8;  // raw pointer to the output element
        size_t kw_range;     // kernel columns overlapping the image
        size_t kh_range;     // kernel rows overlapping the image
        float idivider;      // 1 / divisor for average pooling
    };

    /* --- fixed register assignment for the generated code --- */
    Reg64 reg_ptr_src_i8 = r8;   // current src base (advanced per c-block)
    Reg64 reg_ptr_dst_i8 = r9;   // current dst base (advanced per c-block)

    Reg64 ki = r10;              // kernel-width loop counter
    Reg64 kj = r11;              // kernel-height loop counter
    Reg64 reg_kw = r12;          // kw_range
    Reg64 reg_kh = r13;          // kh_range
    Reg64 c_iter = r14;          // channel-block loop counter

    Reg64 aux_reg_src_h = rax;   // src cursor advanced per kernel row
    Reg64 aux_reg_src_w = rbx;   // src cursor advanced per kernel column

    Reg64 reg_tmp = rdx;         // scratch (constants, idivider)

    Reg64 reg_mask = r15;        // scratch used to load tail masks

    Opmask k_cmp_mask = Opmask(7);  // vpcmp* result for the max blend

    // Tail opmasks: idx 0..3 maps to k6..k3, filled by init_mask().
    Opmask mask(int idx) {
        return Opmask(6 - idx);
    }

    Xmm xmm_tmp = Xmm(0);        // staging for broadcasts
    Zmm vreg_tmp = Zmm(30);      // broadcast constant (type min / idivider)
    Zmm vreg_zeros = Zmm(31);    // all-zero vector

    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }

    /* max pooling: packed source and accumulator registers per block jj */
    Zmm vreg_src(int idx) {
        return Zmm(idx);
    }

    Zmm vreg_dst(int idx) {
        return Zmm(jpp.ur_c + idx);
    }

    /* avg pooling: per block jj, dword group ll (0..3) —
     * zmm[12*jj + 0..3]  widened source,
     * zmm[12*jj + 4..7]  s32 accumulator,
     * zmm[12*jj + 8..11] f32 scaled result. */
    Zmm vreg_src_s32(int jj, int ll) {
        return Zmm(12*jj + ll);
    }

    Zmm vreg_dst_s32(int jj, int ll) {
        return Zmm(12*jj + ll + 4);
    }

    Zmm vreg_dst_f32(int jj, int ll) {
        return Zmm(12*jj + ll + 8);
    }

    void (*ker_)(const call_params_t *);  // entry point of the emitted code
    jit_pool_conf_t jpp;                  // immutable kernel configuration

    void init_tmp_reg();
    void init_mask();

    void load_src(int jj, int ll, int c_tail);
    void store_dst(int jj, int ll, int c_tail);

    void compute_avg_step(int ur_c, int c_tail);
    void compute_max_step(int ur_c, int c_tail);
    void compute_step(int ur_c, int c_tail);

    void compute_c_block();
    void generate();

    static status_t init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d);

    /* Emits the kernel at construction time and publishes it via ker_. */
    jit_avx512_core_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
           : jpp(jpp_) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
                       getCode()));
    }
};
129
/* Emit the load of one source vector for channel block `jj` (dword group
 * `ll` for avg pooling) from aux_reg_src_w. On the last block with a
 * channel tail the load is masked so only valid channels are touched.
 * Avg pooling additionally widens 8-bit elements to s32 via
 * vpmovsxbd/vpmovzxbd; max pooling keeps the packed layout. */
void jit_avx512_core_i8i8_pool_fwd_ker_t::load_src(int jj, int ll, int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            // Packed load: one full zmm (16 dwords or 64 bytes) per block.
            auto offset = jj*c_block*sizeof_src_dt();
            if (jj == ur_c - 1 && c_tail) {
                if (jpp.src_dt == data_type::s32) {
                    vmovups(vreg_src(jj) | mask(0),
                            ptr[aux_reg_src_w + offset]);
                } else {
                    // byte-granular masked load for s8/u8
                    vmovdqu8(vreg_src(jj) | mask(0),
                            ptr[aux_reg_src_w + offset]);
                }
            } else {
                vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
            }
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            // Each ll addresses a quarter of the channel block.
            auto offset = (ll*(c_block/4) + jj*c_block)*sizeof_src_dt();
            if (jj == jpp.ur_c - 1 && c_tail) {
                // Groups whose mask is all-zero lie entirely past the
                // channel tail — skip them.
                if (jpp.tail[ll]) {
                    switch (jpp.src_dt) {
                        case s32:
                            vmovups(vreg_src_s32(jj, ll) | mask(ll),
                                    ptr[aux_reg_src_w + offset]);
                            break;
                        case s8:
                            // sign-extend bytes to dwords
                            vpmovsxbd(vreg_src_s32(jj, ll) | mask(ll),
                                    ptr[aux_reg_src_w + offset]);
                            break;
                        case u8:
                            // zero-extend bytes to dwords
                            vpmovzxbd(vreg_src_s32(jj, ll) | mask(ll),
                                    ptr[aux_reg_src_w + offset]);
                            break;
                        default: assert(!"unsupported src data type");
                    }
                }
            } else {
                switch (jpp.src_dt) {
                    case s32:
                        vmovups(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case s8:
                        vpmovsxbd(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    case u8:
                        vpmovzxbd(vreg_src_s32(jj, ll),
                                ptr[aux_reg_src_w + offset]);
                        break;
                    default: assert(!"unsupported src data type");
                }
            }
            break;
        }
        default: assert(!"unsupported algorithm");
    }
}
195
196 void jit_avx512_core_i8i8_pool_fwd_ker_t::store_dst(int jj, int ll,
197         int c_tail) {
198     using namespace data_type;
199
200     int c_block = jpp.c_block;
201     int ur_c = jpp.ur_c;
202
203     switch(jpp.alg) {
204         case pooling_max: {
205             auto offset = jj*c_block*sizeof_dst_dt();
206             if (jj == ur_c - 1 && c_tail) {
207                 if (jpp.src_dt == data_type::s32) {
208                     vmovups(ptr[reg_ptr_dst_i8 + offset],
209                            vreg_dst(jj) | mask(0));
210                 } else {
211                     vmovdqu8(ptr[reg_ptr_dst_i8 + offset],
212                             vreg_dst(jj) | mask(0));
213                 }
214             } else {
215                 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
216             }
217             break;
218         }
219         case pooling_avg_include_padding:
220         case pooling_avg_exclude_padding: {
221             auto offset = (ll*(c_block/4) + jj*c_block)*sizeof_dst_dt();
222             if (jj == ur_c - 1 && c_tail) {
223                 if (jpp.tail[ll]) {
224                     switch (jpp.dst_dt) {
225                         case s32:
226                             vmovups(ptr[reg_ptr_dst_i8 + offset],
227                                 vreg_dst_s32(jj, ll) | mask(ll));
228                             break;
229                         case s8:
230                             vpmovdb(ptr[reg_ptr_dst_i8 + offset],
231                                 vreg_dst_s32(jj, ll) | mask(ll));
232                             break;
233                         case u8:
234                             vpmovusdb(ptr[reg_ptr_dst_i8 + offset],
235                                 vreg_dst_s32(jj, ll) | mask(ll));
236                             break;
237                         default: assert(!"unsupported dst data_type");
238                     }
239                 }
240             } else {
241                 switch (jpp.dst_dt) {
242                     case s32:
243                         vmovups(ptr[reg_ptr_dst_i8 + offset],
244                             vreg_dst_s32(jj, ll));
245                         break;
246                     case s8:
247                         vpmovdb(ptr[reg_ptr_dst_i8 + offset],
248                             vreg_dst_s32(jj, ll));
249                         break;
250                     case u8:
251                         vpmovusdb(ptr[reg_ptr_dst_i8 + offset],
252                             vreg_dst_s32(jj, ll));
253                         break;
254                     default: assert(!"unsuppotred dst data_type");
255                 }
256             }
257             break;
258         }
259         default: assert(!"unsupported pooling algorithm");
260     }
261 }
262
/* Emit the max-pooling inner loops for `ur_c` channel blocks. Accumulators
 * vreg_dst(jj) start at the data type's lowest value (broadcast into
 * vreg_tmp by init_tmp_reg()) and take an element-wise max over the
 * kh_range x kw_range window via compare + masked blend. */
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_max_step(int ur_c, int c_tail)
{
    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    // Seed accumulators with the minimum so any loaded value wins.
    for (int jj = 0; jj < ur_c; jj++)
        vmovups(vreg_dst(jj), vreg_tmp);

    mov(aux_reg_src_h, reg_ptr_src_i8);

    xor_(kj, kj);
    L(l_kh);
    {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw);
        {
            for (int jj = 0; jj < ur_c; jj++) {
                load_src(jj, 0, c_tail);
                if (jpp.src_dt == data_type::s32) {
                    // dst = max(dst, src): compare dwords, then blend the
                    // source into lanes where dst < src.
                    vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
                    vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj),
                            vreg_src(jj));
                } else {
                    // Byte data: signed compare for s8, unsigned for u8.
                    if (jpp.src_dt == data_type::s8)
                        vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj),
                                _cmp_lt_os);
                    else
                        vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj),
                                _cmp_lt_os);
                    vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj),
                            vreg_src(jj));
                }
            }
            // Next column: one pixel is `c` contiguous elements
            // (channels-innermost layout).
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki);
            cmp(ki, reg_kw);
            jl(l_kw, T_NEAR);
        }
        // Next input row.
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj);
        cmp(kj, reg_kh);
        jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++)
        store_dst(jj, 0, c_tail);
}
313
/* Emit the average-pooling inner loops for `ur_c` channel blocks.
 * 8-bit sources are widened to s32 and accumulated across the window in
 * four dword groups (ll = 0..3); s32 sources need only one group. The sum
 * is scaled by idivider (broadcast in vreg_tmp) in f32, rounded back to
 * s32, and store_dst() narrows it to the destination type. */
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_avg_step(int ur_c, int c_tail)
{
    using namespace data_type;

    Label l_kw, l_kh;

    int iw = jpp.iw;
    int c = jpp.c;

    // One 64-byte block of 8-bit data expands into 4 zmm's of dwords.
    int num_ll = jpp.src_dt == data_type::s32 ? 1 : 4;

    // Clear all four groups unconditionally (unused ones stay zero).
    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < 4; ll++) {
            uni_vpxor(vreg_src_s32(jj, ll),
                    vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
            uni_vpxor(vreg_dst_s32(jj, ll),
                    vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
        }
    }

    mov(aux_reg_src_h, reg_ptr_src_i8);

    xor_(kj, kj);
    L(l_kh);
    {
        mov(aux_reg_src_w, aux_reg_src_h);
        xor_(ki, ki);
        L(l_kw);
        {
            for (int jj = 0; jj < ur_c; jj++) {
                for (int ll = 0; ll < num_ll; ll++) {
                    // Widened load + dword accumulate.
                    load_src(jj, ll, c_tail);
                    vpaddd(vreg_dst_s32(jj, ll),
                            vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll));
                }
            }
            // Next column: one pixel is `c` contiguous elements.
            add(aux_reg_src_w, c * sizeof_src_dt());
            inc(ki);
            cmp(ki, reg_kw);
            jl(l_kw, T_NEAR);
        }
        // Next input row.
        add(aux_reg_src_h, iw * c * sizeof_src_dt());
        inc(kj);
        cmp(kj, reg_kh);
        jl(l_kh, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            // f32 = s32sum * idivider + 0, then round-to-nearest
            // (T_rn_sae) back into the s32 register before the store.
            vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
            vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
            vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));

            store_dst(jj, ll, c_tail);
        }
    }
}
371
372 void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_step(int ur_c, int c_tail) {
373     switch (jpp.alg) {
374         case pooling_max:
375             compute_max_step(ur_c, c_tail); break;
376         case pooling_avg_include_padding:
377         case pooling_avg_exclude_padding:
378             compute_avg_step(ur_c, c_tail); break;
379         default: assert(!"unsupported pooling algorithm");
380     }
381 }
382
/* Emit the loop over channel blocks: `c_steps` full-width iterations that
 * advance the src/dst base pointers, followed by an optional tail step
 * that uses the masked load/store path. */
void jit_avx512_core_i8i8_pool_fwd_ker_t::compute_c_block(){
    Label l_main_loop;

    int nb_c = jpp.nb_c;
    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;
    int ur_c_tail = jpp.ur_c_tail;
    int c_steps = nb_c / ur_c;
    int c_tail = jpp.c_tail;

    xor_(c_iter, c_iter);
    if (c_steps > 0) {
        L(l_main_loop); {
            compute_step(ur_c, 0);  // full blocks: c_tail == 0, no masking
            add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
            add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
            inc(c_iter);
            cmp(c_iter, c_steps);
            jl(l_main_loop, T_NEAR);
        }
    }

    if (ur_c_tail != 0) {
        // Remaining channels, gated by the tail opmasks.
        compute_step(ur_c_tail, c_tail);
    }
}
409
410 void jit_avx512_core_i8i8_pool_fwd_ker_t::init_mask() {
411     for (int i = 0; i < 4; i++) {
412         mov(reg_mask, jpp.tail[i]);
413         kmovq(mask(i), reg_mask);
414     }
415 }
416
/* Emit code that fills vreg_tmp with the algorithm's constant:
 *  - avg: the f32 idivider from the call params, broadcast to all lanes;
 *  - max: the lowest value of the source data type (seed for the max). */
void jit_avx512_core_i8i8_pool_fwd_ker_t::init_tmp_reg() {
    using namespace data_type;

    switch (jpp.alg) {
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            // NOTE(review): this 64-bit mov reads 8 bytes at the 4-byte
            // idivider field (the last member of call_params_t); only the
            // low dword is consumed by vpbroadcastd, but the load itself
            // reaches past the struct — confirm this is benign upstream.
            mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
            movq(xmm_tmp, reg_tmp);
            vpbroadcastd(vreg_tmp, xmm_tmp);
            break;
        case pooling_max:
            switch (jpp.src_dt) {
                case s32:
                    mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
                    break;
                case s8:
                    mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
                    break;
                case u8:
                    mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
                    break;
                default: assert(!"unsupported src data_type");
            }

            // Broadcast at the element width of the source type.
            movq(xmm_tmp, reg_tmp);
            if (jpp.src_dt == s32)
                vpbroadcastd(vreg_tmp, xmm_tmp);
            else
                vpbroadcastb(vreg_tmp, xmm_tmp);
            break;
        default: assert(!"unsupported pooling algorithm");
    }

}
451
/* Assemble the whole kernel: prologue, load of the call_params_t fields
 * from abi_param1, constant and tail-mask setup, the channel-block loop,
 * and the epilogue. */
void jit_avx512_core_i8i8_pool_fwd_ker_t::generate() {
    preamble();

    // Pull each needed field of *abi_param1 into its dedicated register.
#   define READ_PARAM(reg, field) \
        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src_i8, src_i8);
    READ_PARAM(reg_ptr_dst_i8, dst_i8);
    READ_PARAM(reg_kw, kw_range);
    READ_PARAM(reg_kh, kh_range);

#   undef READ_PARAM

    init_tmp_reg();  // vreg_tmp: type minimum (max) or idivider (avg)
    init_mask();     // k3..k6: channel-tail masks

    uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);

    compute_c_block();

    postamble();
}
473
/* Fill `jpp` from the pooling descriptor and memory descriptors; returns
 * unimplemented when the CPU lacks avx512_core or the algorithm is not
 * supported. Only 4D (n, c, h, w) shapes are read here. */
status_t jit_avx512_core_i8i8_pool_fwd_ker_t::init_conf(jit_pool_conf_t &jpp,
        const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &dst_d) {
    if (!mayiuse(avx512_core)) {
        return status::unimplemented;
    }

    jpp.mb = src_d.dims()[0];
    jpp.c = src_d.dims()[1];
    jpp.ih = src_d.dims()[2];
    jpp.iw = src_d.dims()[3];
    jpp.oh = dst_d.dims()[2];
    jpp.ow = dst_d.dims()[3];

    jpp.stride_h = pd.strides[0];
    jpp.stride_w = pd.strides[1];
    jpp.kh = pd.kernel[0];
    jpp.kw = pd.kernel[1];

    jpp.t_pad = pd.padding[0][0];
    jpp.l_pad = pd.padding[0][1];

    jpp.alg = pd.alg_kind;

    jpp.src_dt = pd.src_desc.data_type;
    jpp.dst_dt = pd.dst_desc.data_type;

    // One zmm covers 64 bytes: 16 s32 channels or 64 8-bit channels.
    jpp.c_block = 64 / (jpp.src_dt == data_type::s32 ? 4 : 1);
    jpp.c_tail = jpp.c % jpp.c_block;
    jpp.nb_c = jpp.c / jpp.c_block;
    jpp.ur_c = 1;
    // With ur_c == 1 this reduces to (c_tail != 0): at most one extra
    // masked step after the full blocks.
    jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
            (jpp.c_tail != 0);

    // Low c_tail bits set; zero when the channel count divides evenly.
    size_t tail_mask = (1ULL << jpp.c_tail) - 1;

    switch(jpp.alg) {
        case pooling_max:
            // Max operates on packed data: a single element-granular mask.
            jpp.tail[0] = tail_mask;
            jpp.tail[1] = 0;
            jpp.tail[2] = 0;
            jpp.tail[3] = 0;
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            // Avg widens bytes to dwords, so split the byte mask into
            // four 16-lane dword masks (one per ll group).
            jpp.tail[0] = tail_mask & 0xffff;
            for (size_t i = 1, m = tail_mask; i < 4; i++) {
                m = m >> 16;
                jpp.tail[i] = m & 0xffff;
            }
            break;
        default: return status::unimplemented;
    }

    return status::success;
}
530
531 status_t jit_avx512_core_i8i8_pooling_fwd_t::pd_t::jit_conf() {
532     return jit_avx512_core_i8i8_pool_fwd_ker_t::init_conf(jpp_,
533        desc_, src_pd_.desc(), dst_pd_.desc());
534 }
535
/* Construct the primitive: JIT-compiles the pooling kernel from the
 * configuration prepared by pd_t::jit_conf() (owned, freed in the dtor). */
jit_avx512_core_i8i8_pooling_fwd_t::
jit_avx512_core_i8i8_pooling_fwd_t(const pd_t *pd,
          const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr)
{ ker_ = new jit_avx512_core_i8i8_pool_fwd_ker_t(conf_.jpp_); }
541
542 jit_avx512_core_i8i8_pooling_fwd_t::
543 ~jit_avx512_core_i8i8_pooling_fwd_t() { delete ker_; }
544
/* Forward pooling driver: one kernel call per (mb, oh, ow) output point,
 * parallelized over those dimensions. The kernel itself iterates over the
 * in-image part of the pooling window and over all channels. */
void jit_avx512_core_i8i8_pooling_fwd_t::execute_forward() {
    auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
    auto dst_i8 = reinterpret_cast<char *>(memory());

    const memory_desc_wrapper src_d(conf_.src_pd());
    const memory_desc_wrapper dst_d(conf_.dst_pd());

    const auto &jpp = conf_.jpp_;

    parallel_nd(jpp.mb, jpp.oh, jpp.ow,
            [&](int n, int oh, int ow) {
        // Top-left input coordinate of the window, clipped to the image.
        const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
        const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);

        // Extent of the kernel that overlaps the image (padding excluded);
        // the differences below give the number of in-image rows/columns.
        const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
        const int kh_end = nstl::min(jpp.kh,
                jpp.ih + jpp.t_pad - oh * jpp.stride_h);
        const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
        const int kw_end = nstl::min(jpp.kw,
                jpp.iw + jpp.l_pad - ow * jpp.stride_w);

        auto p = jit_avx512_core_i8i8_pool_fwd_ker_t::call_params_t();
        p.src_i8 = &src_i8[
            src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
        p.dst_i8 = &dst_i8[
            dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
        p.kw_range = (size_t)(kw_end - kw_start);
        p.kh_range = (size_t)(kh_end - kh_start);
        // exclude_padding averages over the overlap only; otherwise over
        // the full kernel area.
        p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
            p.kh_range*p.kw_range : jpp.kw*jpp.kh);

        ker_->ker_(&p);
    });
}
579
580 }
581 }
582 }