Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_u8s8s32x_1x1_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_U8S8S32X_1X1_CONV_KERNEL_HPP
18 #define JIT_AVX512_CORE_U8S8S32X_1X1_CONV_KERNEL_HPP
19
20 #include "c_types_map.hpp"
21 #include "jit_generator.hpp"
22 #include "jit_primitive_conf.hpp"
23
24 namespace mkldnn {
25 namespace impl {
26 namespace cpu {
27
28 struct jit_avx512_core_u8s8s32x_1x1_conv_kernel: public jit_generator {
29     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_1x1_conv_fwd_ker_t)
30     jit_avx512_core_u8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
31             const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
32     {
33         this->generate();
34         jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
35     }
36
37     static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
38                                 const primitive_attr_t &attr);
39
40     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
41                                 const convolution_desc_t &cd,
42                                 const memory_desc_wrapper &src_d,
43                                 const memory_desc_wrapper &weights_d,
44                                 const memory_desc_wrapper &dst_d,
45                                 const memory_desc_wrapper &bias_d,
46                                 const primitive_attr_t &attr,
47                                 bool with_relu, float relu_negative_slope,
48                                 int nthreads, bool reduce_src);
49
50     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
51                               const convolution_desc_t &cd,
52                               const memory_desc_wrapper &src_d,
53                               const memory_desc_wrapper &weights_d,
54                               const memory_desc_wrapper &dst_d,
55                               const memory_desc_wrapper &bias_d,
56                               const primitive_attr_t &attr,
57                               int nthreads, bool reduce_src)
58     {
59         return init_conf(jcp, cd, src_d, weights_d, dst_d, bias_d, attr, false,
60             0.0, nthreads, reduce_src);
61     }
62     bool maybe_relu(int position);
63
64     jit_1x1_conv_conf_t jcp;
65     const primitive_attr_t &attr_;
66     void (*jit_ker)(jit_1x1_conv_call_s *);
67
68   private:
69     using reg64_t = const Xbyak::Reg64;
70     using zmm_t = const Xbyak::Zmm;
71     using mask_t = const Xbyak::Opmask;
72
73     reg64_t reg_bcast_data = r8;
74     reg64_t reg_ptr_scales = r8;
75     reg64_t reg_output_data = r9;
76     reg64_t reg_load_data = r10;
77     reg64_t reg_ptr_sum_scale = r10;
78     reg64_t reg_reduce_loop_work = r11;
79     reg64_t reg_bias_data = r12;
80     reg64_t aux_reg_acc_s32 = r12;
81     reg64_t reg_acc_s32 = r13;
82     reg64_t reg_scratch = r13;
83     reg64_t aux_reg_bcast_data = r14;
84     reg64_t aux_reg_load_data = r15;
85     reg64_t imm_addr64 = r15;
86     reg64_t reg_reduce_pos_flag = rax;
87     reg64_t aux1_reg_bcast_data = rbx;
88     reg64_t reg_bcast_loop_work = rbx;
89     reg64_t bcast_loop_iter = rdx; // Note: Fix me
90     reg64_t reg_load_loop_work = rsi;
91     reg64_t aux_reg_output_data = abi_not_param1;
92     reg64_t reduce_loop_iter = abi_param1;
93
94     mask_t vmask = k7;
95
96     Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28);
97     Xbyak::Zmm zmm_one = Xbyak::Zmm(29);
98     Xbyak::Zmm zmm_zero = Xbyak::Zmm(30);
99     Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31);
100
101     int bcast_loop_work_offt = 0;
102     int reg_bias_data_offt = 8;
103     int aux_reg_acc_s32_offt = 16;
104     int reg_bcast_data_off = 24;
105     int reg_load_data_off = 32;
106     int reg_ptr_sum_scale_off = 40;
107     int stack_space_needed = 48;
108
109     void bcast_loop(int load_loop_blk);
110     void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
111
112     void generate();
113     static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
114 };
115 }
116 }
117 }
118
119 #endif