1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_JIT_UNI_ELTWISE_HPP
18 #define CPU_JIT_UNI_ELTWISE_HPP
22 #include "c_types_map.hpp"
23 #include "cpu_eltwise_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
27 #include "jit_generator.hpp"
33 template <cpu_isa_t isa>
34 struct jit_uni_eltwise_injector_f32 {
35 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
36 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
38 jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
39 float alpha, float beta, bool save_state = true,
40 Xbyak::Reg64 p_table = Xbyak::util::rax,
41 Xbyak::Opmask k_mask = Xbyak::Opmask(1))
42 : alg_(alg), alpha_(alpha), beta_(beta), h(host)
43 , save_state_(save_state), p_table(p_table), k_mask(k_mask)
45 using namespace alg_kind;
46 assert(utils::one_of(isa, sse42, avx2, avx512_common));
47 assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
48 eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
49 eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
50 eltwise_clamp, eltwise_exp));
// Convenience constructor for post-ops: unpacks a
// post_ops_t::entry_t::eltwise_t entry and delegates to the
// (host, alg, alpha, beta, ...) constructor, passing save_state, p_table
// and k_mask through unchanged.
53 // note that eltwise.scale is ignored: only alg, alpha and beta are forwarded
54 jit_uni_eltwise_injector_f32(jit_generator *host,
55 const post_ops_t::entry_t::eltwise_t &eltwise,
56 bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
57 Xbyak::Opmask k_mask = Xbyak::Opmask(1))
58 : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
59 eltwise.beta, save_state, p_table, k_mask) {}
// Presumably emits code applying the configured algorithm (alg_) to the
// vector registers with indices in [start_idx, end_idx). end_idx is
// exclusive, as compute_vector() below shows. Defined out of line.
61 void compute_vector_range(size_t start_idx, size_t end_idx);
// Single-register form: exactly compute_vector_range(idx, idx + 1).
62 void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
// Presumably emits the constant table that table_val() operands point into.
// The exact effect of gen_table is defined in the implementation, not
// here -- TODO confirm.
63 void prepare_table(bool gen_table=true);
// Emits `mov p_table, l_table` on the host generator. table_val() builds
// its operands relative to p_table, so this is expected to run before
// those operands are used.
64 void load_table_addr() { h->mov(p_table, l_table); }
// Eltwise algorithm selected at construction (initialized from `alg`).
66 const alg_kind_t alg_;
// Host generator into which instructions are emitted (e.g. h->mov in
// load_table_addr(), h->ptr in table_val()).
// NOTE(review): ownership is not expressed here; presumably the caller
// owns the host and it outlives the injector.
70 jit_generator * const h;
// Copied from the constructor's save_state (default true). Presumably
// selects whether the preamble/postamble preserve the registers the
// injector clobbers -- confirm in the implementation.
72 const bool save_state_;
// General-purpose register holding the constant-table base address
// (default rax). Loaded by load_table_addr(); used as the base in table_val().
73 const Xbyak::Reg64 p_table;
// Opmask register (default Opmask(1), i.e. k1). Presumably only used by
// the avx512_common instantiation -- TODO confirm.
74 const Xbyak::Opmask k_mask;
78 // if only the injector was inherited from jit_generator...
80 _cmp_le_os = jit_generator::_cmp_le_os,
81 _cmp_nle_us = jit_generator::_cmp_nle_us,
82 _op_floor = jit_generator::_op_floor,
// Vector register width for this isa (cpu_isa_traits<isa>::vlen).
// table_val() uses it as the byte stride between table entries.
85 size_t vlen = cpu_isa_traits<isa>::vlen;
// Capacity of preserved_vec_idxs[] below.
87 const static size_t preserved_vecs_max = 5;
// Bookkeeping for vector registers the injector borrows. Presumably filled
// in by injector_preamble(); the exact meaning of each counter should be
// confirmed against the implementation.
89 size_t vecs_to_preserve = 0;
// Number of architectural vector registers: 32 for avx512_common,
// 16 for sse42/avx2.
90 size_t vecs_count = isa == avx512_common ? 32 : 16;
91 size_t preserved_vecs_count = 0;
92 size_t preserved_vec_idxs[preserved_vecs_max] = {0};
93 size_t start_idx_tail = 0;
// Scratch vector registers for the per-algorithm emitters. Presumably
// assigned in the preamble -- TODO confirm.
95 Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
// Memory operand for the index-th table entry, [p_table + index * vlen].
// Entries are one full vector (vlen bytes) apart. p_table must already
// hold the table address (see load_table_addr()).
// NOTE(review): index is int while vlen is size_t, so a negative index
// would be converted to a huge unsigned offset; callers must pass index >= 0.
97 Xbyak::Address table_val(int index)
98 { return h->ptr[p_table + index * vlen]; }
// Presumably the number of auxiliary vector registers `alg` needs.
// Defined out of line -- confirm.
100 int aux_vecs_count(alg_kind_t alg);
// Out-of-line helpers behind compute_vector_range(). By name,
// compute_body presumably dispatches to the *_compute_vector emitters, and
// the preamble/postamble presumably save and restore registers around it.
// Verify against the implementation.
101 void compute_body(size_t start_idx, size_t end_idx);
102 void injector_preamble(size_t start_idx, size_t end_idx);
103 void injector_preamble_tail(size_t start_idx);
104 void injector_postamble();
// Per-algorithm code emitters, one per supported alg kind. Each receives
// the register handle vmm_src; presumably the emitted code overwrites that
// register with the result -- confirm in the implementation.
// relu_zero_ns has no matching entry in the constructor's alg_kind assert;
// presumably it is a specialized relu path (e.g. for alpha == 0) -- TODO confirm.
107 void exp_compute_vector(const Vmm &vmm_src);
108 void relu_compute_vector(const Vmm &vmm_src);
109 void relu_zero_ns_compute_vector(const Vmm &vmm_src);
110 void elu_compute_vector(const Vmm &vmm_src);
111 void tanh_compute_vector(const Vmm &vmm_src);
112 void square_compute_vector(const Vmm &vmm_src);
113 void abs_compute_vector(const Vmm &vmm_src);
114 void sqrt_compute_vector(const Vmm &vmm_src);
115 void linear_compute_vector(const Vmm &vmm_src);
116 void bounded_relu_compute_vector(const Vmm &vmm_src);
117 void soft_relu_compute_vector(const Vmm &vmm_src);
118 void logistic_compute_vector(const Vmm &vmm_src);
119 void clamp_compute_vector(const Vmm &vmm_src);
// Per-algorithm constant-table emitters, presumably invoked from
// prepare_table(). Not every algorithm has one: no exp, tanh, square or
// logistic variant is declared here.
121 void relu_prepare_table();
122 void elu_prepare_table();
123 void soft_relu_prepare_table();
124 void abs_prepare_table();
125 void sqrt_prepare_table();
126 void linear_prepare_table();
127 void bounded_relu_prepare_table();
128 void clamp_prepare_table();
// Forward declaration of the JIT kernel type that the forward and backward
// primitives hold through their kernel_ pointer. Defined elsewhere.
131 struct jit_uni_eltwise_kernel_f32;
133 template <cpu_isa_t isa>
134 struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
135 struct pd_t : public cpu_eltwise_fwd_pd_t {
// Forward primitive descriptor constructor: forwards all arguments
// unchanged to cpu_eltwise_fwd_pd_t.
136 pd_t(engine_t *engine, const eltwise_desc_t *adesc,
137 const primitive_attr_t *attr,
138 const eltwise_fwd_pd_t *hint_fwd_pd)
139 : cpu_eltwise_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
142 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
143 jit_uni_eltwise_fwd_t<isa>);
// Descriptor initialization (overrides the base). Defined out of line;
// presumably reports via status_t whether this jit implementation
// applies -- TODO confirm.
145 virtual status_t init() override;
// Constructor and destructor are defined out of line. Presumably the
// constructor creates kernel_ and the destructor releases it -- TODO
// confirm against the implementation.
148 jit_uni_eltwise_fwd_t(const pd_t *apd, const input_vector &inputs,
149 const output_vector &outputs);
150 ~jit_uni_eltwise_fwd_t();
// Element type processed by this primitive: the C type for data_type::f32.
152 typedef typename prec_traits<data_type::f32>::type data_t;
154 virtual void execute(event_t *e) const
157 e->set_state(event_t::ready);
// Forward computation (defined out of line).
161 void execute_forward() const;
// Typed accessor: casts the base-class descriptor pointer to this pd_t.
162 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Generated kernel, held through a raw pointer.
// NOTE(review): if kernel_ is owned (destructor declared above), copying this
// object would double-free unless a base class already forbids copies --
// confirm; std::unique_ptr would make the ownership explicit.
163 jit_uni_eltwise_kernel_f32 *kernel_;
166 template <cpu_isa_t isa>
167 struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
168 struct pd_t : public cpu_eltwise_bwd_pd_t {
// Backward primitive descriptor constructor: forwards all arguments
// unchanged to cpu_eltwise_bwd_pd_t. Note that hint_fwd_pd is typed as
// the forward descriptor (eltwise_fwd_pd_t).
169 pd_t(engine_t *engine, const eltwise_desc_t *adesc,
170 const primitive_attr_t *attr,
171 const eltwise_fwd_pd_t *hint_fwd_pd)
172 : cpu_eltwise_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
175 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
176 jit_uni_eltwise_bwd_t<isa>);
// Descriptor initialization (overrides the base). Defined out of line;
// presumably reports via status_t whether this jit implementation
// applies -- TODO confirm.
178 virtual status_t init() override;
// Constructor and destructor are defined out of line. Presumably the
// constructor creates kernel_ and the destructor releases it -- TODO
// confirm against the implementation.
181 jit_uni_eltwise_bwd_t(const pd_t *apd, const input_vector &inputs,
182 const output_vector &outputs);
183 ~jit_uni_eltwise_bwd_t();
// Element type processed by this primitive: the C type for data_type::f32.
185 typedef typename prec_traits<data_type::f32>::type data_t;
187 virtual void execute(event_t *e) const
190 e->set_state(event_t::ready);
// Backward computation (defined out of line).
194 void execute_backward() const;
// Typed accessor: casts the base-class descriptor pointer to this pd_t.
195 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Generated kernel, held through a raw pointer.
// NOTE(review): if kernel_ is owned (destructor declared above), copying this
// object would double-free unless a base class already forbids copies --
// confirm; std::unique_ptr would make the ownership explicit.
196 jit_uni_eltwise_kernel_f32 *kernel_;