updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_bf16_1x1_conv_kernel.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_BF16_1x1_CONV_KERNEL_HPP
18 #define JIT_AVX512_CORE_BF16_1x1_CONV_KERNEL_HPP
19
#include <memory>

#include "c_types_map.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_eltwise.hpp"
#include "jit_avx512_core_bf16cvt.hpp"
25
26 //#define BF16_CONV_1x1_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
27
28 namespace mkldnn {
29 namespace impl {
30 namespace cpu {
31
// Code-buffer size (1 MiB); judging by the name it is meant for the bf16
// backward-weights kernels (not referenced in this header). A namespace-scope
// const object already has internal linkage, so this keeps the previous
// per-translation-unit semantics without putting an unnamed namespace in a
// header.
constexpr size_t code_size_bf16_bwd_w = 1024 * 1024;
33
34 struct jit_avx512_core_bf16_1x1_conv_kernel : public jit_generator {
35     jit_avx512_core_bf16_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
36             const primitive_attr_t &attr) :
37     jit_generator(nullptr, ker_code_size),
38     jcp(ajcp), attr_(attr)
39     , eltwise_injector_(nullptr)
40     , bf16_emu_(nullptr)
41     {
42         if (jcp.with_eltwise)
43             eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
44                     this, jcp.eltwise);
45
46         if (!mayiuse(avx512_core_bf16))
47             bf16_emu_ = new bf16_emulation_t(this,
48                     bf16_emu_reserv_1, bf16_emu_reserv_2,
49                     bf16_emu_reserv_3, bf16_emu_reserv_4,
50                     bf16_emu_reserv_5, bf16_emu_reserv_6);
51
52         this->generate();
53         jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
54     }
55
56     ~jit_avx512_core_bf16_1x1_conv_kernel() {
57         delete eltwise_injector_;
58         delete bf16_emu_;
59     }
60
61     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_1x1_conv_kernel)
62
63     static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
64                                 const primitive_attr_t &attr);
65
66     static status_t init_conf(jit_1x1_conv_conf_t &jcp,
67             const convolution_desc_t &cd,
68             const memory_desc_wrapper &src_d,
69             const memory_desc_wrapper &weights_d,
70             const memory_desc_wrapper &dst_d,
71             const primitive_attr_t &attr,
72             int nthreads, bool reduce_src);
73
74     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
75             const jit_1x1_conv_conf_t &jcp);
76
77     jit_1x1_conv_conf_t jcp;
78     const primitive_attr_t &attr_;
79     void (*jit_ker)(jit_1x1_conv_call_s *);
80
81   private:
82     using reg64_t = const Xbyak::Reg64;
83     using zmm_t = const Xbyak::Zmm;
84     using mask_t = const Xbyak::Opmask;
85     enum {
86         ker_code_size = 1024 * 1024,
87     };
88
89     reg64_t reg_bcast_data = r8;
90     reg64_t reg_load_data = r10;
91     reg64_t reg_output_data = r9;
92     reg64_t aux_reg_bcast_data = r14;
93     reg64_t aux1_reg_bcast_data = rbx;
94     reg64_t aux_reg_load_data = r15;
95     reg64_t imm_addr64 = aux_reg_load_data;
96     reg64_t aux_reg_output_data = abi_not_param1;
97     reg64_t reg_load_loop_work = rsi;
98     reg64_t reg_reduce_loop_work = r11;
99     reg64_t bcast_loop_iter = rdx;
100     reg64_t reduce_loop_iter = abi_param1;
101     reg64_t reg_reduce_pos_flag = rax;
102     reg64_t reg_output_stride = r13;
103     reg64_t reg_bias_data = r12;
104     reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
105     reg64_t reg_trans_tmp = rax;
106
107     mask_t vmask = k7;
108
109     Xbyak::Xmm xmm_relu_ns = Xbyak::Xmm(30);
110     Xbyak::Zmm zmm_relu_ns = Xbyak::Zmm(30);
111     Xbyak::Zmm zmm_zero = Xbyak::Zmm(31);
112     Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
113
114     Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(25);
115     Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(26);
116     Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(27);
117     reg64_t bf16_emu_reserv_4 = imm_addr64;
118     Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(28);
119     Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(29);
120
121     Xbyak::Zmm zmm_tmp2 = Xbyak::Zmm(30);
122
123     Xbyak::Label dst_prm_table;
124
125     jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
126
127     int bcast_loop_work_offt = 0;
128 #ifdef BF16_CONV_1x1_BWD_W_JIT_KER_USES_PERMW_TRANSPOSITION
129     int perm_reg_offset = 8;
130     int broadcast_space = 24;
131 #endif
132     int stack_space_needed = 96;
133
134     void bcast_loop(int load_loop_blk);
135     void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
136
137     void generate();
138     static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
139
140     bf16_emulation_t *bf16_emu_;
141 };
142 }
143 }
144 }
145
146 #endif