Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_reorder.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_SIMPLE_REORDER_HPP
18 #define CPU_SIMPLE_REORDER_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "type_helpers.hpp"
24 #include "math_utils.hpp"
25 #include "mkldnn_thread.hpp"
26 #include "utils.hpp"
27
28 #include "format_traits.hpp"
29 #include "cpu_reorder_pd.hpp"
30 #include "cpu_primitive.hpp"
31
32 #include "simple_q10n.hpp"
33 #include "cpu_isa_traits.hpp"
34
35 namespace mkldnn {
36 namespace impl {
37 namespace cpu {
38
39 using namespace mkldnn::impl::status;
40 using namespace mkldnn::impl::memory_format;
41 using namespace mkldnn::impl::data_type;
42
43 using dk = data_kind_t;
44 using bf = block_format_t;
45
46 using namespace mkldnn::impl::utils;
47 using math::saturate;
48
49 template<impl::data_type_t type>
50 using data_t = typename prec_traits<type>::type;
51
52 template<impl::data_type_t type_i, impl::data_type_t type_o>
53 using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
54
55 template<impl::data_type_t type_i, impl::data_type_t type_o>
56 using _qz = qz<data_t<type_i>, data_t<type_o>>;
57
58 namespace fmt_order {
59     const bool keep = true;
60     const bool reverse = false;
61     const bool any = keep;
62 }
63
64 namespace spec {
65 struct direct_copy {};
66 struct direct_copy_except_dim_0 {};
67 struct reference {};
68 }
69
70 #define SIMPLE_REORDER_TEMPL_DECL \
71     impl::data_type_t type_i, impl::memory_format_t fmt_i, \
72     impl::data_type_t type_o, impl::memory_format_t fmt_o, bool order_keep
73 #define SIMPLE_REORDER_TEMPL_CALL \
74     type_i, fmt_i, type_o, fmt_o, order_keep
75
76 #define DECLARE_COMMON_PARAMS() \
77         const memory_desc_wrapper &input_d = pd->input_pd(); \
78         const memory_desc_wrapper &output_d = pd->output_pd(); \
79         const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
80         const float beta = pd->beta(); MAYBE_UNUSED(beta); \
81         const round_mode_t rmode = pd->attr()->round_mode_; MAYBE_UNUSED(rmode);
82
83 /* specific reorders: common template */
84 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
85 struct simple_reorder_impl {};
86
87 namespace {
88 bool simple_fmt_check(bool order_keep, impl::memory_format_t fmt_i,
89         impl::memory_format_t fmt_o, const memory_desc_wrapper &input_d,
90         const memory_desc_wrapper &output_d) {
91     return input_d.format() == (order_keep ? fmt_i : fmt_o)
92         && output_d.format() == (order_keep ? fmt_o : fmt_i);
93 }
94 bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
95     if (many_scales_support)
96         return true;
97     return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
98 }
99 }
100
101 /* specific reorders: implementation */
102 template <SIMPLE_REORDER_TEMPL_DECL>
103 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
104 typename utils::enable_if<fmt_i == any && (false
105     || fmt_o == hwio_s8s8
106     || fmt_o == hwigo_s8s8)>::type>
107 {
108     static bool is_applicable(const memory_desc_wrapper &input_d,
109             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
110     {
111         const size_t D_mask = utils::array_product(input_d.dims(),
112                                 math::ilog2q(attr->output_scales_.mask_ + 1));
113         const int oc = (input_d.dims()[fmt_o == hwigo_s8s8 + 0]);
114         const int g = (fmt_o == hwigo_s8s8) ? (input_d.dims()[0]) : 1;
115
116         return output_d.format() == fmt_o
117             && (input_d.data_type() == f32 || input_d.data_type() == s8)
118             && output_d.data_type() == s8
119             && (D_mask == 1 || D_mask == (size_t)g * oc);
120     }
121
122     static status_t execute(const cpu_reorder_pd_t *pd,
123         const data_t<type_i> *input, data_t<type_o> *output) {
124         DECLARE_COMMON_PARAMS();
125
126         static constexpr bool w_groups = fmt_o == hwigo_s8s8;
127
128         const auto &dims = input_d.dims();
129         const auto &pdims = output_d.blocking_desc().padding_dims;
130
131         const int G = w_groups ? dims[0] : 1;
132         const int OC = dims[w_groups + 0];
133         const int IC = dims[w_groups + 1];
134         const int H = dims[w_groups + 2];
135         const int W = dims[w_groups + 3];
136
137         const float *scales = pd->attr()->output_scales_.scales_;
138         const size_t D_mask = utils::array_product(input_d.dims(),
139                 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
140
141         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0f : (1.0f / 2.0f);
142
143         size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
144         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
145
146         parallel_nd(G, OC, [&](int g, int oc) {
147             cp[g * OC + oc] = 0;
148             for (int ic = 0; ic < IC; ic++)
149             for (int h = 0; h < H; h++)
150             for (int w = 0; w < W; w++) {
151                 auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
152                 auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
153                 const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
154
155                 o = qz_b0<data_t<type_i>, data_t<type_o>>()(
156                     i, s * adj_scale, rmode);
157                 cp[g * OC + oc] -= (int32_t)o;
158             }
159             cp [g * OC + oc] *= 128;
160         });
161         return success;
162     }
163 };
164
165 template <SIMPLE_REORDER_TEMPL_DECL>
166 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
167     typename utils::enable_if<
168           ((fmt_i == goihw || fmt_i == oihw)
169            && (format_traits<fmt_o>::blk_fmt == bf::_4i16o4i_s8s8
170                || format_traits<fmt_o>::blk_fmt == bf::_2i8o4i_s8s8
171                || format_traits<fmt_o>::blk_fmt == bf::_4o4i_s8s8))
172     >::type>
173 {
174     static bool is_applicable(const memory_desc_wrapper &input_d,
175             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
176     {
177         const size_t D_mask = utils::array_product(input_d.dims(),
178                                 math::ilog2q(attr->output_scales_.mask_ + 1));
179         const int oc = (input_d.dims()[(fmt_i == goihw) + 0]);
180         const int g = (fmt_i == goihw) ? (input_d.dims()[0]) : 1;
181
182         return input_d.format() == fmt_i
183             && output_d.format() == fmt_o
184             && (input_d.data_type() == f32 || input_d.data_type() == s8)
185             && output_d.data_type() == s8
186             && (D_mask == 1 || D_mask == (size_t)g * oc);
187     }
188
189     static status_t execute(const cpu_reorder_pd_t *pd,
190         const data_t<type_i> *input, data_t<type_o> *output) {
191         DECLARE_COMMON_PARAMS();
192
193         static constexpr bool w_groups = fmt_i == goihw;
194         const int blksize = format_traits<fmt_o>::blk_size;
195         const int sblk = 4;
196
197         const auto &_g_oihw_d = order_keep ? input_d : output_d;
198         const auto &dims = input_d.dims();
199         const auto &pdims = order_keep
200             ? output_d.blocking_desc().padding_dims
201             : input_d.blocking_desc().padding_dims;
202
203         const int G = w_groups ? dims[0] : 1;
204         const int OC = dims[w_groups + 0];
205         const int NB_OC = pdims[w_groups + 0] / blksize;
206         const int IC = dims[w_groups + 1];
207         const int NB_IC = pdims[w_groups + 1] / blksize;
208         const int H = dims[w_groups + 2];
209         const int W = dims[w_groups + 3];
210
211         const float *scales = pd->attr()->output_scales_.scales_;
212         const size_t D_mask = utils::array_product(input_d.dims(),
213                             math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
214
215         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
216
217         auto index = [&](const int ic, const int oc) {
218             return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
219         };
220
221         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
222             int32_t *c, const float *s, const int oc_block, const int ic_block) {
223             for (int ic = 0; ic < ic_block; ++ic) {
224             for (int oc = 0; oc < oc_block; ++oc) {
225                 const auto _g_oihw_off =
226                     oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
227                   + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
228                 out[index(ic, oc)]
229                     = qz_b0<data_t<type_i>, data_t<type_o>>()(
230                             inp[_g_oihw_off], s[oc] * adj_scale, rmode);
231                 c[oc] -= (128 * (int32_t)(out[index(ic, oc)]));
232             }
233             }
234         };
235
236         constexpr int i_mult = blksize;
237         constexpr int o_mult = 1;
238
239         size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
240         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
241         parallel_nd(G * NB_OC * blksize, [&](int i) {
242             cp[i] = 0;
243         });
244
245         parallel_nd(G, NB_OC, [&](int g, int O) {
246             for (int I = 0; I < NB_IC; I++)
247                 for (int h = 0; h < H; h++)
248                 for (int w = 0; w < W; w++) {
249                     auto i = &input[input_d.blk_off<!w_groups>(g,
250                             i_mult * O, i_mult * I, h, w)];
251                     auto o = &output[output_d.blk_off<!w_groups>(
252                             g, o_mult * O, o_mult * I, h, w)];
253                     const int oc_block = nstl::min(blksize, OC - O * blksize);
254                     const int ic_block = nstl::min(blksize, IC - I * blksize);
255
256                     int _offset = (g * NB_OC + O) * blksize;
257                     ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
258                             &scales[(D_mask == 1) ? 0 : _offset],
259                                         oc_block, ic_block);
260                 }
261         });
262         return success;
263     }
264 };
265
266 template <SIMPLE_REORDER_TEMPL_DECL>
267 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
268     typename utils::enable_if<true
269     && (fmt_i == goihw && fmt_o == Goihw16g_s8s8)>::type>
270 {
271     static bool is_applicable(const memory_desc_wrapper &input_d,
272             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
273         const size_t D_mask = utils::array_product(input_d.dims(),
274                             math::ilog2q(attr->output_scales_.mask_ + 1));
275         const int oc = input_d.dims()[1];
276         const int g = input_d.dims()[0];
277
278         return true
279             && order_keep
280             && input_d.format() == fmt_i
281             && output_d.format() == fmt_o
282             && (input_d.data_type() == f32 || input_d.data_type() == s8)
283             && output_d.data_type() == s8
284             && (D_mask == 1 || D_mask == (size_t)g * oc);
285     }
286
287     static status_t execute(const cpu_reorder_pd_t *pd,
288             const data_t<type_i> *input, data_t<type_o> *output) {
289         DECLARE_COMMON_PARAMS();
290
291         const int blksize = 16;
292
293         const auto &dims = input_d.dims();
294         const auto &pdims = output_d.blocking_desc().padding_dims;
295         const int G = dims[0];
296         const int Gp = pdims[0];
297         const int OC = dims[1];
298         const int IC = dims[2];
299         const int H = dims[3];
300         const int W = dims[4];
301
302         const size_t D_mask = utils::array_product(input_d.dims(),
303                             math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
304         const float *scales = pd->attr()->output_scales_.scales_;
305         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
306
307
308         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
309                 int32_t *cp, const float *s, const int g_block) {
310             PRAGMA_OMP_SIMD()
311             for (int g = 0; g < g_block; g++) {
312                 const auto i_off = g * input_d.blocking_desc().strides[0][0];
313                 out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
314                         inp[i_off], s[g * OC] * adj_scale, rmode);
315                 cp[g * OC] -= 128 * (int32_t)(out[g]);
316             }
317         };
318
319         size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
320         int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
321         parallel_nd((Gp/blksize) * OC, [&](int ib) {
322             PRAGMA_OMP_SIMD()
323             for (int i = 0; i < blksize; i++)
324                 cp[ib * blksize + i] = 0;
325         });
326
327         parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
328                 for (int I = 0; I < IC; I++) {
329                     for (int h = 0; h < H; h++) {
330                     for (int w = 0; w < W; w++) {
331                         const int g_block = nstl::min(G - gb * blksize, blksize);
332                         const auto inp = &input[input_d.blk_off(gb * blksize, O, I, h, w)];
333                         const auto out = &output[output_d.blk_off(gb, O, I, h, w)];
334                         int offset = gb * blksize + O;
335                         ker(inp, out, &cp[offset],
336                             &scales[(D_mask == 1) ? 0 : offset], g_block);
337                    }
338                    }
339                }
340         });
341         return success;
342     }
343 };
344
345 template <SIMPLE_REORDER_TEMPL_DECL>
346 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
347     typename utils::enable_if<true
348     && format_traits<fmt_i>::blk_fmt == bf::_8i16o2i
349     && format_traits<fmt_o>::blk_fmt == bf::_8o16i2o>::type>
350 {
351     static bool is_applicable(const memory_desc_wrapper &input_d,
352             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
353     {
354         return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
355             && simple_attr_check(attr, false);
356     }
357
358     static status_t execute(const cpu_reorder_pd_t *pd,
359         const data_t<type_i> *input, data_t<type_o> *output) {
360         DECLARE_COMMON_PARAMS();
361
362         static constexpr bool w_groups
363             = format_traits<fmt_o>::data_kind == dk::gwei;
364         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
365         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
366         constexpr int blksize = format_traits<fmt_o>::blk_size;
367
368         const auto &dims = input_d.dims();
369
370         const int G = w_groups ? dims[0] : 1;
371         const int NB_OC = dims[w_groups + 0] / blksize;
372         const int NB_IC = dims[w_groups + 1] / blksize;
373         const int D = is_3d ? dims[w_groups + 2] : 1;
374         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
375         const int W = dims[w_groups + 3 + is_3d - is_1d];
376
377         auto idx_i = [&](const int oc, const int ic)
378         { return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2); };
379
380         auto idx_o = [&](const int oc, const int ic)
381         { return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2); };
382
383         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) -> void {
384             if (alpha == 1.0 && beta == 0.0) {
385                 for (int ic = 0; ic < blksize; ++ic) {
386                     for (int oc = 0; oc < blksize; ++oc) {
387                         o[idx_o(oc, ic)] = _qz_a1b0<type_i, type_o>()(
388                                 i[idx_i(oc, ic)], rmode);
389                     }
390                 }
391             } else {
392                 for (int ic = 0; ic < blksize; ++ic) {
393                     for (int oc = 0; oc < blksize; ++oc) {
394                         o[idx_o(oc, ic)] = _qz<type_i, type_o>()(
395                                 i[idx_i(oc, ic)], o[idx_o(oc, ic)], alpha,
396                                 beta, rmode);
397                     }
398                 }
399             }
400         };
401
402         parallel_nd(G, NB_OC, NB_IC, D, H, W,
403             [&](int g, int o, int i, int d, int h, int w) {
404             auto ptr_i = &input[wei_blk_off_like_gwei3D<fmt_i>(
405                     input_d, g, o, i, d,  h, w)];
406             auto ptr_o = &output[wei_blk_off_like_gwei3D<fmt_o>(
407                     output_d, g, o, i, d, h, w)];
408             ker(ptr_i, ptr_o);
409         });
410
411         return success;
412     }
413 };
414
415 /* reorders with tail support */
416
417 template <SIMPLE_REORDER_TEMPL_DECL>
418 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
419 typename utils::enable_if<fmt_i == nChw8c && fmt_o == nhwc && order_keep>::type>
420 {
421     static bool is_applicable(const memory_desc_wrapper &input_d,
422         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
423         int smask = attr ? attr->output_scales_.mask_ : 0;
424         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nChw8c && output_d._md->format == nhwc;
425     }
426
427     static status_t execute(const cpu_reorder_pd_t *pd,
428         const data_t<type_i> *input, data_t<type_o> *output) {
429         DECLARE_COMMON_PARAMS();
430
431         const auto &pdims = input_d.blocking_desc().padding_dims;
432         const auto &dims = input_d.dims();
433         constexpr int blksize = format_traits<fmt_i>::blk_size;
434         const int C = dims[1];
435         const int H = dims[2];
436         const int W = dims[3];
437
438         constexpr int i_c_mult = 1;
439         constexpr int o_c_mult = blksize;
440
441         const float *scales = pd->attr()->output_scales_.scales_;
442         int smask = pd->attr()->output_scales_.mask_;
443
444         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
445                        const int nb_c, const int c_block) {
446             if (smask == 2) {
447                 for (int w = 0; w < W; ++w) {
448                     const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
449                     PRAGMA_OMP_SIMD()
450                     for (int c = 0; c < c_block; ++c) {
451                         const float scale = scales[nb_c * blksize + c];
452
453                         o[flat_off + c] = _qz<type_i, type_o>()(i[w * blksize + c],
454                                                             o[flat_off + c], scale, beta, rmode);
455                     }
456                 }
457             } else {
458                 for (int w = 0; w < W; ++w) {
459                     const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
460                     PRAGMA_OMP_SIMD()
461                     for (int c = 0; c < c_block; ++c) {
462                         o[flat_off + c] = _qz_a1b0<type_i, type_o>()(i[w * blksize + c], rmode);
463                     }
464                 }
465             }
466         };
467
468         parallel_nd(dims[0], pdims[1] / blksize, H,
469             [&](int n, int nb_c, int h) {
470                     auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
471                     auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
472                     const int c_block = nstl::min(blksize, C - nb_c * blksize);
473                     ker(i, o, nb_c, c_block);
474         });
475
476         return success;
477     }
478 };
479
480 template <SIMPLE_REORDER_TEMPL_DECL>
481 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
482 typename utils::enable_if<fmt_i == nhwc && fmt_o == nChw8c>::type>
483 {
484     static bool is_applicable(const memory_desc_wrapper &input_d,
485         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
486         int smask = attr ? attr->output_scales_.mask_ : 0;
487         return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nChw8c;
488     }
489
490     static status_t execute(const cpu_reorder_pd_t *pd,
491         const data_t<type_i> *input, data_t<type_o> *output) {
492         DECLARE_COMMON_PARAMS();
493
494         const auto &pdims = output_d.blocking_desc().padding_dims;
495         const auto &dims = input_d.dims();
496         constexpr int blksize = format_traits<fmt_o>::blk_size;
497         const int C = dims[1];
498         const int H = dims[2];
499         const int W = dims[3];
500
501         constexpr int i_c_mult = blksize;
502         constexpr int o_c_mult = 1;
503
504         const float *scales = pd->attr()->output_scales_.scales_;
505         int smask = pd->attr()->output_scales_.mask_;
506
507         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
508                        const int nb_c, const int c_block) {
509             if (smask == 2) {
510                 for (int w = 0; w < W; ++w) {
511                     const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
512                     PRAGMA_OMP_SIMD()
513                     for (int c = 0; c < c_block; ++c) {
514                         const float scale = scales[nb_c * blksize + c];
515
516                         o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off + c],
517                                                                    o[w * blksize + c], scale, beta, rmode);
518                     }
519                 }
520             } else {
521                 for (int w = 0; w < W; ++w) {
522                     const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
523                     PRAGMA_OMP_SIMD()
524                     for (int c = 0; c < c_block; ++c) {
525                         o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(i[flat_off + c], rmode);
526                     }
527                 }
528             }
529         };
530
531         parallel_nd(dims[0], pdims[1] / blksize, H,
532             [&](int n, int nb_c, int h) {
533                     auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
534                     auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
535                     const int c_block = nstl::min(blksize, C - nb_c * blksize);
536                     ker(i, o, nb_c, c_block);
537         });
538
539         return success;
540     }
541 };
542
543 template <SIMPLE_REORDER_TEMPL_DECL>
544 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
545 typename utils::enable_if<fmt_i == nhwc && fmt_o == nhwc && type_o != mkldnn_bin>::type>
546 {
547     static bool is_applicable(const memory_desc_wrapper &input_d,
548         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
549         int smask = attr ? attr->output_scales_.mask_ : 0;
550         return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nhwc;
551     }
552
553     static status_t execute(const cpu_reorder_pd_t *pd,
554         const data_t<type_i> *input, data_t<type_o> *output) {
555         DECLARE_COMMON_PARAMS();
556
557         const auto &dims = input_d.dims();
558         const int C = dims[1];
559         const int H = dims[2];
560         const int W = dims[3];
561
562         const float *scales = pd->attr()->output_scales_.scales_;
563
564         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
565                 for (int c = 0; c < C; ++c) {
566                     const float scale = scales[c];
567
568                     o[c] = _qz<type_i, type_o>()(i[c], o[c], scale, beta, rmode);
569                 }
570         };
571
572         parallel_nd(dims[0], H, W,
573             [&](int n, int h, int w) {
574                 auto i = &input[input_d.blk_off(n, 0, h, w)];
575                 auto o = &output[output_d.blk_off(n, 0, h, w)];
576                 ker(i, o);
577         });
578
579         return success;
580     }
581 };
582
583 template <SIMPLE_REORDER_TEMPL_DECL>
584 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
585 typename utils::enable_if<fmt_i == nchw && fmt_o == nhwc && type_i != mkldnn_bin && type_o != mkldnn_bin>::type>
586 {
587     static bool is_applicable(const memory_desc_wrapper &input_d,
588         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
589         int smask = attr ? attr->output_scales_.mask_ : 0;
590         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nchw && output_d._md->format == nhwc;
591     }
592
593     static status_t execute(const cpu_reorder_pd_t *pd,
594         const data_t<type_i> *input, data_t<type_o> *output) {
595         DECLARE_COMMON_PARAMS();
596
597         const auto &dims = input_d.dims();
598         const int C = dims[1];
599         const int H = dims[2];
600         const int W = dims[3];
601
602         int smask = pd->attr()->output_scales_.mask_;
603         const float *scales = pd->attr()->output_scales_.scales_;
604
605         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
606             if (smask == 2) {
607                 for (int c = 0; c < C; ++c) {
608                     const float scale = scales[c];
609
610                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
611
612                     o[c] = _qz<type_i, type_o>()(i[flat_off], o[c], scale, beta, rmode);
613                 }
614             } else {
615                 for (int c = 0; c < C; ++c) {
616                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
617
618                     o[c] = _qz_a1b0<type_i, type_o>()(i[flat_off], rmode);
619                 }
620             }
621         };
622
623         parallel_nd(dims[0], H, W,
624             [&](int n, int h, int w) {
625                 auto i = &input[input_d.blk_off(n, 0, h, w)];
626                 auto o = &output[output_d.blk_off(n, 0, h, w)];
627                 ker(i, o);
628         });
629
630         return success;
631     }
632 };
633
634 template <SIMPLE_REORDER_TEMPL_DECL>
635 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
636 typename utils::enable_if<(fmt_i == nchw || fmt_i == nhwc) && fmt_o == nhwc && (type_i == mkldnn_bin || type_o == mkldnn_bin)>::type>
637 {
638     static bool is_applicable(const memory_desc_wrapper &input_d,
639         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
640         int smask = attr ? attr->output_scales_.mask_ : 0;
641         return smask == 0 && order_keep && (input_d._md->format == nchw || input_d._md->format == nhwc) && output_d._md->format == nhwc;
642     }
643
644     static status_t execute(const cpu_reorder_pd_t *pd,
645         const data_t<type_i> *input, data_t<type_o> *output) {
646         DECLARE_COMMON_PARAMS();
647
648         const auto &dims = input_d.dims();
649         const int C = dims[1];
650         const int H = dims[2];
651         const int W = dims[3];
652
653         int nbits = 8;
654         const int CB = div_up(C, nbits);
655
656         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
657             for (int cb = 0; cb < CB; ++cb) {
658                 uint8_t bin_val = 0x00;
659                 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
660                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
661
662                     auto bit = uint8_t((i[flat_off] > 0) ? 0x01 : 0x00);
663                     bin_val |= (bit << shift);
664                 }
665
666                 o[cb] = bin_val;
667             }
668         };
669
670         parallel_nd(dims[0], H, W,
671             [&](int n, int h, int w) {
672                 auto iidx = input_d.blk_off(n, 0, h, w);
673                 auto oidx = output_d.blk_off(n, 0, h, w);
674
675                 auto i = &input[iidx];
676                 auto o = &output[oidx / nbits];
677                 ker(i, o);
678         });
679
680         return success;
681     }
682 };
683
684 template <SIMPLE_REORDER_TEMPL_DECL>
685 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
686 typename utils::enable_if<fmt_i == nhwc && fmt_o == nchw>::type>
687 {
688     static bool is_applicable(const memory_desc_wrapper &input_d,
689         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
690         int smask = attr ? attr->output_scales_.mask_ : 0;
691         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nchw;
692     }
693
694     static status_t execute(const cpu_reorder_pd_t *pd,
695         const data_t<type_i> *input, data_t<type_o> *output) {
696         DECLARE_COMMON_PARAMS();
697
698         const auto &dims = input_d.dims();
699         const int C = dims[1];
700         const int H = dims[2];
701         const int W = dims[3];
702
703         int smask = pd->attr()->output_scales_.mask_;
704         const float *scales = pd->attr()->output_scales_.scales_;
705
706         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
707             if (smask == 2) {
708                 for (int c = 0; c < C; ++c) {
709                     const float scale = scales[c];
710
711                     const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
712
713                     o[flat_off] = _qz<type_i, type_o>()(i[c], o[flat_off], scale, beta, rmode);
714                 }
715             } else {
716                 for (int c = 0; c < C; ++c) {
717                     const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
718
719                     o[flat_off] = _qz_a1b0<type_i, type_o>()(i[c], rmode);
720                 }
721             }
722         };
723
724         parallel_nd(dims[0], H, W,
725             [&](int n, int h, int w) {
726                 auto i = &input[input_d.blk_off(n, 0, h, w)];
727                 auto o = &output[output_d.blk_off(n, 0, h, w)];
728                 ker(i, o);
729         });
730
731         return success;
732     }
733 };
734
735 template <SIMPLE_REORDER_TEMPL_DECL>
736 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
737 typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
738     && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
739 {
740     static bool is_applicable(const memory_desc_wrapper &input_d,
741             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
742     {
743         return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
744             && simple_attr_check(attr, false);
745     }
746
747     static status_t execute(const cpu_reorder_pd_t *pd,
748         const data_t<type_i> *input, data_t<type_o> *output) {
749         DECLARE_COMMON_PARAMS();
750
751         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
752         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
753         constexpr int blksize_16 = format_traits<fmt_o>::blk_size;
754         constexpr int blksize_8 = format_traits<fmt_i>::blk_size;
755         constexpr int ic_mult = order_keep ? 2 : 1;
756         constexpr int oc_mult = order_keep ? 1 : 2;
757
758         const auto &nchw8c_d = order_keep ? input_d : output_d;
759         const auto &dims = input_d.dims();
760         const auto &pdims = order_keep ? output_d.blocking_desc().padding_dims
761                                        : input_d.blocking_desc().padding_dims;
762         const auto stride_8c = nchw8c_d.blocking_desc().strides[0];
763
764         const int C = dims[1];
765         const int D = is_3d ? dims[2] : 1;
766         const int H = is_1d ? 1 : dims[2 + is_3d];
767         const int W = dims[3 + is_3d - is_1d];
768
769         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
770             const int block_16) {
771             const int nb = (block_16 - 1) / blksize_8 + 1;
772             if (alpha == 1.0 && beta == 0.0) {
773                 for (int b = 0; b < nb; ++b) {
774                     const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
775                                                        : b * blksize_8;
776                     const ptrdiff_t o_off = order_keep ? b * blksize_8
777                                                        : b * stride_8c[1];
778                     const int block_8 = nstl::min(blksize_8,
779                                                   block_16 - b * blksize_8);
780                     for (int c = 0; c < block_8; ++c) {
781                         o[o_off + c] = _qz_a1b0<type_i, type_o>()(
782                                 i[i_off + c], rmode);
783                     }
784                 }
785             } else {
786                 for (int b = 0; b < nb; ++b) {
787                     const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
788                                                        : b * blksize_8;
789                     const ptrdiff_t o_off = order_keep ? b * blksize_8
790                                                        : b * stride_8c[1];
791                     const int block_8 = nstl::min(blksize_8,
792                                                   block_16 - b * blksize_8);
793                     for (int c = 0; c < block_8; ++c) {
794                         o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
795                                 o[o_off + c], alpha, beta, rmode);
796                     }
797                 }
798             }
799         };
800
801 #       define data_blk_off(md, n, c, d, h, w) \
802         ( is_1d ? (md).blk_off(n, c, w) \
803           : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
804
805         parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
806             [&](int n, int nb_c, int d, int h, int w) {
807             auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
808             auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
809             const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
810             ker(i, o, block_16);
811         });
812
813 #       undef data_blk_off
814
815         return success;
816     }
817 };
818
819 #define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
820     static bool is_applicable(const memory_desc_wrapper &input_d, \
821         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
822         return simple_attr_check(attr, false) && (order_keep \
823                 ? output_d.format() == fmt_o && input_d.is_plain() \
824                 : input_d.format() == fmt_o && output_d.is_plain()); \
825     }
826
827 template <SIMPLE_REORDER_TEMPL_DECL>
828 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
829 typename utils::enable_if<fmt_i == any && (false
830     || format_traits<fmt_o>::blk_fmt == bf::_4c
831     || format_traits<fmt_o>::blk_fmt == bf::_8c
832     || format_traits<fmt_o>::blk_fmt == bf::_16c)>::type>
833 {
834     PLAIN_TO_BLOCKED_IS_APPLICABLE();
835
836     static status_t execute(const cpu_reorder_pd_t *pd,
837         const data_t<type_i> *input, data_t<type_o> *output) {
838         DECLARE_COMMON_PARAMS();
839
840         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
841         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
842         constexpr int blksize = format_traits<fmt_o>::blk_size;
843
844         const auto &flat_d = order_keep ? input_d : output_d;
845         const auto &dims = input_d.dims();
846         const auto &pdims = order_keep
847             ? output_d.blocking_desc().padding_dims
848             : input_d.blocking_desc().padding_dims;
849
850         const int C = dims[1];
851         const int D = is_3d ? dims[2] : 1;
852         const int H = is_1d ? 1 : dims[2 + is_3d];
853         const int W = dims[3 + is_3d - is_1d];
854
855         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
856             const int c_block) {
857             if (alpha == 1.0 && beta == 0.0) {
858                 for (int w = 0; w < W; ++w)
859                 for (int c = 0; c < c_block; ++c) {
860                     const ptrdiff_t flat_off = 0
861                         + c * flat_d.blocking_desc().strides[0][1]
862                         + w * flat_d.blocking_desc().strides[0][3 + is_3d
863                             - is_1d];
864                     if (order_keep) {
865                         o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(
866                                 i[flat_off], rmode);
867                     } else {
868                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
869                                 i[w * blksize + c], rmode);
870                     }
871                 }
872             } else {
873                 for (int w = 0; w < W; ++w)
874                 for (int c = 0; c < c_block; ++c) {
875                     const ptrdiff_t flat_off = 0
876                         + c * flat_d.blocking_desc().strides[0][1]
877                         + w * flat_d.blocking_desc().strides[0][3 + is_3d
878                             - is_1d];
879                     if (order_keep) {
880                         o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off],
881                                 o[w * blksize + c], alpha, beta, rmode);
882                     } else {
883                         o[flat_off] = _qz<type_i, type_o>()(i[w * blksize + c],
884                                 o[flat_off], alpha, beta, rmode);
885                     }
886                 }
887             }
888         };
889
890         constexpr int i_c_mult = order_keep ? blksize : 1;
891         constexpr int o_c_mult = order_keep ? 1 : blksize;
892
893 #       define data_blk_off(md, n, c, d, h) \
894         ( is_1d ? (md).blk_off(n, c) \
895           : is_3d ? (md).blk_off(n, c, d, h) : (md).blk_off(n, c, h))
896
897         parallel_nd(dims[0], pdims[1] / blksize, D, H,
898             [&](int n, int nb_c, int d, int h) {
899             auto i = &input[data_blk_off(input_d, n, i_c_mult * nb_c, d, h)];
900             auto o = &output[data_blk_off(output_d, n, o_c_mult * nb_c, d, h)];
901             const int c_block = nstl::min(blksize, C - nb_c * blksize);
902             ker(i, o, c_block);
903         });
904
905 #       undef data_blk_off
906
907         return success;
908     }
909 };
910
911 template <SIMPLE_REORDER_TEMPL_DECL>
912 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
913     typename utils::enable_if<
914           (fmt_i == goihw && fmt_o == gOhIw8o4i_s8s8)
915        || (fmt_i == oihw && fmt_o == OhIw8o4i_s8s8)
916     >::type>
917 {
918     static bool is_applicable(const memory_desc_wrapper &input_d,
919             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
920     {
921         const size_t D_mask = utils::array_product(input_d.dims(),
922                                 math::ilog2q(attr->output_scales_.mask_ + 1));
923         const int oc = (input_d.dims()[(fmt_i == goihw) + 0]);
924         const int g = (fmt_i == goihw) ? (input_d.dims()[0]) : 1;
925
926         return input_d.format() == fmt_i
927             && output_d.format() == fmt_o
928             && (input_d.data_type() == f32 || input_d.data_type() == s8)
929             && output_d.data_type() == s8
930             && (D_mask == 1 || D_mask == (size_t)g * oc);
931     }
932
933     static status_t execute(const cpu_reorder_pd_t *pd,
934         const data_t<type_i> *input, data_t<type_o> *output) {
935         DECLARE_COMMON_PARAMS();
936
937         static constexpr bool w_groups
938             = format_traits<fmt_o>::data_kind == dk::gwei;
939         constexpr int blksize_o = 8;
940         constexpr int blksize_i = 4;
941
942         const auto &flat_d = order_keep ? input_d : output_d;
943         const auto &dims = input_d.dims();
944         const auto &pdims = order_keep
945             ? output_d.blocking_desc().padding_dims
946             : input_d.blocking_desc().padding_dims;
947
948         const int G = w_groups ? dims[0] : 1;
949         const int OC = dims[w_groups + 0];
950         const int NB_OC = pdims[w_groups + 0] / blksize_o;
951         const int IC = dims[w_groups + 1];
952         const int NB_IC = pdims[w_groups + 1] / blksize_i;
953         const int H = dims[w_groups + 2];
954         const int W = dims[w_groups + 3];
955
956         const float *scales = pd->attr()->output_scales_.scales_;
957         const size_t D_mask = utils::array_product(input_d.dims(),
958                                                    math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
959
960         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0 : (1.0 / 2.0);
961
962         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
963             int32_t *c, const float *s, const int oc_block, const int ic_block) {
964 #            define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
965
966             for (int ic = 0; ic < ic_block; ++ic) {
967                 for (int oc = 0; oc < oc_block; ++oc) {
968                     const auto _g_oihw_off = oc * flat_d.blocking_desc().strides[0][w_groups + 0] +
969                                              ic * flat_d.blocking_desc().strides[0][w_groups + 1];
970
971                     if (order_keep) {
972                         out[blk_off(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[_g_oihw_off], s[oc] * adj_scale, rmode);
973                         c[oc] -= (128 * (int32_t)(out[blk_off(oc, ic)]));
974                     } else {
975                         out[_g_oihw_off] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[blk_off(oc, ic)], s[oc] * adj_scale, rmode);
976                         c[oc] -= (128 * (int32_t)(out[_g_oihw_off]));
977                     }
978                 }
979             }
980
981 #           undef blk_off
982         };
983
984         constexpr int i_mult_o = blksize_o;
985         constexpr int i_mult_i = blksize_i;
986
987         size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
988         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
989         parallel_nd(G * NB_OC * blksize_o, [&](int i) {
990             cp[i] = 0;
991         });
992
993         parallel_nd(G, NB_OC, [&](int g, int O) {
994             for (int I = 0; I < NB_IC; I++) {
995                 for (int h = 0; h < H; h++) {
996                     for (int w = 0; w < W; w++) {
997                         auto i = &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, h, w)];
998                         auto o = &output[output_d.blk_off<!w_groups>(g, O, I, h, w)];
999                         const int oc_block = nstl::min(blksize_o, OC - O * blksize_o);
1000                         const int ic_block = nstl::min(blksize_i, IC - I * blksize_i);
1001
1002                         int _offset = (g * NB_OC + O) * blksize_o;
1003                         ker(i, o, (order_keep) ? &cp[_offset] : nullptr, &scales[(D_mask == 1) ? 0 : _offset], oc_block,
1004                             ic_block);
1005                     }
1006                 }
1007             }
1008         });
1009
1010         return success;
1011     }
1012 };
1013
1014 template <SIMPLE_REORDER_TEMPL_DECL>
1015 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1016 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o4i || fmt_o == gOhIw8o4i)>::type>
1017 {
1018     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1019
1020     static status_t execute(const cpu_reorder_pd_t *pd,
1021         const data_t<type_i> *input, data_t<type_o> *output) {
1022         DECLARE_COMMON_PARAMS();
1023
1024         static constexpr bool w_groups
1025             = format_traits<fmt_o>::data_kind == dk::gwei;
1026         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1027         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1028         constexpr int blksize_o = 8;//format_traits<fmt_o>::blk_size;
1029         constexpr int blksize_i = 4;
1030
1031         const auto &flat_d = order_keep ? input_d : output_d;
1032         const auto &dims = input_d.dims();
1033         const auto &pdims = order_keep
1034             ? output_d.blocking_desc().padding_dims
1035             : input_d.blocking_desc().padding_dims;
1036
1037         const int G = w_groups ? dims[0] : 1;
1038         const int OC = dims[w_groups + 0];
1039         const int NB_OC = pdims[w_groups + 0] / blksize_o;
1040         const int IC = dims[w_groups + 1];
1041         const int NB_IC = pdims[w_groups + 1] / blksize_i;
1042         const int D = is_3d ? dims[w_groups + 2] : 1;
1043         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1044         const int W = dims[w_groups + 3 + is_3d - is_1d];
1045
1046         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1047             const int oc_block, const int ic_block) {
1048 #           define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1049
1050             if (alpha == 1.0 && beta == 0.0) {
1051                 for (int oc = 0; oc < oc_block; ++oc)
1052                 for (int ic = 0; ic < ic_block; ++ic) {
1053                     const ptrdiff_t flat_off = 0
1054                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1055                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1056                     if (order_keep) {
1057                         o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1058                                 i[flat_off], rmode);
1059                     } else {
1060                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
1061                                 i[blk_off(oc, ic)], rmode);
1062                     }
1063                 }
1064             } else {
1065                 for (int oc = 0; oc < oc_block; ++oc)
1066                 for (int ic = 0; ic < ic_block; ++ic) {
1067                     const ptrdiff_t flat_off = 0
1068                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1069                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1070                     if (order_keep) {
1071                         o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1072                                 o[blk_off(oc, ic)], alpha, beta, rmode);
1073                     } else {
1074                         o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1075                                 o[flat_off], alpha, beta, rmode);
1076                     }
1077                 }
1078             }
1079
1080 #           undef blk_off
1081         };
1082
1083
1084         constexpr int i_mult_o = blksize_o;
1085         constexpr int i_mult_i = blksize_i;
1086
1087         parallel_nd(G, NB_OC, NB_IC, D, H, W,
1088             [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1089             int i_off = wei_blk_off_like_gwei3D<fmt_o>(input_d,
1090                                                        g, i_mult_o * nb_oc, i_mult_i * nb_ic, d, h, w);
1091             int o_off = wei_blk_off_like_gwei3D<fmt_o>(output_d,
1092                                                        g, nb_oc, nb_ic, d, h, w);
1093             auto i = &input[i_off];
1094             auto o = &output[o_off];
1095             const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1096             const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1097             ker(i, o, oc_block, ic_block);
1098         });
1099
1100         return success;
1101     }
1102 };
1103
1104 template <SIMPLE_REORDER_TEMPL_DECL>
1105 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1106 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o32i || fmt_o == OhIw16o32i) && type_i == mkldnn_bin && type_o == mkldnn_bin>::type>
1107 {
1108     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1109
1110     static status_t execute(const cpu_reorder_pd_t *pd,
1111         const data_t<type_i> *input, data_t<type_o> *output) {
1112         DECLARE_COMMON_PARAMS();
1113
1114         static constexpr bool w_groups
1115             = format_traits<fmt_o>::data_kind == dk::gwei;
1116         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1117         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1118         constexpr int blksize_o = fmt_o == OhIw8o32i ? 8 : 16;
1119         constexpr int blksize_i = 32;
1120
1121         const auto &dims = input_d.dims();
1122         const auto &pdims = order_keep
1123             ? output_d.blocking_desc().padding_dims
1124             : input_d.blocking_desc().padding_dims;
1125
1126         const int G = w_groups ? dims[0] : 1;
1127         const int OC = dims[w_groups + 0];
1128         const int NB_OC = pdims[w_groups + 0] / blksize_o;
1129         const int IC = dims[w_groups + 1];
1130         const int NB_IC = pdims[w_groups + 1] / blksize_i;
1131         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1132         const int W = dims[w_groups + 3 + is_3d - is_1d];
1133
1134         constexpr int i_mult_o = blksize_o;
1135         constexpr int i_mult_i = blksize_i;
1136         constexpr int nbits = 8;
1137
1138         auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
1139             return (uint8_t) ((val >> bit) & 0x0001);
1140         };
1141
1142         parallel_nd(G, NB_OC, NB_IC, H, W,
1143             [&](int g, int nb_oc, int nb_ic, int h, int w) {
1144                 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1145                 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1146
1147                 for (int oc = 0; oc < oc_block; ++oc) {
1148                     for (int icb = 0; icb < div_up(ic_block, nbits); ++icb) {
1149
1150                         uint8_t bin_val = 0x00;
1151                         for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) {
1152                             size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0][0] +
1153                                           (i_mult_i * nb_ic + ic) *input_d.blocking_desc().strides[0][1] +
1154                                                                 h * input_d.blocking_desc().strides[0][2] +
1155                                                                 w;
1156
1157                             uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits));
1158                             bin_val |= (bit << shift);
1159                         }
1160
1161                         size_t oidx = wei_blk_off_like_gwei3D<fmt_o>(output_d, g, nb_oc, nb_ic, 0, h, w) + oc * blksize_i + icb * blksize_o;
1162                         output[oidx / nbits] = bin_val;
1163
1164                     }
1165                 }
1166             });
1167
1168         return success;
1169     }
1170 };
1171
1172 template <SIMPLE_REORDER_TEMPL_DECL>
1173 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1174 typename utils::enable_if<fmt_i == any
1175 && block_format_traits<format_traits<fmt_o>::blk_fmt>::blk_ndims == 2
1176 && fmt_o != OhIw8o4i && fmt_o != gOhIw8o4i && fmt_o != OhIw8o32i && fmt_o != OhIw16o32i>::type>
1177 {
1178     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1179
1180     static status_t execute(const cpu_reorder_pd_t *pd,
1181         const data_t<type_i> *input, data_t<type_o> *output) {
1182         DECLARE_COMMON_PARAMS();
1183
1184         static constexpr bool w_groups
1185             = format_traits<fmt_o>::data_kind == dk::gwei;
1186         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1187         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1188         constexpr int blksize = format_traits<fmt_o>::blk_size;
1189
1190         const auto &flat_d = order_keep ? input_d : output_d;
1191         const auto &dims = input_d.dims();
1192         const auto &pdims = order_keep
1193             ? output_d.blocking_desc().padding_dims
1194             : input_d.blocking_desc().padding_dims;
1195
1196         const int G = w_groups ? dims[0] : 1;
1197         const int OC = dims[w_groups + 0];
1198         const int NB_OC = pdims[w_groups + 0] / blksize;
1199         const int IC = dims[w_groups + 1];
1200         const int NB_IC = pdims[w_groups + 1] / blksize;
1201         const int D = is_3d ? dims[w_groups + 2] : 1;
1202         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1203         const int W = dims[w_groups + 3 + is_3d - is_1d];
1204
1205         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1206             const int oc_block, const int ic_block) {
1207 #           define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1208
1209             if (alpha == 1.0 && beta == 0.0) {
1210                 for (int oc = 0; oc < oc_block; ++oc)
1211                 for (int ic = 0; ic < ic_block; ++ic) {
1212                     const ptrdiff_t flat_off = 0
1213                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1214                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1215                     if (order_keep) {
1216                         o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1217                                 i[flat_off], rmode);
1218                     } else {
1219                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
1220                                 i[blk_off(oc, ic)], rmode);
1221                     }
1222                 }
1223             } else {
1224                 for (int oc = 0; oc < oc_block; ++oc)
1225                 for (int ic = 0; ic < ic_block; ++ic) {
1226                     const ptrdiff_t flat_off = 0
1227                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1228                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1229                     if (order_keep) {
1230                         o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1231                                 o[blk_off(oc, ic)], alpha, beta, rmode);
1232                     } else {
1233                         o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1234                                 o[flat_off], alpha, beta, rmode);
1235                     }
1236                 }
1237             }
1238
1239 #           undef blk_off
1240         };
1241
1242
1243         constexpr int i_mult = order_keep ? blksize : 1;
1244         constexpr int o_mult = order_keep ? 1 : blksize;
1245
1246         parallel_nd(G, NB_OC, NB_IC, D, H, W,
1247             [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1248             auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1249                     g, i_mult * nb_oc, i_mult * nb_ic, d, h, w)];
1250             auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1251                     g, o_mult * nb_oc, o_mult * nb_ic, d, h, w)];
1252             const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1253             const int ic_block = nstl::min(blksize, IC - nb_ic * blksize);
1254             ker(i, o, oc_block, ic_block);
1255         });
1256
1257         return success;
1258     }
1259 };
1260
1261 template <SIMPLE_REORDER_TEMPL_DECL>
1262 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1263 typename utils::enable_if<fmt_i == any && (false
1264     || format_traits<fmt_o>::blk_fmt == bf::_4o
1265     || format_traits<fmt_o>::blk_fmt == bf::_8o
1266     || format_traits<fmt_o>::blk_fmt == bf::_16o)>::type>
1267 {
1268     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1269
1270     static status_t execute(const cpu_reorder_pd_t *pd,
1271         const data_t<type_i> *input, data_t<type_o> *output) {
1272         DECLARE_COMMON_PARAMS();
1273
1274         static constexpr bool w_groups
1275             = format_traits<fmt_o>::data_kind == dk::gwei;
1276         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1277         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1278         constexpr int blksize = format_traits<fmt_o>::blk_size;
1279
1280         const auto &flat_d = order_keep ? input_d : output_d;
1281         const auto &dims = input_d.dims();
1282         const auto &pdims = order_keep
1283             ? output_d.blocking_desc().padding_dims
1284             : input_d.blocking_desc().padding_dims;
1285
1286         const int G = w_groups ? dims[0] : 1;
1287         const int OC = dims[w_groups + 0];
1288         const int IC = dims[w_groups + 1];
1289         const int D = is_3d ? dims[w_groups + 2] : 1;
1290         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1291         const int W = dims[w_groups + 3 + is_3d - is_1d];
1292
1293         constexpr int i_mult = order_keep ? blksize : 1;
1294         constexpr int o_mult = order_keep ? 1 : blksize;
1295         const auto strd_oc = flat_d.blocking_desc().strides[0][w_groups];
1296
1297         parallel_nd(G, pdims[w_groups + 0] / blksize, IC, D, H, W,
1298             [&](int g, int nb_oc, int ic, int d, int h, int w) {
1299             auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1300                     g, i_mult * nb_oc, ic, d, h, w)];
1301             auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1302                     g, o_mult * nb_oc, ic, d, h, w)];
1303             const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1304
1305             if (alpha == 1.0 && beta == 0.0) {
1306                 for (int oc = 0; oc < oc_block; ++oc) {
1307                     const auto off = oc * strd_oc;
1308                     if (order_keep) {
1309                         o[oc] = _qz_a1b0<type_i, type_o>()(i[off], rmode);
1310                     } else {
1311                         o[off] = _qz_a1b0<type_i, type_o>()(i[oc], rmode);
1312                     }
1313                 }
1314             } else {
1315                 for (int oc = 0; oc < oc_block; ++oc) {
1316                     const auto off = oc * strd_oc;
1317                     if (order_keep) {
1318                         o[oc] = _qz<type_i, type_o>()(i[off], o[oc], alpha,
1319                                 beta, rmode);
1320                     } else {
1321                         o[off] = _qz<type_i, type_o>()(i[oc], o[off], alpha,
1322                                 beta, rmode);
1323                     }
1324                 }
1325             }
1326         });
1327
1328         return success;
1329     }
1330 };
1331
1332 /* generic and direct-copy reorders */
1333
1334 template <SIMPLE_REORDER_TEMPL_DECL>
1335 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1336     typename utils::enable_if<
1337         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1338     spec::direct_copy>::type>
1339 {
1340     static bool is_applicable(const memory_desc_wrapper &input_d,
1341             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1342         /* FIXME: is the formula correct? */
1343         return input_d.similar_to(output_d, true, false, 0)
1344             && input_d.is_dense() && output_d.is_dense()
1345             && simple_attr_check(attr, false);
1346     }
1347
1348     static status_t execute(const cpu_reorder_pd_t *pd,
1349         const data_t<type_i> *input, data_t<type_o> *output) {
1350         DECLARE_COMMON_PARAMS();
1351
1352         assert(input_d.is_dense());
1353
1354         input += input_d.blk_off(0);
1355         output += output_d.blk_off(0);
1356
1357         const size_t nelems = input_d.nelems();
1358
1359         constexpr int block_size = 16;
1360         const auto num_blocks = nelems / block_size;
1361         const auto rem_elems = nelems % block_size;
1362
1363         parallel(0, [&](const int ithr, const int nthr) {
1364             size_t start{0}, end{0};
1365             balance211(num_blocks, nthr, ithr, start, end);
1366             start = start * block_size;
1367             end = end * block_size;
1368
1369             if (alpha == 1.0 && beta == 0.0) {
1370                 PRAGMA_OMP_SIMD()
1371                 for (size_t e = start; e < end; ++e) {
1372                     output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
1373                                 (input[e], rmode);
1374                 }
1375             } else if (alpha == 1.0) {
1376                 PRAGMA_OMP_SIMD()
1377                 for (size_t e = start; e < end; ++e) {
1378                     output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
1379                                 (input[e], output[e], beta, rmode);
1380                 }
1381             } else if (beta == 0.0) {
1382                 PRAGMA_OMP_SIMD()
1383                 for (size_t e = start; e < end; ++e) {
1384                     output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
1385                                 (input[e], alpha, rmode);
1386                 }
1387             } else {
1388                 PRAGMA_OMP_SIMD()
1389                 for (size_t e = start; e < end; ++e) {
1390                     output[e] = qz<data_t<type_i>, data_t<type_o>>()
1391                                 (input[e], output[e], alpha, beta, rmode);
1392                 }
1393             }
1394
1395             if (rem_elems != 0 && ithr == nthr - 1){
1396                 if (alpha == 1.0 && beta == 0.0) {
1397                     PRAGMA_OMP_SIMD()
1398                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1399                         output[e] = qz_a1b0<data_t<type_i>,
1400                             data_t<type_o>>()(input[e], rmode);
1401                     }
1402                 } else if (alpha == 1.0) {
1403                     PRAGMA_OMP_SIMD()
1404                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1405                         output[e] = qz_a1<data_t<type_i>,
1406                             data_t<type_o>>()(input[e], output[e], beta, rmode);
1407                     }
1408                 } else if (beta == 0.0) {
1409                     PRAGMA_OMP_SIMD()
1410                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1411                         output[e] = qz_b0<data_t<type_i>,
1412                             data_t<type_o>>()(input[e], alpha, rmode);
1413                     }
1414                 } else {
1415                     PRAGMA_OMP_SIMD()
1416                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1417                         output[e] = qz<data_t<type_i>, data_t<type_o>>()
1418                                     (input[e], output[e], alpha, beta, rmode);
1419                    }
1420                }
1421             }
1422         });
1423         return success;
1424     }
1425 };
1426
1427 template <SIMPLE_REORDER_TEMPL_DECL>
1428 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1429     typename utils::enable_if<
1430         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1431     spec::direct_copy_except_dim_0>::type>
1432 {
1433     static bool is_applicable(const memory_desc_wrapper &input_d,
1434             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1435         auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
1436             return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
1437         };
1438         /* FIXME: is the formula correct? */
1439         return input_d.similar_to(output_d, true, false, 1)
1440             && is_dense_no_0(input_d) && is_dense_no_0(output_d)
1441             && simple_attr_check(attr, false);
1442     }
1443
1444     static status_t execute(const cpu_reorder_pd_t *pd,
1445         const data_t<type_i> *input, data_t<type_o> *output) {
1446         DECLARE_COMMON_PARAMS();
1447
1448         input += input_d.blk_off(0);
1449         output += output_d.blk_off(0);
1450
1451         const int N = input_d.dims()[0];
1452         const size_t is = input_d.blocking_desc().strides[0][0];
1453         const size_t os = output_d.blocking_desc().strides[0][0];
1454         const size_t nelems_no_d0 = nelems_no_dim_0(input_d);
1455         const size_t work_amount = N * nelems_no_d0;
1456
1457         if (alpha == 1.0 && beta == 0.0) {
1458             parallel(0, [&](const int ithr, const int nthr) {
1459                 size_t n{0}, dim1_s{0};
1460                 size_t start{0}, end{0};
1461                 balance211(work_amount, nthr, ithr, start, end);
1462                 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1463                 while(start < end) {
1464                     size_t work_rem = end - start;
1465                     size_t dim1_e = dim1_s + work_rem > nelems_no_d0
1466                         ? nelems_no_d0 : dim1_s + work_rem;
1467                     PRAGMA_OMP_SIMD()
1468                     for (size_t e = dim1_s; e < dim1_e; ++e) {
1469                         output[os * n + e] = _qz_a1b0<type_i, type_o>()(
1470                                 input[is * n + e], rmode);
1471                     }
1472                     nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1473                 }
1474             });
1475         } else {
1476             parallel(0, [&](const int ithr, const int nthr) {
1477                 size_t n{0}, dim1_s{0};
1478                 size_t start{0}, end{0};
1479                 balance211(work_amount, nthr, ithr, start, end);
1480                 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1481                 while(start < end) {
1482                     size_t work_rem = end - start;
1483                     size_t dim1_e =
1484                         dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
1485                         : dim1_s + work_rem;
1486                     PRAGMA_OMP_SIMD()
1487                     for (size_t e = dim1_s; e < dim1_e; ++e){
1488                         output[os * n + e] = _qz<type_i, type_o>()(
1489                                 input[is * n + e], output[os * n + e], alpha,
1490                                 beta, rmode);
1491                     }
1492                     nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1493                 }
1494             });
1495         }
1496
1497         return success;
1498     }
1499
1500 private:
1501     static size_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
1502         const int ndims = data_d.ndims();
1503         if (ndims <= 1) return 1;
1504         return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
1505     }
1506
1507     static size_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
1508         size_t max_size = 0;
1509         auto &blk = data_d.blocking_desc();
1510         for (int d = 1; d < data_d.ndims(); ++d) {
1511             auto block = blk.block_dims[d];
1512             max_size = nstl::max(max_size,
1513                     size_t(size_t(blk.padding_dims[d] / block)
1514                         * blk.strides[0][d]));
1515             if (block > 1)
1516                 max_size = nstl::max(max_size,
1517                         size_t(block * blk.strides[1][d]));
1518         }
1519         return max_size;
1520     }
1521 };
1522
1523 template <SIMPLE_REORDER_TEMPL_DECL>
1524 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1525     typename utils::enable_if<
1526         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1527     spec::reference>::type>
1528 {
1529     static bool is_applicable(const memory_desc_wrapper &input_d,
1530             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1531         /* supported smask: 0x0...011..10...0,
1532          * i.e. 1 should be contiguous */
1533         int smask = attr ? attr->output_scales_.mask_ : 0;
1534         for (; smask > 0 && !(smask & 0x1); smask >>= 1);
1535         for (; smask > 0 && smask & 0x1; smask >>= 1);
1536         return true
1537             && input_d.is_blocking_desc()
1538             && output_d.is_blocking_desc()
1539             && !output_d.is_additional_buffer()
1540             && !input_d.is_additional_buffer()
1541             && smask == 0;
1542     }
1543
1544     static status_t execute(const cpu_reorder_pd_t *pd,
1545         const data_t<type_i> *input, data_t<type_o> *output) {
1546         DECLARE_COMMON_PARAMS();
1547
1548         const size_t nelems = input_d.nelems();
1549
1550         int ndims_start = 0, ndims_mask = 0;
1551         int smask = pd->attr()->output_scales_.mask_;
1552         for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
1553         for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
1554         assert(smask == 0);
1555
1556         const ptrdiff_t D_start
1557             = utils::array_product(input_d.dims(), ndims_start);
1558         const ptrdiff_t D_mask
1559             = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
1560         const ptrdiff_t D_rest = nelems / D_start / D_mask;
1561
1562         const float *scales = pd->attr()->output_scales_.scales_;
1563
1564         parallel_nd(D_start, D_mask, D_rest,
1565             [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
1566             const float scale = scales[dm];
1567
1568             const size_t e = (ds * D_mask + dm) * D_rest + dr;
1569             const auto &i = input[input_d.off_l(e)];
1570             auto &o = output[output_d.off_l(e)];
1571
1572             o = _qz<type_i, type_o>()(i, o, scale, beta, rmode);
1573         });
1574
1575         return success;
1576     }
1577 };
1578
1579
1580 /* high level class declaration */
1581
1582 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
1583 struct simple_reorder_t: public cpu_primitive_t {
1584     struct pd_t: public cpu_reorder_pd_t {
1585         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
1586                 const primitive_attr_t *attr)
1587             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
1588
1589         DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
1590
1591         static status_t create(reorder_pd_t **reorder_pd,
1592                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1593                 const primitive_attr_t *attr) {
1594             assert(input_pd->engine()->kind() == engine_kind::cpu);
1595             assert(output_pd->engine()->kind() == engine_kind::cpu);
1596             bool args_ok = true
1597                 && input_pd->desc()->data_type == type_i
1598                 && output_pd->desc()->data_type == type_o
1599                 && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
1600                 is_applicable(input_pd->desc(), output_pd->desc(), attr);
1601             if (!args_ok)
1602                 return invalid_arguments;
1603
1604             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
1605                     (const cpu_memory_pd_t *)output_pd, attr);
1606             if (_pd == nullptr) return out_of_memory;
1607             if (_pd->init() != success) { delete _pd; return unimplemented; }
1608             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
1609         }
1610     };
1611
1612     simple_reorder_t(const pd_t *apd, const input_vector &inputs,
1613             const output_vector &outputs)
1614         : cpu_primitive_t(apd, inputs, outputs) {}
1615
1616     virtual void execute(event_t *e) const {
1617         auto input = reinterpret_cast<const data_t<type_i> *>(
1618                 this->input_memory(0));
1619         auto output = reinterpret_cast<data_t<type_o> *>(this->memory());
1620         simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
1621                 pd(), input, output);
1622         e->set_state(event_t::ready);
1623     }
1624
1625 private:
1626     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
1627 };
1628
1629 #undef SIMPLE_REORDER_TEMPL_DECL
1630 #undef SIMPLE_REORDER_TEMPL_CALL
1631
1632 }
1633 }
1634 }
1635
1636 #endif
1637
1638 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s