Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
18 #define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
19
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22
23 #include "jit_generator.hpp"
24 #include "jit_primitive_conf.hpp"
25 #include "jit_uni_eltwise.hpp"
26 #include "jit_uni_depthwise.hpp"
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
32 struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
33     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t)
34     jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
35             const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
36     {
37         this->generate();
38         jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
39     }
40
41     ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() {
42         for (auto inj : eltwise_injectors)
43             delete inj;
44         eltwise_injectors.clear();
45
46         for (auto inj : depthwise_injectors)
47             delete inj;
48         depthwise_injectors.clear();
49     }
50
51     static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
52                                 const primitive_attr_t &attr);
53
54     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
55             const convolution_desc_t &cd,
56             const memory_desc_wrapper &src_d,
57             const memory_desc_wrapper &weights_d,
58             const memory_desc_wrapper &dst_d,
59             const memory_desc_wrapper &bias_d,
60             const primitive_attr_t &attr,
61             int nthreads, bool reduce_src);
62
63     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
64             const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
65
66     jit_1x1_conv_conf_t jcp;
67     const primitive_attr_t &attr_;
68     void (*jit_ker)(jit_1x1_conv_call_s *);
69
70   private:
71     nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
72     nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
73
74     using reg64_t = const Xbyak::Reg64;
75     using zmm_t = const Xbyak::Zmm;
76     using mask_t = const Xbyak::Opmask;
77
78     reg64_t reg_bcast_data = r8;
79     reg64_t reg_ptr_scales = r8;
80     reg64_t reg_output_data = r9;
81     reg64_t reg_load_data = r10;
82     reg64_t reg_ptr_sum_scale = r10;
83     reg64_t reg_reduce_loop_work = r11;
84     reg64_t reg_bias_data = r12;
85     reg64_t reg_comp_data = r12;
86     reg64_t reg_scratch = r13;
87     reg64_t aux_reg_bcast_data = r14;
88     reg64_t aux_reg_load_data = r15;
89     reg64_t imm_addr64 = r15;
90     reg64_t reg_reduce_pos_flag = rax;
91     reg64_t aux1_reg_bcast_data = rbx;
92     reg64_t reg_bcast_loop_work = rbx;
93     reg64_t bcast_loop_iter = rdx; // Note: Fix me
94     reg64_t reg_load_loop_work = rsi;
95     reg64_t aux_reg_output_data = abi_not_param1;
96     reg64_t reduce_loop_iter = abi_param1;
97
98     const Xbyak::Reg64 reg_d_weights = aux_reg_bcast_data;
99     const Xbyak::Reg64 reg_d_bias = reduce_loop_iter;
100     const Xbyak::Reg64 reg_oc_off = aux_reg_load_data;
101
102     reg64_t reg_last_load = r8;
103     mask_t ktail_mask = k6;
104
105     mask_t vmask = k7;
106
107     Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28);
108     Xbyak::Zmm zmm_one = Xbyak::Zmm(29);
109     Xbyak::Zmm zmm_zero = Xbyak::Zmm(30);
110     Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31);
111     Xbyak::Zmm zmm_shift = Xbyak::Zmm(30);
112
113     Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31);
114     Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31);
115
116     int bcast_loop_work_off = 0;
117     int reg_bias_data_off = 8;
118     int reg_bcast_data_off = 16;
119     int reg_load_data_off = 24;
120     int reg_ptr_sum_scale_off = 32;
121     int reg_comp_data_off = 40;
122     int stack_space_needed = 48;
123
124     void bcast_loop(int load_loop_blk);
125     void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
126
127     void generate();
128     void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
129         bool mask_flag);
130 };
131
132 }
133 }
134 }
135
136 #endif