updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_planar_conv_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_PLANAR_CONV_KERNEL_F32_HPP
18 #define JIT_UNI_PLANAR_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "jit_uni_eltwise.hpp"
24 #include "jit_uni_depthwise.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 template <cpu_isa_t isa>
31 struct jit_uni_planar_conv_fwd_kernel_f32: public jit_generator {
32     jit_uni_planar_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
33             const primitive_attr_t &attr): jcp(ajcp), attr_(attr)
34     {
35         this->generate();
36         jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
37     }
38
39     ~jit_uni_planar_conv_fwd_kernel_f32() {
40         for (auto inj : eltwise_injectors)
41            delete inj;
42         eltwise_injectors.clear();
43
44         for (auto inj : depthwise_injectors)
45             delete inj;
46         depthwise_injectors.clear();
47     }
48
49     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_planar_conv_fwd_kernel_f32)
50
51     static bool post_ops_ok(jit_conv_conf_t &jcp,
52             const primitive_attr_t &attr);
53     static status_t init_conf(jit_conv_conf_t &jcp,
54             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
55             const memory_desc_wrapper &weights_d,
56             const memory_desc_wrapper &dst_d,
57             const primitive_attr_t &attr);
58
59     jit_conv_conf_t jcp;
60     const primitive_attr_t &attr_;
61     void (*jit_ker)(jit_conv_call_s *);
62
63 private:
64     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
65             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
66     using reg64_t = const Xbyak::Reg64;
67     using reg32_t = const Xbyak::Reg32;
68     const Xbyak::AddressFrame &vmmword = (isa == sse42)
69         ? xword : (isa == avx2) ? yword : zword;
70
71     reg64_t reg_input = r8;
72     reg64_t reg_kernel = r9;
73     reg64_t reg_output = r10;
74
75     reg64_t aux_reg_input_h = r11;
76     reg64_t aux_reg_kernel_h = r12;
77
78     reg64_t aux_reg_input_w = r13;
79     reg64_t aux_reg_kernel_w = r14;
80
81     reg64_t aux_reg_inp_d = r9;
82     reg64_t aux_reg_ker_d = r10;
83
84     reg64_t reg_kd = rbx;
85     reg64_t reg_kh = rdx;
86     reg64_t reg_kw = rsi;
87
88     reg64_t kh_iter = rax;
89     reg64_t kw_iter = abi_not_param1;
90
91     reg64_t reg_bias = r13;
92     reg64_t reg_long_offt = r15;
93     reg32_t reg_ci_flag = r15d;
94
95     reg64_t reg_d_weights = r15;
96     reg64_t reg_d_bias = kh_iter;
97
98     reg64_t reg_ow = rbp;
99
100     reg64_t reg_oh_blocks = aux_reg_kernel_w;
101
102     reg64_t reg_wj = aux_reg_input_w;
103
104     Vmm vmm_ker = Vmm(15);
105     Vmm vmm_tmp = Vmm(15);
106     Vmm vmm_src = Vmm(14);
107     Xbyak::Xmm xmm_ker = Xbyak::Xmm(15);
108     Xbyak::Xmm xmm_tmp = Xbyak::Xmm(15);
109     Xbyak::Xmm xmm_src = Xbyak::Xmm(14);
110
111     nstl::vector<jit_uni_eltwise_injector_f32<isa>*> eltwise_injectors;
112     nstl::vector<jit_uni_depthwise_injector_f32<isa>*> depthwise_injectors;
113
114     inline void load_src(int ur_h, int ur_w);
115     inline void filter(int ur_h);
116     inline void filter_unrolled(int ur_h, int ur_w);
117     inline void apply_filter(int ur_h, int ur_w);
118     inline void apply_postprocess(int ur_h, int ur_w);
119     inline void store_dst(int ur_h, int ur_w);
120     inline void solve_common(int ur_h);
121
122     inline void filter_scalar(int ur_h);
123     inline void load_src_scalar(int ur_h);
124     inline void apply_filter_scalar(int ur_h);
125     inline void apply_postprocess_scalar(int ur_h);
126     inline void store_dst_scalar(int ur_h);
127
128     void generate();
129 };
130
131 }
132 }
133 }
134
135 #endif