Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_conv_kernel_f32.hpp
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17 #ifndef JIT_SSE42_CONV_KERNEL_F32_HPP
18 #define JIT_SSE42_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "cpu_memory.hpp"
22 #include "jit_generator.hpp"
23 #include "jit_primitive_conf.hpp"
24 #include "jit_uni_eltwise.hpp"
25 #include "jit_uni_depthwise.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 struct jit_sse42_conv_fwd_kernel_f32: public jit_generator {
32     jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw,
33             const primitive_attr_t &attr)
34         : jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr)
35     {
36         this->generate();
37         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
38     }
39
40     ~jit_sse42_conv_fwd_kernel_f32() {
41         for (auto inj : eltwise_injectors)
42             delete inj;
43         eltwise_injectors.clear();
44
45         for (auto inj : depthwise_injectors)
46             delete inj;
47         depthwise_injectors.clear();
48     }
49
50     static bool post_ops_ok(jit_conv_conf_t &jcp,
51             const primitive_attr_t &attr);
52
53     static status_t init_conf(jit_conv_conf_t &jcp,
54             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
55             const memory_desc_wrapper &weights_d,
56             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
57     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
58             const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw = jit_conv_conf_t());
59
60     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32)
61     jit_conv_conf_t jcp;
62     jit_conv_conf_t jcp_dw;
63     const primitive_attr_t &attr_;
64     void (*jit_ker)(jit_conv_call_s *);
65
66 private:
67     using reg64_t = const Xbyak::Reg64;
68     reg64_t reg_input = rax;
69     reg64_t aux_reg_input = r8;
70     reg64_t reg_kernel = rdx;
71     reg64_t aux_reg_kernel = r9;
72     reg64_t reg_output = rsi;
73     reg64_t reg_bias = rbx;
74
75     reg64_t kj = r10;
76     reg64_t oi_iter = r11;
77     reg64_t ki_iter = r12;
78     reg64_t reg_kh = abi_not_param1;
79     reg64_t simd_iter = r15;
80     reg64_t reg_oc_blocks = r14;
81     reg64_t imm_addr64 = reg_oc_blocks;
82     Xbyak::Reg32 reg_ci_flag = r13d;
83
84     reg64_t reg_d_weights = imm_addr64;
85     reg64_t reg_d_bias = ki_iter;
86     reg64_t reg_oc_off = abi_param1;
87
88     nstl::vector<jit_uni_eltwise_injector_f32<sse42>*> eltwise_injectors;
89     nstl::vector<jit_uni_depthwise_injector_f32<sse42>*> depthwise_injectors;
90
91     inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
92             int oc_blocks);
93     inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
94     inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
95     inline void solve_common(int oc_blocks);
96
97     void generate();
98 };
99
100 }
101 }
102 }
103
104 #endif