updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_i8i8_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18
19 #include "mkldnn_types.h"
20
21 #include "mkldnn_thread.hpp"
22 #include "utils.hpp"
23
24 #include "jit_generator.hpp"
25
26 #include "jit_sse42_i8i8_pooling.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 using namespace Xbyak;
33
34 using namespace mkldnn::impl::utils;
35 using namespace mkldnn::impl::memory_format;
36 using namespace mkldnn::impl::utils;
37 using namespace mkldnn::impl::types;
38 using namespace alg_kind;
39
40 struct call_params_t {
41     const char *src_i8;
42     const char *dst_i8;
43     size_t kw_range;
44     size_t kh_range;
45     float idivider;
46 };
47
48 struct jit_sse42_i8i8_pool_fwd_ker_t : public jit_generator {
49     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_i8i8_pool_fwd_ker_t)
50
51     Reg64 reg_ptr_src_i8 = r8;
52     Reg64 reg_ptr_dst_i8 = r9;
53
54     Reg64 ki = r10;
55     Reg64 kj = r11;
56     Reg64 reg_kw = r12;
57     Reg64 reg_kh = r13;
58     Reg64 c_iter = r14;
59
60     Reg64 aux_reg_src_h = rax;
61     Reg64 aux_reg_src_w = rbx;
62
63     Reg64 reg_tmp = rdx;
64     Reg64 reg_src_64 = r15;
65     Reg32 reg_src_32 = r15d;
66     Reg8 reg_src_8 = r15b;
67
68     size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
69     size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
70
71     Xmm xmm_tmp = Xmm(0);
72     Xmm vreg_tmp = Xmm(14);
73     Xmm vreg_zeros = Xmm(15);
74
75     /* max pooling */
76     Xmm vmm_src(int jj, int ii) {
77         return Xmm(2*jj + ii);
78     }
79
80     Xmm xmm_src(int jj) {
81         return Xmm(2*jj);
82     }
83
84     Xmm vmm_dst(int jj, int ii) {
85         return Xmm(2*jj + ii + 2 * jpp.ur_c);
86     }
87
88     Xmm xmm_dst(int jj) {
89         return Xmm(2*jj + 2 * jpp.ur_c);
90     }
91
92     /* avg pooling */
93     Xmm vmm_src_s32(int jj, int ii) {
94         return Xmm(2*jj + ii);
95     }
96
97     Xmm xmm_src_s32(int jj, int ii) {
98         return Xmm(2*jj + ii);
99     }
100
101     Xmm vmm_dst_s32(int jj, int ii) {
102         return Xmm(2*jj + ii + 2 * jpp.ur_c);
103     }
104
105     Ymm ymm_dst_s32(int jj, int ii) {
106         return Ymm(2*jj + ii + 2 * jpp.ur_c);
107     }
108
109     Xmm xmm_dst_s32(int jj, int ii) {
110         return Xmm(2*jj + ii + 2 * jpp.ur_c);
111     }
112
113     Xmm vmm_dst_f32(int jj, int ii) {
114         return Xmm(2*jj + ii + 4 * jpp.ur_c);
115     }
116
117     void (*ker_)(const call_params_t *);
118     jit_pool_conf_t jpp;
119
120     void init_tmp_reg();
121
122     void load_src(int jj, int c_step);
123     void store_dst(int jj, int c_step);
124
125     void compute_avg_step(int ur_c, int c_step);
126     void compute_max_step(int ur_c, int c_step);
127     void compute_step(int ur_c, int c_step);
128
129     void compute_c_block();
130     void generate();
131
132     static status_t init_conf(jit_pool_conf_t &jpp,
133         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
134         const memory_desc_wrapper &dst_d);
135
136     jit_sse42_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
137            : jpp(jpp_) {
138         generate();
139         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
140                        getCode()));
141     }
142 };
143
144 void jit_sse42_i8i8_pool_fwd_ker_t::load_src(int jj, int c_step) {
145     using namespace data_type;
146
147     int repeats = c_step != 1 ? 2 : 1;
148     switch (jpp.alg) {
149         case pooling_max: {
150             auto offset = jj*c_step*sizeof_src_dt();
151             if (c_step == jpp.c_block) {
152                 for (int ii = 0; ii < repeats; ii++)
153                     uni_vmovups(vmm_src(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
154             } else if (c_step == 1) {
155                 if (jpp.src_dt == s32) {
156                     movsd(xmm_src(jj), ptr[aux_reg_src_w + offset]);
157                 } else {
158                     mov(reg_src_8, ptr[aux_reg_src_w + offset]);
159                     movq(xmm_src(jj), reg_src_64);
160                 }
161             }
162             break;
163         }
164         case pooling_avg_include_padding:
165         case pooling_avg_exclude_padding: {
166             auto offset = jj*c_step*sizeof_src_dt();
167             switch (jpp.src_dt) {
168                 case s32:
169                     if (c_step == jpp.c_block) {
170                         for (int ii = 0; ii < repeats; ii++)
171                             uni_vmovups(vmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
172                     } else if (c_step == 1) {
173                         movsd(xmm_src_s32(jj, 0), ptr[aux_reg_src_w + offset]);
174                     }
175                     break;
176                 case s8:
177                     if (c_step == jpp.c_block) {
178                         for (int ii = 0; ii < repeats; ii++) {
179                             movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
180
181                             uni_vpmovsxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
182                         }
183                     } else if (c_step == 1) {
184                         movsx(reg_src_32, ptr[aux_reg_src_w + offset]);
185                         movq(xmm_src_s32(jj, 0), reg_src_64);
186                     }
187                     break;
188                 case u8:
189                     if (c_step == jpp.c_block) {
190                         for (int ii = 0; ii < repeats; ii++) {
191                             movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
192
193                             uni_vpmovzxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
194                         }
195                     } else if (c_step == 1) {
196                         movzx(reg_src_32, ptr[aux_reg_src_w + offset]);
197                         movq(xmm_src_s32(jj, 0), reg_src_64);
198                     }
199                     break;
200                 default: assert(!"unsupported src data type");
201             }
202             break;
203         }
204         default: assert(!"unsupported algorithm");
205     }
206 }
207
208 void jit_sse42_i8i8_pool_fwd_ker_t::store_dst(int jj, int c_step) {
209     using namespace data_type;
210
211     int repeats = c_step != 1 ? 2 : 1;
212     switch(jpp.alg) {
213         case pooling_max: {
214             auto offset = jj*c_step*sizeof_dst_dt();
215             if (c_step == jpp.c_block) {
216                 for (int ii = 0; ii < repeats; ii++)
217                     uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst(jj, ii));
218             } else if (c_step == 1) {
219                 if (jpp.src_dt == s32) {
220                     movq(reg_src_64, xmm_dst(jj));
221                     mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
222                 } else {
223                     movq(reg_src_64, xmm_dst(jj));
224                     mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
225                 }
226             }
227             break;
228         }
229         case pooling_avg_include_padding:
230         case pooling_avg_exclude_padding: {
231             auto offset = jj*c_step*sizeof_dst_dt();
232             switch (jpp.dst_dt) {
233                 case s32:
234                     if (c_step == jpp.c_block) {
235                         for (int ii = 0; ii < repeats; ii++)
236                             uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst_s32(jj, ii));
237                     } else if (c_step == 1) {
238                         movq(reg_src_64, xmm_dst_s32(jj, 0));
239                         mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
240                     }
241                     break;
242                 case s8:
243                     if (c_step == jpp.c_block) {
244                         for (int ii = 0; ii < repeats; ii++) {
245                             uni_vpackssdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
246                             uni_vpacksswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
247
248                             movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
249                         }
250                     } else if (c_step == 1) {
251                         vpackssdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
252                         vpacksswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
253                         movq(reg_src_64, xmm_dst_s32(jj, 0));
254                         mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
255                     }
256                     break;
257                 case u8:
258                     if (c_step == jpp.c_block) {
259                         for (int ii = 0; ii < repeats; ii++) {
260                             uni_vpackusdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
261                             uni_vpackuswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
262
263                             movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
264                         }
265                     } else if (c_step == 1) {
266                         vpackusdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
267                         vpackuswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
268                         movq(reg_src_64, xmm_dst_s32(jj, 0));
269                         mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
270                     }
271                     break;
272                 default: assert(!"unsuppotred dst data_type");
273             }
274             break;
275         }
276         default: assert(!"unsupported pooling algorithm");
277     }
278 }
279
280 void jit_sse42_i8i8_pool_fwd_ker_t::compute_max_step(int ur_c, int c_step)
281 {
282     Label l_kw, l_kh;
283
284     int iw = jpp.iw;
285     int c = jpp.c;
286
287     int repeats = c_step != 1 ? 2 : 1;
288
289     for (int jj = 0; jj < ur_c; jj++) {
290         for (int ii = 0; ii < repeats; ii++) {
291             uni_vmovups(vmm_dst(jj, ii), vreg_tmp);
292         }
293     }
294
295     mov(aux_reg_src_h, reg_ptr_src_i8);
296
297     xor_(kj, kj);
298     L(l_kh);
299     {
300         mov(aux_reg_src_w, aux_reg_src_h);
301         xor_(ki, ki);
302         L(l_kw);
303         {
304             for (int jj = 0; jj < ur_c; jj++) {
305                 load_src(jj, c_step);
306
307                 for (int ii = 0; ii < repeats; ii++) {
308                     if (jpp.src_dt == data_type::s32) {
309                         uni_vpmaxsd(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
310                     } else {
311                         if (jpp.src_dt == data_type::s8)
312                             uni_vpmaxsb(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
313                         else
314                             uni_vpmaxub(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
315                     }
316                 }
317             }
318             add(aux_reg_src_w, c * sizeof_src_dt());
319             inc(ki);
320             cmp(ki, reg_kw);
321             jl(l_kw, T_NEAR);
322         }
323         add(aux_reg_src_h, iw * c * sizeof_src_dt());
324         inc(kj);
325         cmp(kj, reg_kh);
326         jl(l_kh, T_NEAR);
327     }
328
329     for (int jj = 0; jj < ur_c; jj++)
330         store_dst(jj, c_step);
331 }
332
333 void jit_sse42_i8i8_pool_fwd_ker_t::compute_avg_step(int ur_c, int c_step)
334 {
335     using namespace data_type;
336
337     Label l_kw, l_kh;
338
339     int iw = jpp.iw;
340     int c = jpp.c;
341
342     int repeats = c_step != 1 ? 2 : 1;
343
344     for (int jj = 0; jj < ur_c; jj++) {
345         for (int ii = 0; ii < repeats; ii++) {
346             uni_vpxor(vmm_src_s32(jj, ii), vmm_src_s32(jj, ii), vmm_src_s32(jj, ii));
347             uni_vpxor(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
348         }
349     }
350
351     mov(aux_reg_src_h, reg_ptr_src_i8);
352
353     xor_(kj, kj);
354     L(l_kh);
355     {
356         mov(aux_reg_src_w, aux_reg_src_h);
357         xor_(ki, ki);
358         L(l_kw);
359         {
360             for (int jj = 0; jj < ur_c; jj++) {
361                 load_src(jj, c_step);
362
363                 for (int ii = 0; ii < repeats; ii++) {
364                     uni_vpaddd(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_src_s32(jj, ii));
365                 }
366             }
367             add(aux_reg_src_w, c * sizeof_src_dt());
368             inc(ki);
369             cmp(ki, reg_kw);
370             jl(l_kw, T_NEAR);
371         }
372         add(aux_reg_src_h, iw * c * sizeof_src_dt());
373         inc(kj);
374         cmp(kj, reg_kh);
375         jl(l_kh, T_NEAR);
376     }
377
378     for (int jj = 0; jj < ur_c; jj++) {
379         for (int ii = 0; ii < repeats; ii++) {
380             uni_vcvtdq2ps(vmm_dst_f32(jj, ii), vmm_dst_s32(jj, ii));
381
382             mulps(vmm_dst_f32(jj, ii), vreg_tmp);
383
384             uni_vcvtps2dq(vmm_dst_s32(jj, ii), vmm_dst_f32(jj, ii));
385         }
386
387         store_dst(jj, c_step);
388     }
389 }
390
391 void jit_sse42_i8i8_pool_fwd_ker_t::compute_step(int ur_c, int c_step) {
392     switch (jpp.alg) {
393         case pooling_max:
394             compute_max_step(ur_c, c_step); break;
395         case pooling_avg_include_padding:
396         case pooling_avg_exclude_padding:
397             compute_avg_step(ur_c, c_step); break;
398         default: assert(!"unsupported pooling algorithm");
399     }
400 }
401
402 void jit_sse42_i8i8_pool_fwd_ker_t::compute_c_block() {
403     Label l_main_loop;
404     Label l_tail_loop;
405     Label exit;
406
407     int ur_c = jpp.ur_c;
408
409     xor_(c_iter, c_iter);
410
411     L(l_main_loop);
412     {
413         cmp(c_iter, jpp.c - ur_c * jpp.c_block);
414         jg(l_tail_loop, T_NEAR);
415
416         compute_step(ur_c, jpp.c_block);
417
418         add(reg_ptr_src_i8, ur_c * jpp.c_block * sizeof_src_dt());
419         add(reg_ptr_dst_i8, ur_c * jpp.c_block * sizeof_dst_dt());
420         add(c_iter, ur_c * jpp.c_block);
421         jmp(l_main_loop);
422     }
423
424     L(l_tail_loop);
425     {
426         cmp(c_iter, jpp.c - ur_c);
427         jg(exit, T_NEAR);
428
429         compute_step(ur_c, 1);
430
431         add(reg_ptr_src_i8, ur_c * sizeof_src_dt());
432         add(reg_ptr_dst_i8, ur_c * sizeof_dst_dt());
433         add(c_iter, ur_c);
434         jmp(l_tail_loop);
435     }
436
437     L(exit);
438 }
439
440 void jit_sse42_i8i8_pool_fwd_ker_t::init_tmp_reg() {
441     using namespace data_type;
442
443     switch (jpp.alg) {
444         case pooling_avg_include_padding:
445         case pooling_avg_exclude_padding:
446             mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
447             movq(xmm_tmp, reg_tmp);
448             uni_vpbroadcastd(vreg_tmp, xmm_tmp);
449             break;
450         case pooling_max:
451             switch (jpp.src_dt) {
452                 case s32:
453                     mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
454                     break;
455                 case s8:
456                     mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
457                     break;
458                 case u8:
459                     mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
460                     break;
461                 default: assert(!"unsupported src data_type");
462             }
463
464             movq(xmm_tmp, reg_tmp);
465             if (jpp.src_dt == s32) {
466                 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
467             } else {
468                 movups(vreg_tmp, xmm_tmp);
469                 uni_vpxor(xmm_tmp, xmm_tmp, xmm_tmp);
470                 pshufb(vreg_tmp, xmm_tmp);
471             }
472             break;
473         default: assert(!"unsupported pooling algorithm");
474     }
475
476 }
477
478 void jit_sse42_i8i8_pool_fwd_ker_t::generate() {
479     preamble();
480
481 #   define READ_PARAM(reg, field) \
482         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
483     READ_PARAM(reg_ptr_src_i8, src_i8);
484     READ_PARAM(reg_ptr_dst_i8, dst_i8);
485     READ_PARAM(reg_kw, kw_range);
486     READ_PARAM(reg_kh, kh_range);
487
488 #   undef READ_PARAM
489
490     init_tmp_reg();
491
492     uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
493
494     compute_c_block();
495
496     postamble();
497 }
498
499 status_t jit_sse42_i8i8_pool_fwd_ker_t::init_conf(jit_pool_conf_t &jpp,
500         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
501         const memory_desc_wrapper &dst_d) {
502     if (!mayiuse(sse42)) {
503         return status::unimplemented;
504     }
505
506     jpp.mb = src_d.dims()[0];
507     jpp.c = src_d.dims()[1];
508     jpp.ih = src_d.dims()[2];
509     jpp.iw = src_d.dims()[3];
510     jpp.oh = dst_d.dims()[2];
511     jpp.ow = dst_d.dims()[3];
512
513     jpp.stride_h = pd.strides[0];
514     jpp.stride_w = pd.strides[1];
515     jpp.kh = pd.kernel[0];
516     jpp.kw = pd.kernel[1];
517
518     jpp.t_pad = pd.padding[0][0];
519     jpp.l_pad = pd.padding[0][1];
520
521     jpp.alg = pd.alg_kind;
522
523     jpp.src_dt = pd.src_desc.data_type;
524     jpp.dst_dt = pd.dst_desc.data_type;
525
526     jpp.c_block = jpp.alg == pooling_max ? 32 / (jpp.src_dt == data_type::s32 ? 4 : 1) : 8;
527     jpp.c_tail = jpp.c % jpp.c_block;
528     jpp.nb_c = jpp.c / jpp.c_block;
529     jpp.ur_c = 1;
530     jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + (jpp.c_tail != 0);
531
532     return status::success;
533 }
534
535 status_t jit_sse42_i8i8_pooling_fwd_t::pd_t::jit_conf() {
536     return jit_sse42_i8i8_pool_fwd_ker_t::init_conf(jpp_,
537        desc_, src_pd_.desc(), dst_pd_.desc());
538 }
539
540 jit_sse42_i8i8_pooling_fwd_t::jit_sse42_i8i8_pooling_fwd_t(const pd_t *apd,
541           const input_vector &inputs, const output_vector &outputs)
542     : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
543 { ker_ = new jit_sse42_i8i8_pool_fwd_ker_t(pd()->jpp_); }
544
545 jit_sse42_i8i8_pooling_fwd_t::~jit_sse42_i8i8_pooling_fwd_t() {
546     delete ker_;
547 }
548
549 void jit_sse42_i8i8_pooling_fwd_t::execute_forward() const {
550     auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
551     auto dst_i8 = reinterpret_cast<char *>(memory());
552
553     const memory_desc_wrapper src_d(pd()->src_pd());
554     const memory_desc_wrapper dst_d(pd()->dst_pd());
555
556     const auto &jpp = pd()->jpp_;
557
558     parallel_nd(jpp.mb, jpp.oh, jpp.ow,
559         [&](int n, int oh, int ow) {
560         const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
561         const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);
562
563         const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
564         const int kh_end = nstl::min(jpp.kh,
565                                      jpp.ih + jpp.t_pad - oh * jpp.stride_h);
566         const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
567         const int kw_end = nstl::min(jpp.kw,
568                                      jpp.iw + jpp.l_pad - ow * jpp.stride_w);
569
570         auto p = call_params_t();
571         p.src_i8 = &src_i8[
572                 src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
573         p.dst_i8 = &dst_i8[
574                 dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
575         p.kw_range = (size_t) (kw_end - kw_start);
576         p.kh_range = (size_t) (kh_end - kh_start);
577         p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
578                              p.kh_range * p.kw_range : jpp.kw * jpp.kh);
579
580         ker_->ker_(&p);
581     });
582 }
583
584 }
585 }
586 }