1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef JIT_AVX512_CORE_CONV_WINOGRAD_KERNEL_F32_HPP
18 #define JIT_AVX512_CORE_CONV_WINOGRAD_KERNEL_F32_HPP
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
26 #include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
// JIT generator (derives from jit_generator) that emits four kernels for the
// Winograd f32 data path: a weights transform, an input transform, an output
// transform and a GEMM loop. The fwd and bwd_data kernel structs in this
// header derive from it and inherit this constructor via using-declarations.
//
// NOTE(review): this listing omits some lines of the definition (e.g. the
// constructor's member-initializer list, the braces that scope each `addr`
// local, access specifiers, the closing brace). The comments below describe
// only the visible lines.
32 struct _jit_avx512_core_conv_winograd_data_kernel_f32 : public jit_generator {
// Constructor: takes the configuration by value and emits all four kernels,
// one after another, into this generator's code buffer. Each function-pointer
// member is set to the address where its kernel's code begins. The first
// kernel's pointer comes from getCode(), the start of the code buffer. Each of
// the other three uses the getCurr() position recorded just before that
// kernel is generated.
// NOTE(review): presumably `jcp` is initialized from `ajcp` before the
// *_generate() calls run (the initializer list is not shown) — confirm.
33 _jit_avx512_core_conv_winograd_data_kernel_f32(
34 jit_conv_winograd_conf_t ajcp)
// 1) Weights transform.
38 this->weights_transform_data_ker_generate();
39 weights_transform_data_ker
40 = (decltype(weights_transform_data_ker)) this->getCode();
// 2) Input transform.
44 const Xbyak::uint8 *addr = getCurr();
45 this->input_transform_data_ker_generate();
46 input_transform_data_ker = (decltype(input_transform_data_ker))addr;
// 3) Output transform.
50 const Xbyak::uint8 *addr = getCurr();
51 this->output_transform_data_ker_generate();
52 output_transform_data_ker
53 = (decltype(output_transform_data_ker))addr;
// 4) GEMM loop.
57 const Xbyak::uint8 *addr = getCurr();
58 this->gemm_loop_generate();
59 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
63 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_conv_winograd_data_kernel_f32)
// Static configuration helpers (defined elsewhere). Both take `jcp` by
// non-const reference and return a status_t. Presumably the first fills it
// from the convolution/memory descriptors and the second from the GEMM sizes
// dimM/dimN/dimK — confirm against the implementation.
65 static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
66 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
67 const memory_desc_wrapper &weights_d,
68 const memory_desc_wrapper &dst_d);
70 static status_t init_conf_kernel(
71 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
// Configuration (presumably copied from `ajcp`) and the entry points of the
// generated kernels, all of which the constructor assigns. The transform
// kernels take a single jit_wino_transform_call_s pointer. gemm_loop_ker takes
// three float pointers and an int; these presumably arrive in the abi_param1..4
// registers, which are bound below to reg_dstC, reg_srcA, reg_srcB and
// reg_is_beta_zero.
73 jit_conv_winograd_conf_t jcp;
74 void (*gemm_loop_ker)(float *, const float *, const float *, const int);
75 void (*input_transform_data_ker)(jit_wino_transform_call_s *);
76 void (*output_transform_data_ker)(jit_wino_transform_call_s *);
77 void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
// Register-alias types (const Xbyak registers), the element size in bytes,
// and one code-emitting routine per kernel (defined elsewhere).
80 using reg64_t = const Xbyak::Reg64;
81 using reg32_t = const Xbyak::Reg32;
82 enum { typesize = sizeof(float) };
84 void gemm_loop_generate();
85 void input_transform_data_ker_generate();
86 void output_transform_data_ker_generate();
87 void weights_transform_data_ker_generate();
// abi_paramN are defined outside this header. Which physical registers they
// name presumably depends on the target calling convention.
89 /* registers used for GEMM */
90 reg64_t reg_dstC = abi_param1;
91 reg64_t reg_srcA = abi_param2;
92 reg64_t reg_srcB = abi_param3;
93 reg64_t reg_is_beta_zero = abi_param4;
95 reg64_t reg_dimM_block_loop_cnt = r10;
96 reg64_t reg_dimK_block_loop_cnt = r11;
// `param` presumably receives the single jit_wino_transform_call_s * argument
// of each transform kernel.
98 /* registers used for transforms*/
99 reg64_t param = abi_param1;
// NOTE(review): in this group, the following pairs share a physical register,
// so neither pair member may be live while the other is:
//   oreg_src/oreg_dst (r11), oreg_tile_block/oreg_out_j (r12),
//   oreg_nb_tile_block_ur/oreg_ydim (r14).
// oreg_temp is rcx. Where abi_param1 is also rcx (e.g. Win64), it aliases
// `param` as well. Presumably intentional — confirm in the generator.
101 /* registers used for output_transform_data_ker */
102 reg64_t oreg_temp = rcx;
103 reg64_t oreg_Ow = r9;
104 reg64_t oreg_src = r11;
105 reg64_t oreg_tile_block = r12;
106 reg64_t oreg_tile_block_ur = r13;
107 reg64_t oreg_nb_tile_block_ur = r14;
109 reg64_t oreg_T = r10;
110 reg64_t oreg_dst = r11;
111 reg64_t oreg_ydim = r14;
112 reg64_t oreg_xdim = r15;
113 reg64_t oreg_out_j = r12;
114 reg64_t oreg_bias = rbx;
115 reg64_t imm_addr64 = rax;
// NOTE(review): in this group, the following pairs share a physical register:
//   ireg_src/ireg_tile_block_ur (r13), ireg_ydim/ireg_nb_tile_block_ur (r14),
//   ireg_xdim/ireg_output (r15), ireg_inp_j/ireg_tile_block (r12),
//   ireg_zero/ireg_Iw (r9).
// ireg_mask_32 is the low 32 bits (esi) of ireg_mask (rsi). ireg_temp is rcx,
// with the same possible `param` aliasing noted for oreg_temp.
117 /* registers used for input_transform_data_ker */
118 reg64_t ireg_temp = rcx;
119 reg64_t ireg_jtiles = rax;
120 reg64_t ireg_itiles = rbx;
122 reg64_t ireg_src = r13;
123 reg64_t ireg_ydim = r14;
124 reg64_t ireg_xdim = r15;
125 reg64_t ireg_inp_j = r12;
126 reg64_t ireg_inp_i = rdx;
127 reg64_t ireg_mask_j = r11;
128 reg64_t ireg_mask = rsi;
129 reg32_t ireg_mask_32 = esi;
130 reg64_t ireg_zero = r9;
131 reg64_t ireg_Iw = r9;
132 reg64_t ireg_T = r10;
133 reg64_t ireg_tile_block = r12;
134 reg64_t ireg_tile_block_ur = r13;
135 reg64_t ireg_nb_tile_block_ur = r14;
136 reg64_t ireg_output = r15;
// NOTE(review): in this group, the following pairs share a physical register:
//   wreg_src/wreg_dst_aux (r9), wreg_MT/wreg_Fw_aux (r15),
//   wreg_M/wreg_F_aux (r14).
// wreg_temp is rcx, with the same possible `param` aliasing noted above.
138 /* registers used for wei transform */
139 reg64_t wreg_temp = rcx;
141 reg64_t wreg_src = r9;
142 reg64_t wreg_MT = r15;
143 reg64_t wreg_M = r14;
144 reg64_t wreg_dst = r10;
145 reg64_t wreg_dst_aux = r9;
146 reg64_t wreg_dst_idx = r8;
147 reg64_t wreg_Fw = r11;
148 reg64_t wreg_T = r12;
149 reg64_t wreg_cnt_j = rdx;
150 reg64_t wreg_F_aux = r14;
151 reg64_t wreg_Fw_aux = r15;
// Forward Winograd f32 kernel. It inherits the data-kernel generator's
// constructor, and therefore the four kernels that constructor emits. No
// access specifier is given, so the base is public by the struct default; the
// bwd_data kernel spells out `public`.
// NOTE(review): the closing brace of this struct is not shown in this listing.
154 struct jit_avx512_core_conv_winograd_fwd_kernel_f32
155 : _jit_avx512_core_conv_winograd_data_kernel_f32 {
156 using _jit_avx512_core_conv_winograd_data_kernel_f32::
157 _jit_avx512_core_conv_winograd_data_kernel_f32;
// Post-op check (defined elsewhere).
// NOTE(review): it takes jit_conv_conf_t, while init_conf takes
// jit_conv_winograd_conf_t. Presumably the latter derives from the former, so
// a winograd config can be passed here — confirm.
159 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
// Forward configuration (defined elsewhere); returns a status_t. with_relu
// defaults to false. relu_negative_slope defaults to 0, written as the double
// literal `0.` and implicitly converted to float.
161 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
162 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
163 const memory_desc_wrapper &weights_d,
164 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
165 bool with_relu = false, float relu_negative_slope = 0.);
// Backward-data Winograd f32 kernel. It inherits the data-kernel generator's
// constructor, and therefore the four kernels that constructor emits.
// NOTE(review): the closing brace of this struct is not shown in this listing.
168 struct jit_avx512_core_conv_winograd_bwd_data_kernel_f32
169 : public _jit_avx512_core_conv_winograd_data_kernel_f32 {
170 using _jit_avx512_core_conv_winograd_data_kernel_f32::
171 _jit_avx512_core_conv_winograd_data_kernel_f32;
// Backward-data configuration (defined elsewhere); returns a status_t. It
// takes descriptors for diff_src, weights and diff_dst. Unlike the forward
// init_conf, it takes no primitive_attr_t and no ReLU parameters.
173 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
174 const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
175 const memory_desc_wrapper &weights_d,
176 const memory_desc_wrapper &diff_dst_d);
// JIT generator for the backward-weights Winograd f32 path. It emits GEMM
// loops, the src and diff_dst transforms, and the diff_weights transforms.
// Unlike the data kernels, it derives from jit_generator directly.
//
// NOTE(review): this listing omits some lines of the definition (e.g. the
// constructor's member-initializer list, scope braces, the closing brace).
// The comments below describe only the visible lines.
179 struct jit_avx512_core_conv_winograd_bwd_weights_kernel_f32
180 : public jit_generator {
// NOTE(review): the macro argument has a leading underscore that this struct's
// name lacks. The data kernel passes its own exact struct name to this macro.
// If the macro turns its argument into the kernel's reported name (its
// definition is not shown), this kernel is reported under a non-matching name.
181 DECLARE_CPU_JIT_AUX_FUNCTIONS(
182 _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
// Constructor: takes the configuration by value and emits the kernels one
// after another into this generator's code buffer.
// gemm_loop_ker_first_iter comes from getCode(), the start of the buffer, and
// every later entry pointer is set from `addr`.
// NOTE(review): only one `addr` declaration is visible (the getCurr() before
// src_transform). Presumably the omitted lines re-capture getCurr() before
// each later kernel — confirm.
184 jit_avx512_core_conv_winograd_bwd_weights_kernel_f32(
185 jit_conv_winograd_conf_t ajcp)
188 //******************* First iter kernel ********************//
189 this->gemm_loop_generate(true);
190 gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode();
// src transform.
193 const Xbyak::uint8 *addr = getCurr();
194 this->src_transform_generate();
195 src_transform = (decltype(src_transform))addr;
// diff_dst transform, first with bias, then without.
200 this->diff_dst_transform_generate(true);
201 diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
206 this->diff_dst_transform_generate(false);
207 diff_dst_transform = (decltype(diff_dst_transform))addr;
// The non-first-iteration GEMM loop is emitted only when this condition holds.
// Otherwise gemm_loop_ker is not assigned in the visible code.
209 if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
212 this->gemm_loop_generate(false);
213 gemm_loop_ker = (decltype(gemm_loop_ker))addr;
// The diff_weights transform (first_tile == true) is always emitted. The
// variant generated with first_tile == false is stored in
// diff_weights_transform_accum, and only under WSCHED_WEI_SDGtWo.
218 this->diff_weights_transform_generate(true);
219 diff_weights_transform = (decltype(diff_weights_transform))addr;
221 if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
224 this->diff_weights_transform_generate(false);
225 diff_weights_transform_accum =
226 (decltype(diff_weights_transform_accum))addr;
// Static configuration helper (defined elsewhere). It takes `jcp` by non-const
// reference and returns a status_t; presumably it fills `jcp` from the
// descriptors — confirm.
230 static status_t init_conf(jit_conv_winograd_conf_t &jcp,
231 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
232 const memory_desc_wrapper &diff_dst_d,
233 const memory_desc_wrapper &diff_weights_d);
// Configuration (presumably copied from `ajcp`) and the kernel entry points.
// NOTE(review): gemm_loop_ker and diff_weights_transform_accum are assigned
// only under the constructor's conditions and have no default member
// initializer here. When those conditions are false, the pointers are
// indeterminate, so callers must test the same conditions before calling them.
// Consider initializing both to nullptr.
235 jit_conv_winograd_conf_t jcp;
236 void (*gemm_loop_ker)(float *, const float *, const float *);
237 void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
238 void (*src_transform)(jit_wino_transform_call_s *);
239 void (*diff_dst_transform)(jit_wino_transform_call_s *);
240 void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
241 void (*diff_weights_transform)(jit_wino_transform_call_s *);
242 void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
// Register-alias types (const Xbyak registers), the element size in bytes,
// and the transform code generators (defined elsewhere).
245 using reg64_t = const Xbyak::Reg64;
246 using reg32_t = const Xbyak::Reg32;
247 enum { typesize = sizeof(float) };
249 void src_transform_generate();
250 void diff_dst_transform_generate(bool with_bias);
251 void diff_weights_transform_generate(bool first_tile);
// NOTE(review): reg_G, reg_temp and reg_maski (below) all name rsi, so they
// must not be live at the same time (see the existing TODO on reg_G). reg_tj
// is rcx; on calling conventions where abi_param1 is rcx (e.g. Win64), it
// aliases reg_transp as well.
253 /*registers common to transforms*/
254 reg64_t reg_transp = abi_param1;
255 reg64_t reg_ti = rbx;
256 reg64_t reg_tj = rcx;
257 reg64_t reg_src = r8;
258 reg64_t reg_dst = r9;
259 reg64_t reg_G = rsi; /*TODO: check if this is ok*/
260 reg64_t reg_temp = rsi;
// reg_maski_32 is the low 32 bits (esi) of reg_maski (rsi).
262 /*registers common to src/diff_dst transform*/
264 reg64_t reg_ydim = r11;
265 reg64_t reg_xdim = r12;
266 reg64_t reg_src_offset = r13;
267 reg64_t reg_zero = r14;
268 reg64_t reg_tile_count = r15;
269 reg64_t reg_maski = rsi;
270 reg32_t reg_maski_32 = esi;
271 reg64_t reg_maskj = rdx;
// Further transform registers; this listing does not show their group comment.
// NOTE(review): reg_ic_simd shares r14 with reg_zero — confirm the two are
// never live together.
274 reg64_t reg_oc_ur = rax;
275 reg64_t reg_ic_simd = r14;
276 reg64_t reg_bias = r10;
// GEMM loop generator and its registers. The three gemm_loop_ker arguments
// presumably arrive in abi_param1..3 (reg_dstC, reg_srcA, reg_srcB). The loop
// counters reuse r9-r12, which the transform aliases above also name. The GEMM
// loops are emitted as separate kernels in the constructor.
278 void gemm_loop_generate(bool is_first_tile);
280 reg64_t reg_dstC = abi_param1;
281 reg64_t reg_srcA = abi_param2;
282 reg64_t reg_srcB = abi_param3;
284 reg64_t reg_dimM_block_loop_cnt = r9;
285 reg64_t reg_dimN_block_loop_cnt = r10;
286 reg64_t reg_nb_dimN_bcast_ur = r11;
287 reg64_t reg_dimK_block_loop_cnt = r12;