1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP
18 #define JIT_UNI_DW_CONV_KERNEL_F32_HPP
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25 #include "jit_uni_eltwise.hpp"
26 #include "jit_uni_depthwise.hpp"
// Forward f32 depthwise-convolution JIT kernel, specialised on the target ISA
// (sse42 / avx2 / other, see Vmm below). The generated machine code is
// exposed through the `jit_ker` function pointer.
32 template <cpu_isa_t isa>
33 struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator {
34 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
// Constructor: copies the convolution configuration into `jcp`, binds
// `attr_` to `attr` by reference (so `attr` must outlive this object), and
// sets `jit_ker` to the pointer returned by getCode(), cast to the kernel
// signature.
// NOTE(review): this view shows no declaration of the `jcp` member, no
// code-generation call before getCode(), and no closing brace for this
// body; those lines appear to be missing here - confirm in the full file.
36 jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
37 const primitive_attr_t &attr): jcp(ajcp), attr_(attr) {
39 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
// Destructor: iterates the raw-pointer injector vectors and clears them.
// NOTE(review): as shown, each `for` takes the following `.clear()` as its
// body (clearing the vector while iterating it) and `inj` is never
// deleted; the per-element cleanup statements and closing braces appear to
// be missing from this view - confirm in the full file.
42 ~jit_uni_dw_conv_fwd_kernel_f32() {
43 for (auto inj : eltwise_injectors)
45 eltwise_injectors.clear();
47 for (auto inj : depthwise_injectors)
49 depthwise_injectors.clear();
// Presumably reports whether the post-ops in `attr` can be handled for the
// given configuration (definition not in this view).
52 static bool post_ops_ok(jit_conv_conf_t &jcp,
53 const primitive_attr_t &attr);
// Presumably fills `jcp` from the convolution descriptor, the src /
// weights / dst memory descriptors and `attr`; returns a status_t
// (definition not in this view).
54 static status_t init_conf(jit_conv_conf_t &jcp,
55 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
56 const memory_desc_wrapper &weights_d,
57 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
// Presumably registers this kernel's scratchpad needs for `jcp` with
// `scratchpad` (definition not in this view).
59 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
60 const jit_conv_conf_t &jcp);
// Attributes given to the constructor, held by reference (not copied).
63 const primitive_attr_t &attr_;
// Entry point to the generated code; assigned in the constructor.
64 void (*jit_ker)(jit_conv_call_s *);
// Vector register type for the ISA: Xmm on sse42, Ymm on avx2, Zmm otherwise.
67 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
68 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
// Register-name members below are const Xbyak::Reg64 values.
69 using reg64_t = const Xbyak::Reg64;
// Memory-operand size specifier matching Vmm (xword / yword / zword).
70 const Xbyak::AddressFrame &vmmword = (isa == sse42)
71 ? xword : (isa == avx2) ? yword : zword;
// Vector length for the ISA in bytes (taken from cpu_isa_traits).
72 const int vlen = cpu_isa_traits<isa>::vlen;
// General-purpose register assignments for the generated code.
75 reg64_t reg_input = r8;
76 reg64_t aux_reg_input = r9;
77 reg64_t aux1_reg_input = r10;
78 reg64_t reg_kernel = r11;
79 reg64_t aux_reg_kernel = r12;
80 reg64_t aux1_reg_kernel = r13;
81 reg64_t reg_output = r14;
82 reg64_t reg_bias = r15;
85 reg64_t iter_kh = rdx;
86 reg64_t iter_kw = rsi;
87 reg64_t reg_ur_w = rbp;
// NOTE: the following names reuse registers declared above:
// reg_ch_blocks, imm_addr64 and reg_d_weights are all r10
// (aux1_reg_input), and reg_d_bias is rdx (iter_kh). Uses of aliased
// names must not be live at the same time in the generated code.
88 reg64_t reg_ch_blocks = aux1_reg_input;
89 reg64_t imm_addr64 = aux1_reg_input;
91 reg64_t reg_d_weights = imm_addr64;
92 reg64_t reg_d_bias = iter_kh;
// Vector register indices: kernel at Vmm(idx), source at Vmm(idx + 1),
// accumulators at Vmm(idx + 4). Without overlap this leaves one kernel
// register (Vmm0) and three source registers (Vmm1-Vmm3) below the
// accumulators.
94 inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
95 inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
96 inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
// Code-generation helpers; their definitions are not in this view.
98 inline void load_src(int ur_ch_blocks, int ur_w);
99 inline void apply_filter(int ur_ch_blocks, int ur_w);
100 inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w);
101 inline void apply_postprocess(int ur_ch_blocks, int ur_w);
102 inline void store_dst(int ur_ch_blocks, int ur_w);
103 inline void loop_body(int ur_ch_blocks);
// Eltwise / depthwise injectors held by raw pointer; the destructor
// iterates and clears these vectors.
107 nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
108 nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
// Backward-data f32 depthwise-convolution JIT kernel, specialised on the
// target ISA (sse42 / avx2 / other, see Vmm below). The generated machine
// code is exposed through the `jit_ker` function pointer.
111 template <cpu_isa_t isa>
112 struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator {
113 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)
// Constructor: copies the configuration into `jcp` and sets `jit_ker` to
// the pointer returned by getCode(), cast to the kernel signature.
// NOTE(review): no `jcp` member declaration, code-generation call, or
// closing brace for this body is visible in this view; those lines appear
// to be missing here - confirm in the full file.
115 jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) {
117 jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
// Presumably fills `jcp` from the convolution descriptor and the
// diff_src / weights / diff_dst memory descriptors; returns a status_t
// (definition not in this view).
120 static status_t init_conf(jit_conv_conf_t &jcp,
121 const convolution_desc_t &cd,
122 const memory_desc_wrapper &diff_src_d,
123 const memory_desc_wrapper &weights_d,
124 const memory_desc_wrapper &diff_dst_d);
// Presumably registers this kernel's scratchpad needs for `jcp` with
// `scratchpad` (definition not in this view).
126 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
127 const jit_conv_conf_t &jcp);
// Entry point to the generated code; assigned in the constructor.
130 void (*jit_ker)(jit_conv_call_s *);
// Vector register type for the ISA: Xmm on sse42, Ymm on avx2, Zmm otherwise.
133 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
134 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
// Register-name members below are const Xbyak::Reg64 values.
135 using reg64_t = const Xbyak::Reg64;
// Vector register indices: kernel at Vmm(idx), source at Vmm(idx + 1),
// accumulators at Vmm(idx + 4). Without overlap this leaves one kernel
// register (Vmm0) and three source registers (Vmm1-Vmm3) below the
// accumulators.
137 inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
138 inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
139 inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
// General-purpose register assignments for the generated code. The names
// suggest ddst = diff_dst, dsrc = diff_src and kernel = weights (naming
// inference only). aux1_reg_ddst uses abi_not_param1, which is defined
// outside this view.
141 reg64_t reg_ddst = rax;
142 reg64_t aux_reg_ddst = r8;
143 reg64_t aux1_reg_ddst = abi_not_param1;
144 reg64_t reg_kernel = rdx;
145 reg64_t aux_reg_kernel = r10;
146 reg64_t aux1_reg_kernel = rbp;
147 reg64_t reg_dsrc = rsi;
149 reg64_t reg_ur_str_w = r9;
150 reg64_t reg_ch_blocks = rbx;
152 reg64_t iter_kh = r11;
153 reg64_t iter_kw = r12;
154 reg64_t reg_kh = r13;
155 reg64_t reg_kw = r14;
// Code-generation helpers; their definitions are not in this view.
157 inline void loop_body(int ur_ch_blocks);
158 inline void load_ddst(int ur_ch_blocks, int ur_str_w);
159 inline void apply_filter(int ur_ch_blocks, int ur_str_w);
160 inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
// Backward-weights f32 depthwise-convolution JIT kernel, specialised on the
// target ISA (sse42 / avx2 / other, see Vmm below). The generated machine
// code is exposed through `jit_ker`, whose argument is a
// jit_dw_conv_call_s. The end of this struct lies beyond this view.
165 template <cpu_isa_t isa>
166 struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {
168 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)
// Constructor: copies the configuration into `jcp` and sets `jit_ker` to
// the pointer returned by getCode(), cast to the kernel signature.
// NOTE(review): no `jcp` member declaration, code-generation call, or
// closing brace for this body is visible in this view; confirm in the
// full file.
170 jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {
172 jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode();
// Presumably fills `jcp` from the convolution descriptor, the src /
// diff_weights / diff_dst memory descriptors and the thread count;
// returns a status_t (definition not in this view).
175 static status_t init_conf(jit_conv_conf_t &jcp,
176 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
177 const memory_desc_wrapper &diff_weights_d,
178 const memory_desc_wrapper &diff_dst_d, int nthreads);
// Presumably registers this kernel's scratchpad needs for `jcp` with
// `scratchpad` (definition not in this view).
180 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
181 const jit_conv_conf_t &jcp);
// Presumably splits the work for `nthreads` threads and records the split
// in `jcp` (definition not in this view).
183 static void balance(jit_conv_conf_t &jcp, int nthreads);
// Entry point to the generated code; assigned in the constructor.
186 void (*jit_ker)(jit_dw_conv_call_s *);
// Vector register type for the ISA: Xmm on sse42, Ymm on avx2, Zmm otherwise.
189 using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
190 isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
// Register-name members below are const Xbyak::Reg64 values.
191 using reg64_t = const Xbyak::Reg64;
// Number of f32 lanes in one vector register (vector bytes / sizeof(float)).
192 const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
// 2 on sse42, 1 otherwise; scales the vector-register index offsets below.
193 const int reg_repeats = (isa == sse42) ? 2 : 1;
// Memory-operand size specifier matching Vmm (xword / yword / zword).
195 const Xbyak::AddressFrame &vmmword
196 = (isa == sse42) ? xword : (isa == avx2) ? yword : zword;
// Vector register layout: bias (default idx) and aux both map to Vmm(0);
// outputs start at Vmm(1), accumulators at Vmm(reg_repeats + 1) and
// inputs at Vmm(4 * reg_repeats + 1). The input-to-accumulator index
// offset is therefore 3 * reg_repeats: 3, or 6 on sse42.
// NOTE(review): the XXX note below states the offset as 3, which holds
// only when reg_repeats == 1. get_bias_reg() (default argument) and
// get_aux_reg() both return Vmm(0), so those two registers alias.
198 /* XXX: offset between input and accumulators is 3, therefore, assume 'kw'
199 * is no larger than 3*/
200 inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
201 inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); }
202 inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); }
203 inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); }
204 inline Vmm get_aux_reg() { return Vmm(0); }
// General-purpose register assignments for the generated code.
// NOTE: reg_kh_offset and reg_oh are both rax, and reg_tmp_filter and
// reg_bias_baddr are both r13; each pair must not be live at the same
// time in the generated code.
206 reg64_t reg_tmp_input = r9;
207 reg64_t reg_tmp_output = r10;
208 reg64_t reg_tmp_filter = r13;
209 reg64_t reg_kh_offset = rax;
211 /* parameter passed by driver into kernel */
// bl is the low byte of rbx.
212 Xbyak::Reg8 reg_exec_flags = bl;
214 reg64_t reg_oh_worksize = r14;
215 reg64_t reg_oh = rax;
217 reg64_t iter_ow_blk = r11;
219 reg64_t reg_kh = rsi;
220 reg64_t reg_kh_count = rdx;
222 /* Base addresses for convolution parameters. */
// reg_filter_baddr uses abi_not_param1, which is defined outside this view.
223 reg64_t reg_input_baddr = r15;
224 reg64_t reg_output_baddr = r12;
225 reg64_t reg_filter_baddr = abi_not_param1;
226 reg64_t reg_bias_baddr = r13;
// NOTE(review): the block comment opened on the next line has no closing
// delimiter before the compute_ow_step_unroll declaration in this view. It
// only ends at the close of the following "JIT'ing the outer loops"
// comment, so that declaration is commented out as shown. A closing line
// appears to be missing here; confirm in the full file.
228 /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
230 inline void compute_ow_step_unroll(
231 int unroll_w, int l_pad, int pad_offset, int ow_block);
233 /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
234 inline void compute_h_step(
235 int unroll_w, int l_pad, int pad_offset, int ow_block);
236 inline void compute_h_loop(
237 int unroll_w, int l_pad, int pad_offset, int ow_block);
// NOTE(review): the block comment opened on the next line is not closed
// anywhere in this view, so every declaration after it here
// (compute_ow_block_unroll through store_bias) is commented out as shown.
// Its closing line appears to be missing; confirm in the full file.
239 /* Write 'width' micro-kernel JITs; depending on the padding and convolution
240 * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
242 inline void compute_ow_block_unroll();
244 inline void compute_zero_filter();
245 inline void load_filter();
246 inline void zero_filter();
247 inline void load_bias();
248 inline void zero_bias();
249 inline void compute_bias_step_unroll(const int unroll_w);
250 inline void compute_bias_loop(const int block_size);
251 inline void store_filter();
252 inline void store_bias();