Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.cpp
index 2896b1b..f659fdc 100644 (file)
@@ -32,21 +32,10 @@ namespace cpu {
 using namespace Xbyak;
 
 template <cpu_isa_t isa>
-bool jit_uni_eltwise_injector_f32<isa>::is_free_vec(size_t idx) {
-    for (size_t i = 0; i < preserved_vecs_count; i++) {
-        if (preserved_vec_idxs[i] == idx) {
-            return false;
-        }
-    }
-    return true;
-}
-
-template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
         size_t end_idx) {
     preserved_vecs_count = 0;
-    vecs_to_preserve = (size_t)jit_uni_eltwise_injector_f32<isa>::
-            aux_vecs_count(elt_alg);
+    vecs_to_preserve = (size_t)aux_vecs_count(alg_);
     start_idx_tail = start_idx;
 
     // For sse42 mask register has to be Xmm(0)
@@ -56,78 +45,80 @@ void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
         preserved_vec_idxs[preserved_vecs_count++] = idx;
     }
 
-    for (size_t i = 0; i < vecs_count; i++) {
-        if (preserved_vecs_count >= vecs_to_preserve)
-            break;
+    for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
+        if (preserved_vecs_count >= vecs_to_preserve) break;
+        if (start_idx <= idx && idx < end_idx) continue;
 
-        size_t idx = i;
-        if (is_free_vec(idx) && (idx < start_idx || idx >= end_idx)) {
-            preserved_vec_idxs[preserved_vecs_count++] = idx;
-        }
+        preserved_vec_idxs[preserved_vecs_count++] = idx;
     }
 
     size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
     for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
-        size_t idx = start_idx_tail;
-        if (is_free_vec(idx)) {
-            preserved_vec_idxs[preserved_vecs_count++] = idx;
-            start_idx_tail++;
-        }
+        preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
     }
 
     assert(preserved_vecs_count == vecs_to_preserve);
 
-    if (save_vecs_state) {
+    if (save_state_) {
         h->push(p_table);
 
-        h->sub(h->rsp, preserved_vecs_count * vlen);
+        if (preserved_vecs_count)
+            h->sub(h->rsp, preserved_vecs_count * vlen);
+
         for (size_t i = 0; i < preserved_vecs_count; ++i)
             h->uni_vmovups(h->ptr[h->rsp + i * vlen],
                     Vmm(preserved_vec_idxs[i]));
+
+        load_table_addr();
     }
 
     assign_regs();
 }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(
-        size_t start_idx) {
+void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
+{
     size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
-    int idx_off = (vecs_to_preserve - tail_vecs_to_preserve);
+    if (tail_vecs_to_preserve == 0) return;
+
+    const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
 
-    if (tail_vecs_to_preserve > 0) {
-        if (save_vecs_state) {
+    if (save_state_) {
+        if (idx_off)
             h->add(h->rsp, idx_off * vlen);
-            for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
-                h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
-                        h->ptr[h->rsp + i * vlen]);
-        }
 
-        for (size_t i = 0; i < tail_vecs_to_preserve; ++i) {
-            preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
-        }
+        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+            h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
+                    h->ptr[h->rsp + i * vlen]);
+    }
 
-        if (save_vecs_state) {
-            for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
-                h->uni_vmovups(h->ptr[h->rsp + i * vlen],
-                        Vmm(preserved_vec_idxs[idx_off + i]));
-            h->sub(h->rsp, idx_off * vlen);
-        }
+    for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+        preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
+
+    if (save_state_) {
+        for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+            h->uni_vmovups(h->ptr[h->rsp + i * vlen],
+                    Vmm(preserved_vec_idxs[idx_off + i]));
 
-        assign_regs();
+        if (idx_off)
+            h->sub(h->rsp, idx_off * vlen);
     }
+
+    assign_regs();
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
-    if (save_vecs_state) {
-        for (size_t i = 0; i < preserved_vecs_count; ++i)
-            h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
-                    h->ptr[h->rsp + i * vlen]);
+    if (!save_state_) return;
+
+    for (size_t i = 0; i < preserved_vecs_count; ++i)
+        h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
+                h->ptr[h->rsp + i * vlen]);
+
+    if (preserved_vecs_count)
         h->add(h->rsp, preserved_vecs_count * vlen);
 
-        h->pop(p_table);
-    }
+    h->pop(p_table);
 }
 
 template <cpu_isa_t isa>
@@ -137,33 +128,26 @@ void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
     vmm_aux1 = Vmm(preserved_vec_idxs[1]);
     vmm_aux2 = Vmm(preserved_vec_idxs[2]);
     vmm_aux3 = Vmm(preserved_vec_idxs[3]);
-
-    p_table = Xbyak::Reg64(table_reg_idx);
-    k_mask = Xbyak::Opmask(opmask_idx);
+    vmm_aux4 = Vmm(preserved_vec_idxs[4]);
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
-    const unsigned char _op_floor = 1;
-
-    h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 10 * vlen]);
-    h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 11 * vlen]);
+    h->uni_vminps(vmm_src, vmm_src, table_val(10));
+    h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
     h->uni_vmovups(vmm_aux0, vmm_src);
     //calculate exp(x)
     // fx = x * log2ef + 0.5
-    h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
-    h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
+    h->uni_vmulps(vmm_src, vmm_src, table_val(2));
+    h->uni_vaddps(vmm_src, vmm_src, table_val(1));
 
     // tmp = floorf(fx)
     if (isa == avx512_common) {
         h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
         h->vcvtdq2ps(vmm_aux1, vmm_aux1);
 
-        unsigned char _cmp_gt_os = 14;
-        Xbyak::Opmask k_mask_tmp = Xbyak::Opmask(2);
-        h->vcmpps(k_mask_tmp, vmm_aux1, vmm_src, _cmp_gt_os);
-        h->vmovups(vmm_aux3 | k_mask_tmp | h->T_z,
-                h->zword[p_table + 0 * vlen]);
+        h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
+        h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
 
         h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
     } else {
@@ -174,105 +158,213 @@ void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
     h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
 
     //x = x - fx * ln2
-    h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, h->ptr[p_table + 3 * vlen]);
+    h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
 
     // compute 2^n
     h->uni_vcvtps2dq(vmm_aux1, vmm_src);
-    h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
+    h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
 
     // y = p5
-    h->uni_vmovups(vmm_src, h->ptr[p_table + 9 * vlen]);
+    h->uni_vmovups(vmm_src, table_val(9));
     // y = y * x + p4
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 8 * vlen]);
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
     // y = y * x + p3
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 7 * vlen]);
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
     // y = y * x + p2
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 6 * vlen]);
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
     // y = y * x + p1
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 0 * vlen]);
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
     // y = y * x + p0
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 5 * vlen]);  //exp(q)
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5));  //exp(q)
     // y = y * 2^n
     h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
 }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(
-        const Vmm &vmm_src) {
-    unsigned char _cmp_gt_os = isa == avx512_common ? 14 : 6;
-
-    int alpha_off = 0 * vlen;
-    int zero_off = 1 * vlen;
+void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
+{
+    const int alpha_off = 0, zero_off = 1;
 
     h->uni_vmovups(vmm_aux1, vmm_src);
     if (isa == sse42) {
         h->movups(vmm_mask, vmm_src);
-        h->mulps(vmm_src, h->ptr[p_table + alpha_off]);
-        h->cmpps(vmm_mask, h->ptr[p_table + zero_off], _cmp_gt_os);
+        h->mulps(vmm_src, table_val(alpha_off));
+        h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
         h->blendvps(vmm_src, vmm_aux1);
     } else if (isa == avx2) {
-        h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
-        h->vcmpgtps(vmm_mask, vmm_aux1, h->ptr[p_table + zero_off]);
+        h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
+        h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
         h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
     } else if (isa == avx512_common) {
-        h->vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
-        h->vcmpps(k_mask, vmm_aux1, h->ptr[p_table + zero_off], _cmp_gt_os);
-        h->vblendmps(vmm_src | k_mask, vmm_src,
-                     vmm_aux1);
+        h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
+        h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
+        h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
     }
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
         const Vmm &vmm_src) {
-    int zero_off = 1 * vlen;
-    h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + zero_off]);
+    const int zero_off = 1;
+    h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
-    const unsigned char _cmp_gt_os = 6;
-    const unsigned char _cmp_let_os = 2;
-    int alpha_off = 12 * vlen;
-    int zero_off = 13 * vlen;
+    const int alpha_off = 23, zero_off = 24;
 
     // compute exponent
     h->uni_vmovups(vmm_aux2, vmm_src);
     exp_compute_vector(vmm_src);
 
     // alpha * (exp(x) - 1)
-    h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * 32]);
-    h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + alpha_off]);
+    h->uni_vsubps(vmm_src, vmm_src, table_val(0));
+    h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
 
     // combine with mask
     if (isa == sse42) {
         h->pxor(vmm_mask, vmm_mask);
-        h->cmpps(vmm_mask,  vmm_aux2, _cmp_let_os);
+        h->cmpps(vmm_mask,  vmm_aux2, _cmp_le_os);
         h->blendvps(vmm_src, vmm_aux2);
     } else if (isa == avx2) {
-        h->uni_vcmpgtps(vmm_mask, vmm_aux2, h->ptr[p_table + zero_off]);
+        h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
     } else if (isa == avx512_common) {
-        h->vcmpps(k_mask, vmm_aux2, h->ptr[p_table + zero_off], _cmp_gt_os);
+        h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
     }
 }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(
-        const Vmm &vmm_src) {
-    // compute exp(2x)
-    h->uni_vaddps(vmm_src, vmm_src, vmm_src);
-    exp_compute_vector(vmm_src);
-    // dup exp(2x)
-    h->uni_vmovups(vmm_aux0, vmm_src);
-    // (exp(2x) - 1)
-    h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 0 * vlen]);
-    // (exp(2x) + 1)
-    h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
-    // y = (exp(2x) - 1) / (exp(2x) + 1)
-    h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
+void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
+{
+    // # comes from Taylor expansion error bound
+    //  > linear_sat_point = single(sqrt(3) * 1b-12);
+    // # comes from the exp formula cancellation
+    //  > exp_bound_point = (single(log(3)/2));
+    // # comes from rounding accuracy in float
+    //  > one_sat_point = round(atanh(1 - 1b-25), single, RU);
+    //  > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
+    //            [linear_sat_point, exp_bound_point], relative, floating);
+    //  > err_bound = D(sup(supnorm(P, tanh(x),
+    //          [linear_sat_point, exp_bound_point], relative, theta)));
+    //    0x1.fffd6f00b9539p-25
+    //  > P;
+    //    x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
+    //        (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
+    //        + x^0x1p1 * 0x1.09fa1p-6))))
+
+    // register mapping
+    // vmm_src contains input
+    // vmm_aux0 contains mask of currently valid results.
+    //     1 is need computation, 0 is already computed
+    // vmm_aux1 contains current output
+    // vmm_aux2, vmm_aux3 contains auxiliary values
+    // vmm_aux4 contains the original sign of inputs
+
+    Label end_tanh_label;
+
+    auto test_exit =[&](Xbyak::Address threshold){
+        // is not necessary for >AVX, but should not matter on perf
+        h->uni_vmovups(vmm_aux0, vmm_src);
+        if (isa == avx512_common){
+            h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
+            h->kortestw(k_mask, k_mask);
+        } else {
+            h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
+            h->uni_vtestps(vmm_aux0, vmm_aux0);
+        }
+        h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
+    };
+
+    auto blend_results=[&](Vmm vmm_partial_res){
+        if (isa == avx512_common)
+            h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
+        else
+            h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
+    };
+
+    // because tanh(x) = -tanh(-x), we extract sign to make x postive
+    // and reapply sign at the end
+    // mov is not necessary for >AVX, but should not matter for performance
+    h->uni_vmovups(vmm_aux4, vmm_src);
+    h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
+    h->uni_vandps(vmm_src, vmm_src, table_val(17));
+
+    // if x < linear_sat_point for all inputs, we just return the input
+    h->uni_vmovups(vmm_aux1, vmm_src);
+    test_exit(table_val(13));
+
+    // if one of the mask is one, we have to compute an better approx
+    h->uni_vmovups(vmm_aux2, vmm_src);
+    h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
+    h->uni_vmovups(vmm_aux3, table_val(22));
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
+    h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
+
+    // we blend only the result that need update
+    blend_results(vmm_aux3);
+
+    // if x < exp_bound_point, we go to return point
+    test_exit(table_val(14));
+
+    // if not we use a better approx 1 - 2 / (1 + exp(2x))
+    // compute 2x
+    h->uni_vmovups(vmm_aux3, vmm_src);
+    h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
+
+    // Compute exp(2x)
+    // We need to save kmask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
+    // vmm_src is not more read afterwards, so we do not have to save it
+    auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
+    h->sub(h->rsp, stack_size);
+    h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
+    h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
+    h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
+    if (isa == avx512_common)
+        h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
+
+    exp_compute_vector(vmm_aux3);
+
+    h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
+    h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
+    h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
+    if (isa == avx512_common)
+        h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
+    h->add(h->rsp, stack_size);
+
+    // 1 + exp(2x)
+    h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
+
+    // 1 - 2 / (1 + exp(2x))
+    h->uni_vmovups(vmm_aux2, table_val(16));
+    h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
+    h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
+
+    // we blend only the result that need update
+    blend_results(vmm_aux2);
+
+    // finally, we saturate to 1 if needed
+    // TODO: maybe move that up if most inputs saturate in practice
+    if (isa == avx512_common)
+        h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
+    else {
+        h->uni_vmovups(vmm_aux0, vmm_src);
+        h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
+    }
+    h->uni_vmovups(vmm_aux2, table_val(0));
+    blend_results(vmm_aux2);
+
+    h->L(end_tanh_label);
+    {
+        // we apply the sign of x to the result and we are done
+        h->uni_vmovups(vmm_src, vmm_aux1);
+        h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
+    }
 }
 
 template <cpu_isa_t isa>
@@ -284,24 +376,22 @@ void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
     // compute abs(x) = _mm_and_ps(x, 01111..111));
-    h->uni_vandps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
+    h->uni_vandps(vmm_src, vmm_src, table_val(0));
 }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(
-        const Vmm &vmm_src) {
+void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
+{
     if (isa == avx512_common) {
-        unsigned char _cmp_gt_os = 6;
-
-        h->vcmpps(k_mask, vmm_src, h->ptr[p_table + 0 * vlen], _cmp_gt_os);
+        h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
         h->uni_vsqrtps(vmm_aux1, vmm_src);
-        h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
+        h->uni_vmovups(vmm_src, table_val(0));
         h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
     } else {
         h->uni_vmovups(vmm_mask, vmm_src);
-        h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 0*vlen]);
+        h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
         h->uni_vsqrtps(vmm_aux1, vmm_src);
-        h->uni_vmovups(vmm_src, h->ptr[p_table + 0*vlen]);
+        h->uni_vmovups(vmm_src, table_val(0));
         h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
     }
 }
@@ -310,48 +400,39 @@ template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
         const Vmm &vmm_src) {
     // compute x = alpha * x + beta;
-    h->uni_vmovups(vmm_aux0, h->ptr[p_table + 0*vlen]);
-    h->uni_vfmadd213ps(vmm_src, vmm_aux0, h->ptr[p_table + 1*vlen]);
+    h->uni_vmovups(vmm_aux0, table_val(0));
+    h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
         const Vmm &vmm_src) {
     // compute bounded relu */
-    h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
-    h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
-        const Vmm &vmm_src) {
-    h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 1*vlen]);
-    h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 0*vlen]);
+    h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
+    h->uni_vminps(vmm_src, vmm_src, table_val(0));
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
         const Vmm &vmm_src) {
-    const unsigned char _op_floor = 1;
     // duplicate src
     h->uni_vmovups(vmm_aux2, vmm_src);
 
-    h->uni_vminps(vmm_src, vmm_src, h->ptr[p_table + 24 * vlen]);
-    h->uni_vmaxps(vmm_src, vmm_src, h->ptr[p_table + 25 * vlen]);
+    h->uni_vminps(vmm_src, vmm_src, table_val(24));
+    h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
     h->uni_vmovups(vmm_aux1, vmm_src);
     // calculate exp(x)
     // fx = x * log2ef + 0.5
-    h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 2 * vlen]);
-    h->uni_vaddps(vmm_src, vmm_src, h->ptr[p_table + 1 * vlen]);
+    h->uni_vmulps(vmm_src, vmm_src, table_val(2));
+    h->uni_vaddps(vmm_src, vmm_src, table_val(1));
 
     // tmp = floorf(fx)
     if (isa == avx512_common) {
         h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
         h->vcvtdq2ps(vmm_aux0, vmm_aux0);
 
-        unsigned char _cmp_gt_os = 14;
-        h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_gt_os);
-        h->vmovups(vmm_aux3 | k_mask | h->T_z, h->ptr[p_table + 0 * vlen]);
+        h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
+        h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
 
         h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
     } else {
@@ -361,32 +442,32 @@ void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
     // keep fx for further computations
     h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
     // calculation fx * ln2
-    h->uni_vmulps(vmm_aux0, vmm_aux0, h->ptr[p_table + 3 * vlen]);
+    h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
     // x = x - fx * ln2
     h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
     // y = p5
-    h->uni_vmovups(vmm_aux3, h->ptr[p_table + 22 * vlen]);
+    h->uni_vmovups(vmm_aux3, table_val(22));
     // y = y * x + p4
-    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 21 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
     // y = y * x + p3
-    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 20 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
     // y = y * x + p2
-    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 19 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
     // y = y * x + p1
-    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 0 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
     // y = y * x + p0
-    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, h->ptr[p_table + 17 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
 
     // compute 2^(-n)
     if (isa == avx512_common) {
-        h->vmulps(vmm_aux1, vmm_src, h->ptr[p_table + 23 * vlen]);
+        h->vmulps(vmm_aux1, vmm_src, table_val(23));
         h->vcvtps2dq(vmm_aux1, vmm_aux1);
     } else {
         h->uni_vcvtps2dq(vmm_aux1, vmm_src);
-        h->uni_vpsignd(vmm_aux1, vmm_aux1, h->ptr[p_table + 23 * vlen]);
+        h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
     }
 
-    h->uni_vpaddd(vmm_aux1, vmm_aux1, h->ptr[p_table + 4 * vlen]);
+    h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
     h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
     // calculate ln(1 + y)
     h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
@@ -396,46 +477,45 @@ void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
     h->uni_vpsrld(vmm_src, vmm_src, 23);
     h->uni_vcvtdq2ps(vmm_src, vmm_src);
     // got n. where n is x = 2^n * y. y = 0.5 .. 1
-    h->uni_vsubps(vmm_src, vmm_src, h->ptr[p_table + 5 * vlen]);
+    h->uni_vsubps(vmm_src, vmm_src, table_val(5));
 
-    h->uni_vandps(vmm_aux3, vmm_aux3, h->ptr[p_table + 6 * vlen]);
+    h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
     // got y. (mantisa)  0.5 < y < 1
-    h->uni_vorps(vmm_aux3, vmm_aux3, h->ptr[p_table + 7 * vlen]);
+    h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
     // y  = y - 1
-    h->uni_vsubps(vmm_aux3, vmm_aux3, h->ptr[p_table + 0 * vlen]);
+    h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
     // y = p8
-    h->uni_vmovups(vmm_aux1, h->ptr[p_table + 16 * vlen]);
+    h->uni_vmovups(vmm_aux1, table_val(16));
     // y = y * x + p7
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 15 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
     // y = y * x + p6
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 14 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
     // y = y * x + p5
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 13 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
     // y = y * x + p4
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 12 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
     // y = y * x + p3
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 11 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
     // y = y * x + p2
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 10 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
     // y = y * x + p1
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 9 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
     // y = y * x + p0 ; p0 = 0
-    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, h->ptr[p_table + 8 * vlen]);
+    h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
     //calculate ln(2) * n
-    h->uni_vmulps(vmm_src, vmm_src, h->ptr[p_table + 3 * vlen]);
+    h->uni_vmulps(vmm_src, vmm_src, table_val(3));
     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
     h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
 
     // get vmm_mask = src > max logf
     h->uni_vmovups(vmm_mask, vmm_aux2);
     if (isa == avx512_common) {
-        unsigned char _cmp_gt_os = 6;
         // y = (x < max log f) ? soft_relu(x) : x
-        h->vcmpps(k_mask, vmm_mask, h->ptr[p_table + 24 * vlen], _cmp_gt_os);
+        h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
         h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
     } else {
         // y = (x < max log f) ? soft_relu(x) : x
-        h->uni_vcmpgtps(vmm_mask, vmm_mask, h->ptr[p_table + 24 * vlen]);
+        h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
         h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
     }
 
@@ -445,23 +525,46 @@ void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
         const Vmm &vmm_src) {
+    // we store the original sign and make x negative
+    // IMPORTANT: we assume vmm_aux0 to be xmm0, as for sse4.2 path it is required
+    // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
+    h->uni_vmovups(vmm_aux2, vmm_src);
+    h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
+    h->uni_vorps(vmm_src, vmm_src, table_val(12));
+
     exp_compute_vector(vmm_src);
     // dup exp(x)
-    h->uni_vmovups(vmm_aux0, vmm_src);
+    h->uni_vmovups(vmm_aux1, vmm_src);
     // (exp(x) + 1)
-    h->uni_vaddps(vmm_aux0, vmm_aux0, h->ptr[p_table + 0 * vlen]);
+    h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
     // y = exp(x) / (exp(x) + 1)
-    h->uni_vdivps(vmm_src, vmm_src, vmm_aux0);
+    h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
+
+    // Now we have to apply the "symmetry" based on original sign
+    h->uni_vmovups(vmm_aux3, table_val(0));
+    h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
+    if (isa == avx512_common) {
+        h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
+        h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
+    } else {
+        h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
+        h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
+    }
+    h->uni_vmovups(vmm_src, vmm_aux3);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::clamp_compute_vector(
+        const Vmm &vmm_src) {
+    // compute clamp */
+    h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
+    h->uni_vminps(vmm_src, vmm_src, table_val(0));
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(alpha));
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(0);
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
 }
 
 template <cpu_isa_t isa>
@@ -479,20 +582,28 @@ void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
             0x3d2bb1b1, // [8] p4 = 0.041917507f
             0x3c091ec1, // [9] p5 = 0.008369149f
             0x42b0c0a5, //[10] max logf = 88.3762589f
-            0xc1766666  //[11] min logf = -14.5f
+            0xc1766666, //[11] min logf = -14.5f
+            // tanh(x) constants,
+            0x80000000, //[12] mask to extract sign
+            0x39ddb3d7, //[13] arg below which tanh(x) = x
+            0x3f0c9f54, //[14] arg below which pol approx is valid
+            0x41102cb4, //[15] arg after which tanh(x) = 1
+            0xc0000000, //[16] -2.0f
+            0x7fffffff, //[17] mask to make positive
+            // tanh pol approx
+            0x3f7fffff, //[18] p0
+            0xbeaaa9cf, //[19] p1
+            0x3e085f1f, //[20] p2
+            0xbd572bda, //[21] p3
+            0x3c84fd08, //[22] p4
     };
 
     for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
-        for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-            h->dd(cvals[i]);
-        }
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(alpha));
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(0);
+        for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
     }
+
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
 }
 
 template <cpu_isa_t isa>
@@ -537,63 +648,48 @@ void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(0x7fffffff);
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(0);
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(alpha));
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(beta));
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(alpha));
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(0);
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
 }
 
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::clamp_prepare_table() {
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(alpha));
-    }
-    for (size_t d = 0; d < vlen / sizeof(float); ++d) {
-        h->dd(float2int(beta));
-    }
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+    for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
 }
 
 template <cpu_isa_t isa>
-int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t elt_alg) {
-    switch (elt_alg) {
-        case alg_kind::eltwise_relu: return (alpha == 0.f) ? 0 : 2;
-        case alg_kind::eltwise_elu: return 4;
-        case alg_kind::eltwise_tanh: return 4;
-        case alg_kind::eltwise_square: return 0;
-        case alg_kind::eltwise_abs: return 0;
-        case alg_kind::eltwise_sqrt: return 2;
-        case alg_kind::eltwise_linear: return 1;
-        case alg_kind::eltwise_bounded_relu: return 0;
-        case alg_kind::eltwise_soft_relu: return 4;
-        case alg_kind::eltwise_logistic: return 4;
-        case alg_kind::eltwise_clamp: return 0;
-        default: assert(!"unsupported eltwise algorithm");
+int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
+    switch (alg_) {
+    case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
+    case alg_kind::eltwise_elu: return 4;
+    case alg_kind::eltwise_tanh: return 5;
+    case alg_kind::eltwise_square: return 0;
+    case alg_kind::eltwise_abs: return 0;
+    case alg_kind::eltwise_sqrt: return 2;
+    case alg_kind::eltwise_linear: return 1;
+    case alg_kind::eltwise_bounded_relu: return 0;
+    case alg_kind::eltwise_soft_relu: return 4;
+    case alg_kind::eltwise_logistic: return 4;
+    case alg_kind::eltwise_clamp: return 0;
+    case alg_kind::eltwise_exp: return 4;
+    default: assert(!"unsupported eltwise algorithm");
     }
 
     return 0;
@@ -602,37 +698,25 @@ int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t elt_alg) {
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
         size_t end_idx) {
-    h->mov(p_table, l_table);
-
+    using namespace alg_kind;
     for (size_t idx = start_idx; idx < end_idx; idx++) {
-        switch (elt_alg) {
-            case alg_kind::eltwise_relu:
-                if (alpha == 0.f)
-                    relu_zero_ns_compute_vector(Vmm(idx));
-                else
-                    relu_compute_vector(Vmm(idx));
-                break;
-            case alg_kind::eltwise_elu:
-                elu_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_tanh:
-                tanh_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_square:
-                square_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_abs:
-                abs_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_sqrt:
-                sqrt_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_linear:
-                linear_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_bounded_relu:
-                bounded_relu_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_soft_relu:
-                soft_relu_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_logistic:
-                logistic_compute_vector(Vmm(idx)); break;
-            case alg_kind::eltwise_clamp:
-                clamp_compute_vector(Vmm(idx)); break;
-            default: assert(!"unsupported eltwise algorithm");
+        switch (alg_) {
+        case eltwise_relu:
+            if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
+            else relu_compute_vector(Vmm(idx));
+            break;
+        case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
+        case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
+        case eltwise_square: square_compute_vector(Vmm(idx)); break;
+        case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
+        case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
+        case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
+        case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
+        case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
+        case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
+        case eltwise_clamp: clamp_compute_vector(Vmm(idx)); break;
+        case eltwise_exp: exp_compute_vector(Vmm(idx)); break;
+        default: assert(!"unsupported eltwise algorithm");
         }
     }
 }
@@ -640,9 +724,7 @@ void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
 template <cpu_isa_t isa>
 void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
         size_t end_idx) {
-    assert(start_idx < vecs_count);
-    assert(end_idx <= vecs_count);
-    assert(start_idx < end_idx);
+    assert(start_idx < end_idx && end_idx <= vecs_count);
 
     injector_preamble(start_idx, end_idx);
     compute_body(start_idx_tail, end_idx);
@@ -652,38 +734,30 @@ void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
 }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::compute_vector(size_t idx) {
-    compute_vector_range(idx, idx + 1);
-}
+void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
+    using namespace alg_kind;
 
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::prepare_table() {
     h->align(64);
     h->L(l_table);
 
-    switch (elt_alg) {
-        case alg_kind::eltwise_relu:
-            relu_prepare_table(); break;
-        case alg_kind::eltwise_elu:
-        case alg_kind::eltwise_tanh:
-        case alg_kind::eltwise_logistic:
+    if (gen_table) {
+        switch (alg_) {
+        case eltwise_relu: relu_prepare_table(); break;
+        case eltwise_elu:
+        case eltwise_tanh:
+        case eltwise_logistic:
+        case eltwise_exp:
             elu_prepare_table(); break;
-        case alg_kind::eltwise_soft_relu:
-            soft_relu_prepare_table(); break;
-        case alg_kind::eltwise_abs:
-            abs_prepare_table(); break;
-        case alg_kind::eltwise_sqrt:
-            sqrt_prepare_table(); break;
-        case alg_kind::eltwise_linear:
-            linear_prepare_table(); break;
-        case alg_kind::eltwise_bounded_relu:
-            bounded_relu_prepare_table(); break;
-        case alg_kind::eltwise_square:
-            break;
-        case alg_kind::eltwise_clamp:
-            clamp_prepare_table(); break;
+        case eltwise_soft_relu: soft_relu_prepare_table(); break;
+        case eltwise_abs: abs_prepare_table(); break;
+        case eltwise_sqrt: sqrt_prepare_table(); break;
+        case eltwise_linear: linear_prepare_table(); break;
+        case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
+        case eltwise_square: break;
+        case eltwise_clamp: clamp_prepare_table(); break;
         default: assert(!"unsupported eltwise algorithm");
     }
+    }
 }
 
 template struct jit_uni_eltwise_injector_f32<avx512_common>;
@@ -861,27 +935,27 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
     jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
         : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
 
-        eltwise_injector = new jit_uni_eltwise_injector_f32<isa>(this,
-                desc.alg_kind, desc.alpha, desc.beta, false, 9, 1);
+        eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
+                desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
 
         using namespace alg_kind;
 
         assert(is_bwd() == false);
         assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
-                    eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
+                    eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
+                    eltwise_clamp, eltwise_exp));
 
         preamble();
 
-        Label vectorized_loop_start;
-        Label reminder_loop_start;
-        Label vectorized_loop_end;
-        Label reminder_loop_end;
-
         Reg64 param = abi_param1;
         mov(reg_from, ptr[param + GET_OFF(from)]);
         mov(reg_to, ptr[param + GET_OFF(to)]);
         mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
+        eltwise_injector_->load_table_addr();
+
+        Label reminder_loop_start, reminder_loop_end;
+        Label vectorized_loop_start, vectorized_loop_end;
 
         cmp(reg_work_amount, simd_w);
         jl(reminder_loop_start, T_NEAR);
@@ -889,7 +963,7 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
         L(vectorized_loop_start);
 
         uni_vmovups(vmm_src, ptr[reg_from]);
-        eltwise_injector->compute_vector(vmm_src.getIdx());
+        eltwise_injector_->compute_vector(vmm_src.getIdx());
         uni_vmovups(ptr[reg_to], vmm_src);
 
         add(reg_from, vlen);
@@ -907,7 +981,7 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
         jle(reminder_loop_end, T_NEAR);
 
         movss(xmm_src, ptr[reg_from]);
-        eltwise_injector->compute_vector(xmm_src.getIdx());
+        eltwise_injector_->compute_vector(xmm_src.getIdx());
         movss(ptr[reg_to], xmm_src);
 
         add(reg_from, sizeof(float));
@@ -920,14 +994,12 @@ struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
 
         postamble();
 
-        eltwise_injector->prepare_table();
+        eltwise_injector_->prepare_table();
 
         ker_ = (decltype(ker_))this->getCode();
     }
 
-    ~jit_uni_kernel_fwd_f32() {
-        delete eltwise_injector;
-    }
+    ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }
 
 private:
     using Vmm = typename utils::conditional3<isa == sse42, Xmm,
@@ -944,7 +1016,7 @@ private:
     Xmm xmm_src = Xmm(1);
     Vmm vmm_src = Vmm(1);
 
-    jit_uni_eltwise_injector_f32<isa>* eltwise_injector;
+    jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
 };
 
 } /* namespace */
@@ -959,23 +1031,23 @@ status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
                 prop_kind::forward_inference)
         && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
         && !has_zero_dim_memory()
-        && IMPLICATION(isa > avx2, utils::one_of(desc()->alg_kind,
-                eltwise_relu, eltwise_elu))
-        && IMPLICATION(isa == sse42 || isa == avx2, utils::one_of(
-                    desc()->alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
-                    eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
-                    eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic))
-        && memory_desc_wrapper(src_pd()).is_dense()
+        && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
+                eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
+                eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
+                eltwise_logistic, eltwise_clamp, eltwise_exp)
+        && memory_desc_wrapper(src_pd()).is_dense(true)
+        && IMPLICATION(!memory_desc_wrapper(src_pd()).is_dense(false),
+                math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
         && attr()->has_default_values();
 
     return ok ? status::success : status::unimplemented;
 }
 
 template <cpu_isa_t isa>
-jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *pd,
+jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd,
         const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
-    const auto &desc = *conf_.desc();
+    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
+    const auto &desc = *pd()->desc();
     switch (desc.alg_kind) {
     case alg_kind::eltwise_relu:
         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
@@ -989,13 +1061,13 @@ jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
 { delete kernel_; }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_fwd_t<isa>::execute_forward() {
+void jit_uni_eltwise_fwd_t<isa>::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto dst = reinterpret_cast<data_t *>(this->memory(0));
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
 
-    const size_t nelems = data_d.nelems();
+    const size_t nelems = data_d.nelems(true);
 
     src += data_d.blocking_desc().offset_padding;
     dst += data_d.blocking_desc().offset_padding;
@@ -1037,10 +1109,10 @@ status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
 }
 
 template <cpu_isa_t isa>
-jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *pd,
+jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd,
         const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), kernel_(nullptr) {
-    const auto &desc = *conf_.desc();
+    : cpu_primitive_t(apd, inputs, outputs), kernel_(nullptr) {
+    const auto &desc = *pd()->desc();
     switch (desc.alg_kind) {
     case alg_kind::eltwise_relu:
         kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
@@ -1053,13 +1125,13 @@ jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
 { delete kernel_; }
 
 template <cpu_isa_t isa>
-void jit_uni_eltwise_bwd_t<isa>::execute_backward() {
+void jit_uni_eltwise_bwd_t<isa>::execute_backward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
 
-    const memory_desc_wrapper data_d(conf_.src_pd());
-    const memory_desc_wrapper diff_data_d(conf_.diff_src_pd());
+    const memory_desc_wrapper data_d(pd()->src_pd());
+    const memory_desc_wrapper diff_data_d(pd()->diff_src_pd());
 
     const size_t nelems = data_d.nelems();