Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_i8i8_pooling.cpp
index ccc1c34..8f93163 100644 (file)
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2017-2018 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
 * limitations under the License.
 *******************************************************************************/
 
+#include "jit_uni_i8i8_pooling.hpp"
+
 #include <math.h>
 
 #include "mkldnn_types.h"
@@ -23,7 +25,6 @@
 
 #include "jit_generator.hpp"
 
-#include "jit_uni_i8i8_pooling.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -37,20 +38,34 @@ using namespace mkldnn::impl::utils;
 using namespace mkldnn::impl::types;
 using namespace alg_kind;
 
-struct call_params_t {
-    const char *src_i8;
-    const char *dst_i8;
-    size_t kw_range;
-    size_t kh_range;
-    float idivider;
-};
-
 template <cpu_isa_t isa>
-struct jit_uni_i8i8_pool_fwd_ker_t : public jit_generator {
-    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pool_fwd_ker_t)
-
+struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
+
+    struct call_params_t {
+        const char *src_i8;
+        const char *dst_i8;
+        size_t kw_range;
+        size_t kh_range;
+        float idivider;
+    };
+
+    using Vmm = typename cpu_isa_traits<isa>::Vmm;
+    Xmm xreg(int idx) const { return Xmm(idx); }
+    Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
+    Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
+
+    // Rounding modes for axv2
+    enum:uint8_t { rnd_op_nearest = 0x0 };
+
+    // In case of avx2 with data type i8 we need to use
+    // maskmovdqu instruction which has its destination hardcoded in rdi.
+    // Windows ABI: abi_param1 is rcx - nothing to do else
+    // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1
+    Reg64 reg_param      = rcx; // Our "unified abi_param1"
     Reg64 reg_ptr_src_i8 = r8;
     Reg64 reg_ptr_dst_i8 = r9;
+    Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
 
     Reg64 ki = r10;
     Reg64 kj = r11;
@@ -62,73 +77,70 @@ struct jit_uni_i8i8_pool_fwd_ker_t : public jit_generator {
     Reg64 aux_reg_src_w = rbx;
 
     Reg64 reg_tmp = rdx;
-    Reg64 reg_src_64 = r15;
-    Reg32 reg_src_32 = r15d;
-    Reg8 reg_src_8 = r15b;
 
-    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
-    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
+    Reg64 reg_mask = r15;
 
-    using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
-            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+    Opmask k_cmp_mask = Opmask(7);
 
-    Xmm xmm_tmp = Xmm(0);
-    Vmm vreg_tmp = Vmm(14);
-    Vmm vreg_zeros = Vmm(15);
-
-    /* max pooling */
-    Vmm vmm_src(int jj, int ii) {
-        return Vmm(2*jj + ii);
+    Opmask mask(int idx) {
+        return Opmask(6 - idx);
     }
 
-    Xmm xmm_src(int jj) {
-        return Xmm(2*jj);
-    }
+    // ref to any of XYZ-regs via xreg/yreg/vreg functions
+    Xmm xmm_tmp = xreg(0);     // temp to init vreg_tmp
+    Vmm vreg_tmp = vreg(0);    // max pooling : holds minimum values for data_type
+    Vmm vreg_zeros = vreg(1);
 
-    Vmm vmm_dst(int jj, int ii) {
-        return Vmm(2*jj + ii + 2 * jpp.ur_c);
-    }
+    // only in case of <isa> == avx2
+    Vmm vreg_mask    = vreg(2); // full byte-mask
+    Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
+    Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately)
+    Xmm xreg_mask_q  = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
+    Vmm vreg_mask_q  = vreg(3); // "avg" - 1/4 part for non-zero tails
 
-    Xmm xmm_dst(int jj) {
-        return Xmm(2*jj + 2 * jpp.ur_c);
-    }
+    enum:int {vidx_base = isa == avx2 ? 4 : 2};
+    Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }
 
-    /* avg pooling */
-    Vmm vmm_src_s32(int jj, int ii) {
-        return Vmm(2*jj + ii);
-    }
-
-    Xmm xmm_src_s32(int jj, int ii) {
-        return Xmm(2*jj + ii);
-    }
-
-    Vmm vmm_dst_s32(int jj, int ii) {
-        return Vmm(2*jj + ii + 2 * jpp.ur_c);
-    }
-
-    Ymm ymm_dst_s32(int jj, int ii) {
-        return Ymm(2*jj + ii + 2 * jpp.ur_c);
-    }
+    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
+    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
 
-    Xmm xmm_dst_s32(int jj, int ii) {
-        return Xmm(2*jj + ii + 2 * jpp.ur_c);
-    }
+    /* max pooling */
+    Vmm vreg_src(int idx) const { return base_vr(idx); }            // [0    .. ur_c-1]
+    Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
 
-    Vmm vmm_dst_f32(int jj, int ii) {
-        return Vmm(2*jj + ii + 4 * jpp.ur_c);
-    }
+    /* avg pooling */
+    // s32 used for processing of s8/u8 data
+    // thus we need to take into account ratio of sizes s32/i8 = 4
+    static constexpr data_type_t avg_proc_dt = data_type::s32;
+    enum:int {
+        s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
+                / sizeof(typename prec_traits<data_type::u8>::type),
+        max_num_ll =  s32_to_i8_ratio
+    };
+    Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); }  // ll: 0..4 [0..3]
+    Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); }  // ll: 0..4 [4..7]
+    Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); }  // ll: 0..4 [8..11]
 
     void (*ker_)(const call_params_t *);
     jit_pool_conf_t jpp;
 
     void init_tmp_reg();
+    void init_mask();
+
+    void load_vreg_mask_q(int ll) {};
+
+    void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+    void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+    void load_src(int jj, int ll, int c_tail);
 
-    void load_src(int jj, int c_step);
-    void store_dst(int jj, int c_step);
+    void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+    void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+    void store_dst(int jj, int ll, int c_tail);
 
-    void compute_avg_step(int ur_c, int c_step);
-    void compute_max_step(int ur_c, int c_step);
-    void compute_step(int ur_c, int c_step);
+    void compute_avg_step(int ur_c, int c_tail);
+    void compute_max_op(const int jj);
+    void compute_max_step(int ur_c, int c_tail);
+    void compute_step(int ur_c, int c_tail);
 
     void compute_c_block();
     void generate();
@@ -137,7 +149,7 @@ struct jit_uni_i8i8_pool_fwd_ker_t : public jit_generator {
         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &dst_d);
 
-    jit_uni_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
+    jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
            : jpp(jpp_) {
         generate();
         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
@@ -145,179 +157,376 @@ struct jit_uni_i8i8_pool_fwd_ker_t : public jit_generator {
     }
 };
 
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
+
+    // extract ll-th part of mask (ll-th QWORD)
+    vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD
+
+    // Move mask from ll-th pos to 0-th pos
+    if (ll>0)
+        vpermq(vreg_mask_q, vreg_mask_q, ll);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    if (masked) {
+        if (jpp.src_dt == s32) {
+            vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk));
+        } else {
+            vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask);
+        }
+    } else
+        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    if (masked) {
+        if (jpp.src_dt == s32)
+            vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+        else
+            vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+    } else
+        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    // Don't generate useless code
+    if (masked && !msk)
+        return;
+
+    auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
+
+        // Need to use mask of tail?
+        if (masked) {
+
+            // load ll-th part of mask into vreg_mask_q
+            load_vreg_mask_q(ll);
+
+            // Load by mask from mem into register vr_src
+            vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);
+
+            // Conversion s8/u8 -> s32
+            if (is_signed)
+                vpmovsxbd(vr_src, vr_src);
+            else
+                vpmovzxbd(vr_src, vr_src);
+        } else {
+
+            // Load from mem into vr_src with conversion
+            if (is_signed)
+                vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+            else
+                vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+        }
+    };
+
+    switch (jpp.src_dt) {
+        case s32:
+            if (masked)
+                vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
+                    static_cast<uint8_t>(msk));
+            else
+                vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
+            break;
+        case s8:
+                load_i8(true, vreg_src_s32(jj, ll));
+            break;
+        case u8:
+                load_i8(false, vreg_src_s32(jj, ll));
+            break;
+        default: assert(!"unsupported src data type");
+    }
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    // Don't generate useless code
+    if (masked && !msk)
+        return;
+
+    const Vmm& vr_src = masked ?
+            vreg_src_s32(jj, ll) | mask(ll) :
+            vreg_src_s32(jj, ll);
+
+    switch (jpp.src_dt) {
+        case s32:
+            vmovups(vr_src, ptr[aux_reg_src_w + offset]);
+            break;
+        case s8:
+            vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+            break;
+        case u8:
+            vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+            break;
+        default: assert(!"unsupported src data type");
+    }
+};
+
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::load_src(int jj, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
     using namespace data_type;
 
-    int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+    int c_block = jpp.c_block;
+    int ur_c = jpp.ur_c;
+
     switch (jpp.alg) {
         case pooling_max: {
-            auto offset = jj*c_step*sizeof_src_dt();
-            if (c_step == jpp.c_block) {
-                for (int ii = 0; ii < repeats; ii++)
-                    uni_vmovups(vmm_src(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-            } else if (c_step == 1) {
-                if (jpp.src_dt == s32) {
-                    movsd(xmm_src(jj), ptr[aux_reg_src_w + offset]);
-                } else {
-                    mov(reg_src_8, ptr[aux_reg_src_w + offset]);
-                    movq(xmm_src(jj), reg_src_64);
-                }
-            }
+            auto offset = jj*c_block*sizeof_src_dt();
+            bool masked = jj == ur_c - 1 && c_tail;
+            load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
             break;
         }
         case pooling_avg_include_padding:
         case pooling_avg_exclude_padding: {
-            auto offset = jj*c_step*sizeof_src_dt();
-            switch (jpp.src_dt) {
-                case s32:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++)
-                            uni_vmovups(vmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-                    } else if (c_step == 1) {
-                        movsd(xmm_src_s32(jj, 0), ptr[aux_reg_src_w + offset]);
-                    }
-                    break;
-                case s8:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++) {
-                            if (isa == sse42)
-                                movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-                            else
-                                movq(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-
-                            uni_vpmovsxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
-                        }
-                    } else if (c_step == 1) {
-                        movsx(reg_src_32, ptr[aux_reg_src_w + offset]);
-                        movq(xmm_src_s32(jj, 0), reg_src_64);
-                    }
-                    break;
-                case u8:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++) {
-                            if (isa == sse42)
-                                movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-                            else
-                                movq(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-
-                            uni_vpmovzxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
-                        }
-                    } else if (c_step == 1) {
-                        movzx(reg_src_32, ptr[aux_reg_src_w + offset]);
-                        movq(xmm_src_s32(jj, 0), reg_src_64);
-                    }
-                    break;
-                default: assert(!"unsupported src data type");
-            }
+            auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
+            bool masked = jj == ur_c - 1 && c_tail;
+            load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
             break;
         }
         default: assert(!"unsupported algorithm");
     }
 }
 
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    int c_block = jpp.c_block;
+
+    if (masked) {
+        switch (jpp.src_dt) {
+            case s32:
+                vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
+                break;
+            case s8:
+            case u8: {
+                // Store low half by mask (bytes 0...15)
+                lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+                maskmovdqu(vreg_dst(jj), xreg_mask_lo);
+
+                // Do we need to store high half (bytes 16...31) ?
+                const uint64_t low_mask = (1ULL << (c_block/2))-1;
+                if (msk & ~low_mask) {
+                    vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
+                    add(reg_ptr_maskmovdqu_dst, c_block / 2);
+                    maskmovdqu(vreg_dst(jj), xreg_mask_hi);
+                }
+            } break;
+            default: assert(!"unsupported src data type");
+        }
+    } else
+        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    if (masked) {
+        switch (jpp.src_dt) {
+            case s32:
+                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+                break;
+            case s8:
+            case u8:
+                vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+                break;
+            default: assert(!"unsupported src data type");
+        }
+    } else
+        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk){
+    using namespace data_type;
+
+    // Don't generate useless code
+    if (masked && !msk)
+        return;
+
+    auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
+
+        // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
+        // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
+        if (is_signed)
+            vpackssdw(vr_dst, vr_dst, vreg_zeros);
+        else
+            vpackusdw(vr_dst, vr_dst, vreg_zeros);
+
+        // Permute qwords to restore original order
+        // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
+        vpermq(vr_dst, vr_dst, 0x58);
+
+        // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
+        // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
+        if (is_signed)
+            vpacksswb(vr_dst, vr_dst, vreg_zeros);
+        else
+            vpackuswb(vr_dst, vr_dst, vreg_zeros);
+
+    };
+
+    auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
+
+        // Conversion s32 -> s8/u8
+        s32_to_i8(is_signed, vr_dst);
+
+        // Need to use mask of tail?
+        if (is_masked) {
+            // load ll-th part of mask into vreg_mask_q
+            load_vreg_mask_q(ll);
+        }
+
+        // store 8 bytes
+        lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+        maskmovdqu(vr_dst, xreg_mask_q);
+    };
+
+    switch (jpp.dst_dt) {
+        case s32:
+            if (masked) {
+                vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
+            } else
+                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
+            break;
+        case s8:
+            store_i8(true, masked, vreg_dst_s32(jj, ll));
+            break;
+        case u8:
+            store_i8(false, masked, vreg_dst_s32(jj, ll));
+            break;
+        default: assert(!"unsuppotred dst data_type");
+    }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
+        size_t offset, bool masked, uint64_t msk) {
+    using namespace data_type;
+
+    // Don't generate useless code
+    if (masked && !msk)
+        return;
+
+    const Vmm& vr_dst = masked ?
+            vreg_dst_s32(jj, ll) | mask(ll) :
+            vreg_dst_s32(jj, ll);
+
+    switch (jpp.dst_dt) {
+        case s32:
+            vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+            break;
+        case s8:
+            vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+            break;
+        case u8:
+            vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+            break;
+        default: assert(!"unsupported dst data_type");
+    }
+}
+
+
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::store_dst(int jj, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
+        int c_tail) {
     using namespace data_type;
 
-    int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+    int c_block = jpp.c_block;
+    int ur_c = jpp.ur_c;
+
     switch(jpp.alg) {
         case pooling_max: {
-            auto offset = jj*c_step*sizeof_dst_dt();
-            if (c_step == jpp.c_block) {
-                for (int ii = 0; ii < repeats; ii++)
-                    uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst(jj, ii));
-            } else if (c_step == 1) {
-                if (jpp.src_dt == s32) {
-                    movq(reg_src_64, xmm_dst(jj));
-                    mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
-                } else {
-                    movq(reg_src_64, xmm_dst(jj));
-                    mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
-                }
-            }
+            auto offset = jj*c_block*sizeof_dst_dt();
+            bool masked = jj == ur_c - 1 && c_tail;
+            store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
             break;
         }
         case pooling_avg_include_padding:
         case pooling_avg_exclude_padding: {
-            auto offset = jj*c_step*sizeof_dst_dt();
-            switch (jpp.dst_dt) {
-                case s32:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++)
-                            uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst_s32(jj, ii));
-                    } else if (c_step == 1) {
-                        movq(reg_src_64, xmm_dst_s32(jj, 0));
-                        mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
-                    }
-                    break;
-                case s8:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++) {
-                            uni_vpackssdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
-
-                            if (isa != sse42)
-                                vpermq(ymm_dst_s32(jj, ii), ymm_dst_s32(jj, ii), 0x08);
-
-                            uni_vpacksswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
-
-                            if (isa != sse42)
-                                movq(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
-                            else
-                                movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
-                        }
-                    } else if (c_step == 1) {
-                        vpackssdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
-                        vpacksswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
-                        movq(reg_src_64, xmm_dst_s32(jj, 0));
-                        mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
-                    }
-                    break;
-                case u8:
-                    if (c_step == jpp.c_block) {
-                        for (int ii = 0; ii < repeats; ii++) {
-                            uni_vpackusdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
-
-                            if (isa != sse42)
-                                vpermq(ymm_dst_s32(jj, ii), ymm_dst_s32(jj, ii), 0x08);
-
-                            uni_vpackuswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
-
-                            if (isa != sse42)
-                                movq(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
-                            else
-                                movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
-                        }
-                    } else if (c_step == 1) {
-                        vpackusdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
-                        vpackuswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
-                        movq(reg_src_64, xmm_dst_s32(jj, 0));
-                        mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
-                    }
-                    break;
-                default: assert(!"unsuppotred dst data_type");
-            }
+            auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
+            bool masked = jj == ur_c - 1 && c_tail;
+            store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
             break;
         }
         default: assert(!"unsupported pooling algorithm");
     }
 }
 
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
+{
+    using namespace data_type;
+    switch (jpp.src_dt) {
+        case s32:
+            vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+            break;
+        case s8:
+            vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+            break;
+        case u8:
+            vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+            break;
+        default: assert(!"unsupported src data type");
+    }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
+{
+    using namespace data_type;
+
+    // Compare
+    switch (jpp.src_dt) {
+        case s32:
+            vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+            break;
+        case s8:
+            vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+            break;
+        case u8:
+            vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+            break;
+        default: assert(!"unsupported src data type");
+    }
+
+    // move max values into vreg_dst
+    if (jpp.src_dt == s32)
+        vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+    else
+        vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+}
+
+
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_step)
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
 {
     Label l_kw, l_kh;
 
     int iw = jpp.iw;
     int c = jpp.c;
 
-    int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
-
-    for (int jj = 0; jj < ur_c; jj++) {
-        for (int ii = 0; ii < repeats; ii++) {
-            uni_vmovups(vmm_dst(jj, ii), vreg_tmp);
-        }
-    }
+    for (int jj = 0; jj < ur_c; jj++)
+        vmovups(vreg_dst(jj), vreg_tmp);
 
     mov(aux_reg_src_h, reg_ptr_src_i8);
 
@@ -329,18 +538,8 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_step)
         L(l_kw);
         {
             for (int jj = 0; jj < ur_c; jj++) {
-                load_src(jj, c_step);
-
-                for (int ii = 0; ii < repeats; ii++) {
-                    if (jpp.src_dt == data_type::s32) {
-                        uni_vpmaxsd(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
-                    } else {
-                        if (jpp.src_dt == data_type::s8)
-                            uni_vpmaxsb(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
-                        else
-                            uni_vpmaxub(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
-                    }
-                }
+                load_src(jj, 0, c_tail);
+                compute_max_op(jj);
             }
             add(aux_reg_src_w, c * sizeof_src_dt());
             inc(ki);
@@ -354,11 +553,11 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_step)
     }
 
     for (int jj = 0; jj < ur_c; jj++)
-        store_dst(jj, c_step);
+        store_dst(jj, 0, c_tail);
 }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_step)
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
 {
     using namespace data_type;
 
@@ -367,12 +566,16 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_step)
     int iw = jpp.iw;
     int c = jpp.c;
 
-    int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+    const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);
 
     for (int jj = 0; jj < ur_c; jj++) {
-        for (int ii = 0; ii < repeats; ii++) {
-            uni_vpxor(vmm_src_s32(jj, ii), vmm_src_s32(jj, ii), vmm_src_s32(jj, ii));
-            uni_vpxor(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
+        for (int ll = 0; ll < num_ll; ll++) {
+            bool masked = jj == ur_c - 1 && c_tail;
+            size_t msk = jpp.tail[ll];
+            if (!(masked && !msk)) {
+                uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
+                uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
+            }
         }
     }
 
@@ -386,10 +589,14 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_step)
         L(l_kw);
         {
             for (int jj = 0; jj < ur_c; jj++) {
-                load_src(jj, c_step);
-
-                for (int ii = 0; ii < repeats; ii++) {
-                    uni_vpaddd(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_src_s32(jj, ii));
+                for (int ll = 0; ll < num_ll; ll++) {
+                    bool masked = jj == ur_c - 1 && c_tail;
+                    size_t msk = jpp.tail[ll];
+                    if (!(masked && !msk)) {
+                        load_src(jj, ll, c_tail);
+                        vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
+                                vreg_src_s32(jj, ll));
+                    }
                 }
             }
             add(aux_reg_src_w, c * sizeof_src_dt());
@@ -404,82 +611,171 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_step)
     }
 
     for (int jj = 0; jj < ur_c; jj++) {
-        for (int ii = 0; ii < repeats; ii++) {
-            uni_vcvtdq2ps(vmm_dst_f32(jj, ii), vmm_dst_s32(jj, ii));
+        for (int ll = 0; ll < num_ll; ll++) {
+            bool masked = jj == ur_c - 1 && c_tail;
+            size_t msk = jpp.tail[ll];
+            if (!(masked && !msk)) {
 
-            if (isa == sse42)
-                mulps(vmm_dst_f32(jj, ii), vreg_tmp);
-            else
-                vfmadd132ps(vmm_dst_f32(jj, ii), vreg_zeros, vreg_tmp);
+                vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
+                vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
 
-            uni_vcvtps2dq(vmm_dst_s32(jj, ii), vmm_dst_f32(jj, ii));
-        }
+                if (isa == avx2) {
+                    uni_vroundps(vreg_dst_f32(jj, ll), vreg_dst_f32(jj, ll), rnd_op_nearest);
+                    vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
+                } else if (isa >= avx512_common) {
+                    // AVX512: use of EVEX-embedded static rounding override
+                    vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));
+                }
 
-        store_dst(jj, c_step);
+                store_dst(jj, ll, c_tail);
+            }
+        }
     }
 }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_step(int ur_c, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
     switch (jpp.alg) {
         case pooling_max:
-            compute_max_step(ur_c, c_step); break;
+            compute_max_step(ur_c, c_tail); break;
         case pooling_avg_include_padding:
         case pooling_avg_exclude_padding:
-            compute_avg_step(ur_c, c_step); break;
+            compute_avg_step(ur_c, c_tail); break;
         default: assert(!"unsupported pooling algorithm");
     }
 }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_c_block() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
     Label l_main_loop;
-    Label l_tail_loop;
-    Label exit;
 
+    int nb_c = jpp.nb_c;
+    int c_block = jpp.c_block;
     int ur_c = jpp.ur_c;
+    int ur_c_tail = jpp.ur_c_tail;
+    int c_steps = nb_c / ur_c;
+    int c_tail = jpp.c_tail;
 
     xor_(c_iter, c_iter);
+    if (c_steps > 0) {
+        L(l_main_loop); {
+            compute_step(ur_c, 0);
+            add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
+            add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
+            inc(c_iter);
+            cmp(c_iter, c_steps);
+            jl(l_main_loop, T_NEAR);
+        }
+    }
 
-    L(l_main_loop);
-    {
-        cmp(c_iter, jpp.c - ur_c * jpp.c_block);
-        jg(l_tail_loop, T_NEAR);
+    if (ur_c_tail != 0) {
+        compute_step(ur_c_tail, c_tail);
+    }
+}
 
-        compute_step(ur_c, jpp.c_block);
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
+    using namespace data_type;
+    using cpu_isa = cpu_isa_traits<avx2>;
+
+    // AVX2 mask initialization: mask stored in Ymm-regs
+    auto init = [&](uint64_t bit_mask, bool init_mask_q) {
+        const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
+
+        uint64_t vmask[QW_PER_VREG];
+        for (size_t i = 0; i < QW_PER_VREG; i++){
+
+            uint64_t qw_vmask=0ULL;
+            const size_t DBITS = 8*sizeof_src_dt();
+            const uint64_t VMSK = 1ULL << (DBITS-1);
+            const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
+            for (size_t j = 0; j < D_PER_QW; j++) {
+                if (bit_mask & 1)
+                    qw_vmask |= VMSK << DBITS * j;
+                bit_mask >>= 1;
+            }
+            vmask[i] = qw_vmask;
+        }
 
-        add(reg_ptr_src_i8, ur_c * jpp.c_block * sizeof_src_dt());
-        add(reg_ptr_dst_i8, ur_c * jpp.c_block * sizeof_dst_dt());
-        add(c_iter, ur_c * jpp.c_block);
-        jmp(l_main_loop);
-    }
+        // Put QWORDS with target mask into xmm regs
+        const int xdst_i[QW_PER_VREG] = {
+                xreg_mask_lo.getIdx(),
+                xreg_mask_lo.getIdx(),
+                xreg_mask_hi.getIdx(),
+                xreg_mask_hi.getIdx()
+        };
+        const int xsrc_i[QW_PER_VREG] = {
+                vreg_zeros.getIdx(),   // 0-th qword insert in zeros -> {qw0,  0}
+                xreg_mask_lo.getIdx(), // 1-st and 0-th merge        -> {qw0,qw1}
+                vreg_zeros.getIdx(),
+                xreg_mask_hi.getIdx()
+        };
+        const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg
+
+        for (size_t i = 0; i < QW_PER_VREG; i++) {
+            mov(reg_mask, vmask[i]);
+            vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
+        }
 
-    L(l_tail_loop);
-    {
-        cmp(c_iter, jpp.c - ur_c);
-        jg(exit, T_NEAR);
+        // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
+        // and High (xreg_mask_hi) into full vreg_mask
+        // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
+        vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
 
-        compute_step(ur_c, 1);
+        // Keep only low qword of mask in xreg_mask_q
+        if (init_mask_q) {
+            mov(reg_mask, vmask[0]);
+            vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
+        }
+    };
 
-        add(reg_ptr_src_i8, ur_c * sizeof_src_dt());
-        add(reg_ptr_dst_i8, ur_c * sizeof_dst_dt());
-        add(c_iter, ur_c);
-        jmp(l_tail_loop);
+    uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
+    switch (jpp.alg) {
+        case pooling_max:
+            // For "max" we need mask only in case of non-zero tail
+            if (tail_mask)
+                init(tail_mask, false);
+            break;
+        case pooling_avg_include_padding:
+        case pooling_avg_exclude_padding:
+            // For "avg" we need mask:
+            // - s32   - in case of the non-zero tail
+            // - s8/u8 - irrespective of the tail
+            switch (jpp.src_dt) {
+                case s32:
+                    if (tail_mask)
+                        init(tail_mask, false);
+                    break;
+                case s8:
+                case u8:
+                    init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
+                    break;
+                default: assert(!"unsupported src data type");
+            }
+            break;
+        default: assert(!"unsupported pooling algorithm");
     }
+}
+
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
 
-    L(exit);
+    for (int ll = 0; ll < max_num_ll; ll++) {
+        mov(reg_mask, jpp.tail[ll]);
+        kmovq(mask(ll), reg_mask);
+    }
 }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::init_tmp_reg() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
     using namespace data_type;
 
     switch (jpp.alg) {
         case pooling_avg_include_padding:
         case pooling_avg_exclude_padding:
-            mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
+            mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
             movq(xmm_tmp, reg_tmp);
-            uni_vpbroadcastd(vreg_tmp, xmm_tmp);
+            vpbroadcastd(vreg_tmp, xmm_tmp);
             break;
         case pooling_max:
             switch (jpp.src_dt) {
@@ -496,17 +792,10 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::init_tmp_reg() {
             }
 
             movq(xmm_tmp, reg_tmp);
-            if (jpp.src_dt == s32) {
-                uni_vpbroadcastd(vreg_tmp, xmm_tmp);
-            } else {
-                if (isa == avx2) {
-                    vpbroadcastb(vreg_tmp, xmm_tmp);
-                } else {
-                    movups(vreg_tmp, xmm_tmp);
-                    uni_vpxor(xmm_tmp, xmm_tmp, xmm_tmp);
-                    pshufb(vreg_tmp, xmm_tmp);
-                }
-            }
+            if (jpp.src_dt == s32)
+                vpbroadcastd(vreg_tmp, xmm_tmp);
+            else
+                vpbroadcastb(vreg_tmp, xmm_tmp);
             break;
         default: assert(!"unsupported pooling algorithm");
     }
@@ -514,11 +803,17 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::init_tmp_reg() {
 }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::generate() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
     preamble();
 
+#if !defined(_WIN32)
+    // Always use rcx as abi_param1 -
+    // see the note about maskmovdqu near reg_param.
+    mov(rcx, rdi);
+#endif
+
 #   define READ_PARAM(reg, field) \
-        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+        mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
     READ_PARAM(reg_ptr_src_i8, src_i8);
     READ_PARAM(reg_ptr_dst_i8, dst_i8);
     READ_PARAM(reg_kw, kw_range);
@@ -526,22 +821,23 @@ void jit_uni_i8i8_pool_fwd_ker_t<isa>::generate() {
 
 #   undef READ_PARAM
 
-    init_tmp_reg();
-
     uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
 
+    init_mask();
+
+    init_tmp_reg();
+
     compute_c_block();
 
     postamble();
 }
 
 template <cpu_isa_t isa>
-status_t jit_uni_i8i8_pool_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
+status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
         const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &dst_d) {
-    if (!mayiuse(isa)) {
+    if (!mayiuse(isa))
         return status::unimplemented;
-    }
 
     jpp.mb = src_d.dims()[0];
     jpp.c = src_d.dims()[1];
@@ -563,71 +859,106 @@ status_t jit_uni_i8i8_pool_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
     jpp.src_dt = pd.src_desc.data_type;
     jpp.dst_dt = pd.dst_desc.data_type;
 
-    jpp.c_block = jpp.alg == pooling_max ? 32 / (jpp.src_dt == data_type::s32 ? 4 : 1) : 8;
+    // data_type items per one vreg on the <isa>
+    //     isa == avx2    : 32 bytes -> 32 for s8/u8, 8 for s32
+    //     isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
+    int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
+
+    jpp.c_block = simd_w;
     jpp.c_tail = jpp.c % jpp.c_block;
     jpp.nb_c = jpp.c / jpp.c_block;
     jpp.ur_c = 1;
-    jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + (jpp.c_tail != 0);
+    jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
+            (jpp.c_tail != 0);
+
+    size_t tail_mask = (1ULL << jpp.c_tail) - 1;
+
+    switch (jpp.alg) {
+        case pooling_max:
+            jpp.tail[0] = tail_mask;
+            jpp.tail[1] = 0;
+            jpp.tail[2] = 0;
+            jpp.tail[3] = 0;
+            break;
+        case pooling_avg_include_padding:
+        case pooling_avg_exclude_padding: {
+            // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
+            // avx2 : 8, avx512 : 16
+            const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
+            const size_t msk_msk = (1ULL << msk_gran) - 1;
+            size_t m = tail_mask;
+            for (size_t ll = 0; ll < max_num_ll; ll++) {
+                jpp.tail[ll] = m & msk_msk;
+                m = m >> msk_gran;
+            }
+            break;
+        }
+        default: return status::unimplemented;
+    }
 
     return status::success;
 }
 
 template <cpu_isa_t isa>
 status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
-    return jit_uni_i8i8_pool_fwd_ker_t<isa>::init_conf(jpp_,
+    return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_,
        desc_, src_pd_.desc(), dst_pd_.desc());
 }
 
 template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *pd,
+jit_uni_i8i8_pooling_fwd_t<isa>::
+jit_uni_i8i8_pooling_fwd_t(const pd_t *apd,
           const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr) {
-    ker_ = new jit_uni_i8i8_pool_fwd_ker_t<isa>(conf_.jpp_);
-}
+    : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
+{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }
 
 template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() {
-    delete ker_;
-}
+jit_uni_i8i8_pooling_fwd_t<isa>::
+~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
 
 template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() {
+void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() const {
     auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
     auto dst_i8 = reinterpret_cast<char *>(memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
 
-    const auto &jpp = conf_.jpp_;
+    const auto &jpp = pd()->jpp_;
 
     parallel_nd(jpp.mb, jpp.oh, jpp.ow,
-        [&](int n, int oh, int ow) {
-        const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
-        const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);
+            [&](int n, int oh, int ow) {
+        const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
+        const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);
 
         const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
         const int kh_end = nstl::min(jpp.kh,
-                                     jpp.ih + jpp.t_pad - oh * jpp.stride_h);
+                jpp.ih + jpp.t_pad - oh * jpp.stride_h);
         const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
         const int kw_end = nstl::min(jpp.kw,
-                                     jpp.iw + jpp.l_pad - ow * jpp.stride_w);
+                jpp.iw + jpp.l_pad - ow * jpp.stride_w);
 
-        auto p = call_params_t();
+        auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
         p.src_i8 = &src_i8[
-                src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
+            src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
         p.dst_i8 = &dst_i8[
-                dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
-        p.kw_range = (size_t) (kw_end - kw_start);
-        p.kh_range = (size_t) (kh_end - kh_start);
+            dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
+        p.kw_range = (size_t)(kw_end - kw_start);
+        p.kh_range = (size_t)(kh_end - kh_start);
         p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
-                             p.kh_range * p.kw_range : jpp.kw * jpp.kh);
+            p.kh_range*p.kw_range : jpp.kw*jpp.kh);
 
         ker_->ker_(&p);
     });
 }
 
+// Explicit instantiation only for supported <isa> values.
+//
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
+template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
+
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
 template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
-template struct jit_uni_i8i8_pooling_fwd_t<sse42>;
 
 }
 }