updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_bf16_sum.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include <float.h>
17 #include "c_types_map.hpp"
18 #include "nstl.hpp"
19 #include "type_helpers.hpp"
20 #include "mkldnn_thread.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include "bfloat16_utils.hpp"
25 #include "jit_avx512_core_bf16_sum.hpp"
26
27 #define GET_OFF(field) offsetof(jit_sum_call_s, field)
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
36 using namespace mkldnn::impl::cpu::bf16_cvt_utils;
37
38 using namespace Xbyak;
39 void jit_avx512_core_bf16_sum_kernel::loop_iteration(int current_unroll)
40 {
41     Label loop_label, exit_label;
42     const int num_compute_elements = 2 * f32_simd_w_ * current_unroll;
43     size_t src_shift = 2 * f32_simd_w_ * jsp.typesize_in;
44     size_t dst_shift = f32_simd_w_* jsp.typesize_out;
45
46     L(loop_label);
47     cmp(reg_sz, num_compute_elements);
48     jl(exit_label, T_NEAR);
49     for (int u_idx = 0; u_idx < current_unroll; u_idx++) {
50         zmm_t vacc0 = Zmm(acc_vreg_idx(u_idx, 0));
51         zmm_t vacc1 = Zmm(acc_vreg_idx(u_idx, 1));
52         vpxord(vacc0, vacc0, vacc0);
53         vpxord(vacc1, vacc1, vacc1);
54
55         int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
56         for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
57             int isrc0 = 2 * acc_iter;
58             int isrc1 = 2 * acc_iter + 1;
59             zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
60             zmm_t vsrc0 = Zmm(src_vreg_idx(u_idx, isrc0));
61             zmm_t vsrc1 = Zmm(src_vreg_idx(u_idx, isrc1));
62             zmm_t vtmp = Zmm(tmp_vreg_idx(u_idx, acc_iter));
63             vmovups(vsrc0, zword[reg_src[isrc0] + u_idx * src_shift]);
64             if (num_acc_iters * 2 > jsp.num_srcs
65                 && acc_iter == num_acc_iters - 1)
66                 vpxord(vtmp, vtmp, vtmp); /* imitate additional zero input
67                                              if number of srcs is odd */
68             else
69                 vmovups(vtmp, zword[reg_src[isrc1] + u_idx * src_shift]);
70             vshuff64x2(vsrc1, vsrc0, vtmp, 0xEE);
71             vpermw(vsrc1, zmm_idx, vsrc1);
72             vshuff64x2(vsrc0, vsrc0, vtmp, 0x44);
73             vpermw(vsrc0, zmm_idx, vsrc0);
74
75                 if (!jsp.is_cpx) {
76                     bf16_emu_->r_vdpbf16ps(vacc0, vsrc0, vscale);
77                     vpbroadcastd(vscale,
78                         ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]);
79                     bf16_emu_->r_vdpbf16ps(vacc1, vsrc1, vscale);
80                 } else {
81                     vdpbf16ps(vacc0, vsrc0, vscale);
82                     vdpbf16ps(vacc1, vsrc1, vscale);
83                 }
84         }
85
86         if (!jsp.is_bf16_dst) {
87             vmovups(zword[reg_dst + 2 * u_idx * dst_shift], vacc0);
88             vmovups(zword[reg_dst + (2 * u_idx + 1) * dst_shift], vacc1);
89         } else {
90             if (jsp.is_cpx) {
91                 zmm_t zmm_str = Zmm(tmp_vreg_idx(u_idx, 0));
92                 vcvtne2ps2bf16(zmm_str, vacc1, vacc0);
93                 vmovups(zword[reg_dst + 2 * u_idx * dst_shift], zmm_str);
94             } else {
95                 auto ymm_str = Ymm(tmp_vreg_idx(u_idx, 0));
96                 bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc0);
97                 vmovups(yword[reg_dst + 2 * u_idx * dst_shift], ymm_str);
98                 bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc1);
99                 vmovups(yword[reg_dst + (2 * u_idx + 1) * dst_shift], ymm_str);
100             }
101         }
102     }
103     sub(reg_sz, num_compute_elements);
104     for (int s = 0; s < jsp.num_srcs; s++)
105         add(reg_src[s], current_unroll * src_shift);
106     add(reg_dst, 2 * current_unroll * dst_shift);
107     jge(loop_label, T_NEAR);
108
109     L(exit_label);
110 }
111
112 void jit_avx512_core_bf16_sum_kernel::generate()
113 {
114     preamble();
115
116     mov(reg_dst, ptr[param + GET_OFF(dst)]);
117     mov(reg_srcs, ptr[param + GET_OFF(srcs)]);
118
119     for (int s = 0; s < jsp.num_srcs; s++)
120          mov(reg_src[s], ptr[reg_srcs + sizeof(void*) * s]);
121
122     mov(reg_scales, ptr[param + GET_OFF(scales)]);
123     mov(reg_sz, ptr[param + GET_OFF(size)]);
124
125     Label tail_label, exit_label, mask_label;
126
127     mov(reg_idx_table, idx_table);
128     vmovups(zmm_idx, ptr[reg_idx_table]);
129
130     int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
131     for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
132         zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
133         vpbroadcastd(vscale, ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]);
134     }
135
136     if (!jsp.is_cpx) bf16_emu_->init_vcvtneps2bf16();
137     if (jsp.loop_unroll > 1)
138         loop_iteration(jsp.loop_unroll);
139
140     loop_iteration(1);
141
142     // tail processing
143     L(tail_label);
144     cmp(reg_sz, 0);
145     jle(exit_label, T_NEAR);
146
147     const int bf16_half_reg = f32_simd_w_;
148     mov(reg32_mask, 0xffff);
149     cmp(reg_sz, bf16_half_reg);
150     jge(mask_label, T_NEAR);
151
152     mov(reg32_mask, 1);
153     mov(rcx, reg_sz);
154     shl(reg32_mask, cl);
155     sub(reg32_mask, 1);
156
157     L(mask_label);
158     kmovd(k_mask, reg32_mask);
159     zmm_t vacc = Zmm(acc_vreg_idx(0, 0));
160     vpxord(vacc, vacc, vacc);
161
162     for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
163         int isrc0 = 2 * acc_iter;
164         int isrc1 = 2 * acc_iter + 1;
165         zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
166         zmm_t vsrc = Zmm(src_vreg_idx(0, isrc0));
167         ymm_t vysrc0 = Ymm(src_vreg_idx(0, isrc0));
168         ymm_t vysrc1 = Ymm(src_vreg_idx(0, isrc1));
169         vpxord(vysrc0, vysrc0, vysrc0);
170         vpxord(vysrc1, vysrc1, vysrc1);
171
172         vmovdqu16(vysrc0 | k_mask | T_z, yword[reg_src[isrc0]]);
173         if (!(num_acc_iters * 2 > jsp.num_srcs
174                 && acc_iter == num_acc_iters - 1))
175             vmovdqu16(vysrc1 | k_mask | T_z, yword[reg_src[isrc1]]);
176         vinserti64x4(vsrc, vsrc, vysrc1, 0x1);
177         vpermw(vsrc, zmm_idx, vsrc);
178
179         if (!jsp.is_cpx) {
180             bf16_emu_->r_vdpbf16ps(vacc, vsrc, vscale);
181         } else {
182             vdpbf16ps(vacc, vsrc, vscale);
183         }
184     }
185     if (!jsp.is_bf16_dst) {
186         vmovups(zword[reg_dst] | k_mask, vacc);
187     } else {
188         if (jsp.is_cpx) {
189             auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
190             vcvtneps2bf16(ymm_str, vacc);
191             vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
192         } else {
193             auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
194             bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc);
195             vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
196         }
197     }
198
199     sub(reg_sz, bf16_half_reg);
200     cmp(reg_sz, 0);
201     jle(exit_label, T_NEAR);
202
203     for (int s = 0; s < jsp.num_srcs; s++)
204         add(reg_src[s], bf16_half_reg * jsp.typesize_in);
205     add(reg_dst, f32_simd_w_ * jsp.typesize_out);
206
207     jmp(tail_label, T_NEAR);
208
209     L(exit_label);
210     postamble();
211
212     align(64);
213     L(idx_table);
214     const uint16_t _idx[] = { 0,16,1,17,2,18,3,19,4,20,5,21,6,22,7,23,8,24,
215                               9,25,10,26,11,27,12,28,13,29,14,30,15,31 };
216     for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
217         dw(_idx[i]);
218 }
219
220 status_t jit_avx512_core_bf16_sum_kernel::init_conf(
221                                 jit_sum_conf_t &jsp,
222                                 const int num_srcs,
223                                 const cpu_memory_pd_t &dst_d)
224 {
225     jsp.is_cpx = mayiuse(avx512_core_bf16);
226
227     jsp.num_srcs = num_srcs;
228     jsp.loop_unroll = 0;
229
230     const int max_unroll = 6; // maximum possible value of unroll is 6
231     for (/*continue*/; jsp.loop_unroll < max_unroll; jsp.loop_unroll++)
232     {
233         int num_regs = num_vregs_required(jsp.loop_unroll + 1, jsp.num_srcs);
234         if (num_regs > max_vregs_available(jsp.is_cpx))
235             break;
236     }
237     const int bf16_simd_w = 32;
238     if (jsp.loop_unroll == 0) return status::unimplemented;
239     jsp.size_blocking = bf16_simd_w * jsp.loop_unroll;
240
241     const memory_desc_wrapper o_d(&dst_d);
242     jsp.is_bf16_dst = data_type::bf16 == o_d.data_type();
243
244     jsp.typesize_in = sizeof(mkldnn_bfloat16_t);
245     jsp.typesize_out = jsp.is_bf16_dst
246                      ? sizeof(mkldnn_bfloat16_t)
247                      : sizeof(float);
248
249     return status::success;
250 }
251
252 template <data_type_t src_data_type, data_type_t dst_data_type>
253 void jit_bf16_sum_t<src_data_type, dst_data_type>::execute() const
254 {
255     auto output = reinterpret_cast<dst_data_t *>(this->memory());
256     const int num_arrs = pd()->n_inputs();
257     const memory_desc_wrapper o_d(pd()->dst_pd());
258     output += o_d.blk_off(0);
259     const size_t nelems = o_d.nelems();
260     const src_data_t *input_ptrs[max_num_arrs];
261     /* Number of scales needs to be multiple of 2 in order
262     to use VNNI instructions */
263     src_data_t scales[max_num_arrs];
264     for (int a = 0; a < num_arrs; ++a) {
265         const memory_desc_wrapper i_d(pd()->src_pd(a));
266
267         input_ptrs[a] = reinterpret_cast<const src_data_t *>(
268                 this->input_memory(a)) + i_d.blk_off(0);
269     }
270     cvt_float_to_bfloat16(scales, &pd()->scales_[0], num_arrs);
271     if (num_arrs % 2 != 0)
272         scales[num_arrs] = cvt_float_to_bfloat16(0.0f);
273
274     const size_t half_L1 = 16 * 1024; // bytes
275     const size_t num_elems_in_block = utils::rnd_up(
276                  utils::div_up(half_L1,
277                      num_arrs * sizeof(src_data_t) + sizeof(dst_data_t)),
278                  pd()->jsp_.size_blocking);
279     const size_t num_blocks = nelems / num_elems_in_block;
280     const size_t tail = nelems % num_elems_in_block;
281
282     parallel(0, [&](const int ithr, const int nthr) {
283         size_t start{0}, end{0};
284         balance211(num_blocks, nthr, ithr, start, end);
285         auto arg = jit_sum_call_s();
286         const src_data_t *local_input_ptrs[max_num_arrs];
287         dst_data_t *local_output;
288
289         for (size_t nb = start; nb < end; ++nb) {
290             size_t start_e = nb * num_elems_in_block;
291             for (int a = 0; a < num_arrs; ++a) {
292                 local_input_ptrs[a] = &input_ptrs[a][start_e];
293             }
294             local_output = &output[start_e];
295             arg.srcs = (const void **)local_input_ptrs;
296             arg.dst = (const void *)local_output;
297             arg.scales = (const void *)scales;
298             arg.size = num_elems_in_block;
299             kernel_->jit_ker(&arg);
300         }
301
302         if (tail != 0 && ithr == nthr - 1) {
303             size_t start_e = nelems - tail;
304             for (int a = 0; a < num_arrs; ++a) {
305                 local_input_ptrs[a] = &input_ptrs[a][start_e];
306             }
307             local_output = &output[start_e];
308             arg.srcs = (const void **)local_input_ptrs;
309             arg.dst = (const void *)local_output;
310             arg.scales = (const void *)scales;
311             arg.size = tail;
312             kernel_->jit_ker(&arg);
313         }
314     });
315 }
316
317 template struct jit_bf16_sum_t<data_type::bf16, data_type::f32>;
318 template struct jit_bf16_sum_t<data_type::bf16, data_type::bf16>;
319
320 }
321 }
322 }