updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_bf16_sum.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef JIT_AVX512_CORE_BF16_SUM_HPP
18 #define JIT_AVX512_CORE_BF16_SUM_HPP
19
20 #include "c_types_map.hpp"
21
22 #include "bfloat16_utils.hpp"
23 #include "cpu_memory.hpp"
24 #include "cpu_sum.hpp"
25 #include "jit_generator.hpp"
26 #include "jit_primitive_conf.hpp"
27 #include "jit_avx512_core_bf16cvt.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 struct jit_sum_conf_t {
34     int num_srcs;
35     int is_cpx;
36     int is_bf16_dst;
37     int typesize_in;
38     int typesize_out;
39     int loop_unroll;
40     int size_blocking; /* minimum recommended data blocking size as this
41                           number of elements computes main unrolled loop
42                           in jit kernel per iteration */
43 };
44
45 struct jit_sum_call_s {
46     const void **srcs;
47     const void *dst;
48     const void *scales;
49     size_t size;
50 };
51
52 namespace {
53 /* this needs to be multiple of 2 for vnni instruction
54  * to work with scales */
55 constexpr int max_num_arrs = 8 ;
56 }
57
58
59
60 struct jit_avx512_core_bf16_sum_kernel : public jit_generator {
61     using cpu_memory_pd_t = cpu_memory_t::pd_t;
62     jit_avx512_core_bf16_sum_kernel(jit_sum_conf_t ajsp)
63     : jsp(ajsp)
64     , bf16_emu_(nullptr)
65     {
66         if (!mayiuse(avx512_core_bf16))
67             bf16_emu_ = new bf16_emulation_t(this, one, even,
68                                          selector, abi_not_param1,
69                                          zmm_tmp0, zmm_tmp1);
70
71         this->generate();
72         jit_ker = (void (*)(jit_sum_call_s *)) this->getCode();
73     }
74
75     ~jit_avx512_core_bf16_sum_kernel() {
76         delete bf16_emu_;
77     }
78
79     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_sum_kernel)
80
81     static status_t init_conf(jit_sum_conf_t &jsp,
82                                 const int num_srcs,
83                                 const cpu_memory_pd_t &dst_d);
84
85     jit_sum_conf_t jsp;
86     void (*jit_ker)(jit_sum_call_s *);
87
88   private:
89     using reg64_t = const Xbyak::Reg64;
90     using reg32_t = const Xbyak::Reg32;
91     using reg8_t = const Xbyak::Reg8;
92     using zmm_t = const Xbyak::Zmm;
93     using ymm_t = const Xbyak::Ymm;
94     using mask_t = const Xbyak::Opmask;
95
96     enum { f32_simd_w_ = 16 };
97
98     reg64_t param = abi_param1; /* may be rcx, note that cl is required
99                                     for mask computation */
100
101     reg64_t reg_srcs = abi_not_param1; /* may be rcx, note that cl is required
102                                           for mask computation */
103     reg64_t reg_idx_table = abi_not_param1; /* may be rcx, note that cl is
104                                                required for mask computation */
105     reg64_t reg_mask = rsi;
106     reg32_t reg32_mask = esi;
107
108     reg64_t reg_dst = rax;
109     reg64_t reg_scales = rbx;
110     reg64_t reg_sz = rdx;
111
112     reg64_t reg_src[max_num_arrs] = {r8, r9, r10, r11, r12, r13, r14, r15};
113
114     static int max_vregs_available(bool is_cpx)
115     {
116         // one vector registers are reserved for vperm index and zero values
117         // additional 5 registers are reserved for bf16 emulation on non-cpx
118         return is_cpx ? 31 : 26;
119     }
120
121     int acc_vreg_idx(int i_unroll, int i_acc)
122     {
123         // 2 accumulation registers per unroll iteration
124         int idx = 2 * i_unroll + i_acc;
125         assert(idx < max_vregs_available(jsp.is_cpx));
126         return idx;
127     }
128
129     int scale_vreg_idx(int i_acc_iter)
130     {
131         int scale_idx_start = 2 * jsp.loop_unroll; // reserved for acc registers
132         int idx = scale_idx_start + i_acc_iter;
133         assert(idx < max_vregs_available(jsp.is_cpx));
134         return idx;
135     }
136
137     int src_vreg_idx(int i_unroll, int i_inp)
138     {
139         // reserved for acc and scale registers
140         int inp_idx_start = 2 * jsp.loop_unroll + utils::div_up(jsp.num_srcs, 2);
141         int idx = inp_idx_start
142                 + utils::rnd_up(jsp.num_srcs, 2) * i_unroll + i_inp;
143         assert(idx < max_vregs_available(jsp.is_cpx));
144         return idx;
145     }
146
147     int tmp_vreg_idx(int i_unroll, int i_acc_iter)
148     {
149         int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
150         // reserved for acc, scale and src registers
151         int tmp_idx_start = utils::div_up(jsp.num_srcs, 2)
152                 + (2 + utils::rnd_up(jsp.num_srcs, 2)) * jsp.loop_unroll;
153         int idx = tmp_idx_start
154                 + num_acc_iters * i_unroll + i_acc_iter;
155         assert(idx < max_vregs_available(jsp.is_cpx));
156         return idx;
157     }
158
159     static int num_vregs_required(int unroll, int num_srcs)
160     {
161         int num_acc_iters = utils::div_up(num_srcs, 2);
162         // reserved for acc, scale and src registers
163         int num_regs = utils::div_up(num_srcs, 2)
164                 + (2 + utils::rnd_up(num_srcs, 2)) * unroll;
165         // tmp registers
166         num_regs += num_acc_iters * unroll;
167         return num_regs;
168     }
169
170     Xbyak::Zmm one = Xbyak::Zmm(26);
171     Xbyak::Zmm even = Xbyak::Zmm(27);
172     Xbyak::Zmm selector = Xbyak::Zmm(28);
173     Xbyak::Zmm zmm_tmp0 = Xbyak::Zmm(29);
174     Xbyak::Zmm zmm_tmp1 = Xbyak::Zmm(30);
175
176     Xbyak::Zmm zmm_idx = Xbyak::Zmm(31);
177
178     Xbyak::Label idx_table;
179     const Xbyak::Opmask k_mask = k1;
180
181     void generate();
182     void loop_iteration(int current_unroll);
183     bf16_emulation_t *bf16_emu_;
184 };
185
186 template <data_type_t src_data_type, data_type_t dst_data_type>
187 struct jit_bf16_sum_t: public cpu_primitive_t {
188     using cpu_memory_pd_t = cpu_memory_t::pd_t;
189
190     struct pd_t: public cpu_sum_pd_t {
191         pd_t(const memory_desc_t *output_d, int n, const float *scales,
192              const cpu_memory_pd_t **input_pds, const primitive_attr_t *attr)
193             : cpu_sum_pd_t(output_d, n, scales, input_pds, attr), jsp_() {}
194
195         DECLARE_CPU_SUM_PD_T(
196                 JIT_IMPL_NAME_HELPER("jit_bf16_", avx512_core, ""),
197                 jit_bf16_sum_t);
198
199         virtual status_t init() override {
200             bool ok = true
201                 && mayiuse(avx512_core)
202                 && cpu_sum_pd_t::init() == success
203                 && src_pds_.size() <= max_num_arrs;
204             if (!ok) return unimplemented;
205
206             const memory_desc_wrapper o_d(&dst_pd_);
207             ok = true
208                 && o_d.data_type() == dst_data_type
209                 && o_d.is_dense();
210             if (!ok) return unimplemented;
211
212             const auto n = src_pds_.size();
213
214             if (n > max_num_arrs)
215                 return status::unimplemented;
216
217             for (size_t i = 0; i < n; ++i) {
218                 const memory_desc_wrapper i_d(&src_pds_[i]);
219                 ok = true
220                     && src_data_type == i_d.data_type()
221                     && i_d.format() == o_d.format()
222                     && i_d.is_dense()
223                     && bf16_cvt_utils::
224                         is_float_representable_in_bfloat16(scales_[i]);
225                 if (!ok) return unimplemented;
226             }
227
228             return jit_avx512_core_bf16_sum_kernel::init_conf(jsp_,
229                                 src_pds_.size(),
230                                 dst_pd_);
231         }
232         jit_sum_conf_t jsp_;
233     };
234
235     jit_bf16_sum_t(const pd_t *apd, const input_vector &inputs,
236             const output_vector &outputs)
237         : cpu_primitive_t(apd, inputs, outputs) {
238         kernel_ = new jit_avx512_core_bf16_sum_kernel(pd()->jsp_);
239     }
240
241     ~jit_bf16_sum_t() {
242         delete kernel_;
243     }
244
245     virtual void execute(event_t *e) const{
246         execute();
247         e->set_state(event_t::ready);
248     }
249
250     typedef typename prec_traits<src_data_type>::type src_data_t;
251     typedef typename prec_traits<dst_data_type>::type dst_data_t;
252     typedef typename prec_traits<data_type::f32>::type acc_data_t;
253
254 private:
255     void execute() const;
256     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
257     jit_avx512_core_bf16_sum_kernel *kernel_;
258 };
259
260 }
261 }
262 }
263
264 #endif