Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
18 #define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "cpu_memory.hpp"
24
25 #include "jit_generator.hpp"
26 #include "jit_primitive_conf.hpp"
27 #include "jit_uni_eltwise.hpp"
28 #include "jit_uni_depthwise.hpp"
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 template<typename Vmm>
35 struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
36     DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
37
38     enum { STATE_FIRST_DST_LOAD = 0x1U };
39
40     _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
41             const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
42     {
43         generate();
44         jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
45     }
46
47     ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
48         for (auto inj : eltwise_injectors)
49             delete inj;
50         eltwise_injectors.clear();
51
52         for (auto inj : depthwise_injectors)
53             delete inj;
54         depthwise_injectors.clear();
55     }
56
57     jit_conv_conf_t jcp;
58     const primitive_attr_t &attr_;
59     void (*jit_ker_)(jit_conv_call_s *);
60
61 private:
62     nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
63     nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
64
65     enum {
66         typesize = sizeof(float),
67         ker_reg_base_idx = 28,
68         ker_dw_reg_base_idx = 30,
69     };
70     typedef enum {
71         no_last_block,
72         last_ic_block,
73         last_sp_block,
74     } ic_block_t;
75
76     /* data regs */
77     const Xbyak::Reg64 reg_ptr_scales = rax;
78     const Xbyak::Reg64 reg_inp = r8;
79     const Xbyak::Reg64 reg_ker = r9;
80     const Xbyak::Reg64 reg_out = r10;
81     const Xbyak::Reg64 aux_reg_inp = r11;
82     const Xbyak::Reg64 reg_ptr_sum_scale = r11;
83     const Xbyak::Reg64 aux_reg_ker = r12;
84     const Xbyak::Reg64 reg_compensation = r14;
85     /* counter regs */
86     const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
87     const Xbyak::Reg64 reg_oi = rbx;
88     const Xbyak::Reg64 reg_bias = rdx;
89     const Xbyak::Reg64 reg_oc_blocks = rsi;
90     const Xbyak::Reg64 reg_owb = aux_reg_ker;
91     const Xbyak::Reg64 reg_scratch = reg_compensation;
92     const Xbyak::Reg64 reg_kj = reg_ptr_scales;
93     const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
94     const Xbyak::Reg64 reg_icb = reg_bias;
95
96     const Xbyak::Reg64 reg_d_weights = r15;
97     const Xbyak::Reg64 reg_d_bias = r13;
98
99     const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
100     const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
101
102     const Vmm vmm_wei = Vmm(31);
103     /* used during bias section of store_output */
104     const Vmm vmm_comp = Vmm(30); // only for signed input
105     const Vmm vmm_bias = Vmm(31);
106     /* used during post_op sum section of store_output */
107     const Vmm vmm_prev_dst = Vmm(31);
108     /* used during write-out section of store_output */
109     const Vmm vmm_zero = Vmm(31);
110
111     /* used in compute_ker (but set during prepare_output) */
112     const Vmm vmm_shift = vmm_comp; // only for signed input
113     /* used in compute_ker (but only for pre-VNNI machines) */
114     const Vmm vmm_tmp = Vmm(28); // not used for depthwise
115     const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
116
117     /* registers use only for depthwise
118        groups are always blocked by 16(padded if needed),
119        hence use only Zmm registers */
120     const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
121     Xbyak::Zmm zmm_src;
122     Xbyak::Zmm zmm_permute;
123     Xbyak::Zmm zmm_zero_blend; // used only for fast depthwise
124
125     Vmm vmm_out(int i_ur, int i_oc) {
126         int idx = i_ur + i_oc * jcp.ur_w;
127         assert(idx < (jcp.is_depthwise
128                     ? ker_dw_reg_base_idx : ker_reg_base_idx));
129         return Vmm(idx);
130     }
131     Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
132         int idx = i_ur + i_oc * jcp.ur_w;
133         assert(idx < (jcp.is_depthwise
134                     ? ker_dw_reg_base_idx : ker_reg_base_idx));
135         return Xbyak::Zmm(idx);
136     }
137     Vmm vmm_inp(int i_ic, int nb_x_blocking) {
138         int idx = i_ic + nb_x_blocking * jcp.ur_w;
139         assert(idx < 31);
140         return Vmm(idx);
141     }
142     Vmm vmm_bias_alpha() {
143         int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
144         return Vmm(nb_c_block * jcp.ur_w);
145     }
146     Xbyak::Xmm xmm_bias_alpha() {
147         int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
148         return Xbyak::Xmm(nb_c_block * jcp.ur_w);
149     }
150     int get_ow_start(int ki, int pad_l) {
151         return nstl::max(0,
152                 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
153     }
154     int get_ow_end(int ur_w, int ki, int pad_r) {
155         return ur_w - nstl::max(0, utils::div_up(pad_r
156                                                    - (jcp.kw - 1 - ki)
157                                                            * (jcp.dilate_w + 1),
158                                            jcp.stride_w));
159     }
160
161     void prepare_output(int ur_w);
162     void store_output(int ur_w, bool last_oc_block_flag);
163     void compute_ker_dw(
164             int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
165     void compute_ker(int ur_w, int pad_l, int pad_r,
166             ic_block_t last_ic_block_flag, bool h_padded = false);
167     void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
168     void icb_loop(
169             int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
170     void generate();
171     void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
172         bool mask_flag);
173     const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
174 };
175
176 struct jit_avx512_core_x8s8s32x_fwd_kernel {
177
178     jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
179             const primitive_attr_t &attr) :
180         jit_ker(nullptr),
181         zmm_kernel_(nullptr),
182         ymm_kernel_(nullptr),
183         xmm_kernel_(nullptr) {
184             int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
185             switch (ch_block) {
186                 case 16:
187                     zmm_kernel_ =
188                         new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
189                                 ajcp, attr);
190                     jit_ker = zmm_kernel_->jit_ker_;
191                     return;
192                 case 8:
193                     ymm_kernel_ =
194                         new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
195                                 ajcp, attr);
196                     jit_ker = ymm_kernel_->jit_ker_;
197                     return;
198                 case 4:
199                     xmm_kernel_ =
200                         new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
201                                 ajcp, attr);
202                     jit_ker = xmm_kernel_->jit_ker_;
203                     return;
204                 default:
205                     assert(!"invalid channel blocking");
206             }
207     }
208
209     ~jit_avx512_core_x8s8s32x_fwd_kernel() {
210         delete xmm_kernel_;
211         delete ymm_kernel_;
212         delete zmm_kernel_;
213     }
214
215     static bool post_ops_ok(jit_conv_conf_t &jcp,
216             const primitive_attr_t &attr);
217
218     static status_t init_conf(jit_conv_conf_t &jcp,
219             const convolution_desc_t &cd,
220             cpu_memory_t::pd_t &src_pd,
221             cpu_memory_t::pd_t &weights_pd,
222             cpu_memory_t::pd_t &dst_pd,
223             cpu_memory_t::pd_t &bias_pd,
224             const primitive_attr_t &attr,
225             int nthreads);
226     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
227             const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
228
229     void (*jit_ker)(jit_conv_call_s *);
230     _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
231     _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
232     _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
233 };
234
235 }
236 }
237 }
238
239 #endif