Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_UNI_ELTWISE_HPP
18 #define CPU_JIT_UNI_ELTWISE_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_eltwise_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27 #include "jit_generator.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 template <cpu_isa_t isa>
34 struct jit_uni_eltwise_injector_f32 {
35     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
36             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
37
38     jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
39             float alpha, float beta, bool save_state = true,
40             Xbyak::Reg64 p_table = Xbyak::util::rax,
41             Xbyak::Opmask k_mask = Xbyak::Opmask(1))
42         : alg_(alg), alpha_(alpha), beta_(beta), h(host)
43         , save_state_(save_state), p_table(p_table), k_mask(k_mask)
44     {
45         using namespace alg_kind;
46         assert(utils::one_of(isa, sse42, avx2, avx512_common));
47         assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
48                     eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
49                     eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic,
50                     eltwise_clamp, eltwise_exp));
51     }
52
53     // note that eltwise.scale is ignored
54     jit_uni_eltwise_injector_f32(jit_generator *host,
55             const post_ops_t::entry_t::eltwise_t &eltwise,
56             bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
57             Xbyak::Opmask k_mask = Xbyak::Opmask(1))
58         : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
59                 eltwise.beta, save_state, p_table, k_mask) {}
60
61     void compute_vector_range(size_t start_idx, size_t end_idx);
62     void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
63     void prepare_table(bool gen_table=true);
64     void load_table_addr() { h->mov(p_table, l_table); }
65
66     const alg_kind_t alg_;
67     const float alpha_;
68     const float beta_;
69
70     jit_generator * const h;
71
72     const bool save_state_;
73     const Xbyak::Reg64 p_table;
74     const Xbyak::Opmask k_mask;
75     Xbyak::Label l_table;
76
77 private:
78     // if only the injector was inherited from jit_generator...
79     enum {
80         _cmp_le_os = jit_generator::_cmp_le_os,
81         _cmp_nle_us = jit_generator::_cmp_nle_us,
82         _op_floor = jit_generator::_op_floor,
83     };
84
85     size_t vlen = cpu_isa_traits<isa>::vlen;
86
87     const static size_t preserved_vecs_max = 5;
88
89     size_t vecs_to_preserve = 0;
90     size_t vecs_count = isa == avx512_common ? 32 : 16;
91     size_t preserved_vecs_count = 0;
92     size_t preserved_vec_idxs[preserved_vecs_max] = {0};
93     size_t start_idx_tail = 0;
94
95     Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
96
97     Xbyak::Address table_val(int index)
98     { return h->ptr[p_table + index * vlen]; }
99
100     int aux_vecs_count(alg_kind_t alg);
101     void compute_body(size_t start_idx, size_t end_idx);
102     void injector_preamble(size_t start_idx, size_t end_idx);
103     void injector_preamble_tail(size_t start_idx);
104     void injector_postamble();
105     void assign_regs();
106
107     void exp_compute_vector(const Vmm &vmm_src);
108     void relu_compute_vector(const Vmm &vmm_src);
109     void relu_zero_ns_compute_vector(const Vmm &vmm_src);
110     void elu_compute_vector(const Vmm &vmm_src);
111     void tanh_compute_vector(const Vmm &vmm_src);
112     void square_compute_vector(const Vmm &vmm_src);
113     void abs_compute_vector(const Vmm &vmm_src);
114     void sqrt_compute_vector(const Vmm &vmm_src);
115     void linear_compute_vector(const Vmm &vmm_src);
116     void bounded_relu_compute_vector(const Vmm &vmm_src);
117     void soft_relu_compute_vector(const Vmm &vmm_src);
118     void logistic_compute_vector(const Vmm &vmm_src);
119     void clamp_compute_vector(const Vmm &vmm_src);
120
121     void relu_prepare_table();
122     void elu_prepare_table();
123     void soft_relu_prepare_table();
124     void abs_prepare_table();
125     void sqrt_prepare_table();
126     void linear_prepare_table();
127     void bounded_relu_prepare_table();
128     void clamp_prepare_table();
129 };
130
131 struct jit_uni_eltwise_kernel_f32;
132
133 template <cpu_isa_t isa>
134 struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
135     struct pd_t : public cpu_eltwise_fwd_pd_t {
136         pd_t(engine_t *engine, const eltwise_desc_t *adesc,
137                 const primitive_attr_t *attr,
138                 const eltwise_fwd_pd_t *hint_fwd_pd)
139             : cpu_eltwise_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
140
141         DECLARE_COMMON_PD_T(
142                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
143                 jit_uni_eltwise_fwd_t<isa>);
144
145         virtual status_t init() override;
146     };
147
148     jit_uni_eltwise_fwd_t(const pd_t *apd, const input_vector &inputs,
149                        const output_vector &outputs);
150     ~jit_uni_eltwise_fwd_t();
151
152     typedef typename prec_traits<data_type::f32>::type data_t;
153
154     virtual void execute(event_t *e) const
155     {
156         execute_forward();
157         e->set_state(event_t::ready);
158     }
159
160 private:
161     void execute_forward() const;
162     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
163     jit_uni_eltwise_kernel_f32 *kernel_;
164 };
165
166 template <cpu_isa_t isa>
167 struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
168     struct pd_t : public cpu_eltwise_bwd_pd_t {
169         pd_t(engine_t *engine, const eltwise_desc_t *adesc,
170                 const primitive_attr_t *attr,
171                 const eltwise_fwd_pd_t *hint_fwd_pd)
172             : cpu_eltwise_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
173
174         DECLARE_COMMON_PD_T(
175                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
176                 jit_uni_eltwise_bwd_t<isa>);
177
178         virtual status_t init() override;
179     };
180
181     jit_uni_eltwise_bwd_t(const pd_t *apd, const input_vector &inputs,
182                        const output_vector &outputs);
183     ~jit_uni_eltwise_bwd_t();
184
185     typedef typename prec_traits<data_type::f32>::type data_t;
186
187     virtual void execute(event_t *e) const
188     {
189         execute_backward();
190         e->set_state(event_t::ready);
191     }
192
193 private:
194     void execute_backward() const;
195     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
196     jit_uni_eltwise_kernel_f32 *kernel_;
197 };
198
199 }
200 }
201 }
202
203 #endif