Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_eltwise.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_UNI_ELTWISE_HPP
18 #define CPU_JIT_UNI_ELTWISE_HPP
19
20 #include <assert.h>
21 #include <mkldnn.hpp>
22
23 #include "c_types_map.hpp"
24 #include "cpu_eltwise_pd.hpp"
25 #include "cpu_engine.hpp"
26 #include "type_helpers.hpp"
27 #include "utils.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 struct jit_uni_eltwise_kernel_f32;
34
35 template <cpu_isa_t isa>
36 struct jit_uni_eltwise_vector_f32;
37
38 template <cpu_isa_t isa>
39 struct jit_uni_relu_vector_f32;
40
41 template <cpu_isa_t isa>
42 struct jit_uni_eltwise_vector_f32 : public c_compatible {
43     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
44             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
45
46     const int vlen = cpu_isa_traits<isa>::vlen;
47
48     Xbyak::Opmask k_mask;
49     Vmm vmm_mask;
50     Vmm vmm_ns;
51     Xbyak::Xmm xmm_ns;
52     Vmm vmm_zero;
53     Vmm vmm_aux0;
54     Xbyak::Xmm xmm_aux0;
55     Vmm vmm_aux1;
56     Vmm vmm_aux2;
57     Vmm vmm_src_rem;
58
59     Xbyak::Reg64 imm_addr64;
60
61     alg_kind_t elt_alg = alg_kind::undef;
62
63     jit_generator *generator = nullptr;
64
65     ~jit_uni_eltwise_vector_f32() { release(); }
66
67     void init(alg_kind_t elt_alg_, nstl::vector<int> &shared_vecs, nstl::vector<Xbyak::Reg64> &shared_regs);
68     void release() { if (generator != nullptr) { delete generator; generator = nullptr; } }
69
70     static int sharedVecsCount(alg_kind_t elt_alg) {
71         //TODO (dmitrygo): upper bound. Should be specified for each type.
72         return isa == avx512_common ? 5 : 4;
73     }
74
75     static int sharedRegsCount(alg_kind_t elt_alg) {
76         //TODO (dmitrygo): upper bound. Should be specified for each type.
77         return 1;
78     }
79
80     jit_code_injection prepareConstants(float alpha, float beta);
81     jit_code_injection computeVector(const Vmm &vmm_src, const Vmm &vmm_dst);
82     jit_code_injection prepareTable();
83 };
84
85 template <cpu_isa_t isa>
86 struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
87     struct pd_t : public cpu_eltwise_fwd_pd_t {
88         pd_t(engine_t *engine, const eltwise_desc_t *adesc,
89                 const primitive_attr_t *attr,
90                 const eltwise_fwd_pd_t *hint_fwd_pd)
91             : cpu_eltwise_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
92
93         DECLARE_COMMON_PD_T(
94                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
95                 jit_uni_eltwise_fwd_t<isa>);
96
97         virtual status_t init() override;
98     };
99
100     jit_uni_eltwise_fwd_t(const pd_t *pd, const input_vector &inputs,
101                        const output_vector &outputs);
102     ~jit_uni_eltwise_fwd_t();
103
104     typedef typename prec_traits<data_type::f32>::type data_t;
105
106     virtual void execute(event_t *e)
107     {
108         execute_forward();
109         e->set_state(event_t::ready);
110     }
111
112 private:
113     void execute_forward();
114     pd_t conf_;
115     jit_uni_eltwise_kernel_f32 *kernel_;
116 };
117
118 template <cpu_isa_t isa>
119 struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
120     struct pd_t : public cpu_eltwise_bwd_pd_t {
121         pd_t(engine_t *engine, const eltwise_desc_t *adesc,
122                 const primitive_attr_t *attr,
123                 const eltwise_fwd_pd_t *hint_fwd_pd)
124             : cpu_eltwise_bwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
125
126         DECLARE_COMMON_PD_T(
127                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
128                 jit_uni_eltwise_bwd_t<isa>);
129
130         virtual status_t init() override;
131     };
132
133     jit_uni_eltwise_bwd_t(const pd_t *pd, const input_vector &inputs,
134                        const output_vector &outputs);
135     ~jit_uni_eltwise_bwd_t();
136
137     typedef typename prec_traits<data_type::f32>::type data_t;
138
139     virtual void execute(event_t *e)
140     {
141         execute_backward();
142         e->set_state(event_t::ready);
143     }
144
145 private:
146     void execute_backward();
147     pd_t conf_;
148     jit_uni_eltwise_kernel_f32 *kernel_;
149 };
150
151 }
152 }
153 }
154
155 #endif