Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_conv_winograd_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
18 #define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
22
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
// Winograd tiling constants.
// alpha is tied to the output tile size (tile_size).
// NOTE(review): presumably alpha == tile_size + kernel_size - 1 for a 3x3
// kernel -- confirm against the transform code.
constexpr int alpha{6};
constexpr int tile_size{4};
// Number of f32 lanes processed per vector operation (SIMD width).
constexpr int simd_w{16};
35
36 struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator {
37     _jit_avx512_common_conv_winograd_data_kernel_f32(
38             jit_conv_winograd_conf_t ajcp)
39         : jcp(ajcp)
40     {
41         //******************* First iter kernel ********************//
42         this->gemm_loop_generate(true);
43         gemm_loop_ker_first_iter
44                 = (decltype(gemm_loop_ker_first_iter)) this->getCode();
45
46         //************** Subsequent iterations kernel **************//
47         if (jcp.dimK_nb_block > 1) {
48             align();
49             const Xbyak::uint8 *addr = getCurr();
50             this->gemm_loop_generate(false);
51             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
52         }
53     }
54
55     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32)
56
57     static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
58             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
59             const memory_desc_wrapper &weights_d,
60             const memory_desc_wrapper &dst_d);
61
62     static status_t init_conf_kernel(
63             jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
64
65     jit_conv_winograd_conf_t jcp;
66     void (*gemm_loop_ker)(float *, const float *, const float *);
67     void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
68
69 protected:
70     using reg64_t = const Xbyak::Reg64;
71     enum { typesize = sizeof(float) };
72
73     void gemm_loop_generate(bool is_beta_zero);
74
75     /* registers used for GEMM */
76     reg64_t reg_dstC = abi_param1;
77     reg64_t reg_srcA = abi_param2;
78     reg64_t reg_srcB = abi_param3;
79
80     reg64_t reg_dimM_block_loop_cnt = r10;
81     reg64_t reg_dimK_block_loop_cnt = r11;
82 };
83
84 struct jit_avx512_common_conv_winograd_fwd_kernel_f32
85         : _jit_avx512_common_conv_winograd_data_kernel_f32 {
86     using _jit_avx512_common_conv_winograd_data_kernel_f32::
87             _jit_avx512_common_conv_winograd_data_kernel_f32;
88
89     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
90
91     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
92             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
93             const memory_desc_wrapper &weights_d,
94             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
95 };
96
97 struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32
98         : public _jit_avx512_common_conv_winograd_data_kernel_f32 {
99     using _jit_avx512_common_conv_winograd_data_kernel_f32::
100             _jit_avx512_common_conv_winograd_data_kernel_f32;
101
102     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
103             const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
104             const memory_desc_wrapper &weights_d,
105             const memory_desc_wrapper &diff_dst_d);
106 };
107
108 struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32
109         : public jit_generator {
110     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32)
111
112     jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
113             jit_conv_winograd_conf_t ajcp)
114         : jcp(ajcp)
115     {
116
117         //******************* First iter kernel ********************//
118         {
119             align();
120             const Xbyak::uint8 *addr = getCurr();
121             this->gemm_loop_generate(true);
122             gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
123         }
124
125         if (jcp.tile_block > 1) {
126             align();
127             const Xbyak::uint8 *addr = getCurr();
128             this->gemm_loop_generate(false);
129             gemm_loop_ker = (decltype(gemm_loop_ker))addr;
130         }
131
132         if (jcp.ver == ver_4fma) {
133             align();
134             const Xbyak::uint8 *addr = getCurr();
135             this->transpose_ker_generate();
136             transpose_4fma_ker = (decltype(transpose_4fma_ker))addr;
137         }
138     }
139
140     static status_t init_conf(jit_conv_winograd_conf_t &jcp,
141             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
142             const memory_desc_wrapper &diff_dst_d,
143             const memory_desc_wrapper &diff_weights_d);
144
145     jit_conv_winograd_conf_t jcp;
146     void (*gemm_loop_ker)(float *, const float *, const float *);
147     void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
148     void (*transpose_4fma_ker)(float *, float *);
149
150 private:
151     using reg64_t = const Xbyak::Reg64;
152     enum { typesize = sizeof(float) };
153
154     void gemm_loop_generate(bool is_first_tile);
155     void transpose_ker_generate();
156
157     reg64_t reg_origB = abi_param2;
158     reg64_t reg_transB = abi_param1;
159
160     reg64_t reg_dstC = abi_param1;
161     reg64_t reg_srcA_const = abi_param2;
162     reg64_t reg_srcB = abi_param3;
163
164     reg64_t reg_sp = rsp;
165     reg64_t reg_srcA = r9;
166     reg64_t reg_nb_ic = r10;
167     reg64_t reg_loop_cpt = r11;
168     reg64_t reg_transB_idx = r13;
169
170     /* Registers used by new kernel */
171     reg64_t reg_dimM_block_loop_cnt = r10;
172     reg64_t reg_dimK_block_loop_cnt = r12;
173     reg64_t reg_dimN_block_loop_cnt = r11;
174 };
175 }
176 }
177 }
178
179 #endif