Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_X8S8S32X_DW_CONV_KERNEL_F32_HPP
18 #define JIT_UNI_X8S8S32X_DW_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "type_helpers.hpp"
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_uni_depthwise.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 template <cpu_isa_t isa>
32 struct jit_uni_x8s8s32x_dw_conv_fwd_kernel: public jit_generator {
33     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
34
35     jit_uni_x8s8s32x_dw_conv_fwd_kernel(jit_conv_conf_t ajcp,
36             const primitive_attr_t &attr): jcp(ajcp), attr_(attr) {
37         this->generate();
38         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
39     }
40
41     ~jit_uni_x8s8s32x_dw_conv_fwd_kernel() {
42         for (auto inj : eltwise_injectors)
43            delete inj;
44         eltwise_injectors.clear();
45
46         for (auto inj : depthwise_injectors)
47             delete inj;
48         depthwise_injectors.clear();
49     }
50
51     static bool post_ops_ok(jit_conv_conf_t &jcp,
52             const primitive_attr_t &attr);
53     static status_t init_conf(jit_conv_conf_t &jcp,
54             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
55             const memory_desc_wrapper &weights_d,
56             const memory_desc_wrapper &dst_d,
57             const memory_desc_wrapper &bias_pd,
58             const primitive_attr_t &attr);
59
60     jit_conv_conf_t jcp;
61     const primitive_attr_t &attr_;
62     void (*jit_ker)(jit_conv_call_s *);
63
64 private:
65     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
66         isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
67     using Ymm = const Xbyak::Ymm;
68     using reg64_t = const Xbyak::Reg64;
69     using reg32_t = const Xbyak::Reg32;
70     using reg8_t = const Xbyak::Reg8;
71     const int vlen = cpu_isa_traits<isa>::vlen;
72
73     reg64_t reg_input_base = r10;
74     reg64_t reg_output_base = r9;
75     reg64_t reg_kernel_base = r11;
76     reg64_t reg_ch_work = r13;
77     reg64_t reg_bias_base = abi_not_param1;
78     reg64_t reg_scales_base = rdx;
79
80     reg64_t reg_input = r8;
81     reg64_t reg_kernel = r12;
82     reg64_t aux_reg_input = r9;
83     reg64_t aux1_reg_input = r10;
84     reg64_t aux_reg_kernel = r13;
85     reg64_t aux1_reg_kernel = r11;
86     reg64_t reg_output = r14;
87
88     reg64_t reg_kh = rax;
89     reg64_t reg_kw = rbx;
90     reg64_t iter_kh = rdx;
91     reg64_t iter_kw = rsi;
92     reg64_t reg_ur_w = rbp;
93
94     reg32_t reg_tmp_32 = r15d;
95     reg64_t reg_tmp_64 = r15;
96     reg8_t reg_tmp_8 = r15b;
97
98     reg64_t imm_addr64 = r10;
99
100     reg64_t reg_oc_off = iter_kw;
101     reg64_t reg_d_weights = aux1_reg_kernel;
102     reg64_t reg_d_bias = aux_reg_input;
103
104     Vmm vmm_zero = Vmm(0);
105     Vmm vmm_bias = Vmm(3);
106     Vmm vmm_scale = Vmm(2);
107     Vmm vmm_prev_dst = Vmm(2);
108
109     inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
110     inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
111     inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
112
113     inline void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, bool scalar_load);
114     inline void store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store);
115
116     inline void load_src(int ur_ch_blocks, int ch_step, int ur_w);
117     inline void apply_filter(int ur_ch_blocks, int ch_step, int ur_w);
118     inline void apply_filter_unrolled(int ur_ch_blocks, int ch_step, int ur_w);
119     inline void store_dst(int ur_ch_blocks, int ch_step, int ur_w);
120     inline void loop_body(int ur_ch_blocks, int ch_step);
121
122     inline void prepare_table();
123     void generate();
124
125     nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
126     nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
127
128     Xbyak::Label l_table;
129 };
130
131 }
132 }
133 }
134
135 #endif