Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_depthwise.hpp
1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_UNI_DEPTHWISE_HPP
18 #define CPU_JIT_UNI_DEPTHWISE_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "cpu_depthwise_pd.hpp"
24 #include "cpu_engine.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27 #include "jit_primitive_conf.hpp"
28 #include "jit_generator.hpp"
29 #include "jit_uni_eltwise.hpp"
30
31 namespace mkldnn {
32 namespace impl {
33 namespace cpu {
34
35 template <cpu_isa_t isa>
36 struct jit_uni_depthwise_injector_f32 {
37     jit_uni_depthwise_injector_f32(jit_generator* host, alg_kind_t depthwise_alg_)
38         : h(host), depthwise_alg(depthwise_alg_) {
39         assert(utils::one_of(isa, sse42, avx2, avx512_common));
40         assert(utils::one_of(depthwise_alg, alg_kind::depthwise_scale_shift, alg_kind::depthwise_prelu));
41     }
42
43     void compute_vector_range(int start_idx, int end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias);
44
45 private:
46     jit_generator* h;
47
48     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
49             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
50
51     size_t vlen = cpu_isa_traits<isa>::vlen;
52
53     alg_kind_t depthwise_alg;
54
55     Vmm vmm_mask;
56     Vmm vmm_aux0;
57
58     Xbyak::Opmask k_mask = Xbyak::Opmask(1);
59
60     const static size_t preserved_vecs_max = 5;
61     size_t vecs_to_preserve = 0;
62     size_t vecs_count = isa == avx512_common ? 32 : 16;
63     size_t preserved_vecs_count = 0;
64     size_t preserved_vec_idxs[preserved_vecs_max] = {0};
65     size_t start_idx_tail = 0;
66
67     int aux_vecs_count(alg_kind_t elt_alg);
68
69     void compute_body(size_t start_idx, size_t end_idx, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias);
70     void injector_preamble(size_t start_idx, size_t end_idx);
71     void injector_preamble_tail(size_t start_idx, size_t end_idx);
72     void injector_postamble();
73     void assign_regs();
74
75     void scale_shift_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias);
76     void prelu_compute_vector(const Vmm &vmm_src, const Xbyak::Reg64& p_weights, const Xbyak::Reg64& p_bias);
77 };
78
79 struct jit_uni_depthwise_kernel_f32;
80
81 template <cpu_isa_t isa>
82 struct jit_uni_depthwise_fwd_t : public cpu_primitive_t {
83     struct pd_t : public cpu_depthwise_fwd_pd_t {
84         pd_t(engine_t *engine, const depthwise_desc_t *adesc,
85                 const primitive_attr_t *attr,
86                 const depthwise_fwd_pd_t *hint_fwd_pd)
87             : cpu_depthwise_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) {}
88
89         DECLARE_COMMON_PD_T(
90                 JIT_IMPL_NAME_HELPER("jit:", isa, ""),
91                 jit_uni_depthwise_fwd_t<isa>);
92
93         virtual status_t init() override;
94     };
95
96     jit_uni_depthwise_fwd_t(const pd_t *apd, const input_vector &inputs,
97                        const output_vector &outputs);
98     ~jit_uni_depthwise_fwd_t();
99
100     typedef typename prec_traits<data_type::f32>::type data_t;
101
102     virtual void execute(event_t *e) const
103     {
104         execute_forward();
105         e->set_state(event_t::ready);
106     }
107
108 private:
109     void execute_forward() const;
110     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
111     jit_uni_depthwise_kernel_f32 *kernel_;
112     data_t *padded_weights_;
113     data_t *padded_bias_;
114 };
115
116
117 template <cpu_isa_t isa>
118 struct jit_uni_dw_conv_row_f32: public jit_generator {
119     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_ds_dw_conv_kernel_f32)
120
121     jit_uni_dw_conv_row_f32(jit_conv_conf_t ajcp, const primitive_attr_t &attr, int ow_stride)
122         : jcp(ajcp), attr_(attr), ow_stride_(ow_stride) {
123         this->generate();
124         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
125     }
126
127     ~jit_uni_dw_conv_row_f32() {
128         for (auto inj : eltwise_injectors)
129             delete inj;
130         eltwise_injectors.clear();
131
132         for (auto inj : depthwise_injectors)
133             delete inj;
134         depthwise_injectors.clear();
135     }
136
137     static bool post_ops_ok(jit_conv_conf_t &jcp,
138             const primitive_attr_t &attr);
139     static status_t init_conf(jit_1x1_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw, const primitive_attr_t &attr);
140     static status_t init_conf(jit_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw, const primitive_attr_t &attr);
141     static status_t init_conf(jit_bin_conv_conf_t &jcp, jit_conv_conf_t &jcp_dw, const primitive_attr_t &attr);
142
143     jit_conv_conf_t jcp;
144     const primitive_attr_t &attr_;
145     void (*jit_ker)(jit_conv_call_s *);
146     int ow_stride_;
147
148 private:
149     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
150         isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
151     using reg64_t = const Xbyak::Reg64;
152     using reg32_t = const Xbyak::Reg32;
153     using reg8_t = const Xbyak::Reg8;
154     const Xbyak::AddressFrame &vmmword = (isa == sse42)
155         ? xword : (isa == avx2) ? yword : zword;
156     const int vlen = cpu_isa_traits<isa>::vlen;
157
158     // dw convolution
159     reg64_t reg_input0 = r8;
160     reg64_t reg_input1 = r9;
161     reg64_t reg_input2 = r10;
162     reg64_t aux_reg_input0 = r11;
163     reg64_t aux_reg_input1 = r12;
164     reg64_t aux_reg_input2 = r13;
165
166     reg64_t reg_kernel = r14;
167     reg64_t aux_reg_kernel = r15;
168     reg64_t reg_output = rdx;
169     reg64_t reg_bias = rbx;
170     reg64_t reg_kh = rax;
171     reg64_t reg_ur_w = rbp;
172     reg64_t reg_oc_work = abi_not_param1;
173
174     reg64_t reg_oc_off = rsi;
175     reg64_t reg_d_weights = aux_reg_input0;
176     reg64_t reg_d_bias = aux_reg_input1;
177
178     reg64_t reg_b_weights = r15;
179     reg64_t reg_b_mask = reg_d_bias;
180
181     reg32_t reg_tmp_32 = r11d;
182     reg64_t reg_tmp_64 = r11;
183     reg8_t reg_tmp_8 = r11b;
184
185     reg32_t reg_tmp2_32 = r13d;
186     reg64_t reg_tmp2_64 = r13;
187
188     inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
189     inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
190     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
191
192     Xbyak::Ymm ymm_tmp = Xbyak::Ymm(0);
193     Vmm vmm_tmp = Vmm(0);
194     Vmm vmm_sum = Vmm(0);
195     Vmm vmm_bias = Vmm(0);
196     Vmm vmm_thr = Vmm(0);
197
198     inline void load_src(int ur_w);
199     inline void apply_filter(int ur_w, int kw_size);
200     inline void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, bool scalar_load);
201     inline void apply_postprocessing(int ur_w, int oc_step);
202     inline void store_dst_typed(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store);
203     inline void store_dst(int ur_w, int oc_step);
204     inline void loop_body(int oc_step);
205
206     void generate();
207
208     nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
209     nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
210 };
211
212 }
213 }
214 }
215
216 #endif