Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_1x1_conv_kernel_f32.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
18 #define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "cpu_memory.hpp"
24 #include "jit_generator.hpp"
25 #include "jit_primitive_conf.hpp"
26 #include "jit_uni_eltwise.hpp"
27 #include "jit_uni_depthwise.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 struct jit_avx2_1x1_conv_kernel_f32: public jit_generator {
34     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32)
35
36     jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw,
37            const primitive_attr_t &attr)
38         : jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr)
39     {
40         this->generate();
41         jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
42     }
43
44     ~jit_avx2_1x1_conv_kernel_f32() {
45         for (auto inj : eltwise_injectors)
46             delete inj;
47         eltwise_injectors.clear();
48
49         for (auto inj : depthwise_injectors)
50             delete inj;
51         depthwise_injectors.clear();
52     }
53
54     static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
55             const primitive_attr_t &attr);
56
57     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
58             const convolution_desc_t &cd,
59             const memory_desc_wrapper &src_d,
60             const memory_desc_wrapper &weights_d,
61             const memory_desc_wrapper &dst_d,
62             const primitive_attr_t &attr);
63
64     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
65             const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw = jit_conv_conf_t());
66
67     jit_1x1_conv_conf_t jcp;
68     jit_conv_conf_t jcp_dw;
69     const primitive_attr_t &attr_;
70     void (*jit_ker)(jit_1x1_conv_call_s *);
71
72 private:
73     using reg64_t = const Xbyak::Reg64;
74     using ymm_t = const Xbyak::Ymm;
75
76     reg64_t reg_bcast_data = rax;
77     reg64_t reg_load_data = rsi;
78     reg64_t reg_output_data = rbx;
79     reg64_t aux_reg_bcast_data = rdx;
80     reg64_t aux1_reg_bcast_data = abi_not_param1;
81     reg64_t aux_reg_output_data = rbp;
82     reg64_t reg_load_loop_work = r9;
83     reg64_t reg_bcast_loop_work = r10;
84     reg64_t reg_reduce_loop_work = r11;
85     reg64_t load_loop_iter = r13;
86     reg64_t aux_reg_load_data = load_loop_iter;
87     reg64_t bcast_loop_iter = r14;
88     reg64_t reduce_loop_iter = r15;
89     reg64_t imm_addr64 = reduce_loop_iter;
90     reg64_t reg_reduce_pos_flag = r8;
91     reg64_t reg_output_stride = r12;
92     reg64_t reg_bias_data = r12;
93     reg64_t reg_diff_bias_data = bcast_loop_iter;
94
95     reg64_t reg_oc_off = abi_param1;
96     reg64_t reg_d_weights = aux_reg_bcast_data;
97     reg64_t reg_d_bias = reduce_loop_iter;
98
99     int reg_diff_bias_data_stack_offt = 0;
100     int stack_space_needed = 8;
101
102     ymm_t vreg_bcast = ymm_t(15);
103     ymm_t vtmp = ymm_t(14);
104
105     void generate_bcast_loop(int load_loop_blk);
106     void generate_reduce_loop(int load_loop_blk, int ur);
107     void generate_diff_bias_loop(int load_loop_blk);
108
109     nstl::vector<jit_uni_eltwise_injector_f32<avx2>*> eltwise_injectors;
110     nstl::vector<jit_uni_depthwise_injector_f32<avx2>*> depthwise_injectors;
111
112     void generate();
113 };
114
115 }
116 }
117 }
118
119 #endif