updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / simple_reorder.hpp
1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_SIMPLE_REORDER_HPP
18 #define CPU_SIMPLE_REORDER_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "type_helpers.hpp"
24 #include "math_utils.hpp"
25 #include "mkldnn_thread.hpp"
26 #include "utils.hpp"
27
28 #include "format_traits.hpp"
29 #include "cpu_reorder_pd.hpp"
30 #include "cpu_primitive.hpp"
31
32 #include "simple_q10n.hpp"
33 #include "cpu_isa_traits.hpp"
34
35 #include "bfloat16_utils.hpp"
36
37 namespace mkldnn {
38 namespace impl {
39 namespace cpu {
40
41 using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::data_type;
44
45 using dk = data_kind_t;
46 using bf = block_format_t;
47
48 using namespace mkldnn::impl::utils;
49 using math::saturate;
50
51 template<impl::data_type_t type>
52 using data_t = typename prec_traits<type>::type;
53
54 template<impl::data_type_t type_i, impl::data_type_t type_o>
55 using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
56
57 template<impl::data_type_t type_i, impl::data_type_t type_o>
58 using _qz = qz<data_t<type_i>, data_t<type_o>>;
59
60 namespace fmt_order {
61     const bool keep = true;
62     const bool reverse = false;
63     const bool any = keep;
64 }
65
66 namespace spec {
67 struct direct_copy {};
68 struct direct_copy_except_dim_0 {};
69 struct reference {};
70 }
71
72 #define SIMPLE_REORDER_TEMPL_DECL \
73     impl::data_type_t type_i, impl::memory_format_t fmt_i, \
74     impl::data_type_t type_o, impl::memory_format_t fmt_o, bool order_keep
75 #define SIMPLE_REORDER_TEMPL_CALL \
76     type_i, fmt_i, type_o, fmt_o, order_keep
77
78 #define DECLARE_COMMON_PARAMS() \
79         const memory_desc_wrapper &input_d = pd->input_pd(); \
80         const memory_desc_wrapper &output_d = pd->output_pd(); \
81         const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
82         const float beta = pd->beta(); MAYBE_UNUSED(beta); \
83         const round_mode_t rmode = pd->attr()->round_mode_; MAYBE_UNUSED(rmode);
84
85 #define GET_SCRATCHPAD_SIZE_ZERO() \
86     static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, \
87             const memory_desc_wrapper &output_d) { \
88         return 0; \
89     }
90
91 /* specific reorders: common template */
92 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
93 struct simple_reorder_impl {};
94
95 namespace {
96 bool simple_fmt_check(bool order_keep, impl::memory_format_t fmt_i,
97         impl::memory_format_t fmt_o, const memory_desc_wrapper &input_d,
98         const memory_desc_wrapper &output_d) {
99     return input_d.format() == (order_keep ? fmt_i : fmt_o)
100         && output_d.format() == (order_keep ? fmt_o : fmt_i);
101 }
102 bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
103     if (many_scales_support)
104         return true;
105     return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
106 }
107 }
108
109 /* specific reorders: implementation */
110 template <SIMPLE_REORDER_TEMPL_DECL>
111 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
112 typename utils::enable_if<fmt_i == any && (false
113     || fmt_o == hwio_s8s8 || fmt_o == dhwio_s8s8
114     || fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8)>::type>
115 {
116     static bool is_applicable(const memory_desc_wrapper &input_d,
117             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
118     {
119         const size_t D_mask = utils::array_product(input_d.dims(),
120                                 math::ilog2q(attr->output_scales_.mask_ + 1));
121         const int oc = (input_d.dims()[fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8 + 0]);
122         const int g = (fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8) ? (input_d.dims()[0]) : 1;
123
124         return output_d.format() == fmt_o
125             && (input_d.data_type() == f32 || input_d.data_type() == s8)
126             && output_d.data_type() == s8
127             && (D_mask == 1 || D_mask == (size_t)g * oc);
128     }
129
130     GET_SCRATCHPAD_SIZE_ZERO();
131
132     static status_t execute(const cpu_reorder_pd_t *pd,
133         const data_t<type_i> *input, data_t<type_o> *output,
134         const memory_tracking::grantor_t &scratchpad) {
135         DECLARE_COMMON_PARAMS();
136
137         static constexpr bool w_groups = fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8;
138         int is_3d = format_traits<fmt_o>::ndims_sp == 3;
139
140         const auto &dims = input_d.dims();
141         const auto &pdims = output_d.blocking_desc().padding_dims;
142
143         const int G = w_groups ? dims[0] : 1;
144         const int OC = dims[w_groups + 0];
145         const int IC = dims[w_groups + 1];
146         const int D = is_3d ? dims[w_groups + 2] : 1;
147         const int H = dims[w_groups + 2 + is_3d];
148         const int W = dims[w_groups + 3 + is_3d];
149
150         const float *scales = pd->attr()->output_scales_.scales_;
151         const size_t D_mask = utils::array_product(input_d.dims(),
152                 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
153
154         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0f : (1.0f / 2.0f);
155
156         size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * D * H * W;
157         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
158
159         parallel_nd(G, OC, [&](int g, int oc) {
160             cp[g * OC + oc] = 0;
161             for (int ic = 0; ic < IC; ic++)
162             for (int d = 0; d < D; d++)
163             for (int h = 0; h < H; h++)
164             for (int w = 0; w < W; w++) {
165                 auto i = is_3d ? input[input_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
166                                : input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
167                 auto &o = is_3d ? output[output_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
168                                 : output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
169                 const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
170
171                 o = qz_b0<data_t<type_i>, data_t<type_o>>()(
172                     i, s * adj_scale, rmode);
173                 cp[g * OC + oc] -= (int32_t)o;
174             }
175             cp [g * OC + oc] *= 128;
176         });
177         return success;
178     }
179 };
180
181 template <SIMPLE_REORDER_TEMPL_DECL>
182 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
183         typename utils::enable_if<(
184                 utils::one_of(fmt_i, goihw, oihw, goiw, oiw, hwio, hwigo)
185                 && (format_traits<fmt_o>::blk_fmt == bf::_4i16o4i_s8s8
186                            || format_traits<fmt_o>::blk_fmt == bf::_2i8o4i_s8s8
187                            || format_traits<fmt_o>::blk_fmt
188                                    == bf::_4o4i_s8s8))>::type> {
189     static bool is_applicable(const memory_desc_wrapper &input_d,
190             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
191     {
192         const size_t D_mask = utils::array_product(input_d.dims(),
193                                 math::ilog2q(attr->output_scales_.mask_ + 1));
194         static constexpr bool w_groups
195                 = format_traits<fmt_i>::data_kind == dk::gwei;
196         const int oc = input_d.dims()[w_groups + 0];
197         const int g = w_groups ? input_d.dims()[0] : 1;
198
199         return input_d.format() == fmt_i
200             && output_d.format() == fmt_o
201             && utils::one_of(input_d.data_type(), f32, s8)
202             && output_d.data_type() == s8
203             && (D_mask == 1 || D_mask == (size_t)g * oc);
204     }
205
206     GET_SCRATCHPAD_SIZE_ZERO();
207
208     static status_t execute(const cpu_reorder_pd_t *pd,
209         const data_t<type_i> *input, data_t<type_o> *output,
210         const memory_tracking::grantor_t &scratchpad) {
211         DECLARE_COMMON_PARAMS();
212
213         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
214         static constexpr bool w_groups
215                 = format_traits<fmt_o>::data_kind == dk::gwei;
216         const int blksize = format_traits<fmt_o>::blk_size;
217         const int sblk = 4;
218
219         const auto &plain_d = order_keep ? input_d : output_d;
220         const auto &dims = input_d.dims();
221         const auto &pdims = order_keep
222             ? output_d.blocking_desc().padding_dims
223             : input_d.blocking_desc().padding_dims;
224
225         const int G = w_groups ? dims[0] : 1;
226         const int OC = dims[w_groups + 0];
227         const int NB_OC = pdims[w_groups + 0] / blksize;
228         const int IC = dims[w_groups + 1];
229         const int NB_IC = pdims[w_groups + 1] / blksize;
230         const int H = is_1d ? 1 : dims[w_groups + 2];
231         const int W = dims[w_groups + 3 - is_1d];
232
233         const float *scales = pd->attr()->output_scales_.scales_;
234         const size_t D_mask = utils::array_product(input_d.dims(),
235                             math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
236
237         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
238
239         auto index = [&](const int ic, const int oc) {
240             return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
241         };
242
243         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
244             int32_t *c, const float *s, const int oc_block, const int ic_block) {
245             for (int ic = 0; ic < ic_block; ++ic) {
246             for (int oc = 0; oc < oc_block; ++oc) {
247                 const auto plain_off =
248                     oc * plain_d.blocking_desc().strides[0][w_groups + 0]
249                   + ic * plain_d.blocking_desc().strides[0][w_groups + 1];
250                 out[index(ic, oc)]
251                     = qz_b0<data_t<type_i>, data_t<type_o>>()(
252                             inp[plain_off], s[oc] * adj_scale, rmode);
253                 c[oc] -= (128 * (int32_t)(out[index(ic, oc)]));
254             }
255             }
256         };
257
258         constexpr int i_mult = blksize;
259         constexpr int o_mult = 1;
260
261         size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
262         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
263         parallel_nd(G * NB_OC * blksize, [&](int i) {
264             cp[i] = 0;
265         });
266
267         parallel_nd(G, NB_OC, [&](int g, int O) {
268             for (int I = 0; I < NB_IC; I++)
269                 for (int h = 0; h < H; h++)
270                 for (int w = 0; w < W; w++) {
271                     auto i = &input[wei_blk_off_like_gwei3D<fmt_i>(
272                             input_d, g, i_mult * O, i_mult * I, 0, h, w)];
273                     auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(
274                             output_d, g, o_mult * O, o_mult * I, 0, h, w)];
275                     const int oc_block = nstl::min(blksize, OC - O * blksize);
276                     const int ic_block = nstl::min(blksize, IC - I * blksize);
277
278                     int _offset = (g * NB_OC + O) * blksize;
279                     ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
280                             &scales[(D_mask == 1) ? 0 : _offset],
281                                         oc_block, ic_block);
282                 }
283         });
284         return success;
285     }
286 };
287
288 template <SIMPLE_REORDER_TEMPL_DECL>
289 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
290     typename utils::enable_if<(
291         (fmt_i == goihw || fmt_i == oihw) &&
292         (format_traits<fmt_o>::blk_fmt == bf::_16i16o
293          || format_traits<fmt_o>::blk_fmt == bf::_8i16o2i
294          || format_traits<fmt_o>::blk_fmt == bf::_8o16i2o) &&
295         type_i == data_type::f32 && type_o == data_type::bf16)>::type>
296 {
297     static bool is_applicable(const memory_desc_wrapper &input_d,
298             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
299     {
300         return order_keep
301             && input_d.format() == fmt_i && output_d.format() == fmt_o
302             && input_d.data_type() == f32 && output_d.data_type() == bf16;
303     }
304
305     static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
306             const memory_desc_wrapper &output_d) {
307         const int blksize = 16;
308         return sizeof(float) * blksize * blksize * mkldnn_get_max_threads();
309     }
310
311     static status_t execute(const cpu_reorder_pd_t *pd,
312         const data_t<type_i> *input, data_t<type_o> *output,
313         const memory_tracking::grantor_t &scratchpad) {
314         DECLARE_COMMON_PARAMS();
315
316         static constexpr bool w_groups = fmt_i == goihw;
317         const int blksize = 16;
318         const int sblk = 2;
319
320         const auto &_g_oihw_d = input_d;
321         const auto &dims = input_d.dims();
322         const auto &pdims = output_d.blocking_desc().padding_dims;
323
324         const int G = w_groups ? dims[0] : 1;
325         const int OC = dims[w_groups + 0];
326         const int NB_OC = pdims[w_groups + 0] / blksize;
327         const int IC = dims[w_groups + 1];
328         const int NB_IC = pdims[w_groups + 1] / blksize;
329         const int H = dims[w_groups + 2];
330         const int W = dims[w_groups + 3];
331
332         const size_t wsp_size = blksize * blksize;
333         float *wspace = scratchpad.template get<float>(
334                 memory_tracking::names::key_reorder_space);
335
336         auto index = [&](const int ic, const int oc) {
337             if (format_traits<fmt_o>::blk_fmt == bf::_16i16o)
338                 return (ic * blksize + oc);
339             else if (format_traits<fmt_o>::blk_fmt == bf::_8i16o2i)
340                 return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
341             else if (format_traits<fmt_o>::blk_fmt == bf::_8o16i2o)
342                 return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk);
343             else
344                 assert(!"Invalid weight format");
345                 return 0;
346         };
347
348         auto ker = [&](const data_t<type_i> *inp, data_t<type_i> *out,
349             const int curr_oc_block, const int oc_block,
350             const int curr_ic_block, const int ic_block) {
351             int ic = 0;
352             for (ic = 0; ic < curr_ic_block; ++ic) {
353                 int oc = 0;
354                 for (oc = 0; oc < curr_oc_block; ++oc) {
355                     const auto _g_oihw_off =
356                         oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
357                       + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
358                     out[index(ic, oc)] = inp[_g_oihw_off];
359                 }
360                 for (/* continue */; oc < oc_block; ++oc) {
361                     out[index(ic, oc)] = (data_t<type_i>)0;
362                 }
363             }
364             for (/* continue */; ic < ic_block; ++ic) {
365                 for (int oc = 0; oc < oc_block; ++oc) {
366                     out[index(ic, oc)] = (data_t<type_i>)0;
367                 }
368             }
369         };
370
371         constexpr int i_mult = blksize;
372         constexpr int o_mult = 1;
373
374         parallel_nd(G, NB_OC, NB_IC, H, W, [&](int g, int O, int I, int h, int w) {
375             int ithr = mkldnn_get_thread_num();
376             float *_wspace = wspace + wsp_size * ithr;
377             auto i = &input[input_d.blk_off<!w_groups>(g,
378                     i_mult * O, i_mult * I, h, w)];
379             auto o = &output[output_d.blk_off<!w_groups>(
380                     g, o_mult * O, o_mult * I, h, w)];
381             const int oc_block = nstl::min(blksize, OC - O * blksize);
382             const int ic_block = nstl::min(blksize, IC - I * blksize);
383             ker(i, _wspace, oc_block, blksize, ic_block, blksize);
384             bf16_cvt_utils::cvt_float_to_bfloat16(o, _wspace, wsp_size);
385         });
386
387         return success;
388     }
389
390 };
391
392 template <SIMPLE_REORDER_TEMPL_DECL>
393 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
394     typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_16i16o &&
395            (fmt_o == goihw || fmt_o == oihw) &&
396            type_i == data_type::bf16 && type_o == data_type::f32>::type>
397 {
398     static bool is_applicable(const memory_desc_wrapper &input_d,
399             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
400     {
401         return order_keep
402             && input_d.format() == fmt_i && output_d.format() == fmt_o
403             && input_d.data_type() == bf16 && output_d.data_type() == f32;
404     }
405
406     GET_SCRATCHPAD_SIZE_ZERO();
407
408     static status_t execute(const cpu_reorder_pd_t *pd,
409         const data_t<type_i> *input, data_t<type_o> *output,
410         const memory_tracking::grantor_t &scratchpad) {
411         DECLARE_COMMON_PARAMS();
412
413         static constexpr bool w_groups = fmt_o == goihw;
414         const int blksize = 16;
415
416         const auto &_g_oihw_d = output_d;
417         const auto &dims = input_d.dims();
418         const auto &pdims = input_d.blocking_desc().padding_dims;
419
420         const int G = w_groups ? dims[0] : 1;
421         const int OC = dims[w_groups + 0];
422         const int NB_OC = pdims[w_groups + 0] / blksize;
423         const int IC = dims[w_groups + 1];
424         const int NB_IC = pdims[w_groups + 1] / blksize;
425         const int H = dims[w_groups + 2];
426         const int W = dims[w_groups + 3];
427
428         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
429             int curr_oc_block, int curr_ic_block) {
430             for (int ic = 0; ic < curr_ic_block; ++ic) {
431                 for (int oc = 0; oc < curr_oc_block; ++oc) {
432                     const auto _g_oihw_off =
433                         oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
434                       + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
435                     bf16_cvt_utils::cvt_bfloat16_to_float(
436                             &o[_g_oihw_off], &i[ic * blksize + oc]);
437                 }
438             }
439         };
440
441         constexpr int i_mult = 1;
442         constexpr int o_mult = blksize;
443
444         parallel_nd(G, NB_OC, NB_IC, H, W, [&](int g, int O, int I, int h, int w) {
445             auto i = &input[input_d.blk_off<!w_groups>(
446                     g, i_mult * O, i_mult * I, h, w)];
447             auto o = &output[output_d.blk_off<!w_groups>(
448                     g, o_mult * O, o_mult * I, h, w)];
449             const int oc_block = nstl::min(blksize, OC - O * blksize);
450             const int ic_block = nstl::min(blksize, IC - I * blksize);
451             ker(i, o, oc_block, ic_block);
452         });
453
454         return success;
455     }
456 };
457
458 template <SIMPLE_REORDER_TEMPL_DECL>
459 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
460 typename utils::enable_if<
461           (fmt_i == nchw && fmt_o == nChw16c) &&
462            type_i == data_type::f32 && type_o == data_type::bf16>::type>
463 {
464     static bool is_applicable(const memory_desc_wrapper &input_d,
465         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
466         return input_d.format() == fmt_i && output_d.format() == fmt_o
467             && input_d.data_type() == f32 && output_d.data_type() == bf16;
468     }
469
470     static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
471             const memory_desc_wrapper &output_d) {
472         const size_t blksize = 16;
473         const size_t W = input_d.dims()[3];
474         return sizeof(float) * blksize * W * mkldnn_get_max_threads();
475     }
476
477     static status_t execute(const cpu_reorder_pd_t *pd,
478         const data_t<type_i> *input, data_t<type_o> *output,
479         const memory_tracking::grantor_t &scratchpad) {
480         DECLARE_COMMON_PARAMS();
481
482         constexpr int blksize = 16;
483
484         const auto &flat_d = input_d;
485         const auto &dims = input_d.dims();
486         const auto &pdims = output_d.blocking_desc().padding_dims;
487
488         const int C = dims[1];
489         const int H = dims[2];
490         const int W = dims[3];
491
492         const int wsp_size = W * blksize;
493         float *wspace = scratchpad.template get<float>(
494                 memory_tracking::names::key_reorder_space);
495
496         auto ker = [&](const data_t<type_i> *i, data_t<type_i> *o,
497             const int curr_c_block, const int c_block) {
498             for (int w = 0; w < W; ++w) {
499                 int c = 0;
500                 for (c = 0; c < curr_c_block; ++c) {
501                     const ptrdiff_t flat_off = 0
502                         + c * flat_d.blocking_desc().strides[0][1]
503                         + w * flat_d.blocking_desc().strides[0][3];
504                     o[w * blksize + c] = i[flat_off];
505                 }
506                 for (/* continue */; c < c_block; ++c) {
507                     o[w * blksize + c] = (data_t<type_i>)0;
508                 }
509             }
510         };
511
512         constexpr int i_c_mult = blksize;
513         constexpr int o_c_mult = 1;
514
515         parallel_nd(dims[0], pdims[1] / blksize, H, [&](int n, int nb_c, int h) {
516             int ithr = mkldnn_get_thread_num();
517             float *_wspace = wspace + wsp_size * ithr;
518             auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
519             auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
520             const int c_block = nstl::min(blksize, C - nb_c * blksize);
521             ker(i, _wspace, c_block, blksize);
522             bf16_cvt_utils::cvt_float_to_bfloat16(o, _wspace, wsp_size);
523         });
524
525         return success;
526     }
527
528 };
529
530 template <SIMPLE_REORDER_TEMPL_DECL>
531 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
532 typename utils::enable_if<
533           (fmt_i == nChw16c && fmt_o == nchw) &&
534           type_i == data_type::bf16 && type_o == data_type::f32>::type>
535 {
536     static bool is_applicable(const memory_desc_wrapper &input_d,
537         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
538         return input_d.format() == fmt_i && output_d.format() == fmt_o
539             && input_d.data_type() == bf16 && output_d.data_type() == f32;
540     }
541
542     GET_SCRATCHPAD_SIZE_ZERO();
543
544     static status_t execute(const cpu_reorder_pd_t *pd,
545         const data_t<type_i> *input, data_t<type_o> *output,
546         const memory_tracking::grantor_t &scratchpad) {
547         DECLARE_COMMON_PARAMS();
548
549         constexpr int blksize = 16;
550         const auto &flat_d = output_d;
551         const auto &dims = input_d.dims();
552         const auto &pdims = input_d.blocking_desc().padding_dims;
553
554         const int C = dims[1];
555         const int H = dims[2];
556         const int W = dims[3];
557
558         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
559             const int c_block) {
560             for (int w = 0; w < W; ++w)
561             for (int c = 0; c < c_block; ++c) {
562                 const ptrdiff_t flat_off = 0
563                     + c * flat_d.blocking_desc().strides[0][1]
564                     + w * flat_d.blocking_desc().strides[0][3];
565
566                 bf16_cvt_utils::cvt_bfloat16_to_float(
567                         &o[flat_off], &i[w * blksize + c]);
568             }
569         };
570
571         constexpr int i_c_mult = 1;
572         constexpr int o_c_mult = blksize;
573
574         parallel_nd(dims[0], pdims[1] / blksize, H, [&](int n, int nb_c, int h) {
575             auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
576             auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
577             const int c_block = nstl::min(blksize, C - nb_c * blksize);
578             ker(i, o, c_block);
579         });
580
581         return success;
582     }
583 };
584
585
586 template <SIMPLE_REORDER_TEMPL_DECL>
587 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
588     typename utils::enable_if<true
589         && utils::one_of(fmt_i, goiw, goihw, hwigo)
590         && format_traits<fmt_o>::blk_fmt == bf::_16g_s8s8>::type> {
591
592     static bool is_applicable(const memory_desc_wrapper &input_d,
593             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
594         const size_t D_mask = utils::array_product(input_d.dims(),
595                             math::ilog2q(attr->output_scales_.mask_ + 1));
596         const int oc = input_d.dims()[1];
597         const int g = input_d.dims()[0];
598
599         return true
600             && order_keep
601             && input_d.format() == fmt_i
602             && output_d.format() == fmt_o
603             && utils::one_of(input_d.data_type(), f32, s8)
604             && output_d.data_type() == s8
605             && (D_mask == 1 || D_mask == (size_t)g * oc);
606     }
607
608     GET_SCRATCHPAD_SIZE_ZERO();
609
610     static status_t execute(const cpu_reorder_pd_t *pd,
611             const data_t<type_i> *input, data_t<type_o> *output,
612             const memory_tracking::grantor_t &scratchpad) {
613         DECLARE_COMMON_PARAMS();
614
615         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
616         const int blksize = format_traits<fmt_o>::blk_size;
617
618         const auto &dims = input_d.dims();
619         const auto &pdims = output_d.blocking_desc().padding_dims;
620         const int G = dims[0];
621         const int Gp = pdims[0];
622         const int OC = dims[1];
623         const int IC = dims[2];
624         const int H = is_1d ? 1 : dims[3];
625         const int W = dims[4 - is_1d];
626
627         const size_t D_mask = utils::array_product(input_d.dims(),
628                             math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
629         const float *scales = pd->attr()->output_scales_.scales_;
630         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
631
632
633         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
634                 int32_t *cp, const float *s, const int g_block) {
635             PRAGMA_OMP_SIMD()
636             for (int g = 0; g < g_block; g++) {
637                 const auto i_off = g * input_d.blocking_desc().strides[0][0];
638                 out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
639                         inp[i_off], s[g * OC] * adj_scale, rmode);
640                 cp[g * OC] -= 128 * (int32_t)(out[g]);
641             }
642         };
643
644         size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
645         int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
646         parallel_nd((Gp/blksize) * OC, [&](int ib) {
647             PRAGMA_OMP_SIMD()
648             for (int i = 0; i < blksize; i++)
649                 cp[ib * blksize + i] = 0;
650         });
651
652         parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
653                 for (int I = 0; I < IC; I++) {
654                     for (int h = 0; h < H; h++) {
655                     for (int w = 0; w < W; w++) {
656                         const int g_block = nstl::min(G - gb * blksize, blksize);
657                         const auto inp = &input[wei_blk_off_like_gwei3D<fmt_i>(
658                                 input_d, gb * blksize, O, I, 0, h, w)];
659                         const auto out = &output[wei_blk_off_like_gwei3D<fmt_o>(
660                                 output_d, gb, O, I, 0, h, w)];
661                         int offset = gb * blksize + O;
662                         ker(inp, out, &cp[offset],
663                             &scales[(D_mask == 1) ? 0 : offset], g_block);
664                    }
665                    }
666                }
667         });
668         return success;
669     }
670 };
671
672 template <SIMPLE_REORDER_TEMPL_DECL>
673 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
674     typename utils::enable_if<true
675     && format_traits<fmt_i>::blk_fmt == bf::_8i16o2i
676     && format_traits<fmt_o>::blk_fmt == bf::_8o16i2o>::type>
677 {
678     static bool is_applicable(const memory_desc_wrapper &input_d,
679             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
680     {
681         return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
682             && simple_attr_check(attr, false);
683     }
684
685     GET_SCRATCHPAD_SIZE_ZERO();
686
687     static status_t execute(const cpu_reorder_pd_t *pd,
688         const data_t<type_i> *input, data_t<type_o> *output,
689         const memory_tracking::grantor_t &scratchpad) {
690         DECLARE_COMMON_PARAMS();
691
692         static constexpr bool w_groups
693             = format_traits<fmt_o>::data_kind == dk::gwei;
694         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
695         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
696         constexpr int blksize = format_traits<fmt_o>::blk_size;
697
698         const auto &dims = input_d.dims();
699
700         const int G = w_groups ? dims[0] : 1;
701         const int NB_OC = dims[w_groups + 0] / blksize;
702         const int NB_IC = dims[w_groups + 1] / blksize;
703         const int D = is_3d ? dims[w_groups + 2] : 1;
704         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
705         const int W = dims[w_groups + 3 + is_3d - is_1d];
706
707         auto idx_i = [&](const int oc, const int ic)
708         { return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2); };
709
710         auto idx_o = [&](const int oc, const int ic)
711         { return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2); };
712
713         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) -> void {
714             if (alpha == 1.0 && beta == 0.0) {
715                 for (int ic = 0; ic < blksize; ++ic) {
716                     for (int oc = 0; oc < blksize; ++oc) {
717                         o[idx_o(oc, ic)] = _qz_a1b0<type_i, type_o>()(
718                                 i[idx_i(oc, ic)], rmode);
719                     }
720                 }
721             } else {
722                 for (int ic = 0; ic < blksize; ++ic) {
723                     for (int oc = 0; oc < blksize; ++oc) {
724                         o[idx_o(oc, ic)] = _qz<type_i, type_o>()(
725                                 i[idx_i(oc, ic)], o[idx_o(oc, ic)], alpha,
726                                 beta, rmode);
727                     }
728                 }
729             }
730         };
731
732         parallel_nd(G, NB_OC, NB_IC, D, H, W,
733             [&](int g, int o, int i, int d, int h, int w) {
734             auto ptr_i = &input[wei_blk_off_like_gwei3D<fmt_i>(
735                     input_d, g, o, i, d,  h, w)];
736             auto ptr_o = &output[wei_blk_off_like_gwei3D<fmt_o>(
737                     output_d, g, o, i, d, h, w)];
738             ker(ptr_i, ptr_o);
739         });
740
741         return success;
742     }
743 };
744
745 /* reorders with tail support */
746
747 template <SIMPLE_REORDER_TEMPL_DECL>
748 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
749 typename utils::enable_if<fmt_i == nChw8c && fmt_o == nhwc && order_keep>::type>
750 {
751     static bool is_applicable(const memory_desc_wrapper &input_d,
752         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
753         int smask = attr ? attr->output_scales_.mask_ : 0;
754         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nChw8c && output_d._md->format == nhwc;
755     }
756
757     GET_SCRATCHPAD_SIZE_ZERO();
758
759     static status_t execute(const cpu_reorder_pd_t *pd,
760         const data_t<type_i> *input, data_t<type_o> *output,
761         const memory_tracking::grantor_t &scratchpad) {
762         DECLARE_COMMON_PARAMS();
763
764         const auto &pdims = input_d.blocking_desc().padding_dims;
765         const auto &dims = input_d.dims();
766         constexpr int blksize = format_traits<fmt_i>::blk_size;
767         const int C = dims[1];
768         const int H = dims[2];
769         const int W = dims[3];
770
771         constexpr int i_c_mult = 1;
772         constexpr int o_c_mult = blksize;
773
774         const float *scales = pd->attr()->output_scales_.scales_;
775         int smask = pd->attr()->output_scales_.mask_;
776
777         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
778                        const int nb_c, const int c_block) {
779             if (smask == 2) {
780                 for (int w = 0; w < W; ++w) {
781                     const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
782                     PRAGMA_OMP_SIMD()
783                     for (int c = 0; c < c_block; ++c) {
784                         const float scale = scales[nb_c * blksize + c];
785
786                         o[flat_off + c] = _qz<type_i, type_o>()(i[w * blksize + c],
787                                                             o[flat_off + c], scale, beta, rmode);
788                     }
789                 }
790             } else {
791                 for (int w = 0; w < W; ++w) {
792                     const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
793                     PRAGMA_OMP_SIMD()
794                     for (int c = 0; c < c_block; ++c) {
795                         o[flat_off + c] = _qz_a1b0<type_i, type_o>()(i[w * blksize + c], rmode);
796                     }
797                 }
798             }
799         };
800
801         parallel_nd(dims[0], pdims[1] / blksize, H,
802             [&](int n, int nb_c, int h) {
803                     auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
804                     auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
805                     const int c_block = nstl::min(blksize, C - nb_c * blksize);
806                     ker(i, o, nb_c, c_block);
807         });
808
809         return success;
810     }
811 };
812
813 template <SIMPLE_REORDER_TEMPL_DECL>
814 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
815 typename utils::enable_if<fmt_i == nhwc && fmt_o == nChw8c>::type>
816 {
817     static bool is_applicable(const memory_desc_wrapper &input_d,
818         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
819         int smask = attr ? attr->output_scales_.mask_ : 0;
820         return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nChw8c;
821     }
822
823     GET_SCRATCHPAD_SIZE_ZERO();
824
825     static status_t execute(const cpu_reorder_pd_t *pd,
826         const data_t<type_i> *input, data_t<type_o> *output,
827         const memory_tracking::grantor_t &scratchpad) {
828         DECLARE_COMMON_PARAMS();
829
830         const auto &pdims = output_d.blocking_desc().padding_dims;
831         const auto &dims = input_d.dims();
832         constexpr int blksize = format_traits<fmt_o>::blk_size;
833         const int C = dims[1];
834         const int H = dims[2];
835         const int W = dims[3];
836
837         constexpr int i_c_mult = blksize;
838         constexpr int o_c_mult = 1;
839
840         const float *scales = pd->attr()->output_scales_.scales_;
841         int smask = pd->attr()->output_scales_.mask_;
842
843         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
844                        const int nb_c, const int c_block) {
845             if (smask == 2) {
846                 for (int w = 0; w < W; ++w) {
847                     const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
848                     PRAGMA_OMP_SIMD()
849                     for (int c = 0; c < c_block; ++c) {
850                         const float scale = scales[nb_c * blksize + c];
851
852                         o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off + c],
853                                                                    o[w * blksize + c], scale, beta, rmode);
854                     }
855                 }
856             } else {
857                 for (int w = 0; w < W; ++w) {
858                     const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
859                     PRAGMA_OMP_SIMD()
860                     for (int c = 0; c < c_block; ++c) {
861                         o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(i[flat_off + c], rmode);
862                     }
863                 }
864             }
865         };
866
867         parallel_nd(dims[0], pdims[1] / blksize, H,
868             [&](int n, int nb_c, int h) {
869                     auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
870                     auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
871                     const int c_block = nstl::min(blksize, C - nb_c * blksize);
872                     ker(i, o, nb_c, c_block);
873         });
874
875         return success;
876     }
877 };
878
879 template <SIMPLE_REORDER_TEMPL_DECL>
880 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
881 typename utils::enable_if<fmt_i == nhwc && fmt_o == nhwc && type_o != mkldnn_bin>::type>
882 {
883     static bool is_applicable(const memory_desc_wrapper &input_d,
884         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
885         int smask = attr ? attr->output_scales_.mask_ : 0;
886         return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nhwc;
887     }
888
889     GET_SCRATCHPAD_SIZE_ZERO();
890
891     static status_t execute(const cpu_reorder_pd_t *pd,
892         const data_t<type_i> *input, data_t<type_o> *output,
893         const memory_tracking::grantor_t &scratchpad) {
894         DECLARE_COMMON_PARAMS();
895
896         const auto &dims = input_d.dims();
897         const int C = dims[1];
898         const int H = dims[2];
899         const int W = dims[3];
900
901         const float *scales = pd->attr()->output_scales_.scales_;
902
903         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
904                 for (int c = 0; c < C; ++c) {
905                     const float scale = scales[c];
906
907                     o[c] = _qz<type_i, type_o>()(i[c], o[c], scale, beta, rmode);
908                 }
909         };
910
911         parallel_nd(dims[0], H, W,
912             [&](int n, int h, int w) {
913                 auto i = &input[input_d.blk_off(n, 0, h, w)];
914                 auto o = &output[output_d.blk_off(n, 0, h, w)];
915                 ker(i, o);
916         });
917
918         return success;
919     }
920 };
921
922 template <SIMPLE_REORDER_TEMPL_DECL>
923 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
924 typename utils::enable_if<fmt_i == nchw && fmt_o == nhwc && type_i != mkldnn_bin && type_o != mkldnn_bin>::type>
925 {
926     static bool is_applicable(const memory_desc_wrapper &input_d,
927         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
928         int smask = attr ? attr->output_scales_.mask_ : 0;
929         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nchw && output_d._md->format == nhwc;
930     }
931
932     GET_SCRATCHPAD_SIZE_ZERO();
933
934     static status_t execute(const cpu_reorder_pd_t *pd,
935         const data_t<type_i> *input, data_t<type_o> *output,
936         const memory_tracking::grantor_t &scratchpad) {
937         DECLARE_COMMON_PARAMS();
938
939         const auto &dims = input_d.dims();
940         const int C = dims[1];
941         const int H = dims[2];
942         const int W = dims[3];
943
944         int smask = pd->attr()->output_scales_.mask_;
945         const float *scales = pd->attr()->output_scales_.scales_;
946
947         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
948             if (smask == 2) {
949                 for (int c = 0; c < C; ++c) {
950                     const float scale = scales[c];
951
952                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
953
954                     o[c] = _qz<type_i, type_o>()(i[flat_off], o[c], scale, beta, rmode);
955                 }
956             } else {
957                 for (int c = 0; c < C; ++c) {
958                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
959
960                     o[c] = _qz_a1b0<type_i, type_o>()(i[flat_off], rmode);
961                 }
962             }
963         };
964
965         parallel_nd(dims[0], H, W,
966             [&](int n, int h, int w) {
967                 auto i = &input[input_d.blk_off(n, 0, h, w)];
968                 auto o = &output[output_d.blk_off(n, 0, h, w)];
969                 ker(i, o);
970         });
971
972         return success;
973     }
974 };
975
976 template <SIMPLE_REORDER_TEMPL_DECL>
977 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
978 typename utils::enable_if<(fmt_i == nchw || fmt_i == nhwc) && fmt_o == nhwc && (type_i == mkldnn_bin || type_o == mkldnn_bin)>::type>
979 {
980     static bool is_applicable(const memory_desc_wrapper &input_d,
981         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
982         int smask = attr ? attr->output_scales_.mask_ : 0;
983         return smask == 0 && order_keep && (input_d._md->format == nchw || input_d._md->format == nhwc) && output_d._md->format == nhwc;
984     }
985
986     GET_SCRATCHPAD_SIZE_ZERO();
987
988     static status_t execute(const cpu_reorder_pd_t *pd,
989         const data_t<type_i> *input, data_t<type_o> *output,
990         const memory_tracking::grantor_t &scratchpad) {
991         DECLARE_COMMON_PARAMS();
992
993         const auto &dims = input_d.dims();
994         const int C = dims[1];
995         const int H = dims[2];
996         const int W = dims[3];
997
998         int nbits = 8;
999         const int CB = div_up(C, nbits);
1000
1001         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
1002             for (int cb = 0; cb < CB; ++cb) {
1003                 uint8_t bin_val = 0x00;
1004                 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
1005                     const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
1006
1007                     auto bit = uint8_t((i[flat_off] > 0) ? 0x01 : 0x00);
1008                     bin_val |= (bit << shift);
1009                 }
1010
1011                 o[cb] = bin_val;
1012             }
1013         };
1014
1015         parallel_nd(dims[0], H, W,
1016             [&](int n, int h, int w) {
1017                 auto iidx = input_d.blk_off(n, 0, h, w);
1018                 auto oidx = output_d.blk_off(n, 0, h, w);
1019
1020                 auto i = &input[iidx];
1021                 auto o = &output[oidx / nbits];
1022                 ker(i, o);
1023         });
1024
1025         return success;
1026     }
1027 };
1028
1029 template <SIMPLE_REORDER_TEMPL_DECL>
1030 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1031 typename utils::enable_if<fmt_i == nhwc && fmt_o == nchw>::type>
1032 {
1033     static bool is_applicable(const memory_desc_wrapper &input_d,
1034         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1035         int smask = attr ? attr->output_scales_.mask_ : 0;
1036         return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nchw;
1037     }
1038
1039     GET_SCRATCHPAD_SIZE_ZERO();
1040
1041     static status_t execute(const cpu_reorder_pd_t *pd,
1042         const data_t<type_i> *input, data_t<type_o> *output,
1043         const memory_tracking::grantor_t &scratchpad) {
1044         DECLARE_COMMON_PARAMS();
1045
1046         const auto &dims = input_d.dims();
1047         const int C = dims[1];
1048         const int H = dims[2];
1049         const int W = dims[3];
1050
1051         int smask = pd->attr()->output_scales_.mask_;
1052         const float *scales = pd->attr()->output_scales_.scales_;
1053
1054         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
1055             if (smask == 2) {
1056                 for (int c = 0; c < C; ++c) {
1057                     const float scale = scales[c];
1058
1059                     const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
1060
1061                     o[flat_off] = _qz<type_i, type_o>()(i[c], o[flat_off], scale, beta, rmode);
1062                 }
1063             } else {
1064                 for (int c = 0; c < C; ++c) {
1065                     const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
1066
1067                     o[flat_off] = _qz_a1b0<type_i, type_o>()(i[c], rmode);
1068                 }
1069             }
1070         };
1071
1072         parallel_nd(dims[0], H, W,
1073             [&](int n, int h, int w) {
1074                 auto i = &input[input_d.blk_off(n, 0, h, w)];
1075                 auto o = &output[output_d.blk_off(n, 0, h, w)];
1076                 ker(i, o);
1077         });
1078
1079         return success;
1080     }
1081 };
1082
1083 template <SIMPLE_REORDER_TEMPL_DECL>
1084 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1085 typename utils::enable_if<true
1086         && (format_traits<fmt_i>::blk_fmt == bf::_4c
1087                 || format_traits<fmt_i>::blk_fmt == bf::_8c)
1088         && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
1089 {
1090     static bool is_applicable(const memory_desc_wrapper &input_d,
1091             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
1092     {
1093         return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
1094             && simple_attr_check(attr, false);
1095     }
1096
1097     GET_SCRATCHPAD_SIZE_ZERO();
1098
1099     static status_t execute(const cpu_reorder_pd_t *pd,
1100         const data_t<type_i> *input, data_t<type_o> *output,
1101         const memory_tracking::grantor_t &scratchpad) {
1102         DECLARE_COMMON_PARAMS();
1103
1104         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1105         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1106         constexpr int blksize_fmt_o = format_traits<fmt_o>::blk_size;
1107         constexpr int blksize_fmt_i = format_traits<fmt_i>::blk_size;
1108         constexpr int ic_mult = order_keep ? 2 : 1;
1109         constexpr int oc_mult = order_keep ? 1 : 2;
1110
1111         const auto &fmt_i_d = order_keep ? input_d : output_d;
1112         const auto &dims = input_d.dims();
1113         const auto &pdims = order_keep ? output_d.blocking_desc().padding_dims
1114                                        : input_d.blocking_desc().padding_dims;
1115         const auto stride_fmt_i = fmt_i_d.blocking_desc().strides[0];
1116
1117         const int C = dims[1];
1118         const int D = is_3d ? dims[2] : 1;
1119         const int H = is_1d ? 1 : dims[2 + is_3d];
1120         const int W = dims[3 + is_3d - is_1d];
1121
1122         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1123             const int block_fmt_o) {
1124             const int nb = (block_fmt_o - 1) / blksize_fmt_i + 1;
1125             if (alpha == 1.0 && beta == 0.0) {
1126                 for (int b = 0; b < nb; ++b) {
1127                     const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
1128                                                        : b * blksize_fmt_i;
1129                     const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
1130                                                        : b * stride_fmt_i[1];
1131                     const int block_fmt_i = nstl::min(blksize_fmt_i,
1132                                                   block_fmt_o - b * blksize_fmt_i);
1133                     for (int c = 0; c < block_fmt_i; ++c) {
1134                         o[o_off + c] = _qz_a1b0<type_i, type_o>()(
1135                                 i[i_off + c], rmode);
1136                     }
1137                 }
1138             } else {
1139                 for (int b = 0; b < nb; ++b) {
1140                     const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
1141                                                        : b * blksize_fmt_i;
1142                     const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
1143                                                        : b * stride_fmt_i[1];
1144                     const int block_fmt_i = nstl::min(blksize_fmt_i,
1145                                                   block_fmt_o - b * blksize_fmt_i);
1146                     for (int c = 0; c < block_fmt_i; ++c) {
1147                         o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
1148                                 o[o_off + c], alpha, beta, rmode);
1149                     }
1150                 }
1151             }
1152         };
1153
1154 #       define data_blk_off(md, n, c, d, h, w) \
1155         ( is_1d ? (md).blk_off(n, c, w) \
1156           : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
1157
1158         parallel_nd(dims[0], pdims[1] / blksize_fmt_o, D, H, W,
1159             [&](int n, int nb_c, int d, int h, int w) {
1160             auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
1161             auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
1162             const int block_fmt_o = nstl::min(blksize_fmt_o, C - nb_c * blksize_fmt_o);
1163             ker(i, o, block_fmt_o);
1164         });
1165
1166 #       undef data_blk_off
1167
1168         return success;
1169     }
1170 };
1171
1172 #define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
1173     static bool is_applicable(const memory_desc_wrapper &input_d, \
1174         const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
1175         return simple_attr_check(attr, false) && (order_keep \
1176                 ? output_d.format() == fmt_o && input_d.is_plain() \
1177                 : input_d.format() == fmt_o && output_d.is_plain()); \
1178     }
1179
1180 template <SIMPLE_REORDER_TEMPL_DECL>
1181 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1182 typename utils::enable_if<fmt_i == any && (false
1183     || format_traits<fmt_o>::blk_fmt == bf::_4c
1184     || format_traits<fmt_o>::blk_fmt == bf::_8c
1185     || format_traits<fmt_o>::blk_fmt == bf::_16c)>::type>
1186 {
1187     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1188
1189     GET_SCRATCHPAD_SIZE_ZERO();
1190
1191     static status_t execute(const cpu_reorder_pd_t *pd,
1192         const data_t<type_i> *input, data_t<type_o> *output,
1193         const memory_tracking::grantor_t &scratchpad) {
1194         DECLARE_COMMON_PARAMS();
1195
1196         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1197         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1198         constexpr int blksize = format_traits<fmt_o>::blk_size;
1199
1200         const auto &flat_d = order_keep ? input_d : output_d;
1201         const auto &dims = input_d.dims();
1202         const auto &pdims = order_keep
1203             ? output_d.blocking_desc().padding_dims
1204             : input_d.blocking_desc().padding_dims;
1205
1206         const int C = dims[1];
1207         const int D = is_3d ? dims[2] : 1;
1208         const int H = is_1d ? 1 : dims[2 + is_3d];
1209         const int W = dims[3 + is_3d - is_1d];
1210
1211         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1212             const int c_block) {
1213             if (alpha == 1.0 && beta == 0.0) {
1214                 for (int w = 0; w < W; ++w)
1215                 for (int c = 0; c < c_block; ++c) {
1216                     const ptrdiff_t flat_off = 0
1217                         + c * flat_d.blocking_desc().strides[0][1]
1218                         + w * flat_d.blocking_desc().strides[0][3 + is_3d
1219                             - is_1d];
1220                     if (order_keep) {
1221                         o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(
1222                                 i[flat_off], rmode);
1223                     } else {
1224                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
1225                                 i[w * blksize + c], rmode);
1226                     }
1227                 }
1228             } else {
1229                 for (int w = 0; w < W; ++w)
1230                 for (int c = 0; c < c_block; ++c) {
1231                     const ptrdiff_t flat_off = 0
1232                         + c * flat_d.blocking_desc().strides[0][1]
1233                         + w * flat_d.blocking_desc().strides[0][3 + is_3d
1234                             - is_1d];
1235                     if (order_keep) {
1236                         o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off],
1237                                 o[w * blksize + c], alpha, beta, rmode);
1238                     } else {
1239                         o[flat_off] = _qz<type_i, type_o>()(i[w * blksize + c],
1240                                 o[flat_off], alpha, beta, rmode);
1241                     }
1242                 }
1243             }
1244         };
1245
1246         constexpr int i_c_mult = order_keep ? blksize : 1;
1247         constexpr int o_c_mult = order_keep ? 1 : blksize;
1248
1249 #       define data_blk_off(md, n, c, d, h) \
1250         ( is_1d ? (md).blk_off(n, c) \
1251           : is_3d ? (md).blk_off(n, c, d, h) : (md).blk_off(n, c, h))
1252
1253         parallel_nd(dims[0], pdims[1] / blksize, D, H,
1254             [&](int n, int nb_c, int d, int h) {
1255             auto i = &input[data_blk_off(input_d, n, i_c_mult * nb_c, d, h)];
1256             auto o = &output[data_blk_off(output_d, n, o_c_mult * nb_c, d, h)];
1257             const int c_block = nstl::min(blksize, C - nb_c * blksize);
1258             ker(i, o, c_block);
1259         });
1260
1261 #       undef data_blk_off
1262
1263         return success;
1264     }
1265 };
1266
1267 template <SIMPLE_REORDER_TEMPL_DECL>
1268 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1269     typename utils::enable_if<
1270           (fmt_i == goihw && fmt_o == gOhIw8o4i_s8s8)
1271        || (fmt_i == oihw && fmt_o == OhIw8o4i_s8s8)
1272        || (fmt_i == goidhw && fmt_o == gOdhIw8o4i_s8s8)
1273        || (fmt_i == oidhw && fmt_o == OdhIw8o4i_s8s8)
1274     >::type>
1275 {
1276     static bool is_applicable(const memory_desc_wrapper &input_d,
1277             const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
1278     {
1279         const size_t D_mask = utils::array_product(input_d.dims(),
1280                                 math::ilog2q(attr->output_scales_.mask_ + 1));
1281         const int oc = (input_d.dims()[(fmt_i == goihw || fmt_i == goidhw) + 0]);
1282         const int g = (fmt_i == goihw || fmt_i == goidhw) ? (input_d.dims()[0]) : 1;
1283
1284         return input_d.format() == fmt_i
1285             && output_d.format() == fmt_o
1286             && (input_d.data_type() == f32 || input_d.data_type() == s8)
1287             && output_d.data_type() == s8
1288             && (D_mask == 1 || D_mask == (size_t)g * oc);
1289     }
1290
1291     GET_SCRATCHPAD_SIZE_ZERO();
1292
1293     static status_t execute(const cpu_reorder_pd_t *pd,
1294         const data_t<type_i> *input, data_t<type_o> *output,
1295         const memory_tracking::grantor_t &scratchpad) {
1296         DECLARE_COMMON_PARAMS();
1297
1298         static constexpr bool w_groups
1299             = format_traits<fmt_o>::data_kind == dk::gwei;
1300         int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1301         constexpr int blksize_o = 8;
1302         constexpr int blksize_i = 4;
1303
1304         const auto &flat_d = order_keep ? input_d : output_d;
1305         const auto &dims = input_d.dims();
1306         const auto &pdims = order_keep
1307             ? output_d.blocking_desc().padding_dims
1308             : input_d.blocking_desc().padding_dims;
1309
1310         const int G = w_groups ? dims[0] : 1;
1311         const int OC = dims[w_groups + 0];
1312         const int NB_OC = pdims[w_groups + 0] / blksize_o;
1313         const int IC = dims[w_groups + 1];
1314         const int NB_IC = pdims[w_groups + 1] / blksize_i;
1315         const int D = is_3d ? dims[w_groups + 2] : 1;
1316         const int H = dims[w_groups + 2 + is_3d];
1317         const int W = dims[w_groups + 3 + is_3d];
1318
1319         const float *scales = pd->attr()->output_scales_.scales_;
1320         const size_t D_mask = utils::array_product(input_d.dims(),
1321                                                    math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
1322
1323         float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0 : (1.0 / 2.0);
1324
1325         auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
1326             int32_t *c, const float *s, const int oc_block, const int ic_block) {
1327 #            define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1328
1329             for (int ic = 0; ic < ic_block; ++ic) {
1330                 for (int oc = 0; oc < oc_block; ++oc) {
1331                     const auto _g_oihw_off = oc * flat_d.blocking_desc().strides[0][w_groups + 0] +
1332                                              ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1333
1334                     if (order_keep) {
1335                         out[blk_off(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[_g_oihw_off], s[oc] * adj_scale, rmode);
1336                         c[oc] -= (128 * (int32_t)(out[blk_off(oc, ic)]));
1337                     } else {
1338                         out[_g_oihw_off] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[blk_off(oc, ic)], s[oc] * adj_scale, rmode);
1339                         c[oc] -= (128 * (int32_t)(out[_g_oihw_off]));
1340                     }
1341                 }
1342             }
1343
1344 #           undef blk_off
1345         };
1346
1347         constexpr int i_mult_o = blksize_o;
1348         constexpr int i_mult_i = blksize_i;
1349
1350         size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * D * H * W;
1351         int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
1352         parallel_nd(G * NB_OC * blksize_o, [&](int i) {
1353             cp[i] = 0;
1354         });
1355
1356         parallel_nd(G, NB_OC, [&](int g, int O) {
1357             for (int I = 0; I < NB_IC; I++) {
1358                 for (int d = 0; d < D; d++) {
1359                     for (int h = 0; h < H; h++) {
1360                         for (int w = 0; w < W; w++) {
1361                             auto i = is_3d ? &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, d, h, w)]
1362                                            : &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, h, w)];
1363                             auto o = is_3d ? &output[output_d.blk_off<!w_groups>(g, O, I, d, h, w)]
1364                                            : &output[output_d.blk_off<!w_groups>(g, O, I, h, w)];
1365                             const int oc_block = nstl::min(blksize_o, OC - O * blksize_o);
1366                             const int ic_block = nstl::min(blksize_i, IC - I * blksize_i);
1367
1368                             int _offset = (g * NB_OC + O) * blksize_o;
1369                             ker(i, o, (order_keep) ? &cp[_offset] : nullptr, &scales[(D_mask == 1) ? 0 : _offset],
1370                                 oc_block,
1371                                 ic_block);
1372                         }
1373                     }
1374                 }
1375             }
1376         });
1377
1378         return success;
1379     }
1380 };
1381
1382 template <SIMPLE_REORDER_TEMPL_DECL>
1383 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1384 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o4i || fmt_o == gOhIw8o4i)>::type>
1385 {
1386     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1387
1388     GET_SCRATCHPAD_SIZE_ZERO();
1389
1390     static status_t execute(const cpu_reorder_pd_t *pd,
1391         const data_t<type_i> *input, data_t<type_o> *output,
1392         const memory_tracking::grantor_t &scratchpad) {
1393         DECLARE_COMMON_PARAMS();
1394
1395         static constexpr bool w_groups
1396             = format_traits<fmt_o>::data_kind == dk::gwei;
1397         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1398         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1399         constexpr int blksize_o = 8;//format_traits<fmt_o>::blk_size;
1400         constexpr int blksize_i = 4;
1401
1402         const auto &flat_d = order_keep ? input_d : output_d;
1403         const auto &dims = input_d.dims();
1404         const auto &pdims = order_keep
1405             ? output_d.blocking_desc().padding_dims
1406             : input_d.blocking_desc().padding_dims;
1407
1408         const int G = w_groups ? dims[0] : 1;
1409         const int OC = dims[w_groups + 0];
1410         const int NB_OC = pdims[w_groups + 0] / blksize_o;
1411         const int IC = dims[w_groups + 1];
1412         const int NB_IC = pdims[w_groups + 1] / blksize_i;
1413         const int D = is_3d ? dims[w_groups + 2] : 1;
1414         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1415         const int W = dims[w_groups + 3 + is_3d - is_1d];
1416
1417         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1418             const int oc_block, const int ic_block) {
1419 #           define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1420
1421             if (alpha == 1.0 && beta == 0.0) {
1422                 for (int oc = 0; oc < oc_block; ++oc)
1423                 for (int ic = 0; ic < ic_block; ++ic) {
1424                     const ptrdiff_t flat_off = 0
1425                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1426                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1427                     if (order_keep) {
1428                         o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1429                                 i[flat_off], rmode);
1430                     } else {
1431                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
1432                                 i[blk_off(oc, ic)], rmode);
1433                     }
1434                 }
1435             } else {
1436                 for (int oc = 0; oc < oc_block; ++oc)
1437                 for (int ic = 0; ic < ic_block; ++ic) {
1438                     const ptrdiff_t flat_off = 0
1439                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1440                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1441                     if (order_keep) {
1442                         o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1443                                 o[blk_off(oc, ic)], alpha, beta, rmode);
1444                     } else {
1445                         o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1446                                 o[flat_off], alpha, beta, rmode);
1447                     }
1448                 }
1449             }
1450
1451 #           undef blk_off
1452         };
1453
1454
1455         constexpr int i_mult_o = blksize_o;
1456         constexpr int i_mult_i = blksize_i;
1457
1458         parallel_nd(G, NB_OC, NB_IC, D, H, W,
1459             [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1460             int i_off = wei_blk_off_like_gwei3D<fmt_o>(input_d,
1461                                                        g, i_mult_o * nb_oc, i_mult_i * nb_ic, d, h, w);
1462             int o_off = wei_blk_off_like_gwei3D<fmt_o>(output_d,
1463                                                        g, nb_oc, nb_ic, d, h, w);
1464             auto i = &input[i_off];
1465             auto o = &output[o_off];
1466             const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1467             const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1468             ker(i, o, oc_block, ic_block);
1469         });
1470
1471         return success;
1472     }
1473 };
1474
1475 template <SIMPLE_REORDER_TEMPL_DECL>
1476 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1477 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o32i || fmt_o == OhIw16o32i) && type_i == mkldnn_bin && type_o == mkldnn_bin>::type>
1478 {
1479     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1480
1481     GET_SCRATCHPAD_SIZE_ZERO();
1482
1483     static status_t execute(const cpu_reorder_pd_t *pd,
1484         const data_t<type_i> *input, data_t<type_o> *output,
1485         const memory_tracking::grantor_t &scratchpad) {
1486         DECLARE_COMMON_PARAMS();
1487
1488         static constexpr bool w_groups
1489             = format_traits<fmt_o>::data_kind == dk::gwei;
1490         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1491         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1492         constexpr int blksize_o = fmt_o == OhIw8o32i ? 8 : 16;
1493         constexpr int blksize_i = 32;
1494
1495         const auto &dims = input_d.dims();
1496         const auto &pdims = order_keep
1497             ? output_d.blocking_desc().padding_dims
1498             : input_d.blocking_desc().padding_dims;
1499
1500         const int G = w_groups ? dims[0] : 1;
1501         const int OC = dims[w_groups + 0];
1502         const int NB_OC = pdims[w_groups + 0] / blksize_o;
1503         const int IC = dims[w_groups + 1];
1504         const int NB_IC = pdims[w_groups + 1] / blksize_i;
1505         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1506         const int W = dims[w_groups + 3 + is_3d - is_1d];
1507
1508         constexpr int i_mult_o = blksize_o;
1509         constexpr int i_mult_i = blksize_i;
1510         constexpr int nbits = 8;
1511
1512         auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
1513             return (uint8_t) ((val >> bit) & 0x0001);
1514         };
1515
1516         parallel_nd(G, NB_OC, NB_IC, H, W,
1517             [&](int g, int nb_oc, int nb_ic, int h, int w) {
1518                 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1519                 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1520
1521                 for (int oc = 0; oc < oc_block; ++oc) {
1522                     for (int icb = 0; icb < div_up(ic_block, nbits); ++icb) {
1523
1524                         uint8_t bin_val = 0x00;
1525                         for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) {
1526                             size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0][0] +
1527                                           (i_mult_i * nb_ic + ic) * input_d.blocking_desc().strides[0][1] +
1528                                                                 h * input_d.blocking_desc().strides[0][2] +
1529                                                                 w;
1530
1531                             uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits));
1532                             bin_val |= (bit << shift);
1533                         }
1534
1535                         size_t oidx = wei_blk_off_like_gwei3D<fmt_o>(output_d, g, nb_oc, nb_ic, 0, h, w) + oc * blksize_i + icb * nbits;
1536                         output[oidx / nbits] = bin_val;
1537
1538                     }
1539                 }
1540             });
1541
1542         return success;
1543     }
1544 };
1545
1546 template <SIMPLE_REORDER_TEMPL_DECL>
1547 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1548 typename utils::enable_if<fmt_i == any
1549 && block_format_traits<format_traits<fmt_o>::blk_fmt>::blk_ndims == 2
1550 && fmt_o != OhIw8o4i && fmt_o != gOhIw8o4i && fmt_o != OhIw8o32i && fmt_o != OhIw16o32i>::type>
1551 {
1552     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1553
1554     GET_SCRATCHPAD_SIZE_ZERO();
1555
1556     static status_t execute(const cpu_reorder_pd_t *pd,
1557         const data_t<type_i> *input, data_t<type_o> *output,
1558         const memory_tracking::grantor_t &scratchpad) {
1559         DECLARE_COMMON_PARAMS();
1560
1561         static constexpr bool w_groups
1562             = format_traits<fmt_o>::data_kind == dk::gwei;
1563         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1564         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1565         constexpr int blksize = format_traits<fmt_o>::blk_size;
1566
1567         const auto &flat_d = order_keep ? input_d : output_d;
1568         const auto &dims = input_d.dims();
1569         const auto &pdims = order_keep
1570             ? output_d.blocking_desc().padding_dims
1571             : input_d.blocking_desc().padding_dims;
1572
1573         const int G = w_groups ? dims[0] : 1;
1574         const int OC = dims[w_groups + 0];
1575         const int NB_OC = pdims[w_groups + 0] / blksize;
1576         const int IC = dims[w_groups + 1];
1577         const int NB_IC = pdims[w_groups + 1] / blksize;
1578         const int D = is_3d ? dims[w_groups + 2] : 1;
1579         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1580         const int W = dims[w_groups + 3 + is_3d - is_1d];
1581
1582         auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1583             const int oc_block, const int ic_block) {
1584 #           define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1585
1586             if (alpha == 1.0 && beta == 0.0) {
1587                 for (int oc = 0; oc < oc_block; ++oc)
1588                 for (int ic = 0; ic < ic_block; ++ic) {
1589                     const ptrdiff_t flat_off = 0
1590                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1591                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1592                     if (order_keep) {
1593                         o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1594                                 i[flat_off], rmode);
1595                     } else {
1596                         o[flat_off] = _qz_a1b0<type_i, type_o>()(
1597                                 i[blk_off(oc, ic)], rmode);
1598                     }
1599                 }
1600             } else {
1601                 for (int oc = 0; oc < oc_block; ++oc)
1602                 for (int ic = 0; ic < ic_block; ++ic) {
1603                     const ptrdiff_t flat_off = 0
1604                         + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1605                         + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1606                     if (order_keep) {
1607                         o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1608                                 o[blk_off(oc, ic)], alpha, beta, rmode);
1609                     } else {
1610                         o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1611                                 o[flat_off], alpha, beta, rmode);
1612                     }
1613                 }
1614             }
1615
1616 #           undef blk_off
1617         };
1618
1619
1620         constexpr int i_mult = order_keep ? blksize : 1;
1621         constexpr int o_mult = order_keep ? 1 : blksize;
1622
1623         parallel_nd(G, NB_OC, NB_IC, D, H, W,
1624             [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1625             auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1626                     g, i_mult * nb_oc, i_mult * nb_ic, d, h, w)];
1627             auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1628                     g, o_mult * nb_oc, o_mult * nb_ic, d, h, w)];
1629             const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1630             const int ic_block = nstl::min(blksize, IC - nb_ic * blksize);
1631             ker(i, o, oc_block, ic_block);
1632         });
1633
1634         return success;
1635     }
1636 };
1637
1638 template <SIMPLE_REORDER_TEMPL_DECL>
1639 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1640 typename utils::enable_if<fmt_i == any && (false
1641     || format_traits<fmt_o>::blk_fmt == bf::_4o
1642     || format_traits<fmt_o>::blk_fmt == bf::_8o
1643     || format_traits<fmt_o>::blk_fmt == bf::_16o)>::type>
1644 {
1645     PLAIN_TO_BLOCKED_IS_APPLICABLE();
1646
1647     GET_SCRATCHPAD_SIZE_ZERO();
1648
1649     static status_t execute(const cpu_reorder_pd_t *pd,
1650         const data_t<type_i> *input, data_t<type_o> *output,
1651         const memory_tracking::grantor_t &scratchpad) {
1652         DECLARE_COMMON_PARAMS();
1653
1654         static constexpr bool w_groups
1655             = format_traits<fmt_o>::data_kind == dk::gwei;
1656         constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1657         constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1658         constexpr int blksize = format_traits<fmt_o>::blk_size;
1659
1660         const auto &flat_d = order_keep ? input_d : output_d;
1661         const auto &dims = input_d.dims();
1662         const auto &pdims = order_keep
1663             ? output_d.blocking_desc().padding_dims
1664             : input_d.blocking_desc().padding_dims;
1665
1666         const int G = w_groups ? dims[0] : 1;
1667         const int OC = dims[w_groups + 0];
1668         const int IC = dims[w_groups + 1];
1669         const int D = is_3d ? dims[w_groups + 2] : 1;
1670         const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1671         const int W = dims[w_groups + 3 + is_3d - is_1d];
1672
1673         constexpr int i_mult = order_keep ? blksize : 1;
1674         constexpr int o_mult = order_keep ? 1 : blksize;
1675         const auto strd_oc = flat_d.blocking_desc().strides[0][w_groups];
1676
1677         parallel_nd(G, pdims[w_groups + 0] / blksize, IC, D, H, W,
1678             [&](int g, int nb_oc, int ic, int d, int h, int w) {
1679             auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1680                     g, i_mult * nb_oc, ic, d, h, w)];
1681             auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1682                     g, o_mult * nb_oc, ic, d, h, w)];
1683             const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1684
1685             if (alpha == 1.0 && beta == 0.0) {
1686                 for (int oc = 0; oc < oc_block; ++oc) {
1687                     const auto off = oc * strd_oc;
1688                     if (order_keep) {
1689                         o[oc] = _qz_a1b0<type_i, type_o>()(i[off], rmode);
1690                     } else {
1691                         o[off] = _qz_a1b0<type_i, type_o>()(i[oc], rmode);
1692                     }
1693                 }
1694             } else {
1695                 for (int oc = 0; oc < oc_block; ++oc) {
1696                     const auto off = oc * strd_oc;
1697                     if (order_keep) {
1698                         o[oc] = _qz<type_i, type_o>()(i[off], o[oc], alpha,
1699                                 beta, rmode);
1700                     } else {
1701                         o[off] = _qz<type_i, type_o>()(i[oc], o[off], alpha,
1702                                 beta, rmode);
1703                     }
1704                 }
1705             }
1706         });
1707
1708         return success;
1709     }
1710 };
1711
1712 /* generic and direct-copy reorders */
1713
1714 template <SIMPLE_REORDER_TEMPL_DECL>
1715 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1716     typename utils::enable_if<
1717         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1718     spec::direct_copy>::type>
1719 {
1720     static bool is_applicable(const memory_desc_wrapper &input_d,
1721             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1722         /* FIXME: is the formula correct? */
1723         return input_d.similar_to(output_d, true, false, 0)
1724             && input_d.is_dense() && output_d.is_dense()
1725             && simple_attr_check(attr, false);
1726     }
1727
1728     GET_SCRATCHPAD_SIZE_ZERO();
1729
1730     static status_t execute(const cpu_reorder_pd_t *pd,
1731         const data_t<type_i> *input, data_t<type_o> *output,
1732         const memory_tracking::grantor_t &scratchpad) {
1733         DECLARE_COMMON_PARAMS();
1734
1735         assert(input_d.is_dense());
1736
1737         input += input_d.blk_off(0);
1738         output += output_d.blk_off(0);
1739
1740         const size_t nelems = input_d.nelems();
1741
1742         constexpr int block_size = 16;
1743         const auto num_blocks = nelems / block_size;
1744         const auto rem_elems = nelems % block_size;
1745
1746         parallel(0, num_blocks, [&](const int ithr, const int nthr) {
1747             size_t start{0}, end{0};
1748             balance211(num_blocks, nthr, ithr, start, end);
1749             start = start * block_size;
1750             end = end * block_size;
1751
1752             if (alpha == 1.0 && beta == 0.0) {
1753                 PRAGMA_OMP_SIMD()
1754                 for (size_t e = start; e < end; ++e) {
1755                     output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
1756                                 (input[e], rmode);
1757                 }
1758             } else if (alpha == 1.0) {
1759                 PRAGMA_OMP_SIMD()
1760                 for (size_t e = start; e < end; ++e) {
1761                     output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
1762                                 (input[e], output[e], beta, rmode);
1763                 }
1764             } else if (beta == 0.0) {
1765                 PRAGMA_OMP_SIMD()
1766                 for (size_t e = start; e < end; ++e) {
1767                     output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
1768                                 (input[e], alpha, rmode);
1769                 }
1770             } else {
1771                 PRAGMA_OMP_SIMD()
1772                 for (size_t e = start; e < end; ++e) {
1773                     output[e] = qz<data_t<type_i>, data_t<type_o>>()
1774                                 (input[e], output[e], alpha, beta, rmode);
1775                 }
1776             }
1777
1778             if (rem_elems != 0 && ithr == nthr - 1){
1779                 if (alpha == 1.0 && beta == 0.0) {
1780                     PRAGMA_OMP_SIMD()
1781                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1782                         output[e] = qz_a1b0<data_t<type_i>,
1783                             data_t<type_o>>()(input[e], rmode);
1784                     }
1785                 } else if (alpha == 1.0) {
1786                     PRAGMA_OMP_SIMD()
1787                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1788                         output[e] = qz_a1<data_t<type_i>,
1789                             data_t<type_o>>()(input[e], output[e], beta, rmode);
1790                     }
1791                 } else if (beta == 0.0) {
1792                     PRAGMA_OMP_SIMD()
1793                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1794                         output[e] = qz_b0<data_t<type_i>,
1795                             data_t<type_o>>()(input[e], alpha, rmode);
1796                     }
1797                 } else {
1798                     PRAGMA_OMP_SIMD()
1799                     for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1800                         output[e] = qz<data_t<type_i>, data_t<type_o>>()
1801                                     (input[e], output[e], alpha, beta, rmode);
1802                    }
1803                }
1804             }
1805         });
1806         return success;
1807     }
1808 };
1809
1810 template <SIMPLE_REORDER_TEMPL_DECL>
1811 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1812     typename utils::enable_if<
1813         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1814     spec::direct_copy_except_dim_0>::type>
1815 {
1816     static bool is_applicable(const memory_desc_wrapper &input_d,
1817             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1818         auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
1819             return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
1820         };
1821         /* FIXME: is the formula correct? */
1822         return input_d.similar_to(output_d, true, false, 1)
1823             && is_dense_no_0(input_d) && is_dense_no_0(output_d)
1824             && simple_attr_check(attr, false);
1825     }
1826
1827     GET_SCRATCHPAD_SIZE_ZERO();
1828
1829     static status_t execute(const cpu_reorder_pd_t *pd,
1830         const data_t<type_i> *input, data_t<type_o> *output,
1831         const memory_tracking::grantor_t &scratchpad) {
1832         DECLARE_COMMON_PARAMS();
1833
1834         input += input_d.blk_off(0);
1835         output += output_d.blk_off(0);
1836
1837         const int N = input_d.dims()[0];
1838         const size_t is = input_d.blocking_desc().strides[0][0];
1839         const size_t os = output_d.blocking_desc().strides[0][0];
1840         const size_t nelems_no_d0 = nelems_no_dim_0(input_d);
1841         const size_t work_amount = N * nelems_no_d0;
1842
1843         if (alpha == 1.0 && beta == 0.0) {
1844             parallel(0, work_amount, [&](const int ithr, const int nthr) {
1845                 size_t n{0}, dim1_s{0};
1846                 size_t start{0}, end{0};
1847                 balance211(work_amount, nthr, ithr, start, end);
1848                 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1849                 while(start < end) {
1850                     size_t work_rem = end - start;
1851                     size_t dim1_e = dim1_s + work_rem > nelems_no_d0
1852                         ? nelems_no_d0 : dim1_s + work_rem;
1853                     PRAGMA_OMP_SIMD()
1854                     for (size_t e = dim1_s; e < dim1_e; ++e) {
1855                         output[os * n + e] = _qz_a1b0<type_i, type_o>()(
1856                                 input[is * n + e], rmode);
1857                     }
1858                     nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1859                 }
1860             });
1861         } else {
1862             parallel(0, work_amount, [&](const int ithr, const int nthr) {
1863                 size_t n{0}, dim1_s{0};
1864                 size_t start{0}, end{0};
1865                 balance211(work_amount, nthr, ithr, start, end);
1866                 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1867                 while(start < end) {
1868                     size_t work_rem = end - start;
1869                     size_t dim1_e =
1870                         dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
1871                         : dim1_s + work_rem;
1872                     PRAGMA_OMP_SIMD()
1873                     for (size_t e = dim1_s; e < dim1_e; ++e){
1874                         output[os * n + e] = _qz<type_i, type_o>()(
1875                                 input[is * n + e], output[os * n + e], alpha,
1876                                 beta, rmode);
1877                     }
1878                     nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1879                 }
1880             });
1881         }
1882
1883         return success;
1884     }
1885
1886 private:
1887     static size_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
1888         const int ndims = data_d.ndims();
1889         if (ndims <= 1) return 1;
1890         return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
1891     }
1892
1893     static size_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
1894         size_t max_size = 0;
1895         auto &blk = data_d.blocking_desc();
1896         for (int d = 1; d < data_d.ndims(); ++d) {
1897             auto block = blk.block_dims[d];
1898             max_size = nstl::max(max_size,
1899                     size_t(size_t(blk.padding_dims[d] / block)
1900                         * blk.strides[0][d]));
1901             if (block > 1)
1902                 max_size = nstl::max(max_size,
1903                         size_t(block * blk.strides[1][d]));
1904         }
1905         return max_size;
1906     }
1907 };
1908
1909 template <SIMPLE_REORDER_TEMPL_DECL>
1910 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1911     typename utils::enable_if<
1912         fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1913     spec::reference>::type>
1914 {
1915     static bool is_applicable(const memory_desc_wrapper &input_d,
1916             const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1917         /* supported smask: 0x0...011..10...0,
1918          * i.e. 1 should be contiguous */
1919         int smask = attr ? attr->output_scales_.mask_ : 0;
1920         for (; smask > 0 && !(smask & 0x1); smask >>= 1);
1921         for (; smask > 0 && smask & 0x1; smask >>= 1);
1922         return true
1923             && input_d.is_blocking_desc()
1924             && output_d.is_blocking_desc()
1925             && !output_d.is_additional_buffer()
1926             && !input_d.is_additional_buffer()
1927             && smask == 0;
1928     }
1929
1930     GET_SCRATCHPAD_SIZE_ZERO();
1931
1932     static status_t execute(const cpu_reorder_pd_t *pd,
1933         const data_t<type_i> *input, data_t<type_o> *output,
1934         const memory_tracking::grantor_t &scratchpad) {
1935         DECLARE_COMMON_PARAMS();
1936
1937         const size_t nelems = input_d.nelems();
1938
1939         int ndims_start = 0, ndims_mask = 0;
1940         int smask = pd->attr()->output_scales_.mask_;
1941         for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
1942         for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
1943         assert(smask == 0);
1944
1945         const ptrdiff_t D_start
1946             = utils::array_product(input_d.dims(), ndims_start);
1947         const ptrdiff_t D_mask
1948             = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
1949         const ptrdiff_t D_rest = nelems / D_start / D_mask;
1950
1951         const float *scales = pd->attr()->output_scales_.scales_;
1952
1953         parallel_nd(D_start, D_mask, D_rest,
1954             [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
1955             const float scale = scales[dm];
1956
1957             const size_t e = (ds * D_mask + dm) * D_rest + dr;
1958             const auto &i = input[input_d.off_l(e)];
1959             auto &o = output[output_d.off_l(e)];
1960
1961             o = _qz<type_i, type_o>()(i, o, scale, beta, rmode);
1962         });
1963
1964         return success;
1965     }
1966 };
1967
1968
1969 /* high level class declaration */
1970
1971 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
1972 struct simple_reorder_t: public cpu_primitive_t {
1973     struct pd_t: public cpu_reorder_pd_t {
1974         pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
1975                 const primitive_attr_t *attr)
1976             : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
1977
1978         DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
1979
1980         static status_t create(reorder_pd_t **reorder_pd,
1981                 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1982                 const primitive_attr_t *attr) {
1983             assert(input_pd->engine()->kind() == engine_kind::cpu);
1984             assert(output_pd->engine()->kind() == engine_kind::cpu);
1985             bool args_ok = true
1986                 && input_pd->desc()->data_type == type_i
1987                 && output_pd->desc()->data_type == type_o
1988                 && IMPLICATION(utils::one_of(data_type::bf16, type_i, type_o),
1989                         mayiuse(avx512_core))
1990                 && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
1991                 is_applicable(input_pd->desc(), output_pd->desc(), attr);
1992             if (!args_ok)
1993                 return invalid_arguments;
1994
1995             auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
1996                     (const cpu_memory_pd_t *)output_pd, attr);
1997             if (_pd == nullptr) return out_of_memory;
1998             if (_pd->init() != success) { delete _pd; return unimplemented; }
1999
2000             const size_t scratchpad_sz_ =
2001                 simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
2002                     get_scratchpad_size(input_pd->desc(), output_pd->desc());
2003             auto scratchpad = _pd->scratchpad_registry().registrar();
2004             scratchpad.book(memory_tracking::names::key_reorder_space,
2005                     scratchpad_sz_);
2006             return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
2007         }
2008     };
2009
2010     simple_reorder_t(const pd_t *apd, const input_vector &inputs,
2011             const output_vector &outputs)
2012         : cpu_primitive_t(apd, inputs, outputs) {}
2013
2014     virtual void execute(event_t *e) const {
2015         auto input = reinterpret_cast<const data_t<type_i> *>(
2016                 this->input_memory(0));
2017         auto output = reinterpret_cast<data_t<type_o> *>(this->memory());
2018         simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
2019                 pd(), input, output, this->scratchpad());
2020         e->set_state(event_t::ready);
2021     }
2022
2023 private:
2024     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
2025 };
2026
2027 #undef SIMPLE_REORDER_TEMPL_DECL
2028 #undef SIMPLE_REORDER_TEMPL_CALL
2029
2030 }
2031 }
2032 }
2033
2034 #endif
2035
2036 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s