updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_reorder.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
22 #include "nstl.hpp"
23 #include "type_helpers.hpp"
24
25 #include "cpu_primitive.hpp"
26 #include "cpu_reorder_pd.hpp"
27 #include "jit_uni_reorder.hpp"
28
29 #include "jit_avx512_core_bf16cvt.hpp"
30 #include "jit_generator.hpp"
31
32 // #define TR_DEBUG
33 #if defined(TR_DEBUG)
34 #define DEBUg(...) do { __VA_ARGS__ } while (0)
35 #else
36 #define DEBUg(...)
37 #endif
38 #define DEBUG(...) DEBUg(__VA_ARGS__)
39
40 #ifdef _WIN32
41 /* seems like s_addr is a reserved macro on Windows */
42 #undef s_addr
43 #endif
44
45 using namespace Xbyak;
46 using namespace mkldnn::impl::types;
47
48 namespace mkldnn {
49 namespace impl {
50 namespace cpu {
51
52 namespace tr {
53
54 /** Minimal reasonable/desirable kernel size.
55  * The constant might be used to determine how a problem should be split
56  * between kernel and threading driver. */
57 const size_t ker_prb_size_min = 64;
58
59 /* kernel */
60 struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator {
61     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
62
63     enum {
64         len_unroll_max = 256,
65         ndims_jit_loop_max = 3,
66     };
67
68     struct simple_impl_desc_t {
69         int ndims_full_unroll;
70         int len_last_dim_unroll;
71         int len_unroll;
72     };
73
74     static bool simple_impl_desc_init(const prb_t &prb,
75             simple_impl_desc_t *desc) {
76         const int ndims = prb.ndims;
77
78         int ndims_full_unroll = 0;
79         int len_last_dim_unroll = 1;
80         int len_unroll = 1;
81
82         for (int d = 0; d < ndims; ++d) {
83             auto &node = prb.nodes[d];
84             if (len_unroll * node.n <= len_unroll_max) {
85                 ndims_full_unroll++;
86                 len_unroll *= node.n;
87             } else {
88                 len_last_dim_unroll = len_unroll_max / len_unroll;
89                 while (node.n % len_last_dim_unroll)
90                     --len_last_dim_unroll;
91                 len_unroll *= len_last_dim_unroll;
92                 break;
93             }
94         }
95
96         if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max)
97             return false;
98
99         if (desc) {
100             desc->ndims_full_unroll = ndims_full_unroll;
101             desc->len_last_dim_unroll = len_last_dim_unroll;
102             desc->len_unroll = len_unroll;
103         }
104
105         return true;
106     }
107
108     static bool applicable(const prb_t &p) {
109         using namespace data_type;
110
111         bool ok = true
112             && p.ndims > 0
113             && utils::one_of(p.itype, f32, bf16, s32, s8, u8)
114             && utils::one_of(p.otype, f32, bf16, s32, s8, u8)
115             && IMPLICATION(p.itype == bf16, utils::one_of(p.otype, f32, bf16))
116             && IMPLICATION(p.otype == bf16, utils::one_of(p.itype, f32, bf16))
117             && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
118             && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
119             && simple_impl_desc_init(p, nullptr)
120             && mayiuse(sse42)
121             && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype),
122                     mayiuse(avx))
123             && IMPLICATION((p.itype == bf16 || p.otype == bf16),
124                     mayiuse(avx512_core));
125         if (!ok) return false;
126
127         const ptrdiff_t max_stride = (1LL<<31) - 1;
128         for (int d = 0; d < p.ndims; ++d) {
129             const ptrdiff_t cms = max_stride / p.nodes[d].n;
130             bool strides_ok = true
131                 && p.nodes[d].is < cms / (int)data_type_size(p.itype)
132                 && p.nodes[d].os < cms / (int)data_type_size(p.otype);
133             if (!strides_ok) return false;
134         }
135
136         return true;
137     }
138
139     int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; }
140     int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; }
141     int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; }
142     int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; }
143
144     Address i_addr(int i_off)
145     { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; }
146
147     Address o_addr(int o_off)
148     { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; }
149
150     Address s_addr(int s_off)
151     { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; }
152
153     void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
154             int &i_off, int &o_off, int &s_off, int step_size = 1) {
155         i_off = prev_i_off;
156         o_off = prev_o_off;
157         s_off = prev_s_off;
158
159         if (off == 0) return;
160
161         int start_dim = 0, dims_prod = 1;
162         for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
163             dims_prod *= n(start_dim);
164         assert(start_dim < prb_.ndims);
165         off /= step_size;
166
167         for (int d = start_dim; d < prb_.ndims; ++d) {
168             i_off += is(d);
169             o_off += os(d);
170             s_off += ss(d);
171
172             if (off % n(d)) break;
173
174             i_off += - n(d) * is(d);
175             o_off += - n(d) * os(d);
176             s_off += - n(d) * ss(d);
177             off /= n(d);
178
179             if (off == 0) break; /* FIXME: is it really required? */
180         }
181     }
182
183     void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
184             int step_size = 1) {
185         int dummy = 0;
186         step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
187                 step_size);
188     }
189
190     void tr8x8_avx2(int i_off, int o_off) {
191         for (int i = 0; i < 8; i++) {
192             using namespace data_type;
193
194             if (prb_.itype == s32 && prb_.otype == f32)
195                 vcvtdq2ps(Ymm(i), i_addr(i_off + i * 8));
196             else if (prb_.itype == f32 && prb_.otype == s32)
197                 vcvtps2dq(Ymm(i), i_addr(i_off + i * 8));
198             else
199                 vmovups(Ymm(i), i_addr(i_off + i * 8));
200         }
201
202         for (int i = 0; i < 8 / 2; i++) {
203             vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1));
204             vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
205         }
206
207         const unsigned int lfloat = 0x44;
208         const unsigned int ufloat = 0xee;
209         for (int i = 0; i < 8 / 2; i++) {
210             int j = i % 2 == 0 ? 8 + i : i - 1;
211             vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
212             vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
213         }
214
215         const unsigned int lquad = 0x20;
216         for (int i = 0; i < 8 / 2; i++)
217             vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad);
218
219         const unsigned int uquad = 0x31;
220         for (int i = 8 / 2; i < 8; i++)
221             vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad);
222
223         for (int i = 0; i < 8; i++)
224             vmovups(o_addr(o_off + i * 8), Ymm(i));
225     }
226
227     bool process_unroll_tr8x8(int len) {
228         bool can_do = true
229             && mayiuse(avx2)
230             && prb_.ndims >= 2
231             && utils::everyone_is(4, itype_sz, otype_sz)
232             && utils::everyone_is(8, n(0), n(1))
233             && utils::everyone_is(1, os(0), is(1))
234             && utils::everyone_is(8, os(1), is(0))
235             && prb_.scale_type == scale_type_t::NONE
236             && prb_.beta == 0.f;
237         if (!can_do) return false;
238
239         const int step_size = n(0) * n(1);
240         int i_off = 0, o_off = 0;
241         for (int off = 0; off < len; off += step_size) {
242             step(off, i_off, o_off, i_off, o_off, step_size);
243             tr8x8_avx2(i_off, o_off);
244         }
245
246         return true;
247     }
248
249     template <cpu_isa_t isa>
250     bool process_direct_copy(int len) {
251         using namespace data_type;
252
253         using Vmm = typename cpu_isa_traits<isa>::Vmm;
254         const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;
255
256         bool can_do = true
257             && mayiuse(isa)
258             && utils::everyone_is(1, os(0), is(0))
259             && (false
260                     || prb_.itype == prb_.otype
261                     || (prb_.itype == s32 && prb_.otype == f32)
262                     || (prb_.itype == f32 && prb_.otype == s32)
263                     )
264             && len % simd_w == 0
265             && n(0) % len == 0
266             && prb_.scale_type == scale_type_t::NONE
267             && prb_.beta == 0.f;
268         if (!can_do) return false;
269
270         for (int off = 0; off < len;) {
271             const int unroll = nstl::min(16, (len - off) / simd_w);
272
273             for (int ur = 0; ur < unroll; ++ur)
274                 uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w));
275
276             if (prb_.itype != prb_.otype) {
277                 for (int ur = 0; ur < unroll; ++ur) {
278                     if (prb_.itype == s32 && prb_.otype == f32)
279                         uni_vcvtdq2ps(Vmm(ur), Vmm(ur));
280                     else if (prb_.itype == f32 && prb_.otype == s32)
281                         uni_vcvtps2dq(Vmm(ur), Vmm(ur));
282                     else assert(!"unreachable");
283                 }
284             }
285
286             for (int ur = 0; ur < unroll; ++ur)
287                 uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur));
288
289             off += unroll * simd_w;
290         }
291
292         return true;
293     }
294
295     void process_unroll_generic_step(int reg_unroll, const int *i_off,
296             const int *o_off, const int *s_off) {
297         using namespace data_type;
298
299         auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) {
300             Xmm dst_pure = Xmm(dst.getIdx());
301             switch (idt) {
302             case f32:
303                 if (src.isMEM() || src.getIdx() != dst.getIdx())
304                     vmovups(dst, src);
305                 break;
306             case bf16: vpmovzxwd(dst, src); vpslld(dst, dst, 0x10); break;
307             case s32: vcvtdq2ps(dst, src); break;
308             case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
309             case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
310             default: assert(!"unreachable");
311             }
312         };
313
314         auto cvt2odt = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) {
315             switch (odt) {
316             case bf16:
317                 if (idt == f32) {
318                     if (is_cpx_) {
319                         vcvtneps2bf16(xmm, xmm);
320                     } else {
321                         bf16_emu_->r_vcvtneps2bf16(
322                                 Ymm(xmm.getIdx()), Zmm(xmm.getIdx()));
323                     }
324                 }
325                 break;
326             case s32:
327                 if (idt == f32) vcvtps2dq(xmm, xmm);
328                 else if (idt == s8) vpmovsxbd(xmm, xmm);
329                 else if (idt == u8) vpmovzxbd(xmm, xmm);
330                 break;
331             case s8:
332                 if (idt == f32) vcvtps2dq(xmm, xmm);
333                 if (idt == f32 || idt == s32) {
334                     if (mayiuse(avx512_core)) {
335                         vpmovsdb(xmm, xmm);
336                     } else {
337                         vpackssdw(xmm, xmm, xmm_zero);
338                         vpacksswb(xmm, xmm, xmm_zero);
339                     }
340                 }
341                 if (idt == u8) vpminub(xmm, xmm, xmm_4x127b);
342                 break;
343             case u8:
344                 if (idt == f32) vcvtps2dq(xmm, xmm);
345                 if (idt == f32 || idt == s32) {
346                     if (mayiuse(avx512_core)) {
347                         vpmaxsd(xmm, xmm, xmm_zero);
348                         vpmovusdb(xmm, xmm);
349                     } else {
350                         vpackssdw(xmm, xmm, xmm_zero);
351                         vpackuswb(xmm, xmm, xmm_zero);
352                     }
353                 }
354                 if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero);
355                 break;
356             default: assert(!"unreachable");
357             }
358         };
359
360         auto load = [=](const Xmm &xmm, const Address &addr, int size) {
361             switch (size) {
362             case 16: movups(xmm, addr); break;
363             case 8: movsd(xmm, addr); break;
364             case 4: movss(xmm, addr); break;
365             case 2: pinsrw(xmm, addr, 0x0); break;
366             case 1: pinsrb(xmm, addr, 0x0); break;
367             default: assert(!"unreachable");
368             }
369         };
370
371         auto store = [=](const Address &addr, const Xmm &xmm, int size) {
372             switch (size) {
373             case 16: movups(addr, xmm); break;
374             case 8: movsd(addr, xmm); break;
375             case 4: movss(addr, xmm); break;
376             case 2: pextrw(addr, xmm, 0x0); break;
377             case 1: pextrb(addr, xmm, 0x0); break;
378             default: assert(!"unreachable");
379             }
380         };
381
382         /* check whether loading 4 values at once is possible */
383         bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0;
384         for (int ur = 1; ur < reg_unroll; ++ur)
385             if (i_off[ur] != i_off[ur - 1] + 1)
386                 can_load_xmm = false;
387         const int load_step = can_load_xmm ? 4 : 1;
388
389         /* check whether storing 4 values at once is possible */
390         bool can_store_xmm = reg_unroll % 4 == 0;
391         for (int ur = 1; ur < reg_unroll; ++ur)
392             if (o_off[ur] != o_off[ur - 1] + 1)
393                 can_store_xmm = false;
394         const int ur_step = can_store_xmm ? 4 : 1;
395
396         const bool interim_f32 = false
397             || utils::one_of(f32, prb_.itype, prb_.otype)
398             || prb_.scale_type != scale_type_t::NONE
399             || prb_.beta != 0.f;
400
401         if (!can_load_xmm && can_store_xmm) {
402             assert(ur_step == 4);
403             /* load with stride */
404             for (int ur = 0; ur < reg_unroll; ur += ur_step) {
405                 for (int r = 0; r < ur_step; ++r) {
406                     if (itype_sz == 4)
407                         pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r);
408                     else if (itype_sz == 2)
409                         pinsrw(Xmm(ur), i_addr(i_off[ur + r]), r);
410                     else
411                         pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r);
412                 }
413             }
414         } else {
415             for (int ur = 0; ur < reg_unroll; ur += load_step)
416                 load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz);
417         }
418
419         /* xmm[:] <-- (f32)xmm[:] */
420         if (interim_f32) {
421             const int cvt_step = nstl::max(load_step, ur_step);
422             for (int ur = 0; ur < reg_unroll; ur += cvt_step)
423                 cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
424         }
425
426         if (can_load_xmm && !can_store_xmm) {
427             const bool fast_return = true // transposition on the fly
428                 && prb_.scale_type != scale_type_t::MANY
429                 && prb_.beta == 0.f;
430             if (fast_return) {
431                 for (int ur = 0; ur < reg_unroll; ur += load_step) {
432                     if (prb_.scale_type == scale_type_t::COMMON)
433                         mulps(Xmm(ur), xmm_scale);
434                     if (prb_.otype != f32)
435                         cvt2odt(Xmm(ur), prb_.otype,
436                                 interim_f32 ? f32 : prb_.itype);
437                     for (int r = 0; r < load_step; ++r) {
438                         if (otype_sz == 4)
439                             pextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
440                         else if (otype_sz == 2)
441                             pextrw(o_addr(o_off[ur + r]), Xmm(ur), r);
442                         else
443                             pextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
444                     }
445                 }
446                 return;
447             }
448
449             /* scatter elements of xmm into 4 xmms */
450             if (itype_sz == 4 || interim_f32) {
451                 for (int ur = 0; ur < reg_unroll; ur += load_step)
452                     for (int r = 1; r < load_step; ++r)
453                         vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
454             } else if (itype_sz == 2) {
455                 for (int ur = 0; ur < reg_unroll; ur += load_step)
456                     for (int r = 1; r < load_step; ++r)
457                         vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), 2 * r);
458             } else {
459                 for (int ur = 0; ur < reg_unroll; ur += load_step)
460                     for (int r = 1; r < load_step; ++r)
461                         vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
462             }
463         }
464
465         /* scale and beta processing */
466         if (can_store_xmm) {
467             /* xmm <-- scale * xmm[:] */
468             if (prb_.scale_type == scale_type_t::COMMON) {
469                 for (int ur = 0; ur < reg_unroll; ur += ur_step)
470                     mulps(Xmm(ur), xmm_scale);
471             } else if (prb_.scale_type == scale_type_t::MANY) {
472                 enum class scale_load_type_t { bcast, load, gather };
473
474                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
475                     scale_load_type_t scale_load_type =
476                         scale_load_type_t::bcast; // the best case
477
478                     for (int r = ur + 1; r < ur + ur_step; ++r)
479                         if (s_off[r] != s_off[r - 1] + 0)
480                             scale_load_type = scale_load_type_t::load;
481
482                     if (scale_load_type == scale_load_type_t::bcast) {
483                         movss(xmm_scale, s_addr(s_off[ur]));
484                         shufps(xmm_scale, xmm_scale, 0x0);
485                         mulps(Xmm(ur), xmm_scale);
486                         continue;
487                     }
488
489                     // bcast doesn't work, the next try -- load
490                     for (int r = ur + 1; r < ur + ur_step; ++r)
491                         if (s_off[r] != s_off[r - 1] + 1)
492                             scale_load_type = scale_load_type_t::gather;
493
494                     if (scale_load_type == scale_load_type_t::load) {
495                         movups(xmm_scale, s_addr(s_off[ur]));
496                         mulps(Xmm(ur), xmm_scale);
497                         continue;
498                     }
499
500                     // load doesn't work as well
501                     // so gather the scale factors one by one
502                     for (int r = ur; r < ur + ur_step; ++r)
503                         pinsrd(xmm_scale, s_addr(s_off[r]), r - ur);
504                     mulps(Xmm(ur), xmm_scale);
505                 }
506             }
507
508             /* dst <-- beta * dst + xmm[:] */
509             assert(prb_.beta == 0.f || prb_.beta == 1.f);
510             if (prb_.beta == 1.f) {
511                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
512                     if (prb_.otype == f32) {
513                         /* non VEX instructions do not support unaligned
514                          * memory for instructions other than movups. */
515                         if (mayiuse(avx)) {
516                             vaddps(Xmm(ur), o_addr(o_off[ur]));
517                         } else {
518                             /* register xmm(1) is unused */
519                             movups(Xmm(1), o_addr(o_off[ur]));
520                             addps(Xmm(ur), Xmm(1));
521                         }
522                     } else {
523                         cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
524                         vaddps(Xmm(ur), Xmm(1));
525                     }
526                 }
527             }
528         } else {
529             /* xmm[0] <-- scale * xmm[0] */
530             if (prb_.scale_type == scale_type_t::COMMON) {
531                 for (int ur = 0; ur < reg_unroll; ur += ur_step)
532                     mulss(Xmm(ur), xmm_scale);
533             } else if (prb_.scale_type == scale_type_t::MANY) {
534                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
535                     mulss(Xmm(ur), s_addr(s_off[ur]));
536                 }
537             }
538
539             /* dst <-- beta * dst + xmm[0] */
540             assert(prb_.beta == 0.f || prb_.beta == 1.f);
541             if (prb_.beta == 1.f) {
542                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
543                     if (prb_.otype == f32) {
544                         addss(Xmm(ur), o_addr(o_off[ur]));
545                     } else {
546                         if (prb_.otype == s32) {
547                             vmovss(xmm_tmp, o_addr(o_off[ur]));
548                         } else if (utils::one_of(prb_.otype, s8, u8)) {
549                             pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0);
550                         } else {
551                             assert(!"unsupported o_type");
552                         }
553                         cvt2ps(xmm_tmp, xmm_tmp, prb_.otype);
554                         addps(Xmm(ur), xmm_tmp);
555                     }
556                 }
557             }
558         }
559
560         for (int ur = 0; ur < reg_unroll; ur += ur_step) {
561             if (prb_.otype != f32)
562                 cvt2odt(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
563             store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz);
564         }
565     }
566
567     void process_unroll_generic(int len) {
568         const int blk = 8;
569
570         int i_off[2 * blk] = {0};
571         int o_off[2 * blk] = {0};
572         int s_off[2 * blk] = {0};
573
574         int curr = 0; // will switch between 0 and 1
575
576         for (int off = 0; off < len; off += blk) {
577             const int reg_unroll = nstl::min(off + blk, len) - off;
578
579             /* compute offsets */
580             for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
581                 const int ur_c = curr * blk + ur;
582                 const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
583                 step(off + ur,
584                         i_off[ur_p], o_off[ur_p], s_off[ur_p],
585                         i_off[ur_c], o_off[ur_c], s_off[ur_c]);
586             }
587
588             process_unroll_generic_step(reg_unroll, i_off + curr * blk,
589                     o_off + curr * blk, s_off + curr * blk);
590
591             curr = 1 - curr;
592         }
593     }
594
595     void loop_begin(Label &l, Reg64 reg_cnt, int len) {
596         mov(reg_cnt, len);
597         L(l);
598     }
599
600     void loop_end(Label &l, Reg64 reg_cnt, int len,
601             int i_step, int o_step, int s_step) {
602         add(reg_off_in, i_step * itype_sz);
603         add(reg_off_out, o_step * otype_sz);
604         if (prb_.scale_type == scale_type_t::MANY)
605             add(reg_off_scale, s_step * stype_sz);
606         dec(reg_cnt);
607         jnz(l);
608
609         sub(reg_off_in, len * i_step * itype_sz);
610         sub(reg_off_out, len * o_step * otype_sz);
611         if (prb_.scale_type == scale_type_t::MANY)
612             sub(reg_off_scale, len * s_step * stype_sz);
613     }
614
615     bool simple_impl() {
616         simple_impl_desc_t d;
617         if (!simple_impl_desc_init(prb_, &d)) return false;
618
619         const int nfu = d.ndims_full_unroll;
620         const int ldu = d.len_last_dim_unroll;
621         const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
622         assert(n_jit_loops <= ndims_jit_loop_max);
623
624         xor_(reg_off_in, reg_off_in);
625         xor_(reg_off_out, reg_off_out);
626         if (prb_.scale_type == scale_type_t::MANY)
627             xor_(reg_off_scale, reg_off_scale);
628
629         Label l_loop[3];
630         Reg64 reg_cnt[3] = {r15, r14, r13};
631
632         if (n_jit_loops > 2)
633             loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));
634
635         if (n_jit_loops > 1)
636             loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));
637
638         if (n_jit_loops > 0)
639             loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);
640
641         const bool optimized = false
642             || process_direct_copy<avx>(d.len_unroll)
643             || process_direct_copy<sse42>(d.len_unroll)
644             || process_unroll_tr8x8(d.len_unroll);
645         if (!optimized)
646             process_unroll_generic(d.len_unroll);
647
648         if (n_jit_loops > 0)
649             loop_end(l_loop[0], reg_cnt[0],
650                     n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu,
651                     ss(nfu + 0) * ldu);
652
653         if (n_jit_loops > 1)
654             loop_end(l_loop[1], reg_cnt[1],
655                     n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1));
656
657         if (n_jit_loops > 2)
658             loop_end(l_loop[2], reg_cnt[2],
659                     n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2));
660
661         return true;
662     }
663
664     void impl() {
665         if (simple_impl()) return;
666         assert(!"no implementation available");
667     }
668
669     jit_uni_reorder_kernel_f32(const desc_t &desc)
670         : kernel_t(desc), jit_generator(), bf16_emu_(nullptr) {
671         itype_sz = data_type_size(prb_.itype);
672         otype_sz = data_type_size(prb_.otype);
673         stype_sz = sizeof(float);
674         is_cpx_ = mayiuse(avx512_core_bf16);
675         if (prb_.otype == data_type::bf16 && !is_cpx_) {
676             bf16_emu_ = new bf16_emulation_t(this, vcvt_bf16_one, vcvt_bf16_eve,
677                     vcvt_bf16_sel, reg_bf16_scratch, vcvt_bf16_tmp, vcvt_bf16_tmp);
678             bf16_emu_->init_vcvtneps2bf16();
679         }
680
681         preamble();
682 #       define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)]
683         if (prb_.scale_type == scale_type_t::COMMON) {
684             auto reg_ptr_scale_tmp = reg_ptr_in;
685             mov(reg_ptr_scale_tmp, PARAM(scale));
686             movups(xmm_scale, ptr[reg_ptr_scale_tmp]);
687         } else if (prb_.scale_type == scale_type_t::MANY) {
688             mov(reg_ptr_scale, PARAM(scale));
689         }
690         mov(reg_ptr_in, PARAM(in));
691         mov(reg_ptr_out, PARAM(out));
692 #       undef PARAM
693
694         if (mayiuse(avx)) {
695             vxorps(xmm_zero, xmm_zero, xmm_zero);
696
697             if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
698                 mov(reg_tmp.cvt32(), 0x7f7f7f7f);
699                 movd(xmm_4x127b, reg_tmp.cvt32());
700             }
701         }
702
703         impl();
704         postamble();
705         ker_ = (void (*)(const call_param_t *))getCode();
706     }
707     ~jit_uni_reorder_kernel_f32() { delete bf16_emu_; }
708
709 private:
710     int itype_sz;
711     int otype_sz;
712     int stype_sz;
713
714     Reg64 reg_ptr_in = rsi;
715     Reg64 reg_ptr_out = rdx;
716     Reg64 reg_ptr_scale = abi_not_param1;
717
718     Reg64 reg_off_in = r8;
719     Reg64 reg_off_out = r9;
720     Reg64 reg_off_scale = r10;
721
722     Reg64 reg_tmp = rax;
723
724     Xmm xmm_scale = xmm15;
725     Xmm xmm_zero = xmm14;
726     Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero
727     Xmm xmm_tmp = xmm12;
728
729     /* bf16 support */
730     bool is_cpx_;
731     bf16_emulation_t *bf16_emu_;
732     Reg64 reg_bf16_scratch = reg_tmp;
733     Zmm vcvt_bf16_one = Zmm(16);
734     Zmm vcvt_bf16_eve = Zmm(17);
735     Zmm vcvt_bf16_sel = Zmm(18);
736     Zmm vcvt_bf16_tmp = Zmm(19);
737 };
738
739 status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb,
740         int ndims_ker_max) {
741     desc.prb = prb;
742     desc.prb.ioff = desc.prb.ooff = 0;
743
744     if (ndims_ker_max > prb.ndims)
745         return status::invalid_arguments;
746
747     auto ndims_ker_max_f = [&]() {
748         size_t cur_size = 1;
749         for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
750             if (cur_size >= ker_prb_size_min) return d;
751         return prb.ndims;
752     };
753
754     if (ndims_ker_max <= 0)
755         ndims_ker_max = ndims_ker_max_f();
756
757     /* traverse through kernel implementations */
758     /* TODO: find a better way to do that... */
759     desc.id = 0;
760     for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
761         desc.prb.ndims = ndims_ker;
762         if (jit_uni_reorder_kernel_f32::applicable(desc.prb))
763             return status::success;
764     }
765
766     return status::unimplemented;
767 }
768
769 kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
770     switch (desc.id) {
771     case 0: return new jit_uni_reorder_kernel_f32(desc);
772     default: assert(!"unknown kernel id"); return nullptr;
773     }
774
775     return nullptr;
776 }
777
778 }
779
780 static void prb_block_for_cache(tr::prb_t &prb) {
781     if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) {
782         /** an attempt to use caches more efficient and
783          * address the 4K-aliasing issue */
784         /* TODO: improve the logic around here */
785         int j = 1;
786         for (; j < prb.ndims && prb.nodes[j].is != 1; ++j);
787         if (j == prb.ndims) return;
788
789         /* it makes sense to re-prioritize sequential read over
790          * sequential write if the former would not trash the
791          * cache, i.e. is == 1 and os % 2^smth != 0. Smth is
792          * set to 2 at the moment */
793         const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1;
794         if (j == move_to) return;
795
796         if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0)
797             prb_node_split(prb, j, 16);
798
799         prb_node_move(prb, j, move_to);
800         DEBUG({ printf("cache: "); prb_dump(prb); });
801     }
802 }
803
804 /** finds the maximum number of dimension the kernel should process and
805  * optionally splits one of the dimension to achieve better balance between
806  * parallel driver and the kernel. */
807 static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) {
808     size_t sz_total = 1;
809     for (int d = 0; d < prb.ndims; ++d)
810         sz_total *= prb.nodes[d].n;
811
812     /* sz_drv_min is the minimal size for the parallel
813      * driver required for good parallelization */
814     const size_t sz_drv_min = nstl::min<size_t>(
815             16 * mkldnn_get_max_threads(),
816             utils::div_up(sz_total, 1024));
817
818     /* kdims -- # of dimensions processed by a kernel
819      * sz_ker_cur -- product of the dimension processed by a kernel
820      * sz_drv_cur -- product of the dimension processed by a driver */
821
822     int kdims = prb.ndims;
823     size_t sz_drv_cur = 1;
824     for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
825         sz_drv_cur *= prb.nodes[kdims - 1].n;
826
827     size_t sz_ker_cur = 1;
828     for (int d = 0; d < kdims; ++d)
829         sz_ker_cur *= prb.nodes[d].n;
830
831     /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
832      *
833      * It might happen that for chosen kdims the sz_ker_cur is too small
834      * (less than tr::ker_prb_size_min). In that case try to split the
835      * innermost driver dimension into two, to increase sz_ker_cur. */
836     bool want_borrow_ker_from_drv = true
837         && kdims < prb.ndims
838         && sz_ker_cur < tr::ker_prb_size_min
839         && sz_drv_cur > sz_drv_min;
840     if (want_borrow_ker_from_drv) {
841         /* sz_want_borrow is the minimal sz, so that:
842          *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
843          *  o) current innermost driver dimension is divisible by
844          *     sz_want_borrow (so that we can evenly split that
845          *     dimension into two)
846          *
847          *  In the worst case the minimal sz_want_borrow is equal
848          *  to the innermost driver dimension itself. In that case
849          *  we will sacrifice it in favor of kernel (is it fine?). */
850         size_t sz_want_borrow
851             = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
852         for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow);
853         if (sz_want_borrow != prb.nodes[kdims].n)
854             prb_node_split(prb, kdims, sz_want_borrow);
855         kdims += 1;
856     }
857
858     /* On the other hand it might happen that for chosen kdims
859      * the sz_drv_cur is too small (less than sz_drv_min). In that case
860      * try to split the outermost kernel dimension into two, to increase
861      * sz_drv_cur. */
862     bool want_borrow_drv_from_ker = true
863         && sz_ker_cur > tr::ker_prb_size_min
864         && sz_drv_cur < sz_drv_min;
865     if (want_borrow_drv_from_ker) {
866         size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
867         for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow);
868         if (sz_want_borrow != prb.nodes[kdims - 1].n)
869             prb_node_split(prb, kdims - 1,
870                     prb.nodes[kdims - 1].n / sz_want_borrow);
871     }
872
873     ndims_ker_max = kdims;
874
875     if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
876         DEBUG({ printf("split: "); prb_dump(prb);
877                 printf("ndims_ker_max = %d\n", ndims_ker_max); });
878     }
879 }
880
881 struct jit_uni_reorder_t : public cpu_primitive_t {
882     struct pd_t : public cpu_reorder_pd_t {
883         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
884                 const primitive_attr_t *attr)
885             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
886
887         DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
888
889         static status_t create(reorder_pd_t **reorder_pd,
890                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
891                 const primitive_attr_t *attr) {
892             const memory_desc_t *imd = input_pd->desc();
893             const memory_desc_t *omd = output_pd->desc();
894
895             auto prb = tr::prb_t();
896
897             if (imd->format == mkldnn_OhIw8o4i || imd->format == mkldnn_gOhIw8o4i ||
898                 imd->format == mkldnn_OhIw8o4i_s8s8 || imd->format == mkldnn_gOhIw8o4i_s8s8 ||
899                 omd->format == mkldnn_OhIw8o4i || omd->format == mkldnn_gOhIw8o4i ||
900                 omd->format == mkldnn_OhIw8o4i_s8s8 || omd->format == mkldnn_gOhIw8o4i_s8s8 ||
901                 omd->format == mkldnn_OdhIw8o4i_s8s8 || omd->format == mkldnn_gOdhIw8o4i_s8s8 ||
902                 omd->format == mkldnn_dhwio_s8s8 || omd->format == mkldnn_dhwigo_s8s8)
903                 return status::unimplemented;
904
905             status_t prb_init_status = prb_init(prb, *imd, *omd, attr);
906             if (prb_init_status != success) return prb_init_status;
907
908             DEBUG({ printf("init : "); prb_dump(prb); });
909             prb_normalize(prb);
910             DEBUG({ printf("norm : "); prb_dump(prb); });
911             prb_simplify(prb);
912             DEBUG({ printf("smpl : "); prb_dump(prb); });
913
914             prb_block_for_cache(prb);
915
916             int ndims_ker_max;
917             prb_thread_kernel_balance(prb, ndims_ker_max);
918
919             tr::kernel_t::desc_t ker_desc;
920             status_t ker_init_status
921                 = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
922             if (ker_init_status != status::success) return ker_init_status;
923
924             const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
925             if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
926                 return status::unimplemented;
927
928             DEBUG({ printf("ker  : "); prb_dump(ker_desc.prb); });
929
930             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
931                     (const cpu_memory_pd_t *)output_pd, attr);
932             if (_pd == nullptr) return out_of_memory;
933             if (_pd->init() != success) { delete _pd; return unimplemented; }
934             _pd->prb_ = prb;
935             _pd->ker_desc_ = ker_desc;
936             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
937         }
938
939         tr::prb_t prb_;
940         tr::kernel_t::desc_t ker_desc_;
941     };
942
943     jit_uni_reorder_t(const pd_t *apd, const input_vector &inputs,
944             const output_vector &outputs)
945         : cpu_primitive_t(apd, inputs, outputs) {
946         kernel_ = tr::kernel_t::create(pd()->ker_desc_);
947         assert(kernel_);
948     }
949     ~jit_uni_reorder_t() { delete kernel_; }
950
951     void omp_driver_0d(int off, const char *in, char *out,
952             const float *scale) const {
953         tr::call_param_t c{in, out, scale};
954         (*kernel_)(&c);
955     }
956
957     void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
958             const float *scale) const {
959         const tr::node_t *ns = pd()->prb_.nodes + off;
960         for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
961             auto c = tr::call_param_t();
962             c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
963             c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
964             c.scale = scale + d0 * ns[0].ss;
965             (*kernel_)(&c);
966         });
967     }
968
969     void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
970             const float *scale) const {
971         const tr::node_t *ns = pd()->prb_.nodes + off;
972         for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
973                 [&](ptrdiff_t d1, ptrdiff_t d0) {
974             auto c = tr::call_param_t();
975             c.in = in + (d0 * ns[0].is + d1 * ns[1].is)
976                 * data_type_size(pd()->prb_.itype);
977             c.out = out + (d0 * ns[0].os + d1 * ns[1].os)
978                 * data_type_size(pd()->prb_.otype);
979             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
980             (*kernel_)(&c);
981         });
982     }
983
984     void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
985             const float *scale) const {
986         const tr::node_t *ns = pd()->prb_.nodes + off;
987         for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
988                 (ptrdiff_t)ns[0].n,
989                 [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
990             auto c = tr::call_param_t();
991             c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
992                 * data_type_size(pd()->prb_.itype);
993             c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
994                 * data_type_size(pd()->prb_.otype);
995             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
996             (*kernel_)(&c);
997         });
998     }
999
1000     void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
1001             const float *scale) const {
1002         const tr::node_t *ns = pd()->prb_.nodes + off;
1003         for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
1004                 (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
1005                 [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
1006             auto c = tr::call_param_t();
1007             c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
1008                     + d3 * ns[3].is) * data_type_size(pd()->prb_.itype);
1009             c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
1010                     + d3 * ns[3].os) * data_type_size(pd()->prb_.otype);
1011             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
1012                 + d3 * ns[3].ss;
1013             (*kernel_)(&c);
1014         });
1015     }
1016
1017     void omp_driver(const char *in, char *out, const float *scale) const {
1018         in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
1019         out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
1020
1021         DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); });
1022         DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); });
1023
1024         int ndims = pd()->prb_.ndims;
1025         int ndims_ker = pd()->ker_desc_.prb.ndims;
1026         assert(ndims - ndims_ker <= ndims_driver_max);
1027
1028         if (ndims - ndims_ker == 0) {
1029             set_rnd_mode(pd()->attr()->round_mode_);
1030             omp_driver_0d(ndims_ker, in, out, scale);
1031             restore_rnd_mode();
1032         } else {
1033             size_t work_amount = 0;
1034             const tr::node_t *ns = pd()->prb_.nodes + ndims_ker;
1035             switch (ndims - ndims_ker) {
1036                 case 1: work_amount = (size_t)ns[0].n; break;
1037                 case 2: work_amount = (size_t)ns[1].n * (size_t)ns[0].n; break;
1038                 case 3: work_amount = (size_t)ns[2].n * (size_t)ns[1].n * (size_t)ns[0].n; break;
1039                 case 4: work_amount = (size_t)ns[3].n * (size_t)ns[2].n * (size_t)ns[1].n * (size_t)ns[0].n; break;
1040                 default: assert(!"unimplemented");
1041             }
1042
1043             parallel(0, work_amount, [&](const int ithr, const int nthr) {
1044                 set_rnd_mode(pd()->attr()->round_mode_);
1045                 switch (ndims - ndims_ker) {
1046                 case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break;
1047                 case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break;
1048                 case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break;
1049                 case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break;
1050                 default: assert(!"unimplemented");
1051                 }
1052                 restore_rnd_mode();
1053             });
1054         }
1055     }
1056
1057     virtual void execute(event_t *e) const {
1058         auto in = reinterpret_cast<const char *>(input_memory(0));
1059         auto out = reinterpret_cast<char *>(memory());
1060
1061         omp_driver(in, out, pd()->attr()->output_scales_.scales_);
1062
1063         e->set_state(event_t::ready);
1064     }
1065
1066     enum { ndims_driver_max = 4 };
1067
1068 private:
1069     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
1070     tr::kernel_t *kernel_;
1071 };
1072
1073 status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
1074         const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1075         const primitive_attr_t *attr) {
1076     return jit_uni_reorder_t::pd_t::create(reorder_pd, input_pd, output_pd,
1077             attr);
1078 }
1079
1080 }
1081 }
1082 }