Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_X8S8S32X_CONV_KERNEL_HPP
18 #define JIT_UNI_X8S8S32X_CONV_KERNEL_HPP
19
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "cpu_memory.hpp"
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_uni_depthwise.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 template <cpu_isa_t isa>
32 struct jit_uni_x8s8s32x_conv_fwd_kernel: public jit_generator {
33     jit_uni_x8s8s32x_conv_fwd_kernel(jit_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw,
34             const primitive_attr_t &attr): jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr)
35     {
36         this->generate();
37         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
38     }
39
40     ~jit_uni_x8s8s32x_conv_fwd_kernel() {
41         for (auto inj : eltwise_injectors)
42            delete inj;
43         eltwise_injectors.clear();
44
45         for (auto inj : depthwise_injectors)
46             delete inj;
47         depthwise_injectors.clear();
48     }
49
50     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_x8s8s32x_conv_fwd_kernel)
51
52     static bool post_ops_ok(jit_conv_conf_t &jcp,
53             const primitive_attr_t &attr);
54     static status_t init_conf(jit_conv_conf_t &jcp,
55             const convolution_desc_t &cd,
56             cpu_memory_t::pd_t &src_pd,
57             cpu_memory_t::pd_t &weights_pd,
58             cpu_memory_t::pd_t &dst_pd,
59             cpu_memory_t::pd_t &bias_pd,
60             const primitive_attr_t &attr);
61     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
62             const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw, const primitive_attr_t &attr);
63
64     jit_conv_conf_t jcp;
65     jit_conv_conf_t jcp_dw;
66     const primitive_attr_t &attr_;
67     void (*jit_ker)(jit_conv_call_s *);
68
69 private:
70     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
71             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
72     const int vlen = cpu_isa_traits<isa>::vlen;
73     using Ymm = const Xbyak::Ymm;
74     using reg64_t = const Xbyak::Reg64;
75     using reg32_t = const Xbyak::Reg32;
76     using reg8_t = const Xbyak::Reg8;
77
78     reg64_t reg_scales_base = r13;
79     reg64_t reg_bias_base = rbp;
80     reg64_t reg_input_base = r8;
81     reg64_t reg_output_base = r9;
82     reg64_t reg_kernel_base = rbx;
83
84     reg64_t reg_input = rax;
85     reg64_t aux_reg_input = r8;
86     reg64_t aux1_reg_input = r13;
87     reg64_t reg_kernel = rdx;
88     reg64_t aux_reg_kernel = r9;
89     reg64_t aux1_reg_kernel = rbx;
90     reg64_t reg_output = rsi;
91
92     reg64_t reg_kj = r10;
93     reg64_t reg_overflow = r10;
94     reg64_t reg_oi_iter = r11;
95     reg64_t reg_ic_iter = r15;
96     reg64_t reg_compensation_base = abi_not_param1;
97     reg64_t reg_oc_work = r12;
98     reg64_t imm_addr64 = rbx;
99
100     reg8_t reg_tmp_8 = r14b;
101     reg32_t reg_tmp_32 = r14d;
102     reg64_t reg_tmp_64 = r14;
103
104     reg64_t reg_oc_off = r10;
105     reg64_t reg_d_weights = aux_reg_kernel;
106     reg64_t reg_d_bias = aux_reg_input;
107
108     Vmm vmm_one = Vmm(15);
109     Vmm vmm_bias_alpha = Vmm(13);
110     Vmm vmm_shift = Vmm(14);
111     Vmm vmm_bias = Vmm(15);
112     Ymm ymm_tmp = Ymm(10);
113     Vmm vmm_scale = Vmm(12);
114     Vmm vmm_comp = Vmm(12);
115     Vmm vmm_prev_dst = Vmm(12);
116
117     inline Vmm get_src_reg(int idx) { return Vmm(idx + 9); }
118     inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
119     inline Vmm get_tmp_reg(int idx) { return Vmm(idx + 13); }
120     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1); }
121
122     inline void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op, bool scalar_load);
123     inline void store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store);
124
125     inline void apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
126                              int tail_size, bool h_padded);
127     inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded);
128     inline void kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step);
129     inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step);
130     inline void solve_common(int oc_blocks, int oc_step);
131
132     void generate();
133
134     void prepare_table();
135
136     nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
137     nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
138
139     Xbyak::Label l_table;
140 };
141
142 }
143 }
144 }
145
146 #endif