Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_reorder.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "memory_desc_wrapper.hpp"
21 #include "mkldnn_debug.h"
22 #include "nstl.hpp"
23 #include "type_helpers.hpp"
24
25 #include "cpu_primitive.hpp"
26 #include "cpu_reorder_pd.hpp"
27 #include "jit_uni_reorder.hpp"
28
29 #include "jit_generator.hpp"
30
31 // #define TR_DEBUG
32 #if defined(TR_DEBUG)
33 #define DEBUg(...) do { __VA_ARGS__ } while (0)
34 #else
35 #define DEBUg(...)
36 #endif
37 #define DEBUG(...) DEBUg(__VA_ARGS__)
38
39 #ifdef _WIN32
40 /* seems like s_addr is a reserved macro on Windows */
41 #undef s_addr
42 #endif
43
44 using namespace Xbyak;
45 using namespace mkldnn::impl::types;
46
47 namespace mkldnn {
48 namespace impl {
49 namespace cpu {
50
51 namespace tr {
52
53 /** Minimal reasonable/desirable kernel size.
54  * The constant might be used to determine how a problem should be split
55  * between kernel and threading driver. */
56 const size_t ker_prb_size_min = 64;
57
58 /* kernel */
59 struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator {
60     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
61
62     enum {
63         len_unroll_max = 256,
64         ndims_jit_loop_max = 3,
65     };
66
67     struct simple_impl_desc_t {
68         int ndims_full_unroll;
69         int len_last_dim_unroll;
70         int len_unroll;
71     };
72
73     static bool simple_impl_desc_init(const prb_t &prb,
74             simple_impl_desc_t *desc) {
75         const int ndims = prb.ndims;
76
77         int ndims_full_unroll = 0;
78         int len_last_dim_unroll = 1;
79         int len_unroll = 1;
80
81         for (int d = 0; d < ndims; ++d) {
82             auto &node = prb.nodes[d];
83             if (len_unroll * node.n <= len_unroll_max) {
84                 ndims_full_unroll++;
85                 len_unroll *= node.n;
86             } else {
87                 len_last_dim_unroll = len_unroll_max / len_unroll;
88                 while (node.n % len_last_dim_unroll)
89                     --len_last_dim_unroll;
90                 len_unroll *= len_last_dim_unroll;
91                 break;
92             }
93         }
94
95         if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max)
96             return false;
97
98         if (desc) {
99             desc->ndims_full_unroll = ndims_full_unroll;
100             desc->len_last_dim_unroll = len_last_dim_unroll;
101             desc->len_unroll = len_unroll;
102         }
103
104         return true;
105     }
106
107     static bool applicable(const prb_t &p) {
108         using namespace data_type;
109
110         bool ok = true
111             && p.ndims > 0
112             && utils::one_of(p.itype, f32, s32, s8, u8)
113             && utils::one_of(p.otype, f32, s32, s8, u8)
114             && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
115             && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
116             && simple_impl_desc_init(p, nullptr)
117             && mayiuse(sse42)
118             && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype),
119                     mayiuse(avx));
120         if (!ok) return false;
121
122         const ptrdiff_t max_stride = (1LL<<31) - 1;
123         for (int d = 0; d < p.ndims; ++d) {
124             const ptrdiff_t cms = max_stride / p.nodes[d].n;
125             bool strides_ok = true
126                 && p.nodes[d].is < cms / (int)data_type_size(p.itype)
127                 && p.nodes[d].os < cms / (int)data_type_size(p.otype);
128             if (!strides_ok) return false;
129         }
130
131         return true;
132     }
133
134     int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; }
135     int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; }
136     int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; }
137     int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; }
138
139     Address i_addr(int i_off)
140     { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; }
141
142     Address o_addr(int o_off)
143     { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; }
144
145     Address s_addr(int s_off)
146     { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; }
147
148     void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
149             int &i_off, int &o_off, int &s_off, int step_size = 1) {
150         i_off = prev_i_off;
151         o_off = prev_o_off;
152         s_off = prev_s_off;
153
154         if (off == 0) return;
155
156         int start_dim = 0, dims_prod = 1;
157         for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
158             dims_prod *= n(start_dim);
159         assert(start_dim < prb_.ndims);
160         off /= step_size;
161
162         for (int d = start_dim; d < prb_.ndims; ++d) {
163             i_off += is(d);
164             o_off += os(d);
165             s_off += ss(d);
166
167             if (off % n(d)) break;
168
169             i_off += - n(d) * is(d);
170             o_off += - n(d) * os(d);
171             s_off += - n(d) * ss(d);
172             off /= n(d);
173
174             if (off == 0) break; /* FIXME: is it really required? */
175         }
176     }
177
178     void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
179             int step_size = 1) {
180         int dummy = 0;
181         step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
182                 step_size);
183     }
184
185     void tr8x8_avx2(int i_off, int o_off) {
186         for (int i = 0; i < 8; i++)
187             vmovups(Ymm(i), i_addr(i_off + i * 8));
188
189         for (int i = 0; i < 8 / 2; i++) {
190             vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1));
191             vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
192         }
193
194         const unsigned int lfloat = 0x44;
195         const unsigned int ufloat = 0xee;
196         for (int i = 0; i < 8 / 2; i++) {
197             int j = i % 2 == 0 ? 8 + i : i - 1;
198             vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
199             vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
200         }
201
202         const unsigned int lquad = 0x20;
203         for (int i = 0; i < 8 / 2; i++)
204             vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad);
205
206         const unsigned int uquad = 0x31;
207         for (int i = 8 / 2; i < 8; i++)
208             vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad);
209
210         for (int i = 0; i < 8; i++)
211             vmovups(o_addr(o_off + i * 8), Ymm(i));
212     }
213
214     bool process_unroll_tr8x8(int len) {
215         bool can_do = true
216             && mayiuse(avx2)
217             && prb_.ndims >= 2
218             && utils::everyone_is(4, itype_sz, otype_sz)
219             && utils::everyone_is(8, n(0), n(1))
220             && utils::everyone_is(1, os(0), is(1))
221             && utils::everyone_is(8, os(1), is(0))
222             && prb_.scale_type == scale_type_t::NONE
223             && prb_.beta == 0.f;
224         if (!can_do) return false;
225
226         const int step_size = n(0) * n(1);
227         int i_off = 0, o_off = 0;
228         for (int off = 0; off < len; off += step_size) {
229             step(off, i_off, o_off, i_off, o_off, step_size);
230             tr8x8_avx2(i_off, o_off);
231         }
232
233         return true;
234     }
235
236     template <cpu_isa_t isa>
237     bool process_direct_copy(int len) {
238         using namespace data_type;
239
240         using Vmm = typename cpu_isa_traits<isa>::Vmm;
241         const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;
242
243         bool can_do = true
244             && mayiuse(isa)
245             && utils::everyone_is(1, os(0), is(0))
246             && (false
247                     || prb_.itype == prb_.otype
248                     || (prb_.itype == s32 && prb_.otype == f32)
249                     || (prb_.itype == f32 && prb_.otype == s32)
250                     )
251             && len % simd_w == 0
252             && n(0) % len == 0
253             && prb_.scale_type == scale_type_t::NONE
254             && prb_.beta == 0.f;
255         if (!can_do) return false;
256
257         for (int off = 0; off < len;) {
258             const int unroll = nstl::min(16, (len - off) / simd_w);
259
260             for (int ur = 0; ur < unroll; ++ur)
261                 uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w));
262
263             if (prb_.itype != prb_.otype) {
264                 for (int ur = 0; ur < unroll; ++ur) {
265                     if (prb_.itype == s32 && prb_.otype == f32)
266                         uni_vcvtdq2ps(Vmm(ur), Vmm(ur));
267                     else if (prb_.itype == f32 && prb_.otype == s32)
268                         uni_vcvtps2dq(Vmm(ur), Vmm(ur));
269                     else assert(!"unreachable");
270                 }
271             }
272
273             for (int ur = 0; ur < unroll; ++ur)
274                 uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur));
275
276             off += unroll * simd_w;
277         }
278
279         return true;
280     }
281
282     void process_unroll_generic_step(int reg_unroll, const int *i_off,
283             const int *o_off, const int *s_off) {
284         using namespace data_type;
285
286         auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) {
287             Xmm dst_pure = Xmm(dst.getIdx());
288             switch (idt) {
289             case f32:
290                 if (src.isMEM() || src.getIdx() != dst.getIdx())
291                     vmovups(dst, src);
292                 break;
293             case s32: vcvtdq2ps(dst, src); break;
294             case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
295             case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
296             default: assert(!"unreachable");
297             }
298         };
299
300         auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) {
301             switch (odt) {
302             case s32:
303                 if (idt == f32) vcvtps2dq(xmm, xmm);
304                 else if (idt == s8) vpmovsxbd(xmm, xmm);
305                 else if (idt == u8) vpmovzxbd(xmm, xmm);
306                 break;
307             case s8:
308                 if (idt == f32) vcvtps2dq(xmm, xmm);
309                 if (idt == f32 || idt == s32) {
310                     if (mayiuse(avx512_core)) {
311                         vpmovsdb(xmm, xmm);
312                     } else {
313                         vpackssdw(xmm, xmm, xmm_zero);
314                         vpacksswb(xmm, xmm, xmm_zero);
315                     }
316                 }
317                 if (idt == u8) vpminub(xmm, xmm, xmm_4x127b);
318                 break;
319             case u8:
320                 if (idt == f32) vcvtps2dq(xmm, xmm);
321                 if (idt == f32 || idt == s32) {
322                     if (mayiuse(avx512_core)) {
323                         vpmaxsd(xmm, xmm, xmm_zero);
324                         vpmovusdb(xmm, xmm);
325                     } else {
326                         vpackssdw(xmm, xmm, xmm_zero);
327                         vpackuswb(xmm, xmm, xmm_zero);
328                     }
329                 }
330                 if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero);
331                 break;
332             default: assert(!"unreachable");
333             }
334         };
335
336         auto load = [=](const Xmm &xmm, const Address &addr, int size) {
337             switch (size) {
338             case 16: movups(xmm, addr); break;
339             case 4: movss(xmm, addr); break;
340             case 1: pinsrb(xmm, addr, 0x0); break;
341             default: assert(!"unreachable");
342             }
343         };
344
345         auto store = [=](const Address &addr, const Xmm &xmm, int size) {
346             switch (size) {
347             case 16: movups(addr, xmm); break;
348             case 4: movss(addr, xmm); break;
349             case 1: pextrb(addr, xmm, 0x0); break;
350             default: assert(!"unreachable");
351             }
352         };
353
354         /* check whether loading 4 values at once is possible */
355         bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0;
356         for (int ur = 1; ur < reg_unroll; ++ur)
357             if (i_off[ur] != i_off[ur - 1] + 1)
358                 can_load_xmm = false;
359         const int load_step = can_load_xmm ? 4 : 1;
360
361         /* check whether storing 4 values at once is possible */
362         bool can_store_xmm = reg_unroll % 4 == 0;
363         for (int ur = 1; ur < reg_unroll; ++ur)
364             if (o_off[ur] != o_off[ur - 1] + 1)
365                 can_store_xmm = false;
366         const int ur_step = can_store_xmm ? 4 : 1;
367
368         const bool interim_f32 = false
369             || utils::one_of(f32, prb_.itype, prb_.otype)
370             || prb_.scale_type != scale_type_t::NONE
371             || prb_.beta != 0.f;
372
373         if (!can_load_xmm && can_store_xmm) {
374             assert(ur_step == 4);
375             /* load with stride */
376             for (int ur = 0; ur < reg_unroll; ur += ur_step) {
377                 for (int r = 0; r < ur_step; ++r) {
378                     if (itype_sz == 4)
379                         pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r);
380                     else
381                         pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r);
382                 }
383             }
384         } else {
385             for (int ur = 0; ur < reg_unroll; ur += load_step)
386                 load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz);
387         }
388
389         /* xmm[:] <-- (f32)xmm[:] */
390         if (interim_f32) {
391             const int cvt_step = nstl::max(load_step, ur_step);
392             for (int ur = 0; ur < reg_unroll; ur += cvt_step)
393                 cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
394         }
395
396         if (can_load_xmm && !can_store_xmm) {
397             const bool fast_return = true // transposition on the fly
398                 && prb_.scale_type != scale_type_t::MANY
399                 && prb_.beta == 0.f;
400             if (fast_return) {
401                 for (int ur = 0; ur < reg_unroll; ur += load_step) {
402                     if (prb_.scale_type == scale_type_t::COMMON)
403                         mulps(Xmm(ur), xmm_scale);
404                     if (prb_.otype != f32)
405                         cvt2int(Xmm(ur), prb_.otype,
406                                 interim_f32 ? f32 : prb_.itype);
407                     for (int r = 0; r < load_step; ++r) {
408                         if (otype_sz == 4)
409                             pextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
410                         else
411                             pextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
412                     }
413                 }
414                 return;
415             }
416
417             /* scatter elements of xmm into 4 xmms */
418             if (itype_sz == 4 || interim_f32) {
419                 for (int ur = 0; ur < reg_unroll; ur += load_step)
420                     for (int r = 1; r < load_step; ++r)
421                         vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
422             } else {
423                 for (int ur = 0; ur < reg_unroll; ur += load_step)
424                     for (int r = 1; r < load_step; ++r)
425                         vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
426             }
427         }
428
429         /* scale and beta processing */
430         if (can_store_xmm) {
431             /* xmm <-- scale * xmm[:] */
432             if (prb_.scale_type == scale_type_t::COMMON) {
433                 for (int ur = 0; ur < reg_unroll; ur += ur_step)
434                     mulps(Xmm(ur), xmm_scale);
435             } else if (prb_.scale_type == scale_type_t::MANY) {
436                 enum class scale_load_type_t { bcast, load, gather };
437
438                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
439                     scale_load_type_t scale_load_type =
440                         scale_load_type_t::bcast; // the best case
441
442                     for (int r = ur + 1; r < ur + ur_step; ++r)
443                         if (s_off[r] != s_off[r - 1] + 0)
444                             scale_load_type = scale_load_type_t::load;
445
446                     if (scale_load_type == scale_load_type_t::bcast) {
447                         movss(xmm_scale, s_addr(s_off[ur]));
448                         shufps(xmm_scale, xmm_scale, 0x0);
449                         mulps(Xmm(ur), xmm_scale);
450                         continue;
451                     }
452
453                     // bcast doesn't work, the next try -- load
454                     for (int r = ur + 1; r < ur + ur_step; ++r)
455                         if (s_off[r] != s_off[r - 1] + 1)
456                             scale_load_type = scale_load_type_t::gather;
457
458                     if (scale_load_type == scale_load_type_t::load) {
459                         movups(xmm_scale, s_addr(s_off[ur]));
460                         mulps(Xmm(ur), xmm_scale);
461                         continue;
462                     }
463
464                     // load doesn't work as well
465                     // so gather the scale factors one by one
466                     for (int r = ur; r < ur + ur_step; ++r)
467                         pinsrd(xmm_scale, s_addr(s_off[r]), r - ur);
468                     mulps(Xmm(ur), xmm_scale);
469                 }
470             }
471
472             /* dst <-- beta * dst + xmm[:] */
473             assert(prb_.beta == 0.f || prb_.beta == 1.f);
474             if (prb_.beta == 1.f) {
475                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
476                     if (prb_.otype == f32) {
477                         /* non VEX instructions do not support unaligned
478                          * memory for instructions other than movups. */
479                         if (mayiuse(avx)) {
480                             vaddps(Xmm(ur), o_addr(o_off[ur]));
481                         } else {
482                             /* register xmm(1) is unused */
483                             movups(Xmm(1), o_addr(o_off[ur]));
484                             addps(Xmm(ur), Xmm(1));
485                         }
486                     } else {
487                         cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
488                         vaddps(Xmm(ur), Xmm(1));
489                     }
490                 }
491             }
492         } else {
493             /* xmm[0] <-- scale * xmm[0] */
494             if (prb_.scale_type == scale_type_t::COMMON) {
495                 for (int ur = 0; ur < reg_unroll; ur += ur_step)
496                     mulss(Xmm(ur), xmm_scale);
497             } else if (prb_.scale_type == scale_type_t::MANY) {
498                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
499                     mulss(Xmm(ur), s_addr(s_off[ur]));
500                 }
501             }
502
503             /* dst <-- beta * dst + xmm[0] */
504             assert(prb_.beta == 0.f || prb_.beta == 1.f);
505             if (prb_.beta == 1.f) {
506                 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
507                     if (prb_.otype == f32) {
508                         addss(Xmm(ur), o_addr(o_off[ur]));
509                     } else {
510                         if (prb_.otype == s32) {
511                             vmovss(xmm_tmp, o_addr(o_off[ur]));
512                         } else if (utils::one_of(prb_.otype, s8, u8)) {
513                             pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0);
514                         } else {
515                             assert(!"unsupported o_type");
516                         }
517                         cvt2ps(xmm_tmp, xmm_tmp, prb_.otype);
518                         addps(Xmm(ur), xmm_tmp);
519                     }
520                 }
521             }
522         }
523
524         for (int ur = 0; ur < reg_unroll; ur += ur_step) {
525             if (prb_.otype != f32)
526                 cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
527             store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz);
528         }
529     }
530
531     void process_unroll_generic(int len) {
532         const int blk = 8;
533
534         int i_off[2 * blk] = {0};
535         int o_off[2 * blk] = {0};
536         int s_off[2 * blk] = {0};
537
538         int curr = 0; // will switch between 0 and 1
539
540         for (int off = 0; off < len; off += blk) {
541             const int reg_unroll = nstl::min(off + blk, len) - off;
542
543             /* compute offsets */
544             for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
545                 const int ur_c = curr * blk + ur;
546                 const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
547                 step(off + ur,
548                         i_off[ur_p], o_off[ur_p], s_off[ur_p],
549                         i_off[ur_c], o_off[ur_c], s_off[ur_c]);
550             }
551
552             process_unroll_generic_step(reg_unroll, i_off + curr * blk,
553                     o_off + curr * blk, s_off + curr * blk);
554
555             curr = 1 - curr;
556         }
557     }
558
559     void loop_begin(Label &l, Reg64 reg_cnt, int len) {
560         mov(reg_cnt, len);
561         L(l);
562     }
563
564     void loop_end(Label &l, Reg64 reg_cnt, int len,
565             int i_step, int o_step, int s_step) {
566         add(reg_off_in, i_step * itype_sz);
567         add(reg_off_out, o_step * otype_sz);
568         if (prb_.scale_type == scale_type_t::MANY)
569             add(reg_off_scale, s_step * stype_sz);
570         dec(reg_cnt);
571         jnz(l);
572
573         sub(reg_off_in, len * i_step * itype_sz);
574         sub(reg_off_out, len * o_step * otype_sz);
575         if (prb_.scale_type == scale_type_t::MANY)
576             sub(reg_off_scale, len * s_step * stype_sz);
577     }
578
579     bool simple_impl() {
580         simple_impl_desc_t d;
581         if (!simple_impl_desc_init(prb_, &d)) return false;
582
583         const int nfu = d.ndims_full_unroll;
584         const int ldu = d.len_last_dim_unroll;
585         const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
586         assert(n_jit_loops <= ndims_jit_loop_max);
587
588         xor_(reg_off_in, reg_off_in);
589         xor_(reg_off_out, reg_off_out);
590         if (prb_.scale_type == scale_type_t::MANY)
591             xor_(reg_off_scale, reg_off_scale);
592
593         Label l_loop[3];
594         Reg64 reg_cnt[3] = {r15, r14, r13};
595
596         if (n_jit_loops > 2)
597             loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));
598
599         if (n_jit_loops > 1)
600             loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));
601
602         if (n_jit_loops > 0)
603             loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);
604
605         const bool optimized = false
606             || process_direct_copy<avx>(d.len_unroll)
607             || process_direct_copy<sse42>(d.len_unroll)
608             || process_unroll_tr8x8(d.len_unroll);
609         if (!optimized)
610             process_unroll_generic(d.len_unroll);
611
612         if (n_jit_loops > 0)
613             loop_end(l_loop[0], reg_cnt[0],
614                     n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu,
615                     ss(nfu + 0) * ldu);
616
617         if (n_jit_loops > 1)
618             loop_end(l_loop[1], reg_cnt[1],
619                     n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1));
620
621         if (n_jit_loops > 2)
622             loop_end(l_loop[2], reg_cnt[2],
623                     n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2));
624
625         return true;
626     }
627
628     void impl() {
629         if (simple_impl()) return;
630         assert(!"no implementation available");
631     }
632
633     jit_uni_reorder_kernel_f32(const desc_t &desc)
634         : kernel_t(desc), jit_generator() {
635         itype_sz = data_type_size(prb_.itype);
636         otype_sz = data_type_size(prb_.otype);
637         stype_sz = sizeof(float);
638
639         preamble();
640 #       define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)]
641         if (prb_.scale_type == scale_type_t::COMMON) {
642             auto reg_ptr_scale_tmp = reg_ptr_in;
643             mov(reg_ptr_scale_tmp, PARAM(scale));
644             movups(xmm_scale, ptr[reg_ptr_scale_tmp]);
645         } else if (prb_.scale_type == scale_type_t::MANY) {
646             mov(reg_ptr_scale, PARAM(scale));
647         }
648         mov(reg_ptr_in, PARAM(in));
649         mov(reg_ptr_out, PARAM(out));
650 #       undef PARAM
651
652         if (mayiuse(avx)) {
653             vxorps(xmm_zero, xmm_zero, xmm_zero);
654
655             if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
656                 mov(reg_tmp.cvt32(), 0x7f7f7f7f);
657                 movd(xmm_4x127b, reg_tmp.cvt32());
658             }
659         }
660
661         impl();
662         postamble();
663         ker_ = (void (*)(const call_param_t *))getCode();
664     }
665
666 private:
667     int itype_sz;
668     int otype_sz;
669     int stype_sz;
670
671     Reg64 reg_ptr_in = rsi;
672     Reg64 reg_ptr_out = rdx;
673     Reg64 reg_ptr_scale = abi_not_param1;
674
675     Reg64 reg_off_in = r8;
676     Reg64 reg_off_out = r9;
677     Reg64 reg_off_scale = r10;
678
679     Reg64 reg_tmp = rax;
680
681     Xmm xmm_scale = xmm15;
682     Xmm xmm_zero = xmm14;
683     Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero
684     Xmm xmm_tmp = xmm12;
685 };
686
687 status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb,
688         int ndims_ker_max) {
689     desc.prb = prb;
690     desc.prb.ioff = desc.prb.ooff = 0;
691
692     if (ndims_ker_max > prb.ndims)
693         return status::invalid_arguments;
694
695     auto ndims_ker_max_f = [&]() {
696         size_t cur_size = 1;
697         for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
698             if (cur_size >= ker_prb_size_min) return d;
699         return prb.ndims;
700     };
701
702     if (ndims_ker_max <= 0)
703         ndims_ker_max = ndims_ker_max_f();
704
705     /* traverse through kernel implementations */
706     /* TODO: find a better way to do that... */
707     desc.id = 0;
708     for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
709         desc.prb.ndims = ndims_ker;
710         if (jit_uni_reorder_kernel_f32::applicable(desc.prb))
711             return status::success;
712     }
713
714     return status::unimplemented;
715 }
716
717 kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
718     switch (desc.id) {
719     case 0: return new jit_uni_reorder_kernel_f32(desc);
720     default: assert(!"unknown kernel id"); return nullptr;
721     }
722
723     return nullptr;
724 }
725
726 }
727
728 static void prb_block_for_cache(tr::prb_t &prb) {
729     if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) {
730         /** an attempt to use caches more efficient and
731          * address the 4K-aliasing issue */
732         /* TODO: improve the logic around here */
733         int j = 1;
734         for (; j < prb.ndims && prb.nodes[j].is != 1; ++j);
735         if (j == prb.ndims) return;
736
737         /* it makes sense to re-prioritize sequential read over
738          * sequential write if the former would not trash the
739          * cache, i.e. is == 1 and os % 2^smth != 0. Smth is
740          * set to 2 at the moment */
741         const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1;
742         if (j == move_to) return;
743
744         if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0)
745             prb_node_split(prb, j, 16);
746
747         prb_node_move(prb, j, move_to);
748         DEBUG({ printf("cache: "); prb_dump(prb); });
749     }
750 }
751
752 /** finds the maximum number of dimension the kernel should process and
753  * optionally splits one of the dimension to achieve better balance between
754  * parallel driver and the kernel. */
755 static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) {
756     size_t sz_total = 1;
757     for (int d = 0; d < prb.ndims; ++d)
758         sz_total *= prb.nodes[d].n;
759
760     /* sz_drv_min is the minimal size for the parallel
761      * driver required for good parallelization */
762     const size_t sz_drv_min = nstl::min<size_t>(
763             16 * mkldnn_get_max_threads(),
764             utils::div_up(sz_total, 1024));
765
766     /* kdims -- # of dimensions processed by a kernel
767      * sz_ker_cur -- product of the dimension processed by a kernel
768      * sz_drv_cur -- product of the dimension processed by a driver */
769
770     int kdims = prb.ndims;
771     size_t sz_drv_cur = 1;
772     for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
773         sz_drv_cur *= prb.nodes[kdims - 1].n;
774
775     size_t sz_ker_cur = 1;
776     for (int d = 0; d < kdims; ++d)
777         sz_ker_cur *= prb.nodes[d].n;
778
779     /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
780      *
781      * It might happen that for chosen kdims the sz_ker_cur is too small
782      * (less than tr::ker_prb_size_min). In that case try to split the
783      * innermost driver dimension into two, to increase sz_ker_cur. */
784     bool want_borrow_ker_from_drv = true
785         && kdims < prb.ndims
786         && sz_ker_cur < tr::ker_prb_size_min
787         && sz_drv_cur > sz_drv_min;
788     if (want_borrow_ker_from_drv) {
789         /* sz_want_borrow is the minimal sz, so that:
790          *  o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
791          *  o) current innermost driver dimension is divisible by
792          *     sz_want_borrow (so that we can evenly split that
793          *     dimension into two)
794          *
795          *  In the worst case the minimal sz_want_borrow is equal
796          *  to the innermost driver dimension itself. In that case
797          *  we will sacrifice it in favor of kernel (is it fine?). */
798         size_t sz_want_borrow
799             = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
800         for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow);
801         if (sz_want_borrow != prb.nodes[kdims].n)
802             prb_node_split(prb, kdims, sz_want_borrow);
803         kdims += 1;
804     }
805
806     /* On the other hand it might happen that for chosen kdims
807      * the sz_drv_cur is too small (less than sz_drv_min). In that case
808      * try to split the outermost kernel dimension into two, to increase
809      * sz_drv_cur. */
810     bool want_borrow_drv_from_ker = true
811         && sz_ker_cur > tr::ker_prb_size_min
812         && sz_drv_cur < sz_drv_min;
813     if (want_borrow_drv_from_ker) {
814         size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
815         for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow);
816         if (sz_want_borrow != prb.nodes[kdims - 1].n)
817             prb_node_split(prb, kdims - 1,
818                     prb.nodes[kdims - 1].n / sz_want_borrow);
819     }
820
821     ndims_ker_max = kdims;
822
823     if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
824         DEBUG({ printf("split: "); prb_dump(prb);
825                 printf("ndims_ker_max = %d\n", ndims_ker_max); });
826     }
827 }
828
829 struct jit_uni_reorder_t : public cpu_primitive_t {
830     struct pd_t : public cpu_reorder_pd_t {
831         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
832                 const primitive_attr_t *attr)
833             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
834
835         DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
836
837         static status_t create(reorder_pd_t **reorder_pd,
838                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
839                 const primitive_attr_t *attr) {
840             const memory_desc_t *imd = input_pd->desc();
841             const memory_desc_t *omd = output_pd->desc();
842
843             auto prb = tr::prb_t();
844
845             if (imd->format == mkldnn_OhIw8o4i || imd->format == mkldnn_gOhIw8o4i ||
846                 imd->format == mkldnn_OhIw8o4i_s8s8 || imd->format == mkldnn_gOhIw8o4i_s8s8 ||
847                 omd->format == mkldnn_OhIw8o4i || omd->format == mkldnn_gOhIw8o4i ||
848                 omd->format == mkldnn_OhIw8o4i_s8s8 || omd->format == mkldnn_gOhIw8o4i_s8s8)
849                 return status::unimplemented;
850
851             status_t prb_init_status = prb_init(prb, *imd, *omd, attr);
852             if (prb_init_status != success) return prb_init_status;
853
854             DEBUG({ printf("init : "); prb_dump(prb); });
855             prb_normalize(prb);
856             DEBUG({ printf("norm : "); prb_dump(prb); });
857             prb_simplify(prb);
858             DEBUG({ printf("smpl : "); prb_dump(prb); });
859
860             prb_block_for_cache(prb);
861
862             int ndims_ker_max;
863             prb_thread_kernel_balance(prb, ndims_ker_max);
864
865             tr::kernel_t::desc_t ker_desc;
866             status_t ker_init_status
867                 = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
868             if (ker_init_status != status::success) return ker_init_status;
869
870             const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
871             if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
872                 return status::unimplemented;
873
874             DEBUG({ printf("ker  : "); prb_dump(ker_desc.prb); });
875
876             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
877                     (const cpu_memory_pd_t *)output_pd, attr);
878             if (_pd == nullptr) return out_of_memory;
879             if (_pd->init() != success) { delete _pd; return unimplemented; }
880             _pd->prb_ = prb;
881             _pd->ker_desc_ = ker_desc;
882             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
883         }
884
885         tr::prb_t prb_;
886         tr::kernel_t::desc_t ker_desc_;
887     };
888
889     jit_uni_reorder_t(const pd_t *apd, const input_vector &inputs,
890             const output_vector &outputs)
891         : cpu_primitive_t(apd, inputs, outputs) {
892         kernel_ = tr::kernel_t::create(pd()->ker_desc_);
893         assert(kernel_);
894     }
895     ~jit_uni_reorder_t() { delete kernel_; }
896
897     void omp_driver_0d(int off, const char *in, char *out,
898             const float *scale) const {
899         tr::call_param_t c{in, out, scale};
900         (*kernel_)(&c);
901     }
902
903     void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
904             const float *scale) const {
905         const tr::node_t *ns = pd()->prb_.nodes + off;
906         for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
907             auto c = tr::call_param_t();
908             c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
909             c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
910             c.scale = scale + d0 * ns[0].ss;
911             (*kernel_)(&c);
912         });
913     }
914
915     void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
916             const float *scale) const {
917         const tr::node_t *ns = pd()->prb_.nodes + off;
918         for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
919                 [&](ptrdiff_t d1, ptrdiff_t d0) {
920             auto c = tr::call_param_t();
921             c.in = in + (d0 * ns[0].is + d1 * ns[1].is)
922                 * data_type_size(pd()->prb_.itype);
923             c.out = out + (d0 * ns[0].os + d1 * ns[1].os)
924                 * data_type_size(pd()->prb_.otype);
925             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
926             (*kernel_)(&c);
927         });
928     }
929
930     void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
931             const float *scale) const {
932         const tr::node_t *ns = pd()->prb_.nodes + off;
933         for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
934                 (ptrdiff_t)ns[0].n,
935                 [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
936             auto c = tr::call_param_t();
937             c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
938                 * data_type_size(pd()->prb_.itype);
939             c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
940                 * data_type_size(pd()->prb_.otype);
941             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
942             (*kernel_)(&c);
943         });
944     }
945
946     void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
947             const float *scale) const {
948         const tr::node_t *ns = pd()->prb_.nodes + off;
949         for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
950                 (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
951                 [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
952             auto c = tr::call_param_t();
953             c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
954                     + d3 * ns[3].is) * data_type_size(pd()->prb_.itype);
955             c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
956                     + d3 * ns[3].os) * data_type_size(pd()->prb_.otype);
957             c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
958                 + d3 * ns[3].ss;
959             (*kernel_)(&c);
960         });
961     }
962
963     void omp_driver(const char *in, char *out, const float *scale) const {
964         in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
965         out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
966
967         DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); });
968         DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); });
969
970         int ndims = pd()->prb_.ndims;
971         int ndims_ker = pd()->ker_desc_.prb.ndims;
972         assert(ndims - ndims_ker <= ndims_driver_max);
973
974         if (ndims - ndims_ker == 0) {
975             set_rnd_mode(pd()->attr()->round_mode_);
976             omp_driver_0d(ndims_ker, in, out, scale);
977             restore_rnd_mode();
978         } else {
979             parallel(0, [&](const int ithr, const int nthr) {
980                 set_rnd_mode(pd()->attr()->round_mode_);
981                 switch (ndims - ndims_ker) {
982                 case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break;
983                 case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break;
984                 case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break;
985                 case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break;
986                 default: assert(!"unimplemented");
987                 }
988                 restore_rnd_mode();
989             });
990         }
991     }
992
993     virtual void execute(event_t *e) const {
994         auto in = reinterpret_cast<const char *>(input_memory(0));
995         auto out = reinterpret_cast<char *>(memory());
996
997         omp_driver(in, out, pd()->attr()->output_scales_.scales_);
998
999         e->set_state(event_t::ready);
1000     }
1001
1002     enum { ndims_driver_max = 4 };
1003
1004 private:
1005     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
1006     tr::kernel_t *kernel_;
1007 };
1008
1009 status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
1010         const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1011         const primitive_attr_t *attr) {
1012     return jit_uni_reorder_t::pd_t::create(reorder_pd, input_pd, output_pd,
1013             attr);
1014 }
1015
1016 }
1017 }
1018 }