updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_bf16cvt.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef _JIT_AVX512_CORE_BF16CVT_HPP
18 #define _JIT_AVX512_CORE_BF16CVT_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "mkldnn_debug.h"
24 #include "nstl.hpp"
25 #include "type_helpers.hpp"
26
27 #include "jit_generator.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
namespace bf16_cvt_utils {
/* Argument pack passed (by pointer) to every kernel generated in this header.
 *
 * The generated code fetches the fields by byte offset through GET_OFF(),
 * so the member order below is part of the kernel ABI: do not reorder. */
struct jit_call_t {
    void *inp; // source buffer
    void *out; // destination buffer
    void *add; // second source; only the add+convert kernel loads it
    size_t size; // amount of work, read by the kernels that take their size
                 // at run time (the unit is kernel specific)
};
} // namespace bf16_cvt_utils
41
42 #define GET_OFF(field) offsetof(bf16_cvt_utils::jit_call_t, field)
43
44 struct bf16_emulation_t {
45     using opmask_t = const Xbyak::Opmask;
46     using Zmm_t = const Xbyak::Zmm;
47     using Ymm_t = const Xbyak::Ymm;
48     using reg64_t = const Xbyak::Reg64;
49
50     bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even,
51             Zmm_t selector, reg64_t scratch, Zmm_t tr0, Zmm_t tr1)
52         : one_(one)
53         , even_(even)
54         , selector_(selector)
55         , tr0_(tr0)
56         , tr1_(tr1)
57         , scratch_(scratch)
58         , host_(host) {}
59
60     void r_vdpbf16ps(Zmm_t &acc, Zmm_t wei, Zmm_t inp) {
61         host_->vpsrad(tr0_, wei, 16);
62         host_->vpslld(tr0_, tr0_, 16);
63
64         host_->vpsrad(tr1_, inp, 16);
65         host_->vpslld(tr1_, tr1_, 16);
66
67         host_->vfmadd231ps(acc, tr1_, tr0_);
68
69         host_->vpslld(tr0_, wei, 16);
70         host_->vpslld(tr1_, inp, 16);
71
72         host_->vfmadd231ps(acc, tr1_, tr0_);
73     }
74
75     void r_vcvtneps2bf16(Ymm_t &out, Zmm_t in) {
76         host_->vpsrld(tr0_, in, 16);
77         host_->vpandd(tr0_, tr0_, one_);
78
79         host_->vpaddd(tr0_, even_, tr0_);
80
81         host_->vpaddd(tr0_, in, tr0_);
82         host_->vfixupimmps(tr0_, in, selector_, 0);
83
84         host_->vpsrad(tr0_, tr0_, 16);
85         host_->vpmovdw(out, tr0_);
86     }
87
88     void init_vcvtneps2bf16() {
89         const int selector_int32 =
90             /* qnan input to qnan output (presenrving input bits 0..21) */
91             encode_fixup_selector(fixup_input_code_snan_, fixup_output_code_qnan_input_) |
92             /* snan input to qnan output (presenrving input bits 0..21) */
93             encode_fixup_selector(fixup_input_code_qnan_, fixup_output_code_qnan_input_) |
94             /* neg inf input copied to output */
95             encode_fixup_selector(fixup_input_code_ninf_, fixup_output_code_copy_input_) |
96             /* pos inf input copied to output */
97             encode_fixup_selector(fixup_input_code_pinf_, fixup_output_code_copy_input_);
98
99         host_->xor_(scratch_, scratch_);
100         host_->mov(scratch_.cvt32(), 0x1);
101         host_->vpbroadcastd(one_, scratch_.cvt32());
102
103         host_->xor_(scratch_, scratch_);
104         host_->mov(scratch_.cvt32(), 0x7fff);
105         host_->vpbroadcastd(even_, scratch_.cvt32());
106
107         host_->xor_(scratch_, scratch_);
108         host_->mov(scratch_.cvt32(), selector_int32);
109         host_->vpbroadcastd(selector_, scratch_.cvt32());
110     }
111
112 private:
113     Zmm_t one_;
114     Zmm_t even_;
115     Zmm_t selector_;
116     Zmm_t tr0_;
117     Zmm_t tr1_;
118     reg64_t scratch_;
119     jit_generator *const host_;
120
121     inline int encode_fixup_selector(int input, int output) {
122         return ((output) << (4 * (input)));
123     }
124
125     enum {
126         fixup_input_code_qnan_ = 0,
127         fixup_input_code_snan_ = 1,
128         fixup_input_code_ninf_ = 4,
129         fixup_input_code_pinf_ = 5,
130         fixup_output_code_copy_input_ = 1,
131         fixup_output_code_qnan_input_ = 2,
132     };
133
134 };
135
136 struct jit_avx512_core_cvt_ps_to_bf16_t : public jit_generator {
137     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_cvt_ps_to_bf16)
138
139     jit_avx512_core_cvt_ps_to_bf16_t(void) : simd_w_(16), is_dynamic_size_(true) {
140         is_cpx_ = mayiuse(avx512_core_bf16);
141         bf16_emu_ = new bf16_emulation_t(this, one, even,
142                 selector, scratch, fp32_tmp, fp32_tmp);
143
144         generate();
145         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
146     }
147
148     jit_avx512_core_cvt_ps_to_bf16_t(size_t size)
149         : size_(size), simd_w_(16), is_dynamic_size_(false) {
150         tail_mask_ = (1 << (size % simd_w_)) - 1;
151
152         is_cpx_ = (mayiuse(avx512_core_bf16)) ? true : false;
153         bf16_emu_ = new bf16_emulation_t(this, one, even,
154                 selector, scratch, fp32_tmp, fp32_tmp);
155
156         generate();
157         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
158     }
159
160     ~jit_avx512_core_cvt_ps_to_bf16_t() { delete bf16_emu_; }
161
162     void generate() {
163         preamble();
164
165         auto cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
166             vmovups(fp32_inp | ktail_mask | T_z,
167                     ptr[reg_inp + sizeof(float) * (idx)]);
168             if (!is_cpx_)
169                 bf16_emu_->r_vcvtneps2bf16(bf16_out, fp32_inp);
170             else
171                 vcvtneps2bf16(bf16_out, fp32_inp);
172             vmovdqu16(
173                     yword[reg_out + sizeof(mkldnn_bfloat16_t) * (idx)] | ktail_mask,
174                     bf16_out);
175         };
176
177         mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
178         mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
179         if (is_dynamic_size_)
180             mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);
181
182         if (!is_cpx_)
183             bf16_emu_->init_vcvtneps2bf16();
184
185         mov(reg32_tail, 0xffff);
186         kmovw(ktail_mask, reg32_tail);
187
188         if (is_dynamic_size_) { // determine size after JIT is called
189             constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
190             Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
191             for (int i = n_unroll; i >= 0; i--) {
192                 const int unroll = 1 << i; // 4, 2, 1
193                 L(l_simd_loop[i + 1]); {
194                     cmp(reg_size, simd_w_ * unroll);
195                     jl(l_simd_loop[i], T_NEAR);
196                     for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
197                         cvt(j, ktail_mask);
198                     }
199                     add(reg_inp, simd_w_ * unroll * sizeof(float));
200                     add(reg_out, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));
201                     sub(reg_size, simd_w_ * unroll);
202                     jmp(l_simd_loop[i + 1], T_NEAR);
203                 }
204             }
205             L(l_simd_loop[0]);
206             test(reg_size, reg_size);
207             jz(l_simd_notail);
208             // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
209             mov(reg32_mask, 1);
210             mov(reg64_tail, reg_size);
211             shl(reg32_mask, reg8_mask_shift);
212             sub(reg32_mask, 1);
213             kmovd(ktail_mask, reg32_mask);
214             cvt(0, ktail_mask);
215             L(l_simd_notail);
216
217         } else {
218
219             size_t blocked_size = (size_ / simd_w_) * simd_w_;
220             const size_t loop_length = 1024;
221             const size_t number_of_loops = blocked_size / loop_length;
222             const size_t tail_of_loops = blocked_size % loop_length;
223
224             if (number_of_loops > 0) {
225                 Xbyak::Label l_number_of_loops;
226                 mov(reg_size, number_of_loops);
227                 L(l_number_of_loops);
228                 for (size_t i = 0; i < loop_length; i += simd_w_)
229                     cvt(i, ktail_mask);
230                 add(reg_inp, sizeof(float) * loop_length);
231                 add(reg_out, sizeof(mkldnn_bfloat16_t) * loop_length);
232
233                 dec(reg_size);
234                 cmp(reg_size, 0);
235                 jg(l_number_of_loops, T_NEAR);
236             }
237             if (tail_of_loops > 0) {
238                 for (size_t i = 0; i < tail_of_loops; i += simd_w_)
239                     cvt(i, ktail_mask);
240                 add(reg_inp, sizeof(float) * tail_of_loops);
241                 add(reg_out, sizeof(mkldnn_bfloat16_t) * tail_of_loops);
242             }
243             if (tail_mask_ != 0) {
244                 mov(reg32_tail, tail_mask_);
245                 kmovw(ktail_mask, reg32_tail);
246                 cvt(0, ktail_mask);
247             }
248         }
249         postamble();
250     }
251
252     void (*jit_ker)(bf16_cvt_utils::jit_call_t *);
253
254 private:
255     size_t size_;
256     int tail_mask_;
257     int simd_w_;
258
259     bf16_emulation_t *bf16_emu_;
260     bool is_cpx_;
261     bool is_dynamic_size_;
262
263     Xbyak::Opmask ktail_mask = k2;
264     Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
265     Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);
266
267     Xbyak::Zmm one = Xbyak::Zmm(2);
268     Xbyak::Zmm even = Xbyak::Zmm(3);
269     Xbyak::Zmm selector = Xbyak::Zmm(4);
270
271     Xbyak::Ymm bf16_out = Xbyak::Ymm(5);
272
273     Xbyak::Reg64 scratch = r15;
274     Xbyak::Reg64 reg_inp = rax;
275     Xbyak::Reg64 reg_out = rbx;
276     Xbyak::Reg64 reg_size = rdx;
277
278     Xbyak::Reg64 reg64_tail = rcx;
279     Xbyak::Reg32 reg32_tail = ecx;
280     Xbyak::Reg8 reg8_mask_shift = cl;
281     Xbyak::Reg32 reg32_mask = r8d;
282
283 };
284
285 struct jit_avx512_core_cvt_bf16_to_ps_t : public jit_generator {
286     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_cvt_bf16_to_ps_t)
287
288     jit_avx512_core_cvt_bf16_to_ps_t(void) : simd_w_(16), is_dynamic_size_(true) {
289         generate();
290         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
291     }
292
293     jit_avx512_core_cvt_bf16_to_ps_t(size_t size)
294         : size_(size), simd_w_(16), is_dynamic_size_(false) {
295         tail_mask_ = (1 << (size_ % simd_w_)) - 1;
296
297         generate();
298         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
299     }
300
301     void generate() {
302         preamble();
303
304         mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
305         mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
306
307         if (is_dynamic_size_) { // determine size after JIT is called
308             mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);
309             constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
310             Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
311             for (int i = n_unroll; i >= 0; i--) {
312                 const int unroll = 1 << i; // 4, 2, 1
313                 L(l_simd_loop[i + 1]); {
314                     cmp(reg_size, simd_w_ * unroll);
315                     jl(l_simd_loop[i], T_NEAR);
316                     for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
317                         vpmovzxwd(zmm_cvt,
318                                 ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * j]);
319                         vpslld(zmm_cvt, zmm_cvt, 0x10);
320                         vmovdqu32(zword[reg_out + sizeof(float) * j], zmm_cvt);
321                     }
322                     add(reg_inp, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));
323                     add(reg_out, simd_w_ * unroll * sizeof(float));
324                     sub(reg_size, simd_w_ * unroll);
325                     jmp(l_simd_loop[i + 1], T_NEAR);
326                 }
327             }
328             L(l_simd_loop[0]);
329             test(reg_size, reg_size);
330             jz(l_simd_notail);
331             // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
332             mov(reg32_mask, 1);
333             mov(reg64_tail, reg_size);
334             shl(reg32_mask, reg8_mask_shift);
335             sub(reg32_mask, 1);
336             kmovd(ktail_mask, reg32_mask);
337             vpmovzxwd(zmm_cvt | ktail_mask | T_z, ptr[reg_inp]);
338             vpslld(zmm_cvt, zmm_cvt, 0x10);
339             vmovdqu32(zword[reg_out] | ktail_mask, zmm_cvt);
340             L(l_simd_notail);
341
342         } else {
343
344             size_t blocked_size = (size_ / simd_w_) * simd_w_;
345             const size_t loop_length = 1024;
346             const size_t number_of_loops = blocked_size / loop_length;
347             const size_t tail_of_loops = blocked_size % loop_length;
348
349             if (number_of_loops > 0) {
350                 Xbyak::Label l_number_of_loops;
351                 mov(reg_size, number_of_loops);
352                 L(l_number_of_loops);
353                 for (size_t i = 0; i < loop_length; i += simd_w_) {
354                     vpmovzxwd(zmm_cvt, ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * i]);
355                     vpslld(zmm_cvt, zmm_cvt, 0x10);
356                     vmovups(zword[reg_out + sizeof(float) * i], zmm_cvt);
357                 }
358                 add(reg_inp, sizeof(mkldnn_bfloat16_t) * loop_length);
359                 add(reg_out, sizeof(float) * loop_length);
360
361                 dec(reg_size);
362                 cmp(reg_size, 0);
363                 jg(l_number_of_loops, T_NEAR);
364             }
365
366             if (tail_of_loops > 0) {
367                 for (size_t i = 0; i < tail_of_loops; i += simd_w_) {
368                     vpmovzxwd(zmm_cvt, ptr[reg_inp + sizeof(mkldnn_bfloat16_t) * i]);
369                     vpslld(zmm_cvt, zmm_cvt, 0x10);
370                     vmovups(zword[reg_out + sizeof(float) * (i)], zmm_cvt);
371                 }
372                 add(reg_inp, sizeof(mkldnn_bfloat16_t) * tail_of_loops);
373                 add(reg_out, sizeof(float) * tail_of_loops);
374             }
375             if (tail_mask_ != 0) {
376                 mov(reg32_mask, tail_mask_);
377                 kmovw(ktail_mask, reg32_mask);
378
379                 vpmovzxwd(zmm_cvt | ktail_mask | T_z, ptr[reg_inp]);
380                 vpslld(zmm_cvt, zmm_cvt, 0x10);
381                 vmovups(zword[reg_out] | ktail_mask, zmm_cvt);
382             }
383         }
384
385         postamble();
386     }
387
388     void (*jit_ker)(bf16_cvt_utils::jit_call_t *);
389
390 private:
391     size_t size_;
392     int tail_mask_;
393     int simd_w_;
394     bool is_dynamic_size_;
395
396     Xbyak::Opmask ktail_mask = k1;
397     Xbyak::Zmm zmm_cvt = Xbyak::Zmm(0);
398
399     Xbyak::Reg64 reg_inp = rax;
400     Xbyak::Reg64 reg_out = rbx;
401     Xbyak::Reg64 reg_size = rdx;
402
403     Xbyak::Reg64 reg64_tail = rcx;
404     Xbyak::Reg32 reg32_tail = ecx;
405     Xbyak::Reg8 reg8_mask_shift = cl;
406     Xbyak::Reg32 reg32_mask = r8d;
407 };
408
409 // performs element-by-element sum of inp and add float arrays and stores
410 // result to bfloat16 out array with downconversion
411 struct jit_avx512_core_add_cvt_ps_to_bf16_t : public jit_generator {
412     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_add_cvt_ps_to_bf16)
413
414     jit_avx512_core_add_cvt_ps_to_bf16_t() : simd_w_(16) {
415         is_cpx_ = mayiuse(avx512_core_bf16);
416         bf16_emu_ = new bf16_emulation_t(this, one, even,
417                 selector, scratch, fp32_tmp, fp32_tmp);
418
419         generate();
420         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
421     }
422
423     ~jit_avx512_core_add_cvt_ps_to_bf16_t() { delete bf16_emu_; }
424
425     void generate() {
426         preamble();
427
428         auto add_cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
429             vmovups(fp32_inp | ktail_mask | T_z, ptr[reg_inp + sizeof(float) * (idx)]);
430             vaddps(fp32_inp | ktail_mask | T_z, fp32_inp, ptr[reg_add + sizeof(float) * (idx)]);
431             if (!is_cpx_)
432                 bf16_emu_->r_vcvtneps2bf16(bf16_out, fp32_inp);
433             else
434                 vcvtneps2bf16(bf16_out, fp32_inp);
435
436             vmovdqu16(
437                     yword[reg_out + sizeof(mkldnn_bfloat16_t) * (idx)] | ktail_mask,
438                     bf16_out);
439         };
440
441         mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
442         mov(reg_add, ptr[abi_param1 + GET_OFF(add)]);
443         mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
444         mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);
445
446         if (!is_cpx_)
447             bf16_emu_->init_vcvtneps2bf16();
448
449         mov(reg32_tail, 0xffff);
450         kmovw(ktail_mask, reg32_tail);
451
452         constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
453         Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
454         for (int i = n_unroll; i >= 0; i--) {
455             const int unroll = 1 << i; // 4, 2, 1
456             L(l_simd_loop[i + 1]); {
457                 cmp(reg_size, simd_w_ * unroll);
458                 jl(l_simd_loop[i], T_NEAR);
459                 for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
460                     add_cvt(j, ktail_mask);
461                 }
462                 add(reg_inp, simd_w_ * unroll * sizeof(float));
463                 add(reg_add, simd_w_ * unroll * sizeof(float));
464                 add(reg_out, simd_w_ * unroll * sizeof(mkldnn_bfloat16_t));
465
466                 sub(reg_size, simd_w_ * unroll);
467                 jmp(l_simd_loop[i + 1], T_NEAR);
468             }
469         }
470         L(l_simd_loop[0]);
471         test(reg_size, reg_size);
472         jz(l_simd_notail);
473         // JIT of `tail_mask_ = (1 << (size_ % simd_w_)) - 1;`
474         mov(reg32_mask, 1);
475         mov(reg64_tail, reg_size);
476         shl(reg32_mask, reg8_mask_shift);
477         sub(reg32_mask, 1);
478         kmovd(ktail_mask, reg32_mask);
479         add_cvt(0, ktail_mask);
480         L(l_simd_notail);
481
482         postamble();
483     }
484
485     void (*jit_ker)(bf16_cvt_utils::jit_call_t *);
486
487 private:
488     int simd_w_;
489
490     bf16_emulation_t *bf16_emu_;
491     bool is_cpx_;
492
493     Xbyak::Opmask ktail_mask = k2;
494     Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
495     Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);
496
497     Xbyak::Zmm one = Xbyak::Zmm(2);
498     Xbyak::Zmm even = Xbyak::Zmm(3);
499     Xbyak::Zmm selector = Xbyak::Zmm(4);
500     Xbyak::Reg64 scratch = r15;
501
502     Xbyak::Ymm bf16_out = Xbyak::Ymm(5);
503
504     Xbyak::Reg64 reg_inp = rax;
505     Xbyak::Reg64 reg_out = rbx;
506     Xbyak::Reg64 reg_add = r11;
507     Xbyak::Reg64 reg_size = rdx;
508
509     Xbyak::Reg64 reg64_tail = rcx;
510     Xbyak::Reg32 reg32_tail = ecx;
511     Xbyak::Reg8 reg8_mask_shift = cl;
512     Xbyak::Reg32 reg32_mask = r8d;
513 };
514
515 // implementation of reorder of part of tensor [s][16c] -> [S][16c][2s]
516 // it is required for quick implementation of 1x1 bf16 bwd_w jit kernel
517 // w/o using permw instruction inside
518 // TODO: consider modification/replacement for outer transformation jit kernel
519 struct jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t : public jit_generator {
520     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_reorder_s16c_to_S16c2s)
521
522     jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() : simd_w_(16) {
523         generate();
524         jit_ker = (void (*)(bf16_cvt_utils::jit_call_t *))getCode();
525     }
526
527     ~jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() {}
528
529     void generate() {
530         preamble();
531
532         mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
533         mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
534         mov(reg_size, ptr[abi_param1 + GET_OFF(size)]);
535
536         auto zmm_reg = [=](int idx) {
537             assert(idx < 31);
538             return Xbyak::Zmm(idx);
539         };
540
541         Xbyak::Label dst_prm_table;
542         mov(reg_prm, dst_prm_table);
543         vmovups(zmm_prm, ptr[reg_prm]);
544
545         constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
546         int sizeofcacheline = 2 * simd_w_ * sizeof(mkldnn_bfloat16_t);
547         Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
548         for (int i = n_unroll; i >= 0; i--) {
549             const int unroll = 1 << i; // 4, 2, 1
550             L(l_simd_loop[i + 1]); {
551                 cmp(reg_size, 2 * unroll);
552                 jl(l_simd_loop[i], T_NEAR);
553                 for (int j = 0; j < unroll; j++) {
554                      auto zmm_inp = zmm_reg(j);
555                      vmovups(zmm_inp, zword[reg_inp + j * sizeofcacheline]);
556                      vpermw(zmm_inp, zmm_prm, zmm_inp);
557                      vmovups(zword[reg_out + j * sizeofcacheline], zmm_inp);
558                 }
559                 add(reg_inp, unroll * sizeofcacheline);
560                 add(reg_out, unroll * sizeofcacheline);
561
562                 sub(reg_size, 2 * unroll);
563                 jmp(l_simd_loop[i + 1], T_NEAR);
564             }
565         }
566         L(l_simd_loop[0]);
567
568         test(reg_size, reg_size);
569         jz(l_simd_notail);
570
571         mov(reg32_tail, 0x00ff);
572         kmovw(ktail_mask, reg32_tail);
573
574         auto zmm_inp = zmm_reg(0);
575         vpxord(zmm_inp, zmm_inp, zmm_inp);
576         vmovups(zmm_inp | ktail_mask | T_z, ptr[reg_inp]);
577         vpermw(zmm_inp, zmm_prm, zmm_inp);
578         vmovups(zword[reg_out], zmm_inp);
579
580         L(l_simd_notail);
581
582         postamble();
583
584         const uint16_t dst_prm_array[32] =
585             {0,16,  1,17,  2,18,  3,19,  4,20,  5,21,  6,22,  7,23,  8,24,
586             9,25,  10,26,  11,27,  12,28,  13,29,  14,30,  15,31 };
587
588         align(64);
589         L(dst_prm_table);
590         for (size_t i = 0; i < 32; ++i)
591             dw(dst_prm_array[i]);
592     }
593
594     void (*jit_ker)(bf16_cvt_utils::jit_call_t *);
595
596 private:
597     int simd_w_;
598
599     Xbyak::Opmask ktail_mask = k2;
600     Xbyak::Zmm zmm_prm = Xbyak::Zmm(31);
601
602     Xbyak::Reg64 reg_inp = rax;
603     Xbyak::Reg64 reg_out = rbx;
604     Xbyak::Reg64 reg_prm = r11;
605     Xbyak::Reg64 reg_size = rdx;
606
607     Xbyak::Reg32 reg32_tail = abi_not_param1.cvt32();
608 };
609
610 #undef GET_OFF
611 }
612 }
613 }
614
615 #endif