1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
20 #include "mkldnn_thread.hpp"
22 #include "cpu_memory.hpp"
24 #include "bfloat16_utils.hpp"
25 #include "jit_avx512_core_bf16_sum.hpp"
27 #define GET_OFF(field) offsetof(jit_sum_call_s, field)
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::utils;
36 using namespace mkldnn::impl::cpu::bf16_cvt_utils;
38 using namespace Xbyak;
// Emits the main (unrolled) accumulation loop of the bf16 sum kernel.
// Per unroll step it loads one zmm of bf16 data from each source,
// interleaves each pair of sources element-wise (vpermw with the idx
// table) so vdpbf16ps can multiply adjacent bf16 elements by the
// broadcast scale pair, and accumulates into two f32 accumulators.
// Results are stored as f32, or converted back to bf16 (native
// vcvtne2ps2bf16 on cpx, emulated otherwise).
// NOTE(review): this listing is missing several original lines (opening
// brace, L(loop_label) placement, else-branches); comments below cover
// only the visible instructions.
39 void jit_avx512_core_bf16_sum_kernel::loop_iteration(int current_unroll)
41 Label loop_label, exit_label;
// One full iteration produces 2 * f32_simd_w_ f32 outputs per unroll step.
42 const int num_compute_elements = 2 * f32_simd_w_ * current_unroll;
43 size_t src_shift = 2 * f32_simd_w_ * jsp.typesize_in; // bytes of bf16 input read per step
44 size_t dst_shift = f32_simd_w_* jsp.typesize_out; // bytes per half-width store
// Skip the loop if fewer than one full iteration's worth of elements remain.
47 cmp(reg_sz, num_compute_elements);
48 jl(exit_label, T_NEAR);
49 for (int u_idx = 0; u_idx < current_unroll; u_idx++) {
50 zmm_t vacc0 = Zmm(acc_vreg_idx(u_idx, 0));
51 zmm_t vacc1 = Zmm(acc_vreg_idx(u_idx, 1));
// Zero both f32 accumulators for this unroll step.
52 vpxord(vacc0, vacc0, vacc0);
53 vpxord(vacc1, vacc1, vacc1);
// Sources are consumed two at a time: each vdpbf16ps pairs elements from
// src[2i] and src[2i+1] against the matching broadcast scale pair.
55 int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
56 for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
57 int isrc0 = 2 * acc_iter;
58 int isrc1 = 2 * acc_iter + 1;
59 zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
60 zmm_t vsrc0 = Zmm(src_vreg_idx(u_idx, isrc0));
61 zmm_t vsrc1 = Zmm(src_vreg_idx(u_idx, isrc1));
62 zmm_t vtmp = Zmm(tmp_vreg_idx(u_idx, acc_iter));
63 vmovups(vsrc0, zword[reg_src[isrc0] + u_idx * src_shift]);
64 if (num_acc_iters * 2 > jsp.num_srcs
65 && acc_iter == num_acc_iters - 1)
66 vpxord(vtmp, vtmp, vtmp); /* imitate additional zero input
67 if number of srcs is odd */
// (else keyword missing in this listing) load the pair's second source.
69 vmovups(vtmp, zword[reg_src[isrc1] + u_idx * src_shift]);
// 0xEE selects the upper 256-bit halves of both operands, 0x44 the lower
// halves; vpermw then interleaves words (src0[k], src1[k]) per idx table.
70 vshuff64x2(vsrc1, vsrc0, vtmp, 0xEE);
71 vpermw(vsrc1, zmm_idx, vsrc1);
72 vshuff64x2(vsrc0, vsrc0, vtmp, 0x44);
73 vpermw(vsrc0, zmm_idx, vsrc0);
// Emulated bf16 dot-product path for pre-cpx hardware.
76 bf16_emu_->r_vdpbf16ps(vacc0, vsrc0, vscale);
// NOTE(review): line 78 is a dangling argument list — the instruction
// reloading vscale from reg_scales is truncated in this listing.
78 ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]);
79 bf16_emu_->r_vdpbf16ps(vacc1, vsrc1, vscale);
// Native avx512_core_bf16 path: vacc += pairwise_dot(vsrc, vscale).
81 vdpbf16ps(vacc0, vsrc0, vscale);
82 vdpbf16ps(vacc1, vsrc1, vscale);
// f32 destination: store both accumulators directly.
86 if (!jsp.is_bf16_dst) {
87 vmovups(zword[reg_dst + 2 * u_idx * dst_shift], vacc0);
88 vmovups(zword[reg_dst + (2 * u_idx + 1) * dst_shift], vacc1);
// bf16 destination, native: pack both f32 accumulators into one bf16 zmm.
91 zmm_t zmm_str = Zmm(tmp_vreg_idx(u_idx, 0));
92 vcvtne2ps2bf16(zmm_str, vacc1, vacc0);
93 vmovups(zword[reg_dst + 2 * u_idx * dst_shift], zmm_str);
// bf16 destination, emulated: down-convert and store one ymm per accumulator.
95 auto ymm_str = Ymm(tmp_vreg_idx(u_idx, 0));
96 bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc0);
97 vmovups(yword[reg_dst + 2 * u_idx * dst_shift], ymm_str);
98 bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc1);
99 vmovups(yword[reg_dst + (2 * u_idx + 1) * dst_shift], ymm_str);
// Advance size counter and all pointers; loop while a full iteration remains.
103 sub(reg_sz, num_compute_elements);
104 for (int s = 0; s < jsp.num_srcs; s++)
105 add(reg_src[s], current_unroll * src_shift);
106 add(reg_dst, 2 * current_unroll * dst_shift);
107 jge(loop_label, T_NEAR);
// Top-level kernel generator: unpacks the jit_sum_call_s arguments,
// broadcasts the bf16 scale pairs into persistent vector registers, runs
// the unrolled main loop, then processes the remainder one masked
// half-register (up to f32_simd_w_ elements) at a time.
// NOTE(review): preamble/postamble, label placements, and several
// branches are missing from this listing; comments cover visible code only.
112 void jit_avx512_core_bf16_sum_kernel::generate()
// Load call arguments: dst pointer, array of src pointers, scales, size.
116 mov(reg_dst, ptr[param + GET_OFF(dst)]);
117 mov(reg_srcs, ptr[param + GET_OFF(srcs)]);
119 for (int s = 0; s < jsp.num_srcs; s++)
120 mov(reg_src[s], ptr[reg_srcs + sizeof(void*) * s]);
122 mov(reg_scales, ptr[param + GET_OFF(scales)]);
123 mov(reg_sz, ptr[param + GET_OFF(size)]);
125 Label tail_label, exit_label, mask_label;
// Load the word-permutation table used to interleave source pairs.
127 mov(reg_idx_table, idx_table);
128 vmovups(zmm_idx, ptr[reg_idx_table]);
// Broadcast each pair of bf16 scales (2 * typesize_in = 4 bytes) into its
// own vector register; these stay live across the whole kernel.
130 int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
131 for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
132 zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
133 vpbroadcastd(vscale, ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]);
// Pre-cpx targets must initialize the vcvtneps2bf16 emulation constants.
136 if (!jsp.is_cpx) bf16_emu_->init_vcvtneps2bf16();
137 if (jsp.loop_unroll > 1)
138 loop_iteration(jsp.loop_unroll);
// (the unroll == 1 / fallback path is missing from this listing)
145 jle(exit_label, T_NEAR);
// Tail: full 16-bit mask unless fewer than a half-register remains;
// the partial-mask computation (lines 151-157) is missing here.
147 const int bf16_half_reg = f32_simd_w_;
148 mov(reg32_mask, 0xffff);
149 cmp(reg_sz, bf16_half_reg);
150 jge(mask_label, T_NEAR);
158 kmovd(k_mask, reg32_mask);
159 zmm_t vacc = Zmm(acc_vreg_idx(0, 0));
160 vpxord(vacc, vacc, vacc);
162 for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
163 int isrc0 = 2 * acc_iter;
164 int isrc1 = 2 * acc_iter + 1;
165 zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
166 zmm_t vsrc = Zmm(src_vreg_idx(0, isrc0));
167 ymm_t vysrc0 = Ymm(src_vreg_idx(0, isrc0));
168 ymm_t vysrc1 = Ymm(src_vreg_idx(0, isrc1));
// Zero-fill so masked-off lanes contribute nothing to the accumulation.
169 vpxord(vysrc0, vysrc0, vysrc0);
170 vpxord(vysrc1, vysrc1, vysrc1);
172 vmovdqu16(vysrc0 | k_mask | T_z, yword[reg_src[isrc0]]);
// Skip the second load for the phantom zero source when num_srcs is odd.
173 if (!(num_acc_iters * 2 > jsp.num_srcs
174 && acc_iter == num_acc_iters - 1))
175 vmovdqu16(vysrc1 | k_mask | T_z, yword[reg_src[isrc1]]);
// Merge both ymm halves into one zmm, then interleave as in the main loop.
176 vinserti64x4(vsrc, vsrc, vysrc1, 0x1);
177 vpermw(vsrc, zmm_idx, vsrc);
// Emulated vs native bf16 dot-product accumulation.
180 bf16_emu_->r_vdpbf16ps(vacc, vsrc, vscale);
182 vdpbf16ps(vacc, vsrc, vscale);
// Masked store: f32 directly, or down-converted to bf16 (native/emulated).
185 if (!jsp.is_bf16_dst) {
186 vmovups(zword[reg_dst] | k_mask, vacc);
189 auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
190 vcvtneps2bf16(ymm_str, vacc);
191 vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
193 auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
194 bf16_emu_->r_vcvtneps2bf16(ymm_str, vacc);
195 vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
// Advance pointers and keep looping over the tail until reg_sz <= 0.
199 sub(reg_sz, bf16_half_reg);
201 jle(exit_label, T_NEAR);
203 for (int s = 0; s < jsp.num_srcs; s++)
204 add(reg_src[s], bf16_half_reg * jsp.typesize_in);
205 add(reg_dst, f32_simd_w_ * jsp.typesize_out);
207 jmp(tail_label, T_NEAR);
// Data table for vpermw: word k of the result is taken alternately from
// the low half (indices 0..15) and high half (16..31) of the input zmm,
// producing the (src0[k], src1[k]) interleave vdpbf16ps expects.
214 const uint16_t _idx[] = { 0,16,1,17,2,18,3,19,4,20,5,21,6,22,7,23,8,24,
215 9,25,10,26,11,27,12,28,13,29,14,30,15,31 };
216 for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
// NOTE(review): the loop body emitting each table entry (presumably
// dw(_idx[i]), line 217) is missing from this listing.
// Fills the kernel configuration (jsp) from the problem description:
// detects native bf16 support, picks the largest loop unroll that fits
// the vector-register budget, and derives the blocking size and the
// input/output element sizes.
// NOTE(review): the parameter list (lines 221-222) and parts of the body
// are truncated in this listing.
220 status_t jit_avx512_core_bf16_sum_kernel::init_conf(
223 const cpu_memory_pd_t &dst_d)
// Native vdpbf16ps/vcvtne2ps2bf16 are available on avx512_core_bf16.
225 jsp.is_cpx = mayiuse(avx512_core_bf16);
227 jsp.num_srcs = num_srcs;
230 const int max_unroll = 6; // maximum possible value of unroll is 6
// Grow the unroll factor while the required register count still fits
// within the registers available on this target.
231 for (/*continue*/; jsp.loop_unroll < max_unroll; jsp.loop_unroll++)
233 int num_regs = num_vregs_required(jsp.loop_unroll + 1, jsp.num_srcs);
234 if (num_regs > max_vregs_available(jsp.is_cpx))
237 const int bf16_simd_w = 32;
// No unroll fits at all -> this shape cannot be handled by the kernel.
238 if (jsp.loop_unroll == 0) return status::unimplemented;
// One kernel blocking unit covers a full bf16 vector per unroll step.
239 jsp.size_blocking = bf16_simd_w * jsp.loop_unroll;
241 const memory_desc_wrapper o_d(&dst_d);
242 jsp.is_bf16_dst = data_type::bf16 == o_d.data_type();
244 jsp.typesize_in = sizeof(mkldnn_bfloat16_t);
245 jsp.typesize_out = jsp.is_bf16_dst
246 ? sizeof(mkldnn_bfloat16_t)
// NOTE(review): the f32 branch of this ternary (line 247) is missing.
249 return status::success;
// Host-side driver for the jitted sum: splits the flat element range into
// cache-friendly blocks sized to roughly half of L1 (rounded up to the
// kernel's blocking unit), and runs the kernel over the blocks in
// parallel. The last thread additionally processes the tail remainder.
// NOTE(review): closing braces and the tail's arg.size assignment are
// missing from this listing.
252 template <data_type_t src_data_type, data_type_t dst_data_type>
253 void jit_bf16_sum_t<src_data_type, dst_data_type>::execute() const
255 auto output = reinterpret_cast<dst_data_t *>(this->memory());
256 const int num_arrs = pd()->n_inputs();
257 const memory_desc_wrapper o_d(pd()->dst_pd());
258 output += o_d.blk_off(0);
259 const size_t nelems = o_d.nelems();
260 const src_data_t *input_ptrs[max_num_arrs];
261 /* Number of scales needs to be multiple of 2 in order
262 to use VNNI instructions */
263 src_data_t scales[max_num_arrs];
// Collect per-source base pointers, offset to the start of the blocks.
264 for (int a = 0; a < num_arrs; ++a) {
265 const memory_desc_wrapper i_d(pd()->src_pd(a));
267 input_ptrs[a] = reinterpret_cast<const src_data_t *>(
268 this->input_memory(a)) + i_d.blk_off(0);
// Convert f32 scales to bf16; pad with a zero scale for odd num_arrs so
// the kernel's pairwise dot product always sees a complete pair.
270 cvt_float_to_bfloat16(scales, &pd()->scales_[0], num_arrs);
271 if (num_arrs % 2 != 0)
272 scales[num_arrs] = cvt_float_to_bfloat16(0.0f);
// Block size: elements whose combined src+dst footprint is ~16KB,
// rounded up to a multiple of the kernel's size_blocking.
274 const size_t half_L1 = 16 * 1024; // bytes
275 const size_t num_elems_in_block = utils::rnd_up(
276 utils::div_up(half_L1,
277 num_arrs * sizeof(src_data_t) + sizeof(dst_data_t)),
278 pd()->jsp_.size_blocking);
279 const size_t num_blocks = nelems / num_elems_in_block;
280 const size_t tail = nelems % num_elems_in_block;
282 parallel(0, [&](const int ithr, const int nthr) {
283 size_t start{0}, end{0};
// Evenly distribute whole blocks over the available threads.
284 balance211(num_blocks, nthr, ithr, start, end);
285 auto arg = jit_sum_call_s();
286 const src_data_t *local_input_ptrs[max_num_arrs];
287 dst_data_t *local_output;
289 for (size_t nb = start; nb < end; ++nb) {
290 size_t start_e = nb * num_elems_in_block;
291 for (int a = 0; a < num_arrs; ++a) {
292 local_input_ptrs[a] = &input_ptrs[a][start_e];
294 local_output = &output[start_e];
295 arg.srcs = (const void **)local_input_ptrs;
296 arg.dst = (const void *)local_output;
297 arg.scales = (const void *)scales;
298 arg.size = num_elems_in_block;
299 kernel_->jit_ker(&arg);
// Tail (elements not covered by full blocks): handled by the last thread.
302 if (tail != 0 && ithr == nthr - 1) {
303 size_t start_e = nelems - tail;
304 for (int a = 0; a < num_arrs; ++a) {
305 local_input_ptrs[a] = &input_ptrs[a][start_e];
307 local_output = &output[start_e];
308 arg.srcs = (const void **)local_input_ptrs;
309 arg.dst = (const void *)local_output;
310 arg.scales = (const void *)scales;
// NOTE(review): the `arg.size = tail;` assignment (line 311) is missing
// from this listing — confirm against the upstream source.
312 kernel_->jit_ker(&arg);
// Explicit instantiations: bf16 inputs with either f32 or bf16 output.
317 template struct jit_bf16_sum_t<data_type::bf16, data_type::f32>;
318 template struct jit_bf16_sum_t<data_type::bf16, data_type::bf16>;