updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_def_conv_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_UNI_DEF_CONV_KERNEL_F32_HPP
18 #define JIT_UNI_DEF_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23 #include "cpu_memory.hpp"
24
25 namespace mkldnn {
26 namespace impl {
27 namespace cpu {
28
29 template <cpu_isa_t isa>
30 struct jit_uni_def_conv_fwd_kernel_f32: public jit_generator {
31     jit_uni_def_conv_fwd_kernel_f32(jit_def_conv_conf_t ajcp,
32             const primitive_attr_t &attr): jcp(ajcp), attr_(attr)
33     {
34         this->generate();
35         jit_ker = (void (*)(jit_def_conv_call_s *))this->getCode();
36     }
37
38     ~jit_uni_def_conv_fwd_kernel_f32() {}
39
40     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_def_conv_fwd_kernel)
41
42     static bool post_ops_ok(jit_def_conv_conf_t &jcp,
43             const primitive_attr_t &attr);
44     static status_t init_conf(jit_def_conv_conf_t &jcp,
45             const deformable_convolution_desc_t &cd,
46             cpu_memory_t::pd_t &src_pd,
47             cpu_memory_t::pd_t &offsets_pd,
48             cpu_memory_t::pd_t &weights_pd,
49             cpu_memory_t::pd_t &dst_pd,
50             cpu_memory_t::pd_t &bias_pd,
51             const primitive_attr_t &attr);
52     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
53             const jit_def_conv_conf_t &jcp, const primitive_attr_t &attr);
54
55     jit_def_conv_conf_t jcp;
56     const primitive_attr_t &attr_;
57     void (*jit_ker)(jit_def_conv_call_s *);
58
59 private:
60     using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
61             isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
62     const int vlen = cpu_isa_traits<isa>::vlen;
63     using Ymm = const Xbyak::Ymm;
64     using Xmm = const Xbyak::Xmm;
65     using reg64_t = const Xbyak::Reg64;
66     using reg32_t = const Xbyak::Reg32;
67     using reg8_t = const Xbyak::Reg8;
68
69     reg64_t reg_input = r8;
70     reg64_t reg_def_off = r9;
71     reg64_t reg_kernel = r10;
72     reg64_t reg_bias = r11;
73     reg64_t reg_output = r12;
74     reg64_t reg_oh_pos = r13;
75     reg64_t aux_reg_bias = rsi;
76     reg64_t reg_ow_pos = rdx;
77     reg64_t aux_reg_output = reg_ow_pos;
78     reg64_t reg_dg_iter = reg_output;
79     reg64_t aux_reg_input = rax;
80     reg64_t aux2_reg_input = reg_kernel;
81     reg64_t reg_ic_iter = rbx;
82     reg64_t reg_oc_work = reg_ic_iter;
83     reg64_t aux_reg_def_off = reg_bias;
84     reg64_t reg_input_buffer = abi_not_param1;
85     reg64_t aux_reg_input_buffer = r14;
86     reg32_t reg_tmp_32 = r15d;
87     reg64_t reg_tmp_64 = r15;
88     reg64_t reg_table = rbp;
89     reg64_t aux_reg_kernel = reg_table;
90     reg64_t aux2_reg_kernel = r15;
91     reg64_t aux2_reg_input_buffer = aux_reg_bias;
92     reg64_t aux3_reg_input_buffer = reg_input;
93
94     Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
95
96     inline Xbyak::Address table_val(int index)
97     { return ptr[reg_table + index * vlen]; }
98
99     inline Vmm get_vmm_ker(int idx) { return Vmm(idx + 0); }
100     inline Vmm get_vmm_src(int idx) { return Vmm(idx + 1); }
101     inline Vmm get_vmm_acc(int idx) { return Vmm(idx + jcp.ur_w + 1); }
102     inline Ymm get_ymm_acc(int idx) { return Ymm(idx + jcp.ur_w + 1); }
103     inline Xmm get_xmm_acc(int idx) { return Xmm(idx + jcp.ur_w + 1); }
104
105     inline void interpolate_input(int ow_step);
106     inline void ic_loop(int ow_step, int oc_blocks_step, int oc_step);
107     inline void store_output(int ow_step, int oc_blocks_step, int oc_step);
108     inline void oc_loop(int ow_step);
109     inline void ow_loop();
110     inline void apply_filter(int ow_step, int oc_blocks_step, int oc_step, int ic_step);
111     inline void init_accums(int ow_step, int oc_blocks_step, int oc_step);
112
113     Xbyak::Label l_table;
114
115     void generate();
116
117     void prepare_table();
118 };
119
120 }
121 }
122 }
123
124 #endif