Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_i8i8_pooling.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "jit_uni_i8i8_pooling.hpp"
18
19 #include <math.h>
20
21 #include "mkldnn_types.h"
22
23 #include "mkldnn_thread.hpp"
24 #include "utils.hpp"
25
26 #include "jit_generator.hpp"
27
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace Xbyak;
34
35 using namespace mkldnn::impl::utils;
36 using namespace mkldnn::impl::memory_format;
37 using namespace mkldnn::impl::utils;
38 using namespace mkldnn::impl::types;
39 using namespace alg_kind;
40
41 template <cpu_isa_t isa>
42 struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
43     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
44
45     struct call_params_t {
46         const char *src_i8;
47         const char *dst_i8;
48         size_t kw_range;
49         size_t kh_range;
50         float idivider;
51     };
52
53     using Vmm = typename cpu_isa_traits<isa>::Vmm;
54     Xmm xreg(int idx) const { return Xmm(idx); }
55     Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
56     Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
57
58     // Rounding modes for axv2
59     enum:uint8_t { rnd_op_nearest = 0x0 };
60
61     // In case of avx2 with data type i8 we need to use
62     // maskmovdqu instruction which has its destination hardcoded in rdi.
63     // Windows ABI: abi_param1 is rcx - nothing to do else
64     // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1
65     Reg64 reg_param      = rcx; // Our "unified abi_param1"
66     Reg64 reg_ptr_src_i8 = r8;
67     Reg64 reg_ptr_dst_i8 = r9;
68     Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
69
70     Reg64 ki = r10;
71     Reg64 kj = r11;
72     Reg64 reg_kw = r12;
73     Reg64 reg_kh = r13;
74     Reg64 c_iter = r14;
75
76     Reg64 aux_reg_src_h = rax;
77     Reg64 aux_reg_src_w = rbx;
78
79     Reg64 reg_tmp = rdx;
80
81     Reg64 reg_mask = r15;
82
83     Opmask k_cmp_mask = Opmask(7);
84
85     Opmask mask(int idx) {
86         return Opmask(6 - idx);
87     }
88
89     // ref to any of XYZ-regs via xreg/yreg/vreg functions
90     Xmm xmm_tmp = xreg(0);     // temp to init vreg_tmp
91     Vmm vreg_tmp = vreg(0);    // max pooling : holds minimum values for data_type
92     Vmm vreg_zeros = vreg(1);
93
94     // only in case of <isa> == avx2
95     Vmm vreg_mask    = vreg(2); // full byte-mask
96     Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
97     Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately)
98     Xmm xreg_mask_q  = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
99     Vmm vreg_mask_q  = vreg(3); // "avg" - 1/4 part for non-zero tails
100
101     enum:int {vidx_base = isa == avx2 ? 4 : 2};
102     Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }
103
104     size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
105     size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
106
107     /* max pooling */
108     Vmm vreg_src(int idx) const { return base_vr(idx); }            // [0    .. ur_c-1]
109     Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
110
111     /* avg pooling */
112     // s32 used for processing of s8/u8 data
113     // thus we need to take into account ratio of sizes s32/i8 = 4
114     static constexpr data_type_t avg_proc_dt = data_type::s32;
115     enum:int {
116         s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
117                 / sizeof(typename prec_traits<data_type::u8>::type),
118         max_num_ll =  s32_to_i8_ratio
119     };
120     Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); }  // ll: 0..4 [0..3]
121     Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); }  // ll: 0..4 [4..7]
122     Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); }  // ll: 0..4 [8..11]
123
124     void (*ker_)(const call_params_t *);
125     jit_pool_conf_t jpp;
126
127     void init_tmp_reg();
128     void init_mask();
129
130     void load_vreg_mask_q(int ll) {};
131
132     void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
133     void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
134     void load_src(int jj, int ll, int c_tail);
135
136     void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
137     void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
138     void store_dst(int jj, int ll, int c_tail);
139
140     void compute_avg_step(int ur_c, int c_tail);
141     void compute_max_op(const int jj);
142     void compute_max_step(int ur_c, int c_tail);
143     void compute_step(int ur_c, int c_tail);
144
145     void compute_c_block();
146     void generate();
147
148     static status_t init_conf(jit_pool_conf_t &jpp,
149         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
150         const memory_desc_wrapper &dst_d);
151
152     jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
153            : jpp(jpp_) {
154         generate();
155         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
156                        getCode()));
157     }
158 };
159
160 template <>
161 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
162
163     // extract ll-th part of mask (ll-th QWORD)
164     vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD
165
166     // Move mask from ll-th pos to 0-th pos
167     if (ll>0)
168         vpermq(vreg_mask_q, vreg_mask_q, ll);
169 };
170
171 template <>
172 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
173         size_t offset, bool masked, uint64_t msk) {
174     using namespace data_type;
175
176     if (masked) {
177         if (jpp.src_dt == s32) {
178             vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk));
179         } else {
180             vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask);
181         }
182     } else
183         vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
184 };
185
186 template <>
187 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
188         size_t offset, bool masked, uint64_t msk) {
189     using namespace data_type;
190
191     if (masked) {
192         if (jpp.src_dt == s32)
193             vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
194         else
195             vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
196     } else
197         vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
198 };
199
200 template <>
201 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
202         size_t offset, bool masked, uint64_t msk) {
203     using namespace data_type;
204
205     // Don't generate useless code
206     if (masked && !msk)
207         return;
208
209     auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
210
211         // Need to use mask of tail?
212         if (masked) {
213
214             // load ll-th part of mask into vreg_mask_q
215             load_vreg_mask_q(ll);
216
217             // Load by mask from mem into register vr_src
218             vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);
219
220             // Conversion s8/u8 -> s32
221             if (is_signed)
222                 vpmovsxbd(vr_src, vr_src);
223             else
224                 vpmovzxbd(vr_src, vr_src);
225         } else {
226
227             // Load from mem into vr_src with conversion
228             if (is_signed)
229                 vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
230             else
231                 vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
232         }
233     };
234
235     switch (jpp.src_dt) {
236         case s32:
237             if (masked)
238                 vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
239                     static_cast<uint8_t>(msk));
240             else
241                 vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
242             break;
243         case s8:
244                 load_i8(true, vreg_src_s32(jj, ll));
245             break;
246         case u8:
247                 load_i8(false, vreg_src_s32(jj, ll));
248             break;
249         default: assert(!"unsupported src data type");
250     }
251 };
252
253 template <>
254 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
255         size_t offset, bool masked, uint64_t msk) {
256     using namespace data_type;
257
258     // Don't generate useless code
259     if (masked && !msk)
260         return;
261
262     const Vmm& vr_src = masked ?
263             vreg_src_s32(jj, ll) | mask(ll) :
264             vreg_src_s32(jj, ll);
265
266     switch (jpp.src_dt) {
267         case s32:
268             vmovups(vr_src, ptr[aux_reg_src_w + offset]);
269             break;
270         case s8:
271             vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
272             break;
273         case u8:
274             vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
275             break;
276         default: assert(!"unsupported src data type");
277     }
278 };
279
280 template <cpu_isa_t isa>
281 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
282     using namespace data_type;
283
284     int c_block = jpp.c_block;
285     int ur_c = jpp.ur_c;
286
287     switch (jpp.alg) {
288         case pooling_max: {
289             auto offset = jj*c_block*sizeof_src_dt();
290             bool masked = jj == ur_c - 1 && c_tail;
291             load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
292             break;
293         }
294         case pooling_avg_include_padding:
295         case pooling_avg_exclude_padding: {
296             auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
297             bool masked = jj == ur_c - 1 && c_tail;
298             load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
299             break;
300         }
301         default: assert(!"unsupported algorithm");
302     }
303 }
304
305 template <>
306 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
307         size_t offset, bool masked, uint64_t msk) {
308     using namespace data_type;
309
310     int c_block = jpp.c_block;
311
312     if (masked) {
313         switch (jpp.src_dt) {
314             case s32:
315                 vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
316                 break;
317             case s8:
318             case u8: {
319                 // Store low half by mask (bytes 0...15)
320                 lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
321                 maskmovdqu(vreg_dst(jj), xreg_mask_lo);
322
323                 // Do we need to store high half (bytes 16...31) ?
324                 const uint64_t low_mask = (1ULL << (c_block/2))-1;
325                 if (msk & ~low_mask) {
326                     vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
327                     add(reg_ptr_maskmovdqu_dst, c_block / 2);
328                     maskmovdqu(vreg_dst(jj), xreg_mask_hi);
329                 }
330             } break;
331             default: assert(!"unsupported src data type");
332         }
333     } else
334         vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
335 }
336
337 template <>
338 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
339         size_t offset, bool masked, uint64_t msk) {
340     using namespace data_type;
341
342     if (masked) {
343         switch (jpp.src_dt) {
344             case s32:
345                 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
346                 break;
347             case s8:
348             case u8:
349                 vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
350                 break;
351             default: assert(!"unsupported src data type");
352         }
353     } else
354         vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
355 }
356
357 template <>
358 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
359         size_t offset, bool masked, uint64_t msk){
360     using namespace data_type;
361
362     // Don't generate useless code
363     if (masked && !msk)
364         return;
365
366     auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
367
368         // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
369         // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
370         if (is_signed)
371             vpackssdw(vr_dst, vr_dst, vreg_zeros);
372         else
373             vpackusdw(vr_dst, vr_dst, vreg_zeros);
374
375         // Permute qwords to restore original order
376         // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
377         vpermq(vr_dst, vr_dst, 0x58);
378
379         // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
380         // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
381         if (is_signed)
382             vpacksswb(vr_dst, vr_dst, vreg_zeros);
383         else
384             vpackuswb(vr_dst, vr_dst, vreg_zeros);
385
386     };
387
388     auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
389
390         // Conversion s32 -> s8/u8
391         s32_to_i8(is_signed, vr_dst);
392
393         // Need to use mask of tail?
394         if (is_masked) {
395             // load ll-th part of mask into vreg_mask_q
396             load_vreg_mask_q(ll);
397         }
398
399         // store 8 bytes
400         lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
401         maskmovdqu(vr_dst, xreg_mask_q);
402     };
403
404     switch (jpp.dst_dt) {
405         case s32:
406             if (masked) {
407                 vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
408             } else
409                 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
410             break;
411         case s8:
412             store_i8(true, masked, vreg_dst_s32(jj, ll));
413             break;
414         case u8:
415             store_i8(false, masked, vreg_dst_s32(jj, ll));
416             break;
417         default: assert(!"unsuppotred dst data_type");
418     }
419 }
420
421 template <>
422 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
423         size_t offset, bool masked, uint64_t msk) {
424     using namespace data_type;
425
426     // Don't generate useless code
427     if (masked && !msk)
428         return;
429
430     const Vmm& vr_dst = masked ?
431             vreg_dst_s32(jj, ll) | mask(ll) :
432             vreg_dst_s32(jj, ll);
433
434     switch (jpp.dst_dt) {
435         case s32:
436             vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
437             break;
438         case s8:
439             vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
440             break;
441         case u8:
442             vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
443             break;
444         default: assert(!"unsupported dst data_type");
445     }
446 }
447
448
449 template <cpu_isa_t isa>
450 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
451         int c_tail) {
452     using namespace data_type;
453
454     int c_block = jpp.c_block;
455     int ur_c = jpp.ur_c;
456
457     switch(jpp.alg) {
458         case pooling_max: {
459             auto offset = jj*c_block*sizeof_dst_dt();
460             bool masked = jj == ur_c - 1 && c_tail;
461             store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
462             break;
463         }
464         case pooling_avg_include_padding:
465         case pooling_avg_exclude_padding: {
466             auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
467             bool masked = jj == ur_c - 1 && c_tail;
468             store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
469             break;
470         }
471         default: assert(!"unsupported pooling algorithm");
472     }
473 }
474
475 template <>
476 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
477 {
478     using namespace data_type;
479     switch (jpp.src_dt) {
480         case s32:
481             vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
482             break;
483         case s8:
484             vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
485             break;
486         case u8:
487             vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
488             break;
489         default: assert(!"unsupported src data type");
490     }
491 }
492
493 template <>
494 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
495 {
496     using namespace data_type;
497
498     // Compare
499     switch (jpp.src_dt) {
500         case s32:
501             vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
502             break;
503         case s8:
504             vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
505             break;
506         case u8:
507             vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
508             break;
509         default: assert(!"unsupported src data type");
510     }
511
512     // move max values into vreg_dst
513     if (jpp.src_dt == s32)
514         vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
515     else
516         vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
517 }
518
519
520 template <cpu_isa_t isa>
521 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
522 {
523     Label l_kw, l_kh;
524
525     int iw = jpp.iw;
526     int c = jpp.c;
527
528     for (int jj = 0; jj < ur_c; jj++)
529         vmovups(vreg_dst(jj), vreg_tmp);
530
531     mov(aux_reg_src_h, reg_ptr_src_i8);
532
533     xor_(kj, kj);
534     L(l_kh);
535     {
536         mov(aux_reg_src_w, aux_reg_src_h);
537         xor_(ki, ki);
538         L(l_kw);
539         {
540             for (int jj = 0; jj < ur_c; jj++) {
541                 load_src(jj, 0, c_tail);
542                 compute_max_op(jj);
543             }
544             add(aux_reg_src_w, c * sizeof_src_dt());
545             inc(ki);
546             cmp(ki, reg_kw);
547             jl(l_kw, T_NEAR);
548         }
549         add(aux_reg_src_h, iw * c * sizeof_src_dt());
550         inc(kj);
551         cmp(kj, reg_kh);
552         jl(l_kh, T_NEAR);
553     }
554
555     for (int jj = 0; jj < ur_c; jj++)
556         store_dst(jj, 0, c_tail);
557 }
558
559 template <cpu_isa_t isa>
560 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
561 {
562     using namespace data_type;
563
564     Label l_kw, l_kh;
565
566     int iw = jpp.iw;
567     int c = jpp.c;
568
569     const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);
570
571     for (int jj = 0; jj < ur_c; jj++) {
572         for (int ll = 0; ll < num_ll; ll++) {
573             bool masked = jj == ur_c - 1 && c_tail;
574             size_t msk = jpp.tail[ll];
575             if (!(masked && !msk)) {
576                 uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
577                 uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
578             }
579         }
580     }
581
582     mov(aux_reg_src_h, reg_ptr_src_i8);
583
584     xor_(kj, kj);
585     L(l_kh);
586     {
587         mov(aux_reg_src_w, aux_reg_src_h);
588         xor_(ki, ki);
589         L(l_kw);
590         {
591             for (int jj = 0; jj < ur_c; jj++) {
592                 for (int ll = 0; ll < num_ll; ll++) {
593                     bool masked = jj == ur_c - 1 && c_tail;
594                     size_t msk = jpp.tail[ll];
595                     if (!(masked && !msk)) {
596                         load_src(jj, ll, c_tail);
597                         vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
598                                 vreg_src_s32(jj, ll));
599                     }
600                 }
601             }
602             add(aux_reg_src_w, c * sizeof_src_dt());
603             inc(ki);
604             cmp(ki, reg_kw);
605             jl(l_kw, T_NEAR);
606         }
607         add(aux_reg_src_h, iw * c * sizeof_src_dt());
608         inc(kj);
609         cmp(kj, reg_kh);
610         jl(l_kh, T_NEAR);
611     }
612
613     for (int jj = 0; jj < ur_c; jj++) {
614         for (int ll = 0; ll < num_ll; ll++) {
615             bool masked = jj == ur_c - 1 && c_tail;
616             size_t msk = jpp.tail[ll];
617             if (!(masked && !msk)) {
618
619                 vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
620                 vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
621
622                 if (isa == avx2) {
623                     uni_vroundps(vreg_dst_f32(jj, ll), vreg_dst_f32(jj, ll), rnd_op_nearest);
624                     vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
625                 } else if (isa >= avx512_common) {
626                     // AVX512: use of EVEX-embedded static rounding override
627                     vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));
628                 }
629
630                 store_dst(jj, ll, c_tail);
631             }
632         }
633     }
634 }
635
636 template <cpu_isa_t isa>
637 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
638     switch (jpp.alg) {
639         case pooling_max:
640             compute_max_step(ur_c, c_tail); break;
641         case pooling_avg_include_padding:
642         case pooling_avg_exclude_padding:
643             compute_avg_step(ur_c, c_tail); break;
644         default: assert(!"unsupported pooling algorithm");
645     }
646 }
647
648 template <cpu_isa_t isa>
649 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
650     Label l_main_loop;
651
652     int nb_c = jpp.nb_c;
653     int c_block = jpp.c_block;
654     int ur_c = jpp.ur_c;
655     int ur_c_tail = jpp.ur_c_tail;
656     int c_steps = nb_c / ur_c;
657     int c_tail = jpp.c_tail;
658
659     xor_(c_iter, c_iter);
660     if (c_steps > 0) {
661         L(l_main_loop); {
662             compute_step(ur_c, 0);
663             add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
664             add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
665             inc(c_iter);
666             cmp(c_iter, c_steps);
667             jl(l_main_loop, T_NEAR);
668         }
669     }
670
671     if (ur_c_tail != 0) {
672         compute_step(ur_c_tail, c_tail);
673     }
674 }
675
676 template<>
677 void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
678     using namespace data_type;
679     using cpu_isa = cpu_isa_traits<avx2>;
680
681     // AVX2 mask initialization: mask stored in Ymm-regs
682     auto init = [&](uint64_t bit_mask, bool init_mask_q) {
683         const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
684
685         uint64_t vmask[QW_PER_VREG];
686         for (size_t i = 0; i < QW_PER_VREG; i++){
687
688             uint64_t qw_vmask=0ULL;
689             const size_t DBITS = 8*sizeof_src_dt();
690             const uint64_t VMSK = 1ULL << (DBITS-1);
691             const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
692             for (size_t j = 0; j < D_PER_QW; j++) {
693                 if (bit_mask & 1)
694                     qw_vmask |= VMSK << DBITS * j;
695                 bit_mask >>= 1;
696             }
697             vmask[i] = qw_vmask;
698         }
699
700         // Put QWORDS with target mask into xmm regs
701         const int xdst_i[QW_PER_VREG] = {
702                 xreg_mask_lo.getIdx(),
703                 xreg_mask_lo.getIdx(),
704                 xreg_mask_hi.getIdx(),
705                 xreg_mask_hi.getIdx()
706         };
707         const int xsrc_i[QW_PER_VREG] = {
708                 vreg_zeros.getIdx(),   // 0-th qword insert in zeros -> {qw0,  0}
709                 xreg_mask_lo.getIdx(), // 1-st and 0-th merge        -> {qw0,qw1}
710                 vreg_zeros.getIdx(),
711                 xreg_mask_hi.getIdx()
712         };
713         const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg
714
715         for (size_t i = 0; i < QW_PER_VREG; i++) {
716             mov(reg_mask, vmask[i]);
717             vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
718         }
719
720         // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
721         // and High (xreg_mask_hi) into full vreg_mask
722         // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
723         vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
724
725         // Keep only low qword of mask in xreg_mask_q
726         if (init_mask_q) {
727             mov(reg_mask, vmask[0]);
728             vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
729         }
730     };
731
732     uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
733     switch (jpp.alg) {
734         case pooling_max:
735             // For "max" we need mask only in case of non-zero tail
736             if (tail_mask)
737                 init(tail_mask, false);
738             break;
739         case pooling_avg_include_padding:
740         case pooling_avg_exclude_padding:
741             // For "avg" we need mask:
742             // - s32   - in case of the non-zero tail
743             // - s8/u8 - irrespective of the tail
744             switch (jpp.src_dt) {
745                 case s32:
746                     if (tail_mask)
747                         init(tail_mask, false);
748                     break;
749                 case s8:
750                 case u8:
751                     init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
752                     break;
753                 default: assert(!"unsupported src data type");
754             }
755             break;
756         default: assert(!"unsupported pooling algorithm");
757     }
758 }
759
760 template<>
761 void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
762
763     for (int ll = 0; ll < max_num_ll; ll++) {
764         mov(reg_mask, jpp.tail[ll]);
765         kmovq(mask(ll), reg_mask);
766     }
767 }
768
769 template <cpu_isa_t isa>
770 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
771     using namespace data_type;
772
773     switch (jpp.alg) {
774         case pooling_avg_include_padding:
775         case pooling_avg_exclude_padding:
776             mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
777             movq(xmm_tmp, reg_tmp);
778             vpbroadcastd(vreg_tmp, xmm_tmp);
779             break;
780         case pooling_max:
781             switch (jpp.src_dt) {
782                 case s32:
783                     mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
784                     break;
785                 case s8:
786                     mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
787                     break;
788                 case u8:
789                     mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
790                     break;
791                 default: assert(!"unsupported src data_type");
792             }
793
794             movq(xmm_tmp, reg_tmp);
795             if (jpp.src_dt == s32)
796                 vpbroadcastd(vreg_tmp, xmm_tmp);
797             else
798                 vpbroadcastb(vreg_tmp, xmm_tmp);
799             break;
800         default: assert(!"unsupported pooling algorithm");
801     }
802
803 }
804
805 template <cpu_isa_t isa>
806 void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
807     preamble();
808
809 #if !defined(_WIN32)
810     // Always use rcx as abi_param1 -
811     // see the note about maskmovdqu near reg_param.
812     mov(rcx, rdi);
813 #endif
814
815 #   define READ_PARAM(reg, field) \
816         mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
817     READ_PARAM(reg_ptr_src_i8, src_i8);
818     READ_PARAM(reg_ptr_dst_i8, dst_i8);
819     READ_PARAM(reg_kw, kw_range);
820     READ_PARAM(reg_kh, kh_range);
821
822 #   undef READ_PARAM
823
824     uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
825
826     init_mask();
827
828     init_tmp_reg();
829
830     compute_c_block();
831
832     postamble();
833 }
834
835 template <cpu_isa_t isa>
836 status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
837         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
838         const memory_desc_wrapper &dst_d) {
839     if (!mayiuse(isa))
840         return status::unimplemented;
841
842     jpp.mb = src_d.dims()[0];
843     jpp.c = src_d.dims()[1];
844     jpp.ih = src_d.dims()[2];
845     jpp.iw = src_d.dims()[3];
846     jpp.oh = dst_d.dims()[2];
847     jpp.ow = dst_d.dims()[3];
848
849     jpp.stride_h = pd.strides[0];
850     jpp.stride_w = pd.strides[1];
851     jpp.kh = pd.kernel[0];
852     jpp.kw = pd.kernel[1];
853
854     jpp.t_pad = pd.padding[0][0];
855     jpp.l_pad = pd.padding[0][1];
856
857     jpp.alg = pd.alg_kind;
858
859     jpp.src_dt = pd.src_desc.data_type;
860     jpp.dst_dt = pd.dst_desc.data_type;
861
862     // data_type items per one vreg on the <isa>
863     //     isa == avx2    : 32 bytes -> 32 for s8/u8, 8 for s32
864     //     isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
865     int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
866
867     jpp.c_block = simd_w;
868     jpp.c_tail = jpp.c % jpp.c_block;
869     jpp.nb_c = jpp.c / jpp.c_block;
870     jpp.ur_c = 1;
871     jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
872             (jpp.c_tail != 0);
873
874     size_t tail_mask = (1ULL << jpp.c_tail) - 1;
875
876     switch (jpp.alg) {
877         case pooling_max:
878             jpp.tail[0] = tail_mask;
879             jpp.tail[1] = 0;
880             jpp.tail[2] = 0;
881             jpp.tail[3] = 0;
882             break;
883         case pooling_avg_include_padding:
884         case pooling_avg_exclude_padding: {
885             // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
886             // avx2 : 8, avx512 : 16
887             const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
888             const size_t msk_msk = (1ULL << msk_gran) - 1;
889             size_t m = tail_mask;
890             for (size_t ll = 0; ll < max_num_ll; ll++) {
891                 jpp.tail[ll] = m & msk_msk;
892                 m = m >> msk_gran;
893             }
894             break;
895         }
896         default: return status::unimplemented;
897     }
898
899     return status::success;
900 }
901
902 template <cpu_isa_t isa>
903 status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
904     return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_,
905        desc_, src_pd_.desc(), dst_pd_.desc());
906 }
907
908 template <cpu_isa_t isa>
909 jit_uni_i8i8_pooling_fwd_t<isa>::
910 jit_uni_i8i8_pooling_fwd_t(const pd_t *apd,
911           const input_vector &inputs, const output_vector &outputs)
912     : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
913 { ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }
914
915 template <cpu_isa_t isa>
916 jit_uni_i8i8_pooling_fwd_t<isa>::
917 ~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
918
919 template <cpu_isa_t isa>
920 void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() const {
921     auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
922     auto dst_i8 = reinterpret_cast<char *>(memory());
923
924     const memory_desc_wrapper src_d(pd()->src_pd());
925     const memory_desc_wrapper dst_d(pd()->dst_pd());
926
927     const auto &jpp = pd()->jpp_;
928
929     parallel_nd(jpp.mb, jpp.oh, jpp.ow,
930             [&](int n, int oh, int ow) {
931         const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
932         const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);
933
934         const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
935         const int kh_end = nstl::min(jpp.kh,
936                 jpp.ih + jpp.t_pad - oh * jpp.stride_h);
937         const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
938         const int kw_end = nstl::min(jpp.kw,
939                 jpp.iw + jpp.l_pad - ow * jpp.stride_w);
940
941         auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
942         p.src_i8 = &src_i8[
943             src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
944         p.dst_i8 = &dst_i8[
945             dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
946         p.kw_range = (size_t)(kw_end - kw_start);
947         p.kh_range = (size_t)(kh_end - kh_start);
948         p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
949             p.kh_range*p.kw_range : jpp.kw*jpp.kh);
950
951         ker_->ker_(&p);
952     });
953 }
954
955 // Explicit instantiation only for supported <isa> values.
956 //
957 template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
958 template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
959
960 template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
961 template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
962
963 }
964 }
965 }