updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
18 #define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
19
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25
26 #include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
33     : public jit_generator {
34     _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(
35             jit_conv_winograd_conf_t ajcp)
36         : jcp(ajcp) {
37         {
38             this->weights_transform_data_ker_generate();
39             weights_transform_data_ker
40                     = (decltype(weights_transform_data_ker)) this->getCode();
41         }
42         {
43             align();
44             const Xbyak::uint8 *addr = getCurr();
45             this->input_transform_data_ker_generate();
46             input_transform_data_ker = (decltype(input_transform_data_ker))addr;
47         }
48         {
49             align();
50             const Xbyak::uint8 *addr = getCurr();
51             this->output_transform_data_ker_generate();
52             output_transform_data_ker
53                 = (decltype(output_transform_data_ker))addr;
54         }
55         {
56             align();
57             const Xbyak::uint8 *addr = getCurr();
58             this->gemm_loop_generate();
59             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
60         }
61     }
62
63     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel)
64
65     static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
66             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
67             const memory_desc_wrapper &weights_d,
68             const memory_desc_wrapper &dst_d);
69
70     static status_t init_conf_kernel(
71             jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
72
73     jit_conv_winograd_conf_t jcp;
74     void (*gemm_loop_ker)(float *, const float *, const float *, const int);
75     void (*input_transform_data_ker)(jit_wino_transform_call_s *);
76     void (*output_transform_data_ker)(jit_wino_transform_call_s *);
77     void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
78
79 protected:
80     using reg64_t = const Xbyak::Reg64;
81     using reg32_t = const Xbyak::Reg32;
82     enum { typesize = sizeof(float) };
83
84     void gemm_loop_generate();
85     void input_transform_data_ker_generate();
86     void output_transform_data_ker_generate();
87     void weights_transform_data_ker_generate();
88
89     /* registers used for GEMM */
90     reg64_t reg_dstC = abi_param1;
91     reg64_t reg_srcA = abi_param2;
92     reg64_t reg_srcB = abi_param3;
93     reg64_t reg_is_beta_zero = abi_param4;
94
95     reg64_t reg_dimM_block_loop_cnt = r10;
96     reg64_t reg_dimK_block_loop_cnt = r11;
97
98     /* registers used for transforms*/
99     reg64_t param = abi_param1;
100
101     /* registers used for output_transform_data_ker */
102     reg64_t oreg_temp = abi_not_param1;
103     reg64_t oreg_Ow = r9;
104     reg64_t oreg_src = r11;
105     reg64_t oreg_tile_block = r12;
106     reg64_t oreg_tile_block_ur = r13;
107     reg64_t oreg_nb_tile_block_ur = r14;
108     reg64_t oreg_O = r8;
109     reg64_t oreg_T = r10;
110     reg64_t oreg_dst = r11;
111     reg64_t oreg_ydim = r14;
112     reg64_t oreg_xdim = r15;
113     reg64_t oreg_out_j = r12;
114     reg64_t oreg_bias = rbx;
115     reg64_t imm_addr64 = rax;
116
117     /* registers used for input_transform_data_ker */
118     reg64_t ireg_temp = abi_not_param1;
119     reg64_t ireg_jtiles = rax;
120     reg64_t ireg_itiles = rbx;
121     reg64_t ireg_I = r8;
122     reg64_t ireg_src = r13;
123     reg64_t ireg_ydim = r14;
124     reg64_t ireg_xdim = r15;
125     reg64_t ireg_inp_j = r12;
126     reg64_t ireg_inp_i = rdx;
127     reg64_t ireg_mask_j = r11;
128     reg64_t ireg_mask = rsi;
129     reg32_t ireg_mask_32 = esi;
130     reg64_t ireg_zero = r9;
131     reg64_t ireg_Iw = r9;
132     reg64_t ireg_T = r10;
133     reg64_t ireg_tile_block = r12;
134     reg64_t ireg_tile_block_ur = r13;
135     reg64_t ireg_nb_tile_block_ur = r14;
136     reg64_t ireg_output = r15;
137
138     /* registers used for wei transform */
139     reg64_t wreg_temp = abi_not_param1;
140     reg64_t wreg_F = r8;
141     reg64_t wreg_src = r9;
142     reg64_t wreg_MT = r15;
143     reg64_t wreg_M = r14;
144     reg64_t wreg_dst = r10;
145     reg64_t wreg_dst_aux = r9;
146     reg64_t wreg_dst_idx = r8;
147     reg64_t wreg_Fw = r11;
148     reg64_t wreg_T = r12;
149     reg64_t wreg_cnt_j = rdx;
150     reg64_t wreg_F_aux = r14;
151     reg64_t wreg_Fw_aux = r15;
152 };
153
154 struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel
155         : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
156     using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
157             _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
158
159     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
160
161     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
162             const convolution_desc_t &cd, const cpu_memory_t::pd_t &src_pd,
163             cpu_memory_t::pd_t &weights_pd, const cpu_memory_t::pd_t &dst_pd,
164             const primitive_attr_t &attr);
165 };
166
167 struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel
168         : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
169     using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
170             _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
171
172     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
173             const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
174             const memory_desc_wrapper &weights_d,
175             const memory_desc_wrapper &diff_dst_d);
176 };
177
178 struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel
179         : public jit_generator {
180     DECLARE_CPU_JIT_AUX_FUNCTIONS(
181         _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
182
183     jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
184             jit_conv_winograd_conf_t ajcp)
185         : jcp(ajcp)
186     {
187         //******************* First iter kernel ********************//
188         this->gemm_loop_generate(true);
189         gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode();
190
191         align();
192         const Xbyak::uint8 *addr = getCurr();
193         this->src_transform_generate();
194         src_transform = (decltype(src_transform))addr;
195
196         if (jcp.with_bias) {
197             align();
198             addr = getCurr();
199             this->diff_dst_transform_generate(true);
200             diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
201         }
202
203         align();
204         addr = getCurr();
205         this->diff_dst_transform_generate(false);
206         diff_dst_transform = (decltype(diff_dst_transform))addr;
207
208         if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
209             align();
210             addr = getCurr();
211             this->gemm_loop_generate(false);
212             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
213         }
214
215         align();
216         addr = getCurr();
217         this->diff_weights_transform_generate(true);
218         diff_weights_transform = (decltype(diff_weights_transform))addr;
219
220         if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
221             align();
222             addr = getCurr();
223             this->diff_weights_transform_generate(false);
224             diff_weights_transform_accum =
225                 (decltype(diff_weights_transform_accum))addr;
226         };
227     }
228
229     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
230             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
231             const memory_desc_wrapper &diff_dst_d,
232             const memory_desc_wrapper &diff_weights_d);
233
234     jit_conv_winograd_conf_t jcp;
235     void (*gemm_loop_ker)(float *, const float *, const float *);
236     void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
237     void (*src_transform)(jit_wino_transform_call_s *);
238     void (*diff_dst_transform)(jit_wino_transform_call_s *);
239     void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
240     void (*diff_weights_transform)(jit_wino_transform_call_s *);
241     void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
242
243 private:
244     using reg64_t = const Xbyak::Reg64;
245     using reg32_t = const Xbyak::Reg32;
246     enum { typesize = sizeof(float) };
247
248     void src_transform_generate();
249     void diff_dst_transform_generate(bool with_bias);
250     void diff_weights_transform_generate(bool first_tile);
251
252     /*registers common to transforms*/
253     reg64_t reg_transp = abi_param1;
254     reg64_t reg_ti = rbx;
255     reg64_t reg_tj = abi_not_param1;
256     reg64_t reg_src = r8;
257     reg64_t reg_dst = r9;
258     reg64_t reg_G = rsi; /*TODO: check if this is ok*/
259     reg64_t reg_temp = rsi;
260
261     /*registers common to src/diff_dst transform*/
262     reg64_t reg_I = r10;
263     reg64_t reg_ydim = r11;
264     reg64_t reg_xdim = r12;
265     reg64_t reg_src_offset = r13;
266     reg64_t reg_zero = r14;
267     reg64_t reg_tile_count = r15;
268     reg64_t reg_maski = rsi;
269     reg32_t reg_maski_32 = esi;
270     reg64_t reg_maskj = rdx;
271
272     reg64_t reg_T = rax;
273     reg64_t reg_oc_ur = rax;
274     reg64_t reg_ic_simd = r14;
275     reg64_t reg_bias = r10;
276
277     void gemm_loop_generate(bool is_first_tile);
278
279     reg64_t reg_dstC = abi_param1;
280     reg64_t reg_srcA = abi_param2;
281     reg64_t reg_srcB = abi_param3;
282
283     reg64_t reg_dimM_block_loop_cnt = r9;
284     reg64_t reg_dimN_block_loop_cnt = r10;
285     reg64_t reg_nb_dimN_bcast_ur = r11;
286     reg64_t reg_dimK_block_loop_cnt = r12;
287 };
288 }
289 }
290 }
291
292 #endif