#define CPU_JIT_UNI_ELTWISE_HPP
#include <assert.h>
-#include <mkldnn.hpp>
#include "c_types_map.hpp"
#include "cpu_eltwise_pd.hpp"
template <cpu_isa_t isa>
struct jit_uni_eltwise_injector_f32 {
// Primary constructor. This is a diff hunk: '-' lines are the old form, '+' lines the new one.
//   old: took integer indices (table_reg_idx_, opmask_idx_) and assigned every
//        member inside the constructor body;
//   new: takes the table register and the opmask as Xbyak objects and sets all
//        members through the member-initializer list.
- jit_uni_eltwise_injector_f32(jit_generator* host, alg_kind_t elt_alg_,
- float alpha_, float beta_, bool save_vecs_state_ = true,
- int table_reg_idx_ = 0, int opmask_idx_ = 1) {
// Vmm is the vector register type selected from isa. The alias is re-added here,
// ahead of the constructors, and removed from its old place lower in the struct.
// NOTE(review): presumably conditional3 yields Xmm for sse42, Ymm for avx2 and
// Zmm otherwise -- confirm against utils::conditional3.
+ using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
+ isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+
// host      : generator stored in h and used by the inline helpers of this struct.
// alg/alpha/beta : eltwise algorithm kind and its two float parameters.
// save_state: stored in save_state_ (replaces the old save_vecs_state flag).
// p_table   : 64-bit register stored in p_table; defaults to rax.
// k_mask    : opmask register stored in k_mask; defaults to Opmask(1).
+ jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
+ float alpha, float beta, bool save_state = true,
+ Xbyak::Reg64 p_table = Xbyak::util::rax,
+ Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+ : alg_(alg), alpha_(alpha), beta_(beta), h(host)
+ , save_state_(save_state), p_table(p_table), k_mask(k_mask)
+ {
// Brings the alg_kind names into scope so the assertion below can list them unqualified.
+ using namespace alg_kind;
// Sanity check: only these three ISA values are accepted.
assert(utils::one_of(isa, sse42, avx2, avx512_common));
- assert(utils::one_of(elt_alg_, alg_kind::eltwise_relu,
- alg_kind::eltwise_tanh, alg_kind::eltwise_elu,
- alg_kind::eltwise_square, alg_kind::eltwise_abs,
- alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
- alg_kind::eltwise_bounded_relu, alg_kind::eltwise_soft_relu,
- alg_kind::eltwise_logistic, alg_kind::eltwise_clamp));
-
- h = host;
- elt_alg = elt_alg_;
- alpha = alpha_;
- beta = beta_;
- save_vecs_state = save_vecs_state_;
- table_reg_idx = table_reg_idx_;
- opmask_idx = opmask_idx_;
// Sanity check on the algorithm kind. Compared with the removed assertion the
// accepted list gains eltwise_exp; all other entries are unchanged.
+ assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
+ eltwise_clamp, eltwise_exp));
}
// Convenience constructor (added by this diff) that takes a post-op eltwise
// entry and delegates to the (host, alg, alpha, beta, ...) constructor.
+ // Note: eltwise.scale is ignored -- only eltwise.alg, eltwise.alpha and eltwise.beta are forwarded.
+ jit_uni_eltwise_injector_f32(jit_generator *host,
+ const post_ops_t::entry_t::eltwise_t &eltwise,
+ bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
+ Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+ : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
+ eltwise.beta, save_state, p_table, k_mask) {}
+
// Entry points of the injector (declaration only; defined outside this view).
// Operates on the vector indices from start_idx up to end_idx.
// NOTE(review): compute_vector passes (idx, idx + 1), which suggests end_idx is
// exclusive -- confirm in the definition.
void compute_vector_range(size_t start_idx, size_t end_idx);
- void compute_vector(size_t idx);
- void prepare_table();
// The single-vector form becomes an inline wrapper that forwards to
// compute_vector_range with start_idx = idx and end_idx = idx + 1.
+ void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
// prepare_table gains a gen_table flag defaulting to true, so existing calls
// without arguments still compile. Its effect is defined outside this view.
+ void prepare_table(bool gen_table=true);
// Calls h->mov(p_table, l_table) on the host generator.
// NOTE(review): presumably this emits code loading the table address into
// p_table -- confirm against jit_generator / Xbyak.
+ void load_table_addr() { h->mov(p_table, l_table); }
// Data members. The diff removes the 'private:' label that used to sit here and
// re-adds it after these members, so the new members below are declared ahead
// of the private section. They are const (l_table excepted) and are set by the
// constructor's initializer list.
-private:
- jit_generator* h;
// Algorithm kind and its two float parameters, as passed to the constructor.
+ const alg_kind_t alg_;
+ const float alpha_;
+ const float beta_;
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
// Host generator: a const pointer to a non-const jit_generator.
+ jit_generator * const h;
- size_t vlen = cpu_isa_traits<isa>::vlen;
// save_state_ replaces the old save_vecs_state flag. p_table and k_mask are now
// stored as Xbyak register objects rather than as integer indices.
+ const bool save_state_;
+ const Xbyak::Reg64 p_table;
+ const Xbyak::Opmask k_mask;
// Label moved into p_table by load_table_addr().
+ Xbyak::Label l_table;
- alg_kind_t elt_alg;
- float alpha;
- float beta;
+private:
+ // The injector is not derived from jit_generator, so these constants are aliased here under the same names for unqualified use.
+ enum {
+ _cmp_le_os = jit_generator::_cmp_le_os,
+ _cmp_nle_us = jit_generator::_cmp_nle_us,
+ _op_floor = jit_generator::_op_floor,
+ };
// Private state and helpers. The old flag and index members are dropped here;
// save_state_, p_table and k_mask are added earlier in the struct in their place.
- bool save_vecs_state;
- int table_reg_idx;
- int opmask_idx;
// Initialised from cpu_isa_traits<isa>::vlen; used as the per-entry stride in table_val().
+ size_t vlen = cpu_isa_traits<isa>::vlen;
// Capacity of preserved_vec_idxs.
const static size_t preserved_vecs_max = 5;
// Array of preserved_vecs_max entries, zero-initialised.
size_t preserved_vec_idxs[preserved_vecs_max] = {0};
size_t start_idx_tail = 0;
- Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3;
-
- Xbyak::Reg64 p_table;
- Xbyak::Opmask k_mask;
- Xbyak::Label l_table;
// Mask and auxiliary vector registers; the diff adds a fifth auxiliary, vmm_aux4.
+ Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
- int aux_vecs_count(alg_kind_t elt_alg);
// Address of table entry 'index': p_table + index * vlen.
+ Xbyak::Address table_val(int index)
+ { return h->ptr[p_table + index * vlen]; }
// Parameter renamed from elt_alg to alg; the signature is otherwise unchanged.
+ int aux_vecs_count(alg_kind_t alg);
// Helpers declared here and defined outside this view.
void compute_body(size_t start_idx, size_t end_idx);
void injector_preamble(size_t start_idx, size_t end_idx);
void injector_preamble_tail(size_t start_idx);
void injector_postamble();
void assign_regs();
// is_free_vec is removed by this diff.
- bool is_free_vec(size_t idx);
void exp_compute_vector(const Vmm &vmm_src);
void relu_compute_vector(const Vmm &vmm_src);
virtual status_t init() override;
};
// Forward primitive constructor and destructor (declarations only).
// The diff only renames the first parameter from pd to apd.
// NOTE(review): presumably to avoid clashing with the pd() accessor added in
// the private section -- confirm.
- jit_uni_eltwise_fwd_t(const pd_t *pd, const input_vector &inputs,
+ jit_uni_eltwise_fwd_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs);
~jit_uni_eltwise_fwd_t();
// data_t is the type that prec_traits maps data_type::f32 to.
typedef typename prec_traits<data_type::f32>::type data_t;
// Runs the forward computation via execute_forward(), then sets the event's
// state to event_t::ready. The diff makes this method const.
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
execute_forward();
e->set_state(event_t::ready);
}
private:
// The diff makes execute_forward() const and removes the by-value pd_t copy (conf_).
- void execute_forward();
- pd_t conf_;
+ void execute_forward() const;
// Returns primitive_t::pd() cast to this primitive's pd_t type.
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Raw pointer to the kernel object.
// NOTE(review): ownership and release are not visible here -- presumably
// handled by the constructor and destructor; confirm.
jit_uni_eltwise_kernel_f32 *kernel_;
};
virtual status_t init() override;
};
// Backward primitive constructor and destructor (declarations only).
// The diff only renames the first parameter from pd to apd.
// NOTE(review): presumably to avoid clashing with the pd() accessor added in
// the private section -- confirm.
- jit_uni_eltwise_bwd_t(const pd_t *pd, const input_vector &inputs,
+ jit_uni_eltwise_bwd_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs);
~jit_uni_eltwise_bwd_t();
// data_t is the type that prec_traits maps data_type::f32 to.
typedef typename prec_traits<data_type::f32>::type data_t;
// Runs the backward computation via execute_backward(), then sets the event's
// state to event_t::ready. The diff makes this method const.
- virtual void execute(event_t *e)
+ virtual void execute(event_t *e) const
{
execute_backward();
e->set_state(event_t::ready);
}
private:
// The diff makes execute_backward() const and removes the by-value pd_t copy (conf_).
- void execute_backward();
- pd_t conf_;
+ void execute_backward() const;
// Returns primitive_t::pd() cast to this primitive's pd_t type.
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Raw pointer to the kernel object.
// NOTE(review): ownership and release are not visible here -- presumably
// handled by the constructor and destructor; confirm.
jit_uni_eltwise_kernel_f32 *kernel_;
};