Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_conv_winograd_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_CONV_WINOGRAD_KERNEL_F32_HPP
18 #define JIT_AVX512_CORE_CONV_WINOGRAD_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25
26 #include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 struct _jit_avx512_core_conv_winograd_data_kernel_f32 : public jit_generator {
33     _jit_avx512_core_conv_winograd_data_kernel_f32(
34             jit_conv_winograd_conf_t ajcp)
35         : jcp(ajcp)
36     {
37         {
38             this->weights_transform_data_ker_generate();
39             weights_transform_data_ker
40                     = (decltype(weights_transform_data_ker)) this->getCode();
41         }
42         {
43             align();
44             const Xbyak::uint8 *addr = getCurr();
45             this->input_transform_data_ker_generate();
46             input_transform_data_ker = (decltype(input_transform_data_ker))addr;
47         }
48         {
49             align();
50             const Xbyak::uint8 *addr = getCurr();
51             this->output_transform_data_ker_generate();
52             output_transform_data_ker
53                 = (decltype(output_transform_data_ker))addr;
54         }
55         {
56             align();
57             const Xbyak::uint8 *addr = getCurr();
58             this->gemm_loop_generate();
59             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
60         }
61     }
62
63     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_conv_winograd_data_kernel_f32)
64
65     static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
66             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
67             const memory_desc_wrapper &weights_d,
68             const memory_desc_wrapper &dst_d);
69
70     static status_t init_conf_kernel(
71             jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
72
73     jit_conv_winograd_conf_t jcp;
74     void (*gemm_loop_ker)(float *, const float *, const float *, const int);
75     void (*input_transform_data_ker)(jit_wino_transform_call_s *);
76     void (*output_transform_data_ker)(jit_wino_transform_call_s *);
77     void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
78
79 protected:
80     using reg64_t = const Xbyak::Reg64;
81     using reg32_t = const Xbyak::Reg32;
82     enum { typesize = sizeof(float) };
83
84     void gemm_loop_generate();
85     void input_transform_data_ker_generate();
86     void output_transform_data_ker_generate();
87     void weights_transform_data_ker_generate();
88
89     /* registers used for GEMM */
90     reg64_t reg_dstC = abi_param1;
91     reg64_t reg_srcA = abi_param2;
92     reg64_t reg_srcB = abi_param3;
93     reg64_t reg_is_beta_zero = abi_param4;
94
95     reg64_t reg_dimM_block_loop_cnt = r10;
96     reg64_t reg_dimK_block_loop_cnt = r11;
97
98     /* registers used for transforms*/
99     reg64_t param = abi_param1;
100
101     /* registers used for output_transform_data_ker */
102     reg64_t oreg_temp = rcx;
103     reg64_t oreg_Ow = r9;
104     reg64_t oreg_src = r11;
105     reg64_t oreg_tile_block = r12;
106     reg64_t oreg_tile_block_ur = r13;
107     reg64_t oreg_nb_tile_block_ur = r14;
108     reg64_t oreg_O = r8;
109     reg64_t oreg_T = r10;
110     reg64_t oreg_dst = r11;
111     reg64_t oreg_ydim = r14;
112     reg64_t oreg_xdim = r15;
113     reg64_t oreg_out_j = r12;
114     reg64_t oreg_bias = rbx;
115     reg64_t imm_addr64 = rax;
116
117     /* registers used for input_transform_data_ker */
118     reg64_t ireg_temp = rcx;
119     reg64_t ireg_jtiles = rax;
120     reg64_t ireg_itiles = rbx;
121     reg64_t ireg_I = r8;
122     reg64_t ireg_src = r13;
123     reg64_t ireg_ydim = r14;
124     reg64_t ireg_xdim = r15;
125     reg64_t ireg_inp_j = r12;
126     reg64_t ireg_inp_i = rdx;
127     reg64_t ireg_mask_j = r11;
128     reg64_t ireg_mask = rsi;
129     reg32_t ireg_mask_32 = esi;
130     reg64_t ireg_zero = r9;
131     reg64_t ireg_Iw = r9;
132     reg64_t ireg_T = r10;
133     reg64_t ireg_tile_block = r12;
134     reg64_t ireg_tile_block_ur = r13;
135     reg64_t ireg_nb_tile_block_ur = r14;
136     reg64_t ireg_output = r15;
137
138     /* registers used for wei transform */
139     reg64_t wreg_temp = rcx;
140     reg64_t wreg_F = r8;
141     reg64_t wreg_src = r9;
142     reg64_t wreg_MT = r15;
143     reg64_t wreg_M = r14;
144     reg64_t wreg_dst = r10;
145     reg64_t wreg_dst_aux = r9;
146     reg64_t wreg_dst_idx = r8;
147     reg64_t wreg_Fw = r11;
148     reg64_t wreg_T = r12;
149     reg64_t wreg_cnt_j = rdx;
150     reg64_t wreg_F_aux = r14;
151     reg64_t wreg_Fw_aux = r15;
152 };
153
154 struct jit_avx512_core_conv_winograd_fwd_kernel_f32
155         : _jit_avx512_core_conv_winograd_data_kernel_f32 {
156     using _jit_avx512_core_conv_winograd_data_kernel_f32::
157             _jit_avx512_core_conv_winograd_data_kernel_f32;
158
159     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
160
161     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
162             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
163             const memory_desc_wrapper &weights_d,
164             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
165             bool with_relu = false, float relu_negative_slope = 0.);
166 };
167
168 struct jit_avx512_core_conv_winograd_bwd_data_kernel_f32
169         : public _jit_avx512_core_conv_winograd_data_kernel_f32 {
170     using _jit_avx512_core_conv_winograd_data_kernel_f32::
171             _jit_avx512_core_conv_winograd_data_kernel_f32;
172
173     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
174             const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
175             const memory_desc_wrapper &weights_d,
176             const memory_desc_wrapper &diff_dst_d);
177 };
178
179 struct jit_avx512_core_conv_winograd_bwd_weights_kernel_f32
180         : public jit_generator {
181     DECLARE_CPU_JIT_AUX_FUNCTIONS(
182         _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
183
184     jit_avx512_core_conv_winograd_bwd_weights_kernel_f32(
185             jit_conv_winograd_conf_t ajcp)
186         : jcp(ajcp)
187     {
188         //******************* First iter kernel ********************//
189         this->gemm_loop_generate(true);
190         gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode();
191         
192         align();
193         const Xbyak::uint8 *addr = getCurr();
194         this->src_transform_generate();
195         src_transform = (decltype(src_transform))addr;
196
197         if (jcp.with_bias) {
198             align();
199             addr = getCurr();
200             this->diff_dst_transform_generate(true);
201             diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
202         }
203
204         align();
205         addr = getCurr();
206         this->diff_dst_transform_generate(false);
207         diff_dst_transform = (decltype(diff_dst_transform))addr;
208
209         if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
210             align();
211             addr = getCurr();
212             this->gemm_loop_generate(false);
213             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
214         }
215
216         align();
217         addr = getCurr();
218         this->diff_weights_transform_generate(true);
219         diff_weights_transform = (decltype(diff_weights_transform))addr;
220
221         if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
222             align();
223             addr = getCurr();
224             this->diff_weights_transform_generate(false);
225             diff_weights_transform_accum =
226                 (decltype(diff_weights_transform_accum))addr;
227         };
228     }
229
230     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
231             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
232             const memory_desc_wrapper &diff_dst_d,
233             const memory_desc_wrapper &diff_weights_d);
234
235     jit_conv_winograd_conf_t jcp;
236     void (*gemm_loop_ker)(float *, const float *, const float *);
237     void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
238     void (*src_transform)(jit_wino_transform_call_s *);
239     void (*diff_dst_transform)(jit_wino_transform_call_s *);
240     void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
241     void (*diff_weights_transform)(jit_wino_transform_call_s *);
242     void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
243
244 private:
245     using reg64_t = const Xbyak::Reg64;
246     using reg32_t = const Xbyak::Reg32;
247     enum { typesize = sizeof(float) };
248
249     void src_transform_generate();
250     void diff_dst_transform_generate(bool with_bias);
251     void diff_weights_transform_generate(bool first_tile);
252
253     /*registers common to transforms*/
254     reg64_t reg_transp = abi_param1;
255     reg64_t reg_ti = rbx;
256     reg64_t reg_tj = rcx;
257     reg64_t reg_src = r8;
258     reg64_t reg_dst = r9;
259     reg64_t reg_G = rsi; /*TODO: check if this is ok*/
260     reg64_t reg_temp = rsi;
261
262     /*registers common to src/diff_dst transform*/
263     reg64_t reg_I = r10;
264     reg64_t reg_ydim = r11;
265     reg64_t reg_xdim = r12;
266     reg64_t reg_src_offset = r13;
267     reg64_t reg_zero = r14;
268     reg64_t reg_tile_count = r15;
269     reg64_t reg_maski = rsi;
270     reg32_t reg_maski_32 = esi;
271     reg64_t reg_maskj = rdx;
272
273     reg64_t reg_T = rax;
274     reg64_t reg_oc_ur = rax;
275     reg64_t reg_ic_simd = r14;
276     reg64_t reg_bias = r10;
277
278     void gemm_loop_generate(bool is_first_tile);
279
280     reg64_t reg_dstC = abi_param1;
281     reg64_t reg_srcA = abi_param2;
282     reg64_t reg_srcB = abi_param3;
283
284     reg64_t reg_dimM_block_loop_cnt = r9;
285     reg64_t reg_dimN_block_loop_cnt = r10;
286     reg64_t reg_nb_dimN_bcast_ur = r11;
287     reg64_t reg_dimK_block_loop_cnt = r12;
288 };
289 }
290 }
291 }
292
293 #endif