Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.hpp
index 063556d..1acc239 100644 (file)
@@ -18,7 +18,6 @@
 #define CPU_JIT_UNI_ELTWISE_HPP
 
 #include <assert.h>
-#include <mkldnn.hpp>
 
 #include "c_types_map.hpp"
 #include "cpu_eltwise_pd.hpp"
@@ -33,45 +32,57 @@ namespace cpu {
 
 template <cpu_isa_t isa>
 struct jit_uni_eltwise_injector_f32 {
-    jit_uni_eltwise_injector_f32(jit_generator* host, alg_kind_t elt_alg_, // old ctor (removed): table register and opmask were given as int indices
-            float alpha_, float beta_, bool save_vecs_state_ = true,
-            int table_reg_idx_ = 0, int opmask_idx_ = 1) {
+    using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, // vector register type picked from isa; the alias now sits ahead of the ctor
+            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+
+    jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, // new ctor: table register and opmask are passed as Xbyak objects
+            float alpha, float beta, bool save_state = true,
+            Xbyak::Reg64 p_table = Xbyak::util::rax,
+            Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+        : alg_(alg), alpha_(alpha), beta_(beta), h(host) // members are set in the initializer list; they are declared const in this revision
+        , save_state_(save_state), p_table(p_table), k_mask(k_mask)
+    {
+        using namespace alg_kind; // lets the algorithm names below be written unqualified
         assert(utils::one_of(isa, sse42, avx2, avx512_common)); // isa must be one of these three values
-        assert(utils::one_of(elt_alg_, alg_kind::eltwise_relu, // old algorithm whitelist (removed)
-                alg_kind::eltwise_tanh, alg_kind::eltwise_elu,
-                alg_kind::eltwise_square, alg_kind::eltwise_abs,
-                alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
-                alg_kind::eltwise_bounded_relu, alg_kind::eltwise_soft_relu,
-                alg_kind::eltwise_logistic, alg_kind::eltwise_clamp));
-
-        h = host; // old in-body member assignments (removed)
-        elt_alg = elt_alg_;
-        alpha = alpha_;
-        beta = beta_;
-        save_vecs_state = save_vecs_state_;
-        table_reg_idx = table_reg_idx_;
-        opmask_idx = opmask_idx_;
+        assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu, // new whitelist: the old eleven algorithms plus eltwise_exp
+                    eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+                    eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
+                    eltwise_clamp, eltwise_exp));
     }
 
+    // note that eltwise.scale is ignored
+    jit_uni_eltwise_injector_f32(jit_generator *host, // added ctor: configures the injector from a post-op eltwise descriptor
+            const post_ops_t::entry_t::eltwise_t &eltwise,
+            bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
+            Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+        : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha, // delegates to the (host, alg, alpha, beta, ...) ctor; only alg, alpha and beta are read
+                eltwise.beta, save_state, p_table, k_mask) {}
+
     void compute_vector_range(size_t start_idx, size_t end_idx); // declared only; the single-index wrapper below calls it with (idx, idx + 1)
-    void compute_vector(size_t idx);
-    void prepare_table();
+    void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); } // now defined inline as a one-element range call
+    void prepare_table(bool gen_table=true); // gains a gen_table flag; the default keeps zero-argument calls compiling
+    void load_table_addr() { h->mov(p_table, l_table); } // new: issues mov(p_table, l_table) on the host generator - presumably loads the table label's address
 
-private: // label removed here and re-added further down, so the members added next fall before it
-    jit_generator* h;
+    const alg_kind_t alg_; // configuration fields: renamed with a trailing underscore and made const
+    const float alpha_;
+    const float beta_;
 
-    using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm, // alias removed from this spot; the same definition is added near the top of the struct
-            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+    jit_generator * const h; // const pointer to a non-const generator: it cannot be reseated after construction
 
-    size_t vlen = cpu_isa_traits<isa>::vlen;
+    const bool save_state_; // replaces the old save_vecs_state
+    const Xbyak::Reg64 p_table; // replaces the old int table_reg_idx
+    const Xbyak::Opmask k_mask; // replaces the old int opmask_idx
+    Xbyak::Label l_table; // the only non-const member in this group
 
-    alg_kind_t elt_alg;
-    float alpha;
-    float beta;
+private: // private section now starts here
+    // if only the injector was inherited from jit_generator...
+    enum { // re-exports three jit_generator constants under the same names
+        _cmp_le_os = jit_generator::_cmp_le_os,
+        _cmp_nle_us = jit_generator::_cmp_nle_us,
+        _op_floor = jit_generator::_op_floor,
+    };
 
-    bool save_vecs_state;
-    int table_reg_idx;
-    int opmask_idx;
+    size_t vlen = cpu_isa_traits<isa>::vlen; // same initializer as before, now declared after the enum
 
     const static size_t preserved_vecs_max = 5; // also sizes the preserved_vec_idxs array declared further down
 
@@ -81,20 +92,17 @@ private:
     size_t preserved_vec_idxs[preserved_vecs_max] = {0}; // preserved_vecs_max slots, all zero-initialized
     size_t start_idx_tail = 0;
 
-    Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3;
-
-    Xbyak::Reg64 p_table; // p_table, k_mask and l_table leave this section; they are re-declared earlier in the struct
-    Xbyak::Opmask k_mask;
-    Xbyak::Label l_table;
+    Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4; // same register list with vmm_aux4 appended
 
-    int aux_vecs_count(alg_kind_t elt_alg);
+    Xbyak::Address table_val(int index) // new helper: address operand at p_table + index * vlen
+    { return h->ptr[p_table + index * vlen]; }
 
+    int aux_vecs_count(alg_kind_t alg); // same declaration, parameter renamed from elt_alg to alg
     void compute_body(size_t start_idx, size_t end_idx);
     void injector_preamble(size_t start_idx, size_t end_idx);
     void injector_preamble_tail(size_t start_idx);
     void injector_postamble();
     void assign_regs();
-    bool is_free_vec(size_t idx); // declaration dropped in this revision
 
     void exp_compute_vector(const Vmm &vmm_src); // per-algorithm helpers; the list continues beyond this hunk
     void relu_compute_vector(const Vmm &vmm_src);
@@ -137,21 +145,21 @@ struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
         virtual status_t init() override;
     };
 
-    jit_uni_eltwise_fwd_t(const pd_t *pd, const input_vector &inputs,
+    jit_uni_eltwise_fwd_t(const pd_t *apd, const input_vector &inputs, // parameter renamed pd -> apd; presumably to avoid clashing with the pd() accessor added below
                        const output_vector &outputs);
     ~jit_uni_eltwise_fwd_t(); // declared only; the definition is not part of this header hunk
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e)
+    virtual void execute(event_t *e) const // new revision: execute is const-qualified
     {
         execute_forward(); // run the forward computation first
         e->set_state(event_t::ready); // then set the event's state to ready
     }
 
 private:
-    void execute_forward();
-    pd_t conf_; // the pd_t conf_ member is removed
+    void execute_forward() const; // const so that the const execute() can call it
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // new accessor: casts primitive_t::pd() to this class's pd_t
     jit_uni_eltwise_kernel_f32 *kernel_; // raw kernel pointer, untouched by this hunk; ownership is not visible here
 };
 
@@ -170,21 +178,21 @@ struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
         virtual status_t init() override;
     };
 
-    jit_uni_eltwise_bwd_t(const pd_t *pd, const input_vector &inputs,
+    jit_uni_eltwise_bwd_t(const pd_t *apd, const input_vector &inputs, // parameter renamed pd -> apd; presumably to avoid clashing with the pd() accessor added below
                        const output_vector &outputs);
     ~jit_uni_eltwise_bwd_t(); // declared only; the definition is not part of this header hunk
 
     typedef typename prec_traits<data_type::f32>::type data_t;
 
-    virtual void execute(event_t *e)
+    virtual void execute(event_t *e) const // new revision: execute is const-qualified
     {
         execute_backward(); // run the backward computation first
         e->set_state(event_t::ready); // then set the event's state to ready
     }
 
 private:
-    void execute_backward();
-    pd_t conf_; // the pd_t conf_ member is removed
+    void execute_backward() const; // const so that the const execute() can call it
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } // new accessor: casts primitive_t::pd() to this class's pd_t
     jit_uni_eltwise_kernel_f32 *kernel_; // raw kernel pointer, untouched by this hunk; ownership is not visible here
 };