1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef _JIT_AVX512_CORE_BF16CVT_HPP
18 #define _JIT_AVX512_CORE_BF16CVT_HPP
22 #include "c_types_map.hpp"
23 #include "mkldnn_debug.h"
25 #include "type_helpers.hpp"
27 #include "jit_generator.hpp"
33 namespace bf16_cvt_utils {
42 #define GET_OFF(field) offsetof(bf16_cvt_utils::jit_call_t, field)
// Emulates the AVX512_CORE_BF16 instructions (bf16 dot-product and
// fp32 -> bf16 down-conversion) on CPUs that only have plain AVX512_CORE.
// Holds Xbyak register handles supplied by the owning jit_generator and
// emits instructions into it; it allocates no registers of its own.
struct bf16_emulation_t {
    using opmask_t = const Xbyak::Opmask;
    using Zmm_t = const Xbyak::Zmm;
    using Ymm_t = const Xbyak::Ymm;
    using reg64_t = const Xbyak::Reg64;

    // `one`, `even` and `selector` are constant vector registers that must
    // be filled by init_vcvtneps2bf16() before r_vcvtneps2bf16() is used;
    // `tr0`/`tr1` are temporaries clobbered by the emulation sequences.
    bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even,
            Zmm_t selector, reg64_t scratch, Zmm_t tr0, Zmm_t tr1)

    // Emulated vdpbf16ps: acc += wei * inp, where each 32-bit lane of
    // wei/inp holds a pair of packed bf16 values.
    void r_vdpbf16ps(Zmm_t &acc, Zmm_t wei, Zmm_t inp) {
        // Expand the upper bf16 of each pair to fp32: arithmetic shift
        // right then shift left leaves it in fp32 position, low half zero.
        host_->vpsrad(tr0_, wei, 16);
        host_->vpslld(tr0_, tr0_, 16);

        host_->vpsrad(tr1_, inp, 16);
        host_->vpslld(tr1_, tr1_, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);

        // Expand the lower bf16 of each pair: a bf16 value is exactly the
        // upper half of the corresponding fp32, so shift left by 16.
        host_->vpslld(tr0_, wei, 16);
        host_->vpslld(tr1_, inp, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);

    // Emulated vcvtneps2bf16: fp32 -> bf16 with round-to-nearest-even.
    // Clobbers tr0_ only.
    void r_vcvtneps2bf16(Ymm_t &out, Zmm_t in) {
        // tr0 = LSB of the truncated bf16 result (bit 16 of the fp32)...
        host_->vpsrld(tr0_, in, 16);
        host_->vpandd(tr0_, tr0_, one_);

        // ...plus 0x7fff: the classic round-to-nearest-even bias.
        host_->vpaddd(tr0_, even_, tr0_);

        host_->vpaddd(tr0_, in, tr0_);
        // Repair NaN/Inf lanes that must not go through the rounding add
        // (see the fixup selector built in init_vcvtneps2bf16()).
        host_->vfixupimmps(tr0_, in, selector_, 0);

        // Keep the upper 16 bits of each lane and pack dwords to words.
        host_->vpsrad(tr0_, tr0_, 16);
        host_->vpmovdw(out, tr0_);

    // Emits code that loads the constants used by r_vcvtneps2bf16() into
    // one_/even_/selector_ (clobbers scratch_). Emit once per kernel,
    // before any r_vcvtneps2bf16() call.
    void init_vcvtneps2bf16() {
        const int selector_int32 =
                /* snan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(fixup_input_code_snan_, fixup_output_code_qnan_input_) |
                /* qnan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(fixup_input_code_qnan_, fixup_output_code_qnan_input_) |
                /* neg inf input copied to output */
                encode_fixup_selector(fixup_input_code_ninf_, fixup_output_code_copy_input_) |
                /* pos inf input copied to output */
                encode_fixup_selector(fixup_input_code_pinf_, fixup_output_code_copy_input_);

        // one_ = broadcast 1 (mask for the rounding LSB).
        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x1);
        host_->vpbroadcastd(one_, scratch_.cvt32());

        // even_ = broadcast 0x7fff (nearest-even rounding bias).
        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x7fff);
        host_->vpbroadcastd(even_, scratch_.cvt32());

        // selector_ = broadcast the vfixupimmps control token.
        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), selector_int32);
        host_->vpbroadcastd(selector_, scratch_.cvt32());

    jit_generator *const host_;

    // Builds one vfixupimmps token: places the 4-bit `output` action code
    // in the nibble selected by the `input` class code.
    inline int encode_fixup_selector(int input, int output) {
        return ((output) << (4 * (input)));

        // vfixupimmps input-class codes...
        fixup_input_code_qnan_ = 0,
        fixup_input_code_snan_ = 1,
        fixup_input_code_ninf_ = 4,
        fixup_input_code_pinf_ = 5,
        // ...and output-action codes.
        fixup_output_code_copy_input_ = 1,
        fixup_output_code_qnan_input_ = 2,
// JIT kernel: down-convert an fp32 array to bfloat16.
// Two flavors:
//  - dynamic size: element count read from jit_call_t::size at run time;
//  - fixed size: element count baked into the generated code.
// Uses the native vcvtneps2bf16 when avx512_core_bf16 is available
// (is_cpx_), otherwise the bf16_emulation_t sequence.
struct jit_avx512_core_cvt_ps_to_bf16_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_cvt_ps_to_bf16)

    // Dynamic-size kernel: the element count comes from the call args.
    jit_avx512_core_cvt_ps_to_bf16_t(void) : simd_w_(16), is_dynamic_size_(true) {
        is_cpx_ = mayiuse(avx512_core_bf16);
        // fp32_tmp is passed for both tr0 and tr1; r_vcvtneps2bf16() only
        // uses a single temporary, so sharing one zmm is sufficient here.
        bf16_emu_ = new bf16_emulation_t(this, one, even,
                selector, scratch, fp32_tmp, fp32_tmp);

        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

    // Fixed-size kernel for `size` elements; the final partial vector is
    // handled through a compile-time opmask (tail_mask_).
    jit_avx512_core_cvt_ps_to_bf16_t(size_t size)
        : size_(size), simd_w_(16), is_dynamic_size_(false) {
        tail_mask_ = (1 << (size % simd_w_)) - 1;

        is_cpx_ = (mayiuse(avx512_core_bf16)) ? true : false;
        bf16_emu_ = new bf16_emulation_t(this, one, even,
                selector, scratch, fp32_tmp, fp32_tmp);

        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

    ~jit_avx512_core_cvt_ps_to_bf16_t() { delete bf16_emu_; }

        // Convert one zmm (16 floats) at element offset `idx`, masked by
        // ktail_mask; disabled lanes are zeroed on load (T_z).
        auto cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
            vmovups(fp32_inp | ktail_mask | T_z,
                    ptr[reg_inp + sizeof(float) * (idx)]);
            // Emulated conversion (no avx512_core_bf16)...
            bf16_emu_->r_vcvtneps2bf16(bf16_out, fp32_inp);
            // ...or the native instruction.
            vcvtneps2bf16(bf16_out, fp32_inp);
            // Masked store of the 16 bf16 results.
            yword[reg_out + sizeof(mkldnn_bfloat16_t) * (idx)] | ktail_mask,

        // Load kernel arguments (see bf16_cvt_utils::jit_call_t).
        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
        if (is_dynamic_size_)
            mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);

        // Materialize the constants used by the emulated conversion.
        bf16_emu_->init_vcvtneps2bf16();

        // Start with a full-vector mask (all 16 lanes enabled).
        mov(reg32_tail, 0xffff);
        kmovw(ktail_mask, reg32_tail);

        if (is_dynamic_size_) { // determine size after JIT is called
            // Cascade of unrolled loops (4x/2x/1x vectors); when fewer
            // elements remain than the current unroll, fall through to
            // the next smaller one.
            constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
            Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
            for (int i = n_unroll; i >= 0; i--) {
                const int unroll = 1 << i; // 4, 2, 1
                L(l_simd_loop[i + 1]); {
                    cmp(reg_size, simd_w_ * unroll);
                    jl(l_simd_loop[i], T_NEAR);
                    for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                    add(reg_inp, simd_w_ * unroll * sizeof(float));
                    add(reg_out, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));
                    sub(reg_size, simd_w_ * unroll);
                    jmp(l_simd_loop[i + 1], T_NEAR);
            // 0..15 elements left: build the tail opmask at run time.
            test(reg_size, reg_size);
            // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
            mov(reg64_tail, reg_size);
            shl(reg32_mask, reg8_mask_shift); // variable shift count must be in cl
            kmovd(ktail_mask, reg32_mask);
            // Fixed-size path: process blocked_size = floor(size_/16)*16
            // elements, emitted in chunks of `loop_length` elements to
            // keep the generated code size bounded.
            size_t blocked_size = (size_ / simd_w_) * simd_w_;
            const size_t loop_length = 1024;
            const size_t number_of_loops = blocked_size / loop_length;
            const size_t tail_of_loops = blocked_size % loop_length;

            if (number_of_loops > 0) {
                Xbyak::Label l_number_of_loops;
                mov(reg_size, number_of_loops);
                L(l_number_of_loops);
                for (size_t i = 0; i < loop_length; i += simd_w_)
                add(reg_inp, sizeof(float) * loop_length);
                add(reg_out, sizeof(mkldnn_bfloat16_t) * loop_length);
                jg(l_number_of_loops, T_NEAR);

            // Leftover full vectors after the chunked loop.
            if (tail_of_loops > 0) {
                for (size_t i = 0; i < tail_of_loops; i += simd_w_)
                add(reg_inp, sizeof(float) * tail_of_loops);
                add(reg_out, sizeof(mkldnn_bfloat16_t) * tail_of_loops);
            // Compile-time tail of size_ % 16 elements.
            if (tail_mask_ != 0) {
                mov(reg32_tail, tail_mask_);
                kmovw(ktail_mask, reg32_tail);

    // Entry point of the generated kernel.
    void (*jit_ker)(bf16_cvt_utils::jit_call_t *);

    bf16_emulation_t *bf16_emu_;

    bool is_dynamic_size_;

    Xbyak::Opmask ktail_mask = k2;
    Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
    Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);

    // Constant registers owned by bf16_emulation_t.
    Xbyak::Zmm one = Xbyak::Zmm(2);
    Xbyak::Zmm even = Xbyak::Zmm(3);
    Xbyak::Zmm selector = Xbyak::Zmm(4);

    Xbyak::Ymm bf16_out = Xbyak::Ymm(5);

    Xbyak::Reg64 scratch = r15;
    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_size = rdx;

    // reg64_tail/reg32_tail/reg8_mask_shift alias rcx on purpose: the
    // variable-count shl above requires its count in cl.
    Xbyak::Reg64 reg64_tail = rcx;
    Xbyak::Reg32 reg32_tail = ecx;
    Xbyak::Reg8 reg8_mask_shift = cl;
    Xbyak::Reg32 reg32_mask = r8d;
// JIT kernel: up-convert a bfloat16 array to fp32.
// A bf16 value is exactly the upper half of the corresponding fp32, so
// the conversion is lossless: zero-extend word->dword (vpmovzxwd) then
// shift left by 16 (vpslld).
struct jit_avx512_core_cvt_bf16_to_ps_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_cvt_bf16_to_ps_t)

    // Dynamic-size kernel: element count read from jit_call_t::size.
    jit_avx512_core_cvt_bf16_to_ps_t(void) : simd_w_(16), is_dynamic_size_(true) {
        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

    // Fixed-size kernel for `size` elements; the final partial vector is
    // handled through a compile-time opmask (tail_mask_).
    jit_avx512_core_cvt_bf16_to_ps_t(size_t size)
        : size_(size), simd_w_(16), is_dynamic_size_(false) {
        tail_mask_ = (1 << (size_ % simd_w_)) - 1;

        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

        // Load kernel arguments (see bf16_cvt_utils::jit_call_t).
        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);

        if (is_dynamic_size_) { // determine size after JIT is called
            mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);
            // Cascade of unrolled loops (4x/2x/1x vectors); fall through
            // to a smaller unroll when fewer elements remain.
            constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
            Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
            for (int i = n_unroll; i >= 0; i--) {
                const int unroll = 1 << i; // 4, 2, 1
                L(l_simd_loop[i + 1]); {
                    cmp(reg_size, simd_w_ * unroll);
                    jl(l_simd_loop[i], T_NEAR);
                    for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                        ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * j]);
                        // bf16 -> fp32: shift into the high half.
                        vpslld(zmm_cvt, zmm_cvt, 0x10);
                        vmovdqu32(zword[reg_out + sizeof(float) * j], zmm_cvt);
                    add(reg_inp, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));
                    add(reg_out, simd_w_ * unroll * sizeof(float));
                    sub(reg_size, simd_w_ * unroll);
                    jmp(l_simd_loop[i + 1], T_NEAR);
            // 0..15 elements left: build the tail opmask at run time.
            test(reg_size, reg_size);
            // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
            mov(reg64_tail, reg_size);
            shl(reg32_mask, reg8_mask_shift); // variable shift count must be in cl
            kmovd(ktail_mask, reg32_mask);
            // Masked tail: zero disabled lanes on load, convert, store.
            vpmovzxwd(zmm_cvt | ktail_mask | T_z, ptr[reg_inp]);
            vpslld(zmm_cvt, zmm_cvt, 0x10);
            vmovdqu32(zword[reg_out] | ktail_mask, zmm_cvt);
            // Fixed-size path: process blocked_size = floor(size_/16)*16
            // elements, emitted in chunks of `loop_length` elements to
            // keep the generated code size bounded.
            size_t blocked_size = (size_ / simd_w_) * simd_w_;
            const size_t loop_length = 1024;
            const size_t number_of_loops = blocked_size / loop_length;
            const size_t tail_of_loops = blocked_size % loop_length;

            if (number_of_loops > 0) {
                Xbyak::Label l_number_of_loops;
                mov(reg_size, number_of_loops);
                L(l_number_of_loops);
                for (size_t i = 0; i < loop_length; i += simd_w_) {
                    vpmovzxwd(zmm_cvt, ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * i]);
                    vpslld(zmm_cvt, zmm_cvt, 0x10);
                    vmovups(zword[reg_out + sizeof(float) * i], zmm_cvt);
                add(reg_inp, sizeof(mkldnn_bfloat16_t) * loop_length);
                add(reg_out, sizeof(float) * loop_length);
                jg(l_number_of_loops, T_NEAR);

            // Leftover full vectors after the chunked loop.
            if (tail_of_loops > 0) {
                for (size_t i = 0; i < tail_of_loops; i += simd_w_) {
                    vpmovzxwd(zmm_cvt, ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * i]);
                    vpslld(zmm_cvt, zmm_cvt, 0x10);
                    vmovups(zword[reg_out + sizeof(float) * (i)], zmm_cvt);
                add(reg_inp, sizeof(mkldnn_bfloat16_t) * tail_of_loops);
                add(reg_out, sizeof(float) * tail_of_loops);
            // Compile-time tail of size_ % 16 elements.
            if (tail_mask_ != 0) {
                mov(reg32_mask, tail_mask_);
                kmovw(ktail_mask, reg32_mask);

                vpmovzxwd(zmm_cvt | ktail_mask | T_z, ptr[reg_inp]);
                vpslld(zmm_cvt, zmm_cvt, 0x10);
                vmovups(zword[reg_out] | ktail_mask, zmm_cvt);

    // Entry point of the generated kernel.
    void (*jit_ker)(bf16_cvt_utils::jit_call_t *);

    bool is_dynamic_size_;

    Xbyak::Opmask ktail_mask = k1;
    Xbyak::Zmm zmm_cvt = Xbyak::Zmm(0);

    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_size = rdx;

    // reg64_tail/reg32_tail/reg8_mask_shift alias rcx on purpose: the
    // variable-count shl above requires its count in cl.
    Xbyak::Reg64 reg64_tail = rcx;
    Xbyak::Reg32 reg32_tail = ecx;
    Xbyak::Reg8 reg8_mask_shift = cl;
    Xbyak::Reg32 reg32_mask = r8d;
// Performs an element-by-element sum of the `inp` and `add` float arrays
// and stores the result to the bfloat16 `out` array with down-conversion.
// Always dynamic-size: the element count comes from jit_call_t::size.
struct jit_avx512_core_add_cvt_ps_to_bf16_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_add_cvt_ps_to_bf16)

    jit_avx512_core_add_cvt_ps_to_bf16_t() : simd_w_(16) {
        is_cpx_ = mayiuse(avx512_core_bf16);
        // fp32_tmp is passed for both tr0 and tr1; r_vcvtneps2bf16() only
        // uses a single temporary, so sharing one zmm is sufficient here.
        bf16_emu_ = new bf16_emulation_t(this, one, even,
                selector, scratch, fp32_tmp, fp32_tmp);

        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

    ~jit_avx512_core_add_cvt_ps_to_bf16_t() { delete bf16_emu_; }

        // Load 16 floats from inp, add the matching 16 from add, then
        // down-convert the sum to bf16; all steps masked by ktail_mask
        // with disabled lanes zeroed (T_z).
        auto add_cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
            vmovups(fp32_inp | ktail_mask | T_z, ptr[reg_inp + sizeof(float) * (idx)]);
            vaddps(fp32_inp | ktail_mask | T_z, fp32_inp, ptr[reg_add + sizeof(float) * (idx)]);
            // Emulated conversion (no avx512_core_bf16)...
            bf16_emu_->r_vcvtneps2bf16(bf16_out, fp32_inp);
            // ...or the native instruction.
            vcvtneps2bf16(bf16_out, fp32_inp);
            // Masked store of the 16 bf16 results.
            yword[reg_out + sizeof(mkldnn_bfloat16_t) * (idx)] | ktail_mask,

        // Load kernel arguments (see bf16_cvt_utils::jit_call_t).
        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_add, ptr[abi_param1 + GET_OFF(add)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
        mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);

        // Materialize the constants used by the emulated conversion.
        bf16_emu_->init_vcvtneps2bf16();

        // Start with a full-vector mask (all 16 lanes enabled).
        mov(reg32_tail, 0xffff);
        kmovw(ktail_mask, reg32_tail);

        // Cascade of unrolled loops (4x/2x/1x vectors); fall through to
        // a smaller unroll when fewer elements remain.
        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]); {
                cmp(reg_size, simd_w_ * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                    add_cvt(j, ktail_mask);
                add(reg_inp, simd_w_ * unroll * sizeof(float));
                add(reg_add, simd_w_ * unroll * sizeof(float));
                add(reg_out, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));

                sub(reg_size, simd_w_ * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
        // 0..15 elements left: build the tail opmask at run time.
        test(reg_size, reg_size);
        // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
        mov(reg64_tail, reg_size);
        shl(reg32_mask, reg8_mask_shift); // variable shift count must be in cl
        kmovd(ktail_mask, reg32_mask);
        add_cvt(0, ktail_mask);

    // Entry point of the generated kernel.
    void (*jit_ker)(bf16_cvt_utils::jit_call_t *);

    bf16_emulation_t *bf16_emu_;

    Xbyak::Opmask ktail_mask = k2;
    Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
    Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);

    // Constant registers owned by bf16_emulation_t.
    Xbyak::Zmm one = Xbyak::Zmm(2);
    Xbyak::Zmm even = Xbyak::Zmm(3);
    Xbyak::Zmm selector = Xbyak::Zmm(4);
    Xbyak::Reg64 scratch = r15;

    Xbyak::Ymm bf16_out = Xbyak::Ymm(5);

    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_add = r11;
    Xbyak::Reg64 reg_size = rdx;

    // reg64_tail/reg32_tail/reg8_mask_shift alias rcx on purpose: the
    // variable-count shl above requires its count in cl.
    Xbyak::Reg64 reg64_tail = rcx;
    Xbyak::Reg32 reg32_tail = ecx;
    Xbyak::Reg8 reg8_mask_shift = cl;
    Xbyak::Reg32 reg32_mask = r8d;
// implementation of reorder of part of tensor [s][16c] -> [S][16c][2s]
// it is required for quick implementation of 1x1 bf16 bwd_w jit kernel
// w/o using permw instruction inside
// TODO: consider modification/replacement for outer transformation jit kernel
struct jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_reorder_s16c_to_S16c2s)

    jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() : simd_w_(16) {
        jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();

    ~jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() {}

        // Load kernel arguments; `size` counts 16c rows (two rows are
        // processed per 64-byte cache line).
        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
        mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);

        auto zmm_reg = [=](int idx) {
            return Xbyak::Zmm(idx);

        // Load the word-interleave pattern (dst_prm_table, emitted below).
        Xbyak::Label dst_prm_table;
        mov(reg_prm, dst_prm_table);
        vmovups(zmm_prm, ptr[reg_prm]);

        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
        // One cache line = two rows of 16 bf16 channels.
        int sizeofcacheline = 2 * simd_w_ * sizeof(mkldnn_bfloat16_t);
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]); {
                cmp(reg_size, 2 * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < unroll; j++) {
                    auto zmm_inp = zmm_reg(j);
                    vmovups(zmm_inp, zword[reg_inp + j * sizeofcacheline]);
                    // Interleave the two rows: [2s][16c] -> [16c][2s].
                    vpermw(zmm_inp, zmm_prm, zmm_inp);
                    vmovups(zword[reg_out + j * sizeofcacheline], zmm_inp);
                add(reg_inp, unroll * sizeofcacheline);
                add(reg_out, unroll * sizeofcacheline);

                sub(reg_size, 2 * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
        test(reg_size, reg_size);

        // Odd trailing row: mask loads only the low half of the zmm
        // (one 16c row); the second row is zeros, then permute as usual.
        mov(reg32_tail, 0x00ff);
        kmovw(ktail_mask, reg32_tail);

        auto zmm_inp = zmm_reg(0);
        vpxord(zmm_inp, zmm_inp, zmm_inp);
        vmovups(zmm_inp | ktail_mask | T_z, ptr[reg_inp]);
        vpermw(zmm_inp, zmm_prm, zmm_inp);
        vmovups(zword[reg_out], zmm_inp);

        // vpermw control table: pairs word i of row 0 with word i of
        // row 1 (indices 0..15 interleaved with 16..31).
        const uint16_t dst_prm_array[32] =
                {0,16, 1,17, 2,18, 3,19, 4,20, 5,21, 6,22, 7,23, 8,24,
                        9,25, 10,26, 11,27, 12,28, 13,29, 14,30, 15,31 };

        for (size_t i = 0; i < 32; ++i)
            dw(dst_prm_array[i]);

    // Entry point of the generated kernel.
    void (*jit_ker)(bf16_cvt_utils::jit_call_t *);

    Xbyak::Opmask ktail_mask = k2;
    Xbyak::Zmm zmm_prm = Xbyak::Zmm(31);

    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_prm = r11;
    Xbyak::Reg64 reg_size = rdx;

    Xbyak::Reg32 reg32_tail = abi_not_param1.cvt32();