Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_dw_conv_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP
18 #define JIT_UNI_DW_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25 #include "jit_uni_eltwise.hpp"
26 #include "jit_uni_depthwise.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 template <cpu_isa_t isa>
33 struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator {
34     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
35
36     jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
37             const primitive_attr_t &attr): jcp(ajcp), attr_(attr) {
38         this->generate();
39         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
40     }
41
42     ~jit_uni_dw_conv_fwd_kernel_f32() {
43         for (auto inj : eltwise_injectors)
44             delete inj;
45         eltwise_injectors.clear();
46
47         for (auto inj : depthwise_injectors)
48             delete inj;
49         depthwise_injectors.clear();
50     }
51
52     static bool post_ops_ok(jit_conv_conf_t &jcp,
53             const primitive_attr_t &attr);
54     static status_t init_conf(jit_conv_conf_t &jcp,
55             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
56             const memory_desc_wrapper &weights_d,
57             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
58
59     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
60             const jit_conv_conf_t &jcp);
61
62     jit_conv_conf_t jcp;
63     const primitive_attr_t &attr_;
64     void (*jit_ker)(jit_conv_call_s *);
65
66 private:
67     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
68         isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
69     using reg64_t = const Xbyak::Reg64;
70     const Xbyak::AddressFrame &vmmword = (isa == sse42)
71         ? xword : (isa == avx2) ? yword : zword;
72     const int vlen = cpu_isa_traits<isa>::vlen;
73
74     // dw convolution
75     reg64_t reg_input = r8;
76     reg64_t aux_reg_input = r9;
77     reg64_t aux1_reg_input = r10;
78     reg64_t reg_kernel = r11;
79     reg64_t aux_reg_kernel = r12;
80     reg64_t aux1_reg_kernel = r13;
81     reg64_t reg_output = r14;
82     reg64_t reg_bias = r15;
83     reg64_t reg_kh = rax;
84     reg64_t reg_kw = rbx;
85     reg64_t iter_kh = rdx;
86     reg64_t iter_kw = rsi;
87     reg64_t reg_ur_w = rbp;
88     reg64_t reg_ch_blocks = aux1_reg_input;
89     reg64_t imm_addr64 = aux1_reg_input;
90
91     reg64_t reg_d_weights = imm_addr64;
92     reg64_t reg_d_bias = iter_kh;
93
94     inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
95     inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
96     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
97
98     inline void load_src(int ur_ch_blocks, int ur_w);
99     inline void apply_filter(int ur_ch_blocks, int ur_w);
100     inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w);
101     inline void apply_postprocess(int ur_ch_blocks, int ur_w);
102     inline void store_dst(int ur_ch_blocks, int ur_w);
103     inline void loop_body(int ur_ch_blocks);
104
105     void generate();
106
107     nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
108     nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
109 };
110
111 template <cpu_isa_t isa>
112 struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator {
113     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)
114
115     jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) {
116         this->generate();
117         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
118     }
119
120     static status_t init_conf(jit_conv_conf_t &jcp,
121             const convolution_desc_t &cd,
122             const memory_desc_wrapper &diff_src_d,
123             const memory_desc_wrapper &weights_d,
124             const memory_desc_wrapper &diff_dst_d);
125
126     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
127             const jit_conv_conf_t &jcp);
128
129     jit_conv_conf_t jcp;
130     void (*jit_ker)(jit_conv_call_s *);
131
132 private:
133     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
134         isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
135     using reg64_t = const Xbyak::Reg64;
136
137     inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
138     inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
139     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
140
141     reg64_t reg_ddst       = rax;
142     reg64_t aux_reg_ddst   = r8;
143     reg64_t aux1_reg_ddst = abi_not_param1;
144     reg64_t reg_kernel     = rdx;
145     reg64_t aux_reg_kernel = r10;
146     reg64_t aux1_reg_kernel = rbp;
147     reg64_t reg_dsrc       = rsi;
148
149     reg64_t reg_ur_str_w = r9;
150     reg64_t reg_ch_blocks = rbx;
151
152     reg64_t iter_kh = r11;
153     reg64_t iter_kw = r12;
154     reg64_t reg_kh  = r13;
155     reg64_t reg_kw  = r14;
156
157     inline void loop_body(int ur_ch_blocks);
158     inline void load_ddst(int ur_ch_blocks, int ur_str_w);
159     inline void apply_filter(int ur_ch_blocks, int ur_str_w);
160     inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
161
162     void generate();
163 };
164
165 template <cpu_isa_t isa>
166 struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {
167
168     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)
169
170     jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {
171         this->generate();
172         jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode();
173     }
174
175     static status_t init_conf(jit_conv_conf_t &jcp,
176             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
177             const memory_desc_wrapper &diff_weights_d,
178             const memory_desc_wrapper &diff_dst_d, int nthreads);
179
180     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
181             const jit_conv_conf_t &jcp);
182
183     static void balance(jit_conv_conf_t &jcp, int nthreads);
184
185     jit_conv_conf_t jcp;
186     void (*jit_ker)(jit_dw_conv_call_s *);
187
188 private:
189     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
190             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
191     using reg64_t = const Xbyak::Reg64;
192     const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
193     const int reg_repeats = (isa == sse42) ? 2 : 1;
194
195     const Xbyak::AddressFrame &vmmword
196             = (isa == sse42) ? xword : (isa == avx2) ? yword : zword;
197
198     /* XXX: offset between input and accummulators is 3, therefore, assume 'kw'
199      * is no larger than 3*/
200     inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
201     inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); }
202     inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); }
203     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); }
204     inline Vmm get_aux_reg() { return Vmm(0); }
205
206     reg64_t reg_tmp_input = r9;
207     reg64_t reg_tmp_output = r10;
208     reg64_t reg_tmp_filter = r13;
209     reg64_t reg_kh_offset = rax;
210
211     /* parameter passed by driver into kernel */
212     Xbyak::Reg8 reg_exec_flags = bl;
213
214     reg64_t reg_oh_worksize = r14;
215     reg64_t reg_oh = rax;
216
217     reg64_t iter_ow_blk = r11;
218
219     reg64_t reg_kh = rsi;
220     reg64_t reg_kh_count = rdx;
221
222     /* Base addresses for convolution parameters. */
223     reg64_t reg_input_baddr = r15;
224     reg64_t reg_output_baddr = r12;
225     reg64_t reg_filter_baddr = abi_not_param1;
226     reg64_t reg_bias_baddr = r13;
227
228     /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
229      */
230     inline void compute_ow_step_unroll(
231             int unroll_w, int l_pad, int pad_offset, int ow_block);
232
233     /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
234     inline void compute_h_step(
235             int unroll_w, int l_pad, int pad_offset, int ow_block);
236     inline void compute_h_loop(
237             int unroll_w, int l_pad, int pad_offset, int ow_block);
238
239     /* Write 'width' micro-kernel JITs; depending on the padding and convolution
240      * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
241      * right ow-block.*/
242     inline void compute_ow_block_unroll();
243
244     inline void compute_zero_filter();
245     inline void load_filter();
246     inline void zero_filter();
247     inline void load_bias();
248     inline void zero_bias();
249     inline void compute_bias_step_unroll(const int unroll_w);
250     inline void compute_bias_loop(const int block_size);
251     inline void store_filter();
252     inline void store_bias();
253
254     void generate();
255 };
256 }
257 }
258 }
259
260 #endif