1 /*******************************************************************************
2 * Copyright 2016-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_SIMPLE_REORDER_HPP
18 #define CPU_SIMPLE_REORDER_HPP
22 #include "c_types_map.hpp"
23 #include "type_helpers.hpp"
24 #include "math_utils.hpp"
25 #include "mkldnn_thread.hpp"
28 #include "format_traits.hpp"
29 #include "cpu_reorder_pd.hpp"
30 #include "cpu_primitive.hpp"
32 #include "simple_q10n.hpp"
33 #include "cpu_isa_traits.hpp"
35 #include "bfloat16_utils.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::status;
42 using namespace mkldnn::impl::memory_format;
43 using namespace mkldnn::impl::data_type;
45 using dk = data_kind_t;
46 using bf = block_format_t;
48 using namespace mkldnn::impl::utils;
51 template<impl::data_type_t type>
52 using data_t = typename prec_traits<type>::type;
54 template<impl::data_type_t type_i, impl::data_type_t type_o>
55 using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
57 template<impl::data_type_t type_i, impl::data_type_t type_o>
58 using _qz = qz<data_t<type_i>, data_t<type_o>>;
namespace fmt_order {
    const bool keep = true;
    const bool reverse = false;
    const bool any = keep;
}

namespace spec {
struct direct_copy {};
struct direct_copy_except_dim_0 {};
}
72 #define SIMPLE_REORDER_TEMPL_DECL \
73 impl::data_type_t type_i, impl::memory_format_t fmt_i, \
74 impl::data_type_t type_o, impl::memory_format_t fmt_o, bool order_keep
75 #define SIMPLE_REORDER_TEMPL_CALL \
76 type_i, fmt_i, type_o, fmt_o, order_keep
78 #define DECLARE_COMMON_PARAMS() \
79 const memory_desc_wrapper &input_d = pd->input_pd(); \
80 const memory_desc_wrapper &output_d = pd->output_pd(); \
81 const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
82 const float beta = pd->beta(); MAYBE_UNUSED(beta); \
83 const round_mode_t rmode = pd->attr()->round_mode_; MAYBE_UNUSED(rmode);
85 #define GET_SCRATCHPAD_SIZE_ZERO() \
86 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d) { \
        return 0; \
    }
91 /* specific reorders: common template */
92 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
93 struct simple_reorder_impl {};
inline bool simple_fmt_check(bool order_keep, impl::memory_format_t fmt_i,
        impl::memory_format_t fmt_o, const memory_desc_wrapper &input_d,
        const memory_desc_wrapper &output_d) {
    return input_d.format() == (order_keep ? fmt_i : fmt_o)
        && output_d.format() == (order_keep ? fmt_o : fmt_i);
}

inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
    if (many_scales_support)
        return true;
    return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
}
109 /* specific reorders: implementation */
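/* Plain (f32/s8) weights -> (d)hwio / (d)hwigo s8 weights for int8 convolution.
 * Besides quantizing each weight with the per-tensor or per-(g, oc) output
 * scale, the reorder appends a compensation buffer of G * OC int32 values
 * right past the padded weights:
 *     cp[g * OC + oc] = -128 * sum_{ic, d, h, w} w_s8[g, oc, ic, d, h, w]
 * so that a convolution computed on u8 activations shifted by +128 can be
 * corrected by adding cp to the accumulator.
 * adj_scale = 0.5 on ISAs without VNNI: the non-VNNI int8 kernels work with
 * halved weights to keep the intermediate s16 products from overflowing. */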
110 template <SIMPLE_REORDER_TEMPL_DECL>
111 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
112 typename utils::enable_if<fmt_i == any && (false
113 || fmt_o == hwio_s8s8 || fmt_o == dhwio_s8s8
114 || fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8)>::type>
116 static bool is_applicable(const memory_desc_wrapper &input_d,
117 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
119 const size_t D_mask = utils::array_product(input_d.dims(),
120 math::ilog2q(attr->output_scales_.mask_ + 1));
        const int oc = input_d.dims()
                [(fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8) + 0];
        const int g = (fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8)
                ? input_d.dims()[0] : 1;
124 return output_d.format() == fmt_o
125 && (input_d.data_type() == f32 || input_d.data_type() == s8)
126 && output_d.data_type() == s8
127 && (D_mask == 1 || D_mask == (size_t)g * oc);
130 GET_SCRATCHPAD_SIZE_ZERO();
132 static status_t execute(const cpu_reorder_pd_t *pd,
133 const data_t<type_i> *input, data_t<type_o> *output,
134 const memory_tracking::grantor_t &scratchpad) {
135 DECLARE_COMMON_PARAMS();
137 static constexpr bool w_groups = fmt_o == hwigo_s8s8 || fmt_o == dhwigo_s8s8;
138 int is_3d = format_traits<fmt_o>::ndims_sp == 3;
140 const auto &dims = input_d.dims();
141 const auto &pdims = output_d.blocking_desc().padding_dims;
143 const int G = w_groups ? dims[0] : 1;
144 const int OC = dims[w_groups + 0];
145 const int IC = dims[w_groups + 1];
146 const int D = is_3d ? dims[w_groups + 2] : 1;
147 const int H = dims[w_groups + 2 + is_3d];
148 const int W = dims[w_groups + 3 + is_3d];
150 const float *scales = pd->attr()->output_scales_.scales_;
151 const size_t D_mask = utils::array_product(input_d.dims(),
152 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
154 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0f : (1.0f / 2.0f);
156 size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * D * H * W;
157 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
        parallel_nd(G, OC, [&](int g, int oc) {
            cp[g * OC + oc] = 0;
            for (int ic = 0; ic < IC; ic++)
162 for (int d = 0; d < D; d++)
163 for (int h = 0; h < H; h++)
164 for (int w = 0; w < W; w++) {
165 auto i = is_3d ? input[input_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
166 : input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
167 auto &o = is_3d ? output[output_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
168 : output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
169 const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
171 o = qz_b0<data_t<type_i>, data_t<type_o>>()(
172 i, s * adj_scale, rmode);
                cp[g * OC + oc] -= (int32_t)o;
            }
            cp[g * OC + oc] *= 128;
        });

        return success;
    }
};
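/* Plain goihw/oihw/goiw/oiw/hwio/hwigo weights -> blocked 4i16o4i / 2i8o4i /
 * 4o4i s8s8 formats used by the int8 convolution kernels. The quantization,
 * the trailing int32 compensation buffer and the non-VNNI adj_scale follow
 * the same scheme as the plain s8s8 reorder above. */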
181 template <SIMPLE_REORDER_TEMPL_DECL>
182 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
183 typename utils::enable_if<(
184 utils::one_of(fmt_i, goihw, oihw, goiw, oiw, hwio, hwigo)
185 && (format_traits<fmt_o>::blk_fmt == bf::_4i16o4i_s8s8
186 || format_traits<fmt_o>::blk_fmt == bf::_2i8o4i_s8s8
187 || format_traits<fmt_o>::blk_fmt
188 == bf::_4o4i_s8s8))>::type> {
189 static bool is_applicable(const memory_desc_wrapper &input_d,
190 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
192 const size_t D_mask = utils::array_product(input_d.dims(),
193 math::ilog2q(attr->output_scales_.mask_ + 1));
194 static constexpr bool w_groups
195 = format_traits<fmt_i>::data_kind == dk::gwei;
196 const int oc = input_d.dims()[w_groups + 0];
197 const int g = w_groups ? input_d.dims()[0] : 1;
199 return input_d.format() == fmt_i
200 && output_d.format() == fmt_o
201 && utils::one_of(input_d.data_type(), f32, s8)
202 && output_d.data_type() == s8
203 && (D_mask == 1 || D_mask == (size_t)g * oc);
206 GET_SCRATCHPAD_SIZE_ZERO();
208 static status_t execute(const cpu_reorder_pd_t *pd,
209 const data_t<type_i> *input, data_t<type_o> *output,
210 const memory_tracking::grantor_t &scratchpad) {
211 DECLARE_COMMON_PARAMS();
213 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
214 static constexpr bool w_groups
215 = format_traits<fmt_o>::data_kind == dk::gwei;
        const int blksize = format_traits<fmt_o>::blk_size;
        /* inner sub-block of 4 input channels (the "4i" in 4i16o4i / 2i8o4i / 4o4i) */
        const int sblk = 4;
219 const auto &plain_d = order_keep ? input_d : output_d;
220 const auto &dims = input_d.dims();
221 const auto &pdims = order_keep
222 ? output_d.blocking_desc().padding_dims
223 : input_d.blocking_desc().padding_dims;
225 const int G = w_groups ? dims[0] : 1;
226 const int OC = dims[w_groups + 0];
227 const int NB_OC = pdims[w_groups + 0] / blksize;
228 const int IC = dims[w_groups + 1];
229 const int NB_IC = pdims[w_groups + 1] / blksize;
230 const int H = is_1d ? 1 : dims[w_groups + 2];
231 const int W = dims[w_groups + 3 - is_1d];
233 const float *scales = pd->attr()->output_scales_.scales_;
234 const size_t D_mask = utils::array_product(input_d.dims(),
235 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
237 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
239 auto index = [&](const int ic, const int oc) {
240 return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
243 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
244 int32_t *c, const float *s, const int oc_block, const int ic_block) {
245 for (int ic = 0; ic < ic_block; ++ic) {
246 for (int oc = 0; oc < oc_block; ++oc) {
247 const auto plain_off =
248 oc * plain_d.blocking_desc().strides[0][w_groups + 0]
249 + ic * plain_d.blocking_desc().strides[0][w_groups + 1];
                    out[index(ic, oc)]
                        = qz_b0<data_t<type_i>, data_t<type_o>>()(
252 inp[plain_off], s[oc] * adj_scale, rmode);
253 c[oc] -= (128 * (int32_t)(out[index(ic, oc)]));
258 constexpr int i_mult = blksize;
259 constexpr int o_mult = 1;
261 size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
262 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
        parallel_nd(G * NB_OC * blksize, [&](int i) {
            cp[i] = 0;
        });
267 parallel_nd(G, NB_OC, [&](int g, int O) {
268 for (int I = 0; I < NB_IC; I++)
269 for (int h = 0; h < H; h++)
270 for (int w = 0; w < W; w++) {
271 auto i = &input[wei_blk_off_like_gwei3D<fmt_i>(
272 input_d, g, i_mult * O, i_mult * I, 0, h, w)];
273 auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(
274 output_d, g, o_mult * O, o_mult * I, 0, h, w)];
275 const int oc_block = nstl::min(blksize, OC - O * blksize);
276 const int ic_block = nstl::min(blksize, IC - I * blksize);
278 int _offset = (g * NB_OC + O) * blksize;
279 ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
                    &scales[(D_mask == 1) ? 0 : _offset],
                    oc_block, ic_block);
            }
        });

        return success;
    }
};
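/* f32 goihw/oihw weights -> bf16 16i16o / 8i16o2i / 8o16i2o blocked weights.
 * Each 16x16 (oc, ic) tile is first gathered into a per-thread f32 scratchpad
 * (padded tails are zero-filled) and then converted to bf16 with a single
 * cvt_float_to_bfloat16() call over the whole tile. */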
288 template <SIMPLE_REORDER_TEMPL_DECL>
289 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
290 typename utils::enable_if<(
291 (fmt_i == goihw || fmt_i == oihw) &&
292 (format_traits<fmt_o>::blk_fmt == bf::_16i16o
293 || format_traits<fmt_o>::blk_fmt == bf::_8i16o2i
294 || format_traits<fmt_o>::blk_fmt == bf::_8o16i2o) &&
295 type_i == data_type::f32 && type_o == data_type::bf16)>::type>
297 static bool is_applicable(const memory_desc_wrapper &input_d,
298 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
301 && input_d.format() == fmt_i && output_d.format() == fmt_o
302 && input_d.data_type() == f32 && output_d.data_type() == bf16;
305 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
306 const memory_desc_wrapper &output_d) {
307 const int blksize = 16;
308 return sizeof(float) * blksize * blksize * mkldnn_get_max_threads();
311 static status_t execute(const cpu_reorder_pd_t *pd,
312 const data_t<type_i> *input, data_t<type_o> *output,
313 const memory_tracking::grantor_t &scratchpad) {
314 DECLARE_COMMON_PARAMS();
316 static constexpr bool w_groups = fmt_i == goihw;
317 const int blksize = 16;
320 const auto &_g_oihw_d = input_d;
321 const auto &dims = input_d.dims();
322 const auto &pdims = output_d.blocking_desc().padding_dims;
324 const int G = w_groups ? dims[0] : 1;
325 const int OC = dims[w_groups + 0];
326 const int NB_OC = pdims[w_groups + 0] / blksize;
327 const int IC = dims[w_groups + 1];
328 const int NB_IC = pdims[w_groups + 1] / blksize;
329 const int H = dims[w_groups + 2];
330 const int W = dims[w_groups + 3];
332 const size_t wsp_size = blksize * blksize;
333 float *wspace = scratchpad.template get<float>(
334 memory_tracking::names::key_reorder_space);
        /* pairs of 2 elements: the "2i"/"2o" tail of the 8i16o2i / 8o16i2o formats */
        const int sblk = 2;

        auto index = [&](const int ic, const int oc) {
337 if (format_traits<fmt_o>::blk_fmt == bf::_16i16o)
338 return (ic * blksize + oc);
339 else if (format_traits<fmt_o>::blk_fmt == bf::_8i16o2i)
340 return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
341 else if (format_traits<fmt_o>::blk_fmt == bf::_8o16i2o)
342 return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk);
            else
                assert(!"Invalid weight format");
            return 0;
        };
348 auto ker = [&](const data_t<type_i> *inp, data_t<type_i> *out,
349 const int curr_oc_block, const int oc_block,
350 const int curr_ic_block, const int ic_block) {
            int ic = 0;
            for (ic = 0; ic < curr_ic_block; ++ic) {
                int oc = 0;
                for (oc = 0; oc < curr_oc_block; ++oc) {
355 const auto _g_oihw_off =
356 oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
357 + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
358 out[index(ic, oc)] = inp[_g_oihw_off];
360 for (/* continue */; oc < oc_block; ++oc) {
361 out[index(ic, oc)] = (data_t<type_i>)0;
364 for (/* continue */; ic < ic_block; ++ic) {
365 for (int oc = 0; oc < oc_block; ++oc) {
366 out[index(ic, oc)] = (data_t<type_i>)0;
371 constexpr int i_mult = blksize;
372 constexpr int o_mult = 1;
374 parallel_nd(G, NB_OC, NB_IC, H, W, [&](int g, int O, int I, int h, int w) {
375 int ithr = mkldnn_get_thread_num();
376 float *_wspace = wspace + wsp_size * ithr;
377 auto i = &input[input_d.blk_off<!w_groups>(g,
378 i_mult * O, i_mult * I, h, w)];
379 auto o = &output[output_d.blk_off<!w_groups>(
380 g, o_mult * O, o_mult * I, h, w)];
381 const int oc_block = nstl::min(blksize, OC - O * blksize);
382 const int ic_block = nstl::min(blksize, IC - I * blksize);
383 ker(i, _wspace, oc_block, blksize, ic_block, blksize);
384 bf16_cvt_utils::cvt_float_to_bfloat16(o, _wspace, wsp_size);
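/* bf16 16i16o blocked weights -> plain f32 goihw/oihw weights: the inverse of
 * the reorder above, converting one element at a time with
 * cvt_bfloat16_to_float() while scattering to the plain layout. */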
392 template <SIMPLE_REORDER_TEMPL_DECL>
393 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
394 typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_16i16o &&
395 (fmt_o == goihw || fmt_o == oihw) &&
396 type_i == data_type::bf16 && type_o == data_type::f32>::type>
398 static bool is_applicable(const memory_desc_wrapper &input_d,
399 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
402 && input_d.format() == fmt_i && output_d.format() == fmt_o
403 && input_d.data_type() == bf16 && output_d.data_type() == f32;
406 GET_SCRATCHPAD_SIZE_ZERO();
408 static status_t execute(const cpu_reorder_pd_t *pd,
409 const data_t<type_i> *input, data_t<type_o> *output,
410 const memory_tracking::grantor_t &scratchpad) {
411 DECLARE_COMMON_PARAMS();
413 static constexpr bool w_groups = fmt_o == goihw;
414 const int blksize = 16;
416 const auto &_g_oihw_d = output_d;
417 const auto &dims = input_d.dims();
418 const auto &pdims = input_d.blocking_desc().padding_dims;
420 const int G = w_groups ? dims[0] : 1;
421 const int OC = dims[w_groups + 0];
422 const int NB_OC = pdims[w_groups + 0] / blksize;
423 const int IC = dims[w_groups + 1];
424 const int NB_IC = pdims[w_groups + 1] / blksize;
425 const int H = dims[w_groups + 2];
426 const int W = dims[w_groups + 3];
428 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
429 int curr_oc_block, int curr_ic_block) {
430 for (int ic = 0; ic < curr_ic_block; ++ic) {
431 for (int oc = 0; oc < curr_oc_block; ++oc) {
432 const auto _g_oihw_off =
433 oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
434 + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
435 bf16_cvt_utils::cvt_bfloat16_to_float(
436 &o[_g_oihw_off], &i[ic * blksize + oc]);
441 constexpr int i_mult = 1;
442 constexpr int o_mult = blksize;
444 parallel_nd(G, NB_OC, NB_IC, H, W, [&](int g, int O, int I, int h, int w) {
445 auto i = &input[input_d.blk_off<!w_groups>(
446 g, i_mult * O, i_mult * I, h, w)];
447 auto o = &output[output_d.blk_off<!w_groups>(
448 g, o_mult * O, o_mult * I, h, w)];
449 const int oc_block = nstl::min(blksize, OC - O * blksize);
450 const int ic_block = nstl::min(blksize, IC - I * blksize);
451 ker(i, o, oc_block, ic_block);
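/* f32 nchw -> bf16 nChw16c activations: a W x 16 row is packed (with zero
 * padding of the channel tail) into a per-thread f32 scratchpad and then
 * converted to bf16 in one cvt_float_to_bfloat16() call. */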
458 template <SIMPLE_REORDER_TEMPL_DECL>
459 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
460 typename utils::enable_if<
461 (fmt_i == nchw && fmt_o == nChw16c) &&
462 type_i == data_type::f32 && type_o == data_type::bf16>::type>
464 static bool is_applicable(const memory_desc_wrapper &input_d,
465 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
466 return input_d.format() == fmt_i && output_d.format() == fmt_o
467 && input_d.data_type() == f32 && output_d.data_type() == bf16;
470 static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
471 const memory_desc_wrapper &output_d) {
472 const size_t blksize = 16;
473 const size_t W = input_d.dims()[3];
474 return sizeof(float) * blksize * W * mkldnn_get_max_threads();
477 static status_t execute(const cpu_reorder_pd_t *pd,
478 const data_t<type_i> *input, data_t<type_o> *output,
479 const memory_tracking::grantor_t &scratchpad) {
480 DECLARE_COMMON_PARAMS();
482 constexpr int blksize = 16;
484 const auto &flat_d = input_d;
485 const auto &dims = input_d.dims();
486 const auto &pdims = output_d.blocking_desc().padding_dims;
488 const int C = dims[1];
489 const int H = dims[2];
490 const int W = dims[3];
492 const int wsp_size = W * blksize;
493 float *wspace = scratchpad.template get<float>(
494 memory_tracking::names::key_reorder_space);
496 auto ker = [&](const data_t<type_i> *i, data_t<type_i> *o,
497 const int curr_c_block, const int c_block) {
            for (int w = 0; w < W; ++w) {
                int c = 0;
                for (c = 0; c < curr_c_block; ++c) {
501 const ptrdiff_t flat_off = 0
502 + c * flat_d.blocking_desc().strides[0][1]
503 + w * flat_d.blocking_desc().strides[0][3];
504 o[w * blksize + c] = i[flat_off];
506 for (/* continue */; c < c_block; ++c) {
507 o[w * blksize + c] = (data_t<type_i>)0;
512 constexpr int i_c_mult = blksize;
513 constexpr int o_c_mult = 1;
515 parallel_nd(dims[0], pdims[1] / blksize, H, [&](int n, int nb_c, int h) {
516 int ithr = mkldnn_get_thread_num();
517 float *_wspace = wspace + wsp_size * ithr;
518 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
519 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
520 const int c_block = nstl::min(blksize, C - nb_c * blksize);
521 ker(i, _wspace, c_block, blksize);
522 bf16_cvt_utils::cvt_float_to_bfloat16(o, _wspace, wsp_size);
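/* bf16 nChw16c -> f32 nchw activations: element-wise cvt_bfloat16_to_float()
 * from the blocked source into the strided plain destination. */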
530 template <SIMPLE_REORDER_TEMPL_DECL>
531 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
532 typename utils::enable_if<
533 (fmt_i == nChw16c && fmt_o == nchw) &&
534 type_i == data_type::bf16 && type_o == data_type::f32>::type>
536 static bool is_applicable(const memory_desc_wrapper &input_d,
537 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
538 return input_d.format() == fmt_i && output_d.format() == fmt_o
539 && input_d.data_type() == bf16 && output_d.data_type() == f32;
542 GET_SCRATCHPAD_SIZE_ZERO();
544 static status_t execute(const cpu_reorder_pd_t *pd,
545 const data_t<type_i> *input, data_t<type_o> *output,
546 const memory_tracking::grantor_t &scratchpad) {
547 DECLARE_COMMON_PARAMS();
549 constexpr int blksize = 16;
550 const auto &flat_d = output_d;
551 const auto &dims = input_d.dims();
552 const auto &pdims = input_d.blocking_desc().padding_dims;
554 const int C = dims[1];
555 const int H = dims[2];
556 const int W = dims[3];
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                const int c_block) {
560 for (int w = 0; w < W; ++w)
561 for (int c = 0; c < c_block; ++c) {
562 const ptrdiff_t flat_off = 0
563 + c * flat_d.blocking_desc().strides[0][1]
564 + w * flat_d.blocking_desc().strides[0][3];
566 bf16_cvt_utils::cvt_bfloat16_to_float(
567 &o[flat_off], &i[w * blksize + c]);
571 constexpr int i_c_mult = 1;
572 constexpr int o_c_mult = blksize;
574 parallel_nd(dims[0], pdims[1] / blksize, H, [&](int n, int nb_c, int h) {
575 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
576 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
            const int c_block = nstl::min(blksize, C - nb_c * blksize);
            ker(i, o, c_block);
        });

        return success;
    }
};
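/* goiw/goihw/hwigo weights -> Goihw16g_s8s8: s8 weights blocked by 16 groups,
 * used by depthwise int8 convolution. As in the other s8s8 reorders, an int32
 * compensation array of Gp * OC entries (-128 times the sum of each output
 * channel's s8 weights) is stored past the weights, and adj_scale halves the
 * weights when VNNI is not available. */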
586 template <SIMPLE_REORDER_TEMPL_DECL>
587 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
588 typename utils::enable_if<true
589 && utils::one_of(fmt_i, goiw, goihw, hwigo)
590 && format_traits<fmt_o>::blk_fmt == bf::_16g_s8s8>::type> {
592 static bool is_applicable(const memory_desc_wrapper &input_d,
593 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
594 const size_t D_mask = utils::array_product(input_d.dims(),
595 math::ilog2q(attr->output_scales_.mask_ + 1));
596 const int oc = input_d.dims()[1];
597 const int g = input_d.dims()[0];
601 && input_d.format() == fmt_i
602 && output_d.format() == fmt_o
603 && utils::one_of(input_d.data_type(), f32, s8)
604 && output_d.data_type() == s8
605 && (D_mask == 1 || D_mask == (size_t)g * oc);
608 GET_SCRATCHPAD_SIZE_ZERO();
610 static status_t execute(const cpu_reorder_pd_t *pd,
611 const data_t<type_i> *input, data_t<type_o> *output,
612 const memory_tracking::grantor_t &scratchpad) {
613 DECLARE_COMMON_PARAMS();
615 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
616 const int blksize = format_traits<fmt_o>::blk_size;
618 const auto &dims = input_d.dims();
619 const auto &pdims = output_d.blocking_desc().padding_dims;
620 const int G = dims[0];
621 const int Gp = pdims[0];
622 const int OC = dims[1];
623 const int IC = dims[2];
624 const int H = is_1d ? 1 : dims[3];
625 const int W = dims[4 - is_1d];
627 const size_t D_mask = utils::array_product(input_d.dims(),
628 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
629 const float *scales = pd->attr()->output_scales_.scales_;
630 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
633 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
634 int32_t *cp, const float *s, const int g_block) {
636 for (int g = 0; g < g_block; g++) {
637 const auto i_off = g * input_d.blocking_desc().strides[0][0];
638 out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
639 inp[i_off], s[g * OC] * adj_scale, rmode);
640 cp[g * OC] -= 128 * (int32_t)(out[g]);
644 size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
645 int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
646 parallel_nd((Gp/blksize) * OC, [&](int ib) {
648 for (int i = 0; i < blksize; i++)
649 cp[ib * blksize + i] = 0;
652 parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
653 for (int I = 0; I < IC; I++) {
654 for (int h = 0; h < H; h++) {
655 for (int w = 0; w < W; w++) {
656 const int g_block = nstl::min(G - gb * blksize, blksize);
657 const auto inp = &input[wei_blk_off_like_gwei3D<fmt_i>(
658 input_d, gb * blksize, O, I, 0, h, w)];
659 const auto out = &output[wei_blk_off_like_gwei3D<fmt_o>(
660 output_d, gb, O, I, 0, h, w)];
661 int offset = gb * blksize + O;
662 ker(inp, out, &cp[offset],
663 &scales[(D_mask == 1) ? 0 : offset], g_block);
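/* Blocked-to-blocked weights conversion between the 8i16o2i and 8o16i2o
 * layouts: a 16x16 permutation of each (oc, ic) tile, with optional
 * alpha/beta scaling through the generic quantizers. */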
672 template <SIMPLE_REORDER_TEMPL_DECL>
673 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
674 typename utils::enable_if<true
675 && format_traits<fmt_i>::blk_fmt == bf::_8i16o2i
676 && format_traits<fmt_o>::blk_fmt == bf::_8o16i2o>::type>
678 static bool is_applicable(const memory_desc_wrapper &input_d,
679 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
681 return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
682 && simple_attr_check(attr, false);
685 GET_SCRATCHPAD_SIZE_ZERO();
687 static status_t execute(const cpu_reorder_pd_t *pd,
688 const data_t<type_i> *input, data_t<type_o> *output,
689 const memory_tracking::grantor_t &scratchpad) {
690 DECLARE_COMMON_PARAMS();
692 static constexpr bool w_groups
693 = format_traits<fmt_o>::data_kind == dk::gwei;
694 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
695 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
696 constexpr int blksize = format_traits<fmt_o>::blk_size;
698 const auto &dims = input_d.dims();
700 const int G = w_groups ? dims[0] : 1;
701 const int NB_OC = dims[w_groups + 0] / blksize;
702 const int NB_IC = dims[w_groups + 1] / blksize;
703 const int D = is_3d ? dims[w_groups + 2] : 1;
704 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
705 const int W = dims[w_groups + 3 + is_3d - is_1d];
707 auto idx_i = [&](const int oc, const int ic)
708 { return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2); };
710 auto idx_o = [&](const int oc, const int ic)
711 { return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2); };
713 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) -> void {
714 if (alpha == 1.0 && beta == 0.0) {
715 for (int ic = 0; ic < blksize; ++ic) {
716 for (int oc = 0; oc < blksize; ++oc) {
717 o[idx_o(oc, ic)] = _qz_a1b0<type_i, type_o>()(
718 i[idx_i(oc, ic)], rmode);
722 for (int ic = 0; ic < blksize; ++ic) {
723 for (int oc = 0; oc < blksize; ++oc) {
724 o[idx_o(oc, ic)] = _qz<type_i, type_o>()(
                            i[idx_i(oc, ic)], o[idx_o(oc, ic)], alpha,
                            beta, rmode);
                    }
                }
            }
        };
732 parallel_nd(G, NB_OC, NB_IC, D, H, W,
733 [&](int g, int o, int i, int d, int h, int w) {
734 auto ptr_i = &input[wei_blk_off_like_gwei3D<fmt_i>(
735 input_d, g, o, i, d, h, w)];
736 auto ptr_o = &output[wei_blk_off_like_gwei3D<fmt_o>(
                    output_d, g, o, i, d, h, w)];
            ker(ptr_i, ptr_o);
        });

        return success;
    }
};
745 /* reorders with tail support */
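/* nChw8c -> nhwc: unpack the channel blocks into the interleaved layout,
 * applying per-channel output scales when the scale mask is 2 (one scale per
 * channel) and doing a plain copy otherwise. */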
747 template <SIMPLE_REORDER_TEMPL_DECL>
748 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
749 typename utils::enable_if<fmt_i == nChw8c && fmt_o == nhwc && order_keep>::type>
751 static bool is_applicable(const memory_desc_wrapper &input_d,
752 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
753 int smask = attr ? attr->output_scales_.mask_ : 0;
        return (smask == 0 || smask == 2) && order_keep
            && input_d._md->format == nChw8c
            && output_d._md->format == nhwc;
757 GET_SCRATCHPAD_SIZE_ZERO();
759 static status_t execute(const cpu_reorder_pd_t *pd,
760 const data_t<type_i> *input, data_t<type_o> *output,
761 const memory_tracking::grantor_t &scratchpad) {
762 DECLARE_COMMON_PARAMS();
764 const auto &pdims = input_d.blocking_desc().padding_dims;
765 const auto &dims = input_d.dims();
766 constexpr int blksize = format_traits<fmt_i>::blk_size;
767 const int C = dims[1];
768 const int H = dims[2];
769 const int W = dims[3];
771 constexpr int i_c_mult = 1;
772 constexpr int o_c_mult = blksize;
774 const float *scales = pd->attr()->output_scales_.scales_;
775 int smask = pd->attr()->output_scales_.mask_;
777 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                const int nb_c, const int c_block) {
            if (smask == 2) {
780 for (int w = 0; w < W; ++w) {
781 const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
783 for (int c = 0; c < c_block; ++c) {
784 const float scale = scales[nb_c * blksize + c];
786 o[flat_off + c] = _qz<type_i, type_o>()(i[w * blksize + c],
787 o[flat_off + c], scale, beta, rmode);
                    }
                }
            } else {
                for (int w = 0; w < W; ++w) {
792 const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
794 for (int c = 0; c < c_block; ++c) {
795 o[flat_off + c] = _qz_a1b0<type_i, type_o>()(i[w * blksize + c], rmode);
801 parallel_nd(dims[0], pdims[1] / blksize, H,
802 [&](int n, int nb_c, int h) {
803 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
804 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
805 const int c_block = nstl::min(blksize, C - nb_c * blksize);
806 ker(i, o, nb_c, c_block);
813 template <SIMPLE_REORDER_TEMPL_DECL>
814 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
815 typename utils::enable_if<fmt_i == nhwc && fmt_o == nChw8c>::type>
817 static bool is_applicable(const memory_desc_wrapper &input_d,
818 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
819 int smask = attr ? attr->output_scales_.mask_ : 0;
        return (smask == 2) && order_keep
            && input_d._md->format == nhwc
            && output_d._md->format == nChw8c;
823 GET_SCRATCHPAD_SIZE_ZERO();
825 static status_t execute(const cpu_reorder_pd_t *pd,
826 const data_t<type_i> *input, data_t<type_o> *output,
827 const memory_tracking::grantor_t &scratchpad) {
828 DECLARE_COMMON_PARAMS();
830 const auto &pdims = output_d.blocking_desc().padding_dims;
831 const auto &dims = input_d.dims();
832 constexpr int blksize = format_traits<fmt_o>::blk_size;
833 const int C = dims[1];
834 const int H = dims[2];
835 const int W = dims[3];
837 constexpr int i_c_mult = blksize;
838 constexpr int o_c_mult = 1;
840 const float *scales = pd->attr()->output_scales_.scales_;
841 int smask = pd->attr()->output_scales_.mask_;
843 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                const int nb_c, const int c_block) {
            if (smask == 2) {
846 for (int w = 0; w < W; ++w) {
847 const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
849 for (int c = 0; c < c_block; ++c) {
850 const float scale = scales[nb_c * blksize + c];
852 o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off + c],
853 o[w * blksize + c], scale, beta, rmode);
                    }
                }
            } else {
                for (int w = 0; w < W; ++w) {
858 const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
860 for (int c = 0; c < c_block; ++c) {
861 o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(i[flat_off + c], rmode);
867 parallel_nd(dims[0], pdims[1] / blksize, H,
868 [&](int n, int nb_c, int h) {
869 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
870 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
871 const int c_block = nstl::min(blksize, C - nb_c * blksize);
872 ker(i, o, nb_c, c_block);
879 template <SIMPLE_REORDER_TEMPL_DECL>
880 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
881 typename utils::enable_if<fmt_i == nhwc && fmt_o == nhwc && type_o != mkldnn_bin>::type>
883 static bool is_applicable(const memory_desc_wrapper &input_d,
884 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
885 int smask = attr ? attr->output_scales_.mask_ : 0;
        return (smask == 2) && order_keep
            && input_d._md->format == nhwc
            && output_d._md->format == nhwc;
889 GET_SCRATCHPAD_SIZE_ZERO();
891 static status_t execute(const cpu_reorder_pd_t *pd,
892 const data_t<type_i> *input, data_t<type_o> *output,
893 const memory_tracking::grantor_t &scratchpad) {
894 DECLARE_COMMON_PARAMS();
896 const auto &dims = input_d.dims();
897 const int C = dims[1];
898 const int H = dims[2];
899 const int W = dims[3];
901 const float *scales = pd->attr()->output_scales_.scales_;
903 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
904 for (int c = 0; c < C; ++c) {
905 const float scale = scales[c];
907 o[c] = _qz<type_i, type_o>()(i[c], o[c], scale, beta, rmode);
911 parallel_nd(dims[0], H, W,
912 [&](int n, int h, int w) {
913 auto i = &input[input_d.blk_off(n, 0, h, w)];
914 auto o = &output[output_d.blk_off(n, 0, h, w)];
922 template <SIMPLE_REORDER_TEMPL_DECL>
923 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
924 typename utils::enable_if<fmt_i == nchw && fmt_o == nhwc && type_i != mkldnn_bin && type_o != mkldnn_bin>::type>
926 static bool is_applicable(const memory_desc_wrapper &input_d,
927 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
928 int smask = attr ? attr->output_scales_.mask_ : 0;
        return (smask == 0 || smask == 2) && order_keep
            && input_d._md->format == nchw
            && output_d._md->format == nhwc;
932 GET_SCRATCHPAD_SIZE_ZERO();
934 static status_t execute(const cpu_reorder_pd_t *pd,
935 const data_t<type_i> *input, data_t<type_o> *output,
936 const memory_tracking::grantor_t &scratchpad) {
937 DECLARE_COMMON_PARAMS();
939 const auto &dims = input_d.dims();
940 const int C = dims[1];
941 const int H = dims[2];
942 const int W = dims[3];
944 int smask = pd->attr()->output_scales_.mask_;
945 const float *scales = pd->attr()->output_scales_.scales_;
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
            if (smask == 2) {
949 for (int c = 0; c < C; ++c) {
950 const float scale = scales[c];
952 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
954 o[c] = _qz<type_i, type_o>()(i[flat_off], o[c], scale, beta, rmode);
                }
            } else {
                for (int c = 0; c < C; ++c) {
958 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
960 o[c] = _qz_a1b0<type_i, type_o>()(i[flat_off], rmode);
965 parallel_nd(dims[0], H, W,
966 [&](int n, int h, int w) {
967 auto i = &input[input_d.blk_off(n, 0, h, w)];
968 auto o = &output[output_d.blk_off(n, 0, h, w)];
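/* nchw / nhwc -> binarized nhwc (mkldnn_bin): every group of 8 channels is
 * packed into one output byte, bit i set iff the corresponding input value is
 * positive. */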
976 template <SIMPLE_REORDER_TEMPL_DECL>
977 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
978 typename utils::enable_if<(fmt_i == nchw || fmt_i == nhwc) && fmt_o == nhwc && (type_i == mkldnn_bin || type_o == mkldnn_bin)>::type>
980 static bool is_applicable(const memory_desc_wrapper &input_d,
981 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
982 int smask = attr ? attr->output_scales_.mask_ : 0;
        return smask == 0 && order_keep
            && (input_d._md->format == nchw || input_d._md->format == nhwc)
            && output_d._md->format == nhwc;
986 GET_SCRATCHPAD_SIZE_ZERO();
988 static status_t execute(const cpu_reorder_pd_t *pd,
989 const data_t<type_i> *input, data_t<type_o> *output,
990 const memory_tracking::grantor_t &scratchpad) {
991 DECLARE_COMMON_PARAMS();
993 const auto &dims = input_d.dims();
994 const int C = dims[1];
995 const int H = dims[2];
996 const int W = dims[3];
        constexpr int nbits = 8; /* channels (bits) packed per output byte */
        const int CB = div_up(C, nbits);
1001 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
1002 for (int cb = 0; cb < CB; ++cb) {
1003 uint8_t bin_val = 0x00;
1004 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
1005 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
1007 auto bit = uint8_t((i[flat_off] > 0) ? 0x01 : 0x00);
                    bin_val |= (bit << shift);
                }

                o[cb] = bin_val;
            }
        };
1015 parallel_nd(dims[0], H, W,
1016 [&](int n, int h, int w) {
1017 auto iidx = input_d.blk_off(n, 0, h, w);
1018 auto oidx = output_d.blk_off(n, 0, h, w);
1020 auto i = &input[iidx];
            auto o = &output[oidx / nbits];
            ker(i, o);
        });

        return success;
    }
};
1029 template <SIMPLE_REORDER_TEMPL_DECL>
1030 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1031 typename utils::enable_if<fmt_i == nhwc && fmt_o == nchw>::type>
1033 static bool is_applicable(const memory_desc_wrapper &input_d,
1034 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1035 int smask = attr ? attr->output_scales_.mask_ : 0;
        return (smask == 0 || smask == 2) && order_keep
            && input_d._md->format == nhwc
            && output_d._md->format == nchw;
1039 GET_SCRATCHPAD_SIZE_ZERO();
1041 static status_t execute(const cpu_reorder_pd_t *pd,
1042 const data_t<type_i> *input, data_t<type_o> *output,
1043 const memory_tracking::grantor_t &scratchpad) {
1044 DECLARE_COMMON_PARAMS();
1046 const auto &dims = input_d.dims();
1047 const int C = dims[1];
1048 const int H = dims[2];
1049 const int W = dims[3];
1051 int smask = pd->attr()->output_scales_.mask_;
1052 const float *scales = pd->attr()->output_scales_.scales_;
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
            if (smask == 2) {
1056 for (int c = 0; c < C; ++c) {
1057 const float scale = scales[c];
1059 const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
1061 o[flat_off] = _qz<type_i, type_o>()(i[c], o[flat_off], scale, beta, rmode);
                }
            } else {
                for (int c = 0; c < C; ++c) {
1065 const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
1067 o[flat_off] = _qz_a1b0<type_i, type_o>()(i[c], rmode);
1072 parallel_nd(dims[0], H, W,
1073 [&](int n, int h, int w) {
1074 auto i = &input[input_d.blk_off(n, 0, h, w)];
1075 auto o = &output[output_d.blk_off(n, 0, h, w)];
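/* nC(d)hw4c / nC(d)hw8c <-> nC(d)hw16c: re-block the channel dimension between
 * a smaller and a larger block size; each 16-wide destination block gathers
 * (or scatters) two 8-wide or four 4-wide source blocks. */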
1083 template <SIMPLE_REORDER_TEMPL_DECL>
1084 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1085 typename utils::enable_if<true
1086 && (format_traits<fmt_i>::blk_fmt == bf::_4c
1087 || format_traits<fmt_i>::blk_fmt == bf::_8c)
1088 && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
1090 static bool is_applicable(const memory_desc_wrapper &input_d,
1091 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
1093 return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
1094 && simple_attr_check(attr, false);
1097 GET_SCRATCHPAD_SIZE_ZERO();
1099 static status_t execute(const cpu_reorder_pd_t *pd,
1100 const data_t<type_i> *input, data_t<type_o> *output,
1101 const memory_tracking::grantor_t &scratchpad) {
1102 DECLARE_COMMON_PARAMS();
1104 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1105 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1106 constexpr int blksize_fmt_o = format_traits<fmt_o>::blk_size;
1107 constexpr int blksize_fmt_i = format_traits<fmt_i>::blk_size;
1108 constexpr int ic_mult = order_keep ? 2 : 1;
1109 constexpr int oc_mult = order_keep ? 1 : 2;
1111 const auto &fmt_i_d = order_keep ? input_d : output_d;
1112 const auto &dims = input_d.dims();
1113 const auto &pdims = order_keep ? output_d.blocking_desc().padding_dims
1114 : input_d.blocking_desc().padding_dims;
1115 const auto stride_fmt_i = fmt_i_d.blocking_desc().strides[0];
1117 const int C = dims[1];
1118 const int D = is_3d ? dims[2] : 1;
1119 const int H = is_1d ? 1 : dims[2 + is_3d];
1120 const int W = dims[3 + is_3d - is_1d];
1122 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1123 const int block_fmt_o) {
1124 const int nb = (block_fmt_o - 1) / blksize_fmt_i + 1;
1125 if (alpha == 1.0 && beta == 0.0) {
1126 for (int b = 0; b < nb; ++b) {
1127 const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
1128 : b * blksize_fmt_i;
1129 const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
1130 : b * stride_fmt_i[1];
1131 const int block_fmt_i = nstl::min(blksize_fmt_i,
1132 block_fmt_o - b * blksize_fmt_i);
1133 for (int c = 0; c < block_fmt_i; ++c) {
1134 o[o_off + c] = _qz_a1b0<type_i, type_o>()(
1135 i[i_off + c], rmode);
1139 for (int b = 0; b < nb; ++b) {
1140 const ptrdiff_t i_off = order_keep ? b * stride_fmt_i[1]
1141 : b * blksize_fmt_i;
1142 const ptrdiff_t o_off = order_keep ? b * blksize_fmt_i
1143 : b * stride_fmt_i[1];
1144 const int block_fmt_i = nstl::min(blksize_fmt_i,
1145 block_fmt_o - b * blksize_fmt_i);
1146 for (int c = 0; c < block_fmt_i; ++c) {
1147 o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
1148 o[o_off + c], alpha, beta, rmode);
1154 # define data_blk_off(md, n, c, d, h, w) \
1155 ( is_1d ? (md).blk_off(n, c, w) \
1156 : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
1158 parallel_nd(dims[0], pdims[1] / blksize_fmt_o, D, H, W,
1159 [&](int n, int nb_c, int d, int h, int w) {
1160 auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
1161 auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
1162 const int block_fmt_o = nstl::min(blksize_fmt_o, C - nb_c * blksize_fmt_o);
1163 ker(i, o, block_fmt_o);
1166 # undef data_blk_off
1172 #define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
1173 static bool is_applicable(const memory_desc_wrapper &input_d, \
1174 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
1175 return simple_attr_check(attr, false) && (order_keep \
1176 ? output_d.format() == fmt_o && input_d.is_plain() \
            : input_d.format() == fmt_o && output_d.is_plain()); \
    }
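/* Plain (any) -> nC(d)hw4c / nC(d)hw8c / nC(d)hw16c blocked activations and
 * back (direction given by order_keep): W x blksize tiles are copied between
 * the strided plain layout and the dense channel blocks, with optional
 * alpha/beta scaling. */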
1180 template <SIMPLE_REORDER_TEMPL_DECL>
1181 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1182 typename utils::enable_if<fmt_i == any && (false
1183 || format_traits<fmt_o>::blk_fmt == bf::_4c
1184 || format_traits<fmt_o>::blk_fmt == bf::_8c
1185 || format_traits<fmt_o>::blk_fmt == bf::_16c)>::type>
1187 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1189 GET_SCRATCHPAD_SIZE_ZERO();
1191 static status_t execute(const cpu_reorder_pd_t *pd,
1192 const data_t<type_i> *input, data_t<type_o> *output,
1193 const memory_tracking::grantor_t &scratchpad) {
1194 DECLARE_COMMON_PARAMS();
1196 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1197 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1198 constexpr int blksize = format_traits<fmt_o>::blk_size;
1200 const auto &flat_d = order_keep ? input_d : output_d;
1201 const auto &dims = input_d.dims();
1202 const auto &pdims = order_keep
1203 ? output_d.blocking_desc().padding_dims
1204 : input_d.blocking_desc().padding_dims;
1206 const int C = dims[1];
1207 const int D = is_3d ? dims[2] : 1;
1208 const int H = is_1d ? 1 : dims[2 + is_3d];
1209 const int W = dims[3 + is_3d - is_1d];
1211 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1212 const int c_block) {
1213 if (alpha == 1.0 && beta == 0.0) {
1214 for (int w = 0; w < W; ++w)
1215 for (int c = 0; c < c_block; ++c) {
1216 const ptrdiff_t flat_off = 0
1217 + c * flat_d.blocking_desc().strides[0][1]
                        + w * flat_d.blocking_desc().strides[0][3 + is_3d
                            - is_1d];
                    if (order_keep) {
                        o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(
                                i[flat_off], rmode);
                    } else {
                        o[flat_off] = _qz_a1b0<type_i, type_o>()(
                                i[w * blksize + c], rmode);
                    }
1229 for (int w = 0; w < W; ++w)
1230 for (int c = 0; c < c_block; ++c) {
1231 const ptrdiff_t flat_off = 0
1232 + c * flat_d.blocking_desc().strides[0][1]
                        + w * flat_d.blocking_desc().strides[0][3 + is_3d
                            - is_1d];
                    if (order_keep) {
                        o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off],
                                o[w * blksize + c], alpha, beta, rmode);
                    } else {
                        o[flat_off] = _qz<type_i, type_o>()(i[w * blksize + c],
                                o[flat_off], alpha, beta, rmode);
                    }
1246 constexpr int i_c_mult = order_keep ? blksize : 1;
1247 constexpr int o_c_mult = order_keep ? 1 : blksize;
1249 # define data_blk_off(md, n, c, d, h) \
1250 ( is_1d ? (md).blk_off(n, c) \
1251 : is_3d ? (md).blk_off(n, c, d, h) : (md).blk_off(n, c, h))
1253 parallel_nd(dims[0], pdims[1] / blksize, D, H,
1254 [&](int n, int nb_c, int d, int h) {
1255 auto i = &input[data_blk_off(input_d, n, i_c_mult * nb_c, d, h)];
1256 auto o = &output[data_blk_off(output_d, n, o_c_mult * nb_c, d, h)];
            const int c_block = nstl::min(blksize, C - nb_c * blksize);
            ker(i, o, c_block);
        });
1261 # undef data_blk_off
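/* goihw / oihw / goidhw / oidhw weights -> (g)O(d)hIw8o4i_s8s8: int8 weights
 * blocked 8 output channels by 4 input channels, again with per-channel (or
 * per-tensor) scales, the trailing int32 compensation array and the non-VNNI
 * adj_scale of 0.5. */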
1267 template <SIMPLE_REORDER_TEMPL_DECL>
1268 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1269 typename utils::enable_if<
1270 (fmt_i == goihw && fmt_o == gOhIw8o4i_s8s8)
1271 || (fmt_i == oihw && fmt_o == OhIw8o4i_s8s8)
1272 || (fmt_i == goidhw && fmt_o == gOdhIw8o4i_s8s8)
1273 || (fmt_i == oidhw && fmt_o == OdhIw8o4i_s8s8)
1276 static bool is_applicable(const memory_desc_wrapper &input_d,
1277 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
1279 const size_t D_mask = utils::array_product(input_d.dims(),
1280 math::ilog2q(attr->output_scales_.mask_ + 1));
1281 const int oc = (input_d.dims()[(fmt_i == goihw || fmt_i == goidhw) + 0]);
1282 const int g = (fmt_i == goihw || fmt_i == goidhw) ? (input_d.dims()[0]) : 1;
1284 return input_d.format() == fmt_i
1285 && output_d.format() == fmt_o
1286 && (input_d.data_type() == f32 || input_d.data_type() == s8)
1287 && output_d.data_type() == s8
1288 && (D_mask == 1 || D_mask == (size_t)g * oc);
1291 GET_SCRATCHPAD_SIZE_ZERO();
1293 static status_t execute(const cpu_reorder_pd_t *pd,
1294 const data_t<type_i> *input, data_t<type_o> *output,
1295 const memory_tracking::grantor_t &scratchpad) {
1296 DECLARE_COMMON_PARAMS();
1298 static constexpr bool w_groups
1299 = format_traits<fmt_o>::data_kind == dk::gwei;
1300 int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1301 constexpr int blksize_o = 8;
1302 constexpr int blksize_i = 4;
1304 const auto &flat_d = order_keep ? input_d : output_d;
1305 const auto &dims = input_d.dims();
1306 const auto &pdims = order_keep
1307 ? output_d.blocking_desc().padding_dims
1308 : input_d.blocking_desc().padding_dims;
1310 const int G = w_groups ? dims[0] : 1;
1311 const int OC = dims[w_groups + 0];
1312 const int NB_OC = pdims[w_groups + 0] / blksize_o;
1313 const int IC = dims[w_groups + 1];
1314 const int NB_IC = pdims[w_groups + 1] / blksize_i;
1315 const int D = is_3d ? dims[w_groups + 2] : 1;
1316 const int H = dims[w_groups + 2 + is_3d];
1317 const int W = dims[w_groups + 3 + is_3d];
1319 const float *scales = pd->attr()->output_scales_.scales_;
1320 const size_t D_mask = utils::array_product(input_d.dims(),
1321 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
1323 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0 : (1.0 / 2.0);
1325 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
1326 int32_t *c, const float *s, const int oc_block, const int ic_block) {
1327 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1329 for (int ic = 0; ic < ic_block; ++ic) {
1330 for (int oc = 0; oc < oc_block; ++oc) {
1331 const auto _g_oihw_off = oc * flat_d.blocking_desc().strides[0][w_groups + 0] +
1332 ic * flat_d.blocking_desc().strides[0][w_groups + 1];
                if (order_keep) {
                    out[blk_off(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            inp[_g_oihw_off], s[oc] * adj_scale, rmode);
                    c[oc] -= (128 * (int32_t)(out[blk_off(oc, ic)]));
                } else {
                    out[_g_oihw_off] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            inp[blk_off(oc, ic)], s[oc] * adj_scale, rmode);
                    c[oc] -= (128 * (int32_t)(out[_g_oihw_off]));
                }
1347 constexpr int i_mult_o = blksize_o;
1348 constexpr int i_mult_i = blksize_i;
1350 size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * D * H * W;
1351 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
1352 parallel_nd(G * NB_OC * blksize_o, [&](int i) {
1356 parallel_nd(G, NB_OC, [&](int g, int O) {
1357 for (int I = 0; I < NB_IC; I++) {
1358 for (int d = 0; d < D; d++) {
1359 for (int h = 0; h < H; h++) {
1360 for (int w = 0; w < W; w++) {
1361 auto i = is_3d ? &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, d, h, w)]
1362 : &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, h, w)];
1363 auto o = is_3d ? &output[output_d.blk_off<!w_groups>(g, O, I, d, h, w)]
1364 : &output[output_d.blk_off<!w_groups>(g, O, I, h, w)];
1365 const int oc_block = nstl::min(blksize_o, OC - O * blksize_o);
1366 const int ic_block = nstl::min(blksize_i, IC - I * blksize_i);
1368 int _offset = (g * NB_OC + O) * blksize_o;
1369 ker(i, o, (order_keep) ? &cp[_offset] : nullptr, &scales[(D_mask == 1) ? 0 : _offset],
1382 template <SIMPLE_REORDER_TEMPL_DECL>
1383 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1384 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o4i || fmt_o == gOhIw8o4i)>::type>
1386 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1388 GET_SCRATCHPAD_SIZE_ZERO();
1390 static status_t execute(const cpu_reorder_pd_t *pd,
1391 const data_t<type_i> *input, data_t<type_o> *output,
1392 const memory_tracking::grantor_t &scratchpad) {
1393 DECLARE_COMMON_PARAMS();
1395 static constexpr bool w_groups
1396 = format_traits<fmt_o>::data_kind == dk::gwei;
1397 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1398 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1399 constexpr int blksize_o = 8;//format_traits<fmt_o>::blk_size;
1400 constexpr int blksize_i = 4;
1402 const auto &flat_d = order_keep ? input_d : output_d;
1403 const auto &dims = input_d.dims();
1404 const auto &pdims = order_keep
1405 ? output_d.blocking_desc().padding_dims
1406 : input_d.blocking_desc().padding_dims;
1408 const int G = w_groups ? dims[0] : 1;
1409 const int OC = dims[w_groups + 0];
1410 const int NB_OC = pdims[w_groups + 0] / blksize_o;
1411 const int IC = dims[w_groups + 1];
1412 const int NB_IC = pdims[w_groups + 1] / blksize_i;
1413 const int D = is_3d ? dims[w_groups + 2] : 1;
1414 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1415 const int W = dims[w_groups + 3 + is_3d - is_1d];
1417 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1418 const int oc_block, const int ic_block) {
1419 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1421 if (alpha == 1.0 && beta == 0.0) {
1422 for (int oc = 0; oc < oc_block; ++oc)
1423 for (int ic = 0; ic < ic_block; ++ic) {
1424 const ptrdiff_t flat_off = 0
1425 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1426 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
                    if (order_keep) {
                        o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
                                i[flat_off], rmode);
                    } else {
                        o[flat_off] = _qz_a1b0<type_i, type_o>()(
                                i[blk_off(oc, ic)], rmode);
                    }
1436 for (int oc = 0; oc < oc_block; ++oc)
1437 for (int ic = 0; ic < ic_block; ++ic) {
1438 const ptrdiff_t flat_off = 0
1439 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1440 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
                    if (order_keep) {
                        o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
                                o[blk_off(oc, ic)], alpha, beta, rmode);
                    } else {
                        o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
                                o[flat_off], alpha, beta, rmode);
                    }
1455 constexpr int i_mult_o = blksize_o;
1456 constexpr int i_mult_i = blksize_i;
1458 parallel_nd(G, NB_OC, NB_IC, D, H, W,
1459 [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1460 int i_off = wei_blk_off_like_gwei3D<fmt_o>(input_d,
1461 g, i_mult_o * nb_oc, i_mult_i * nb_ic, d, h, w);
1462 int o_off = wei_blk_off_like_gwei3D<fmt_o>(output_d,
1463 g, nb_oc, nb_ic, d, h, w);
1464 auto i = &input[i_off];
1465 auto o = &output[o_off];
1466 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1467 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1468 ker(i, o, oc_block, ic_block);
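/* Binarized weights: plain layout -> OhIw8o32i / OhIw16o32i (mkldnn_bin).
 * Input bits are gathered with extract_bit() from the packed plain buffer and
 * re-packed 8 per byte into the blocked layout (32 input channels per block). */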
1475 template <SIMPLE_REORDER_TEMPL_DECL>
1476 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1477 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o32i || fmt_o == OhIw16o32i) && type_i == mkldnn_bin && type_o == mkldnn_bin>::type>
1479 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1481 GET_SCRATCHPAD_SIZE_ZERO();
1483 static status_t execute(const cpu_reorder_pd_t *pd,
1484 const data_t<type_i> *input, data_t<type_o> *output,
1485 const memory_tracking::grantor_t &scratchpad) {
1486 DECLARE_COMMON_PARAMS();
1488 static constexpr bool w_groups
1489 = format_traits<fmt_o>::data_kind == dk::gwei;
1490 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1491 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1492 constexpr int blksize_o = fmt_o == OhIw8o32i ? 8 : 16;
1493 constexpr int blksize_i = 32;
1495 const auto &dims = input_d.dims();
1496 const auto &pdims = order_keep
1497 ? output_d.blocking_desc().padding_dims
1498 : input_d.blocking_desc().padding_dims;
1500 const int G = w_groups ? dims[0] : 1;
1501 const int OC = dims[w_groups + 0];
1502 const int NB_OC = pdims[w_groups + 0] / blksize_o;
1503 const int IC = dims[w_groups + 1];
1504 const int NB_IC = pdims[w_groups + 1] / blksize_i;
1505 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1506 const int W = dims[w_groups + 3 + is_3d - is_1d];
1508 constexpr int i_mult_o = blksize_o;
1509 constexpr int i_mult_i = blksize_i;
1510 constexpr int nbits = 8;
1512 auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
1513 return (uint8_t) ((val >> bit) & 0x0001);
1516 parallel_nd(G, NB_OC, NB_IC, H, W,
1517 [&](int g, int nb_oc, int nb_ic, int h, int w) {
1518 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1519 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1521 for (int oc = 0; oc < oc_block; ++oc) {
1522 for (int icb = 0; icb < div_up(ic_block, nbits); ++icb) {
1524 uint8_t bin_val = 0x00;
1525 for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) {
1526 size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0][0] +
1527 (i_mult_i * nb_ic + ic) * input_d.blocking_desc().strides[0][1] +
1528 h * input_d.blocking_desc().strides[0][2] +
1531 uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits));
1532 bin_val |= (bit << shift);
1535 size_t oidx = wei_blk_off_like_gwei3D<fmt_o>(output_d, g, nb_oc, nb_ic, 0, h, w) + oc * blksize_i + icb * nbits;
1536 output[oidx / nbits] = bin_val;
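/* Generic plain <-> two-dimensionally blocked (oc, ic) weights reorder for the
 * remaining blocked weight formats, with optional alpha/beta scaling. */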
1546 template <SIMPLE_REORDER_TEMPL_DECL>
1547 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1548 typename utils::enable_if<fmt_i == any
1549 && block_format_traits<format_traits<fmt_o>::blk_fmt>::blk_ndims == 2
1550 && fmt_o != OhIw8o4i && fmt_o != gOhIw8o4i && fmt_o != OhIw8o32i && fmt_o != OhIw16o32i>::type>
1552 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1554 GET_SCRATCHPAD_SIZE_ZERO();
1556 static status_t execute(const cpu_reorder_pd_t *pd,
1557 const data_t<type_i> *input, data_t<type_o> *output,
1558 const memory_tracking::grantor_t &scratchpad) {
1559 DECLARE_COMMON_PARAMS();
1561 static constexpr bool w_groups
1562 = format_traits<fmt_o>::data_kind == dk::gwei;
1563 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1564 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1565 constexpr int blksize = format_traits<fmt_o>::blk_size;
1567 const auto &flat_d = order_keep ? input_d : output_d;
1568 const auto &dims = input_d.dims();
1569 const auto &pdims = order_keep
1570 ? output_d.blocking_desc().padding_dims
1571 : input_d.blocking_desc().padding_dims;
1573 const int G = w_groups ? dims[0] : 1;
1574 const int OC = dims[w_groups + 0];
1575 const int NB_OC = pdims[w_groups + 0] / blksize;
1576 const int IC = dims[w_groups + 1];
1577 const int NB_IC = pdims[w_groups + 1] / blksize;
1578 const int D = is_3d ? dims[w_groups + 2] : 1;
1579 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1580 const int W = dims[w_groups + 3 + is_3d - is_1d];
1582 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1583 const int oc_block, const int ic_block) {
1584 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
1586 if (alpha == 1.0 && beta == 0.0) {
1587 for (int oc = 0; oc < oc_block; ++oc)
1588 for (int ic = 0; ic < ic_block; ++ic) {
1589 const ptrdiff_t flat_off = 0
1590 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1591 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
                    if (order_keep) {
                        o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
                                i[flat_off], rmode);
                    } else {
                        o[flat_off] = _qz_a1b0<type_i, type_o>()(
                                i[blk_off(oc, ic)], rmode);
                    }
1601 for (int oc = 0; oc < oc_block; ++oc)
1602 for (int ic = 0; ic < ic_block; ++ic) {
1603 const ptrdiff_t flat_off = 0
1604 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1605 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
                    if (order_keep) {
                        o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
                                o[blk_off(oc, ic)], alpha, beta, rmode);
                    } else {
                        o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
                                o[flat_off], alpha, beta, rmode);
                    }
1620 constexpr int i_mult = order_keep ? blksize : 1;
1621 constexpr int o_mult = order_keep ? 1 : blksize;
1623 parallel_nd(G, NB_OC, NB_IC, D, H, W,
1624 [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1625 auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1626 g, i_mult * nb_oc, i_mult * nb_ic, d, h, w)];
1627 auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1628 g, o_mult * nb_oc, o_mult * nb_ic, d, h, w)];
1629 const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1630 const int ic_block = nstl::min(blksize, IC - nb_ic * blksize);
1631 ker(i, o, oc_block, ic_block);
1638 template <SIMPLE_REORDER_TEMPL_DECL>
1639 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1640 typename utils::enable_if<fmt_i == any && (false
1641 || format_traits<fmt_o>::blk_fmt == bf::_4o
1642 || format_traits<fmt_o>::blk_fmt == bf::_8o
1643 || format_traits<fmt_o>::blk_fmt == bf::_16o)>::type>
1645 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1647 GET_SCRATCHPAD_SIZE_ZERO();
1649 static status_t execute(const cpu_reorder_pd_t *pd,
1650 const data_t<type_i> *input, data_t<type_o> *output,
1651 const memory_tracking::grantor_t &scratchpad) {
1652 DECLARE_COMMON_PARAMS();
1654 static constexpr bool w_groups
1655 = format_traits<fmt_o>::data_kind == dk::gwei;
1656 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1657 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1658 constexpr int blksize = format_traits<fmt_o>::blk_size;
1660 const auto &flat_d = order_keep ? input_d : output_d;
1661 const auto &dims = input_d.dims();
1662 const auto &pdims = order_keep
1663 ? output_d.blocking_desc().padding_dims
1664 : input_d.blocking_desc().padding_dims;
1666 const int G = w_groups ? dims[0] : 1;
1667 const int OC = dims[w_groups + 0];
1668 const int IC = dims[w_groups + 1];
1669 const int D = is_3d ? dims[w_groups + 2] : 1;
1670 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1671 const int W = dims[w_groups + 3 + is_3d - is_1d];
1673 constexpr int i_mult = order_keep ? blksize : 1;
1674 constexpr int o_mult = order_keep ? 1 : blksize;
1675 const auto strd_oc = flat_d.blocking_desc().strides[0][w_groups];
1677 parallel_nd(G, pdims[w_groups + 0] / blksize, IC, D, H, W,
1678 [&](int g, int nb_oc, int ic, int d, int h, int w) {
1679 auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1680 g, i_mult * nb_oc, ic, d, h, w)];
1681 auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1682 g, o_mult * nb_oc, ic, d, h, w)];
1683 const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1685 if (alpha == 1.0 && beta == 0.0) {
1686 for (int oc = 0; oc < oc_block; ++oc) {
1687 const auto off = oc * strd_oc;
                    if (order_keep)
                        o[oc] = _qz_a1b0<type_i, type_o>()(i[off], rmode);
                    else
                        o[off] = _qz_a1b0<type_i, type_o>()(i[oc], rmode);
                }
            } else {
                for (int oc = 0; oc < oc_block; ++oc) {
                    const auto off = oc * strd_oc;
                    if (order_keep)
                        o[oc] = _qz<type_i, type_o>()(i[off], o[oc], alpha,
                                beta, rmode);
                    else
                        o[off] = _qz<type_i, type_o>()(i[oc], o[off], alpha,
                                beta, rmode);
                }
            }
        });

        return success;
    }
};
1712 /* generic and direct-copy reorders */
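/* direct_copy: both memory descriptors are dense and describe the same logical
 * layout, so the reorder degenerates into an element-wise copy (with optional
 * alpha/beta scaling), parallelized over 16-element chunks with the tail
 * handled by the last thread. */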
1714 template <SIMPLE_REORDER_TEMPL_DECL>
1715 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1716 typename utils::enable_if<
1717 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1718 spec::direct_copy>::type>
1720 static bool is_applicable(const memory_desc_wrapper &input_d,
1721 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1722 /* FIXME: is the formula correct? */
1723 return input_d.similar_to(output_d, true, false, 0)
1724 && input_d.is_dense() && output_d.is_dense()
1725 && simple_attr_check(attr, false);
1728 GET_SCRATCHPAD_SIZE_ZERO();
1730 static status_t execute(const cpu_reorder_pd_t *pd,
1731 const data_t<type_i> *input, data_t<type_o> *output,
1732 const memory_tracking::grantor_t &scratchpad) {
1733 DECLARE_COMMON_PARAMS();
1735 assert(input_d.is_dense());
1737 input += input_d.blk_off(0);
1738 output += output_d.blk_off(0);
1740 const size_t nelems = input_d.nelems();
1742 constexpr int block_size = 16;
1743 const auto num_blocks = nelems / block_size;
1744 const auto rem_elems = nelems % block_size;
1746 parallel(0, num_blocks, [&](const int ithr, const int nthr) {
1747 size_t start{0}, end{0};
1748 balance211(num_blocks, nthr, ithr, start, end);
1749 start = start * block_size;
1750 end = end * block_size;
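/* Work is split into whole 16-element blocks by balance211; each thread
 * converts its [start, end) range and the last thread additionally handles
 * the nelems % block_size tail (see the rem_elems branch below). For example,
 * nelems == 1000 yields 62 full blocks (992 elements) plus an 8-element tail
 * processed by thread nthr - 1. */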
1752 if (alpha == 1.0 && beta == 0.0) {
1754 for (size_t e = start; e < end; ++e) {
1755 output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
1756 (input[e], rmode);
1758 } else if (alpha == 1.0) {
1760 for (size_t e = start; e < end; ++e) {
1761 output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
1762 (input[e], output[e], beta, rmode);
1764 } else if (beta == 0.0) {
1766 for (size_t e = start; e < end; ++e) {
1767 output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
1768 (input[e], alpha, rmode);
1772 for (size_t e = start; e < end; ++e) {
1773 output[e] = qz<data_t<type_i>, data_t<type_o>>()
1774 (input[e], output[e], alpha, beta, rmode);
1778 if (rem_elems != 0 && ithr == nthr - 1) {
1779 if (alpha == 1.0 && beta == 0.0) {
1781 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1782 output[e] = qz_a1b0<data_t<type_i>,
1783 data_t<type_o>>()(input[e], rmode);
1785 } else if (alpha == 1.0) {
1787 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1788 output[e] = qz_a1<data_t<type_i>,
1789 data_t<type_o>>()(input[e], output[e], beta, rmode);
1791 } else if (beta == 0.0) {
1793 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1794 output[e] = qz_b0<data_t<type_i>,
1795 data_t<type_o>>()(input[e], alpha, rmode);
1799 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1800 output[e] = qz<data_t<type_i>, data_t<type_o>>()
1801 (input[e], output[e], alpha, beta, rmode);
1802 }
1803 }
1804 }
1805 });
1807 return success;
1808 }
1809 };
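/* direct_copy_except_dim_0: same idea as direct_copy, but density is only
 * required within each slice along the outermost dimension, which allows a
 * padded or otherwise non-contiguous dim-0 stride. The implementation copies
 * N slices of nelems_no_dim_0() elements each, locating a slice through the
 * per-descriptor dim-0 strides is and os. */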
1810 template <SIMPLE_REORDER_TEMPL_DECL>
1811 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1812 typename utils::enable_if<
1813 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1814 spec::direct_copy_except_dim_0>::type>
1815 {
1816 static bool is_applicable(const memory_desc_wrapper &input_d,
1817 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1818 auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
1819 return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
1820 };
1821 /* FIXME: is the formula correct? */
1822 return input_d.similar_to(output_d, true, false, 1)
1823 && is_dense_no_0(input_d) && is_dense_no_0(output_d)
1824 && simple_attr_check(attr, false);
1825 }
1827 GET_SCRATCHPAD_SIZE_ZERO();
1829 static status_t execute(const cpu_reorder_pd_t *pd,
1830 const data_t<type_i> *input, data_t<type_o> *output,
1831 const memory_tracking::grantor_t &scratchpad) {
1832 DECLARE_COMMON_PARAMS();
1834 input += input_d.blk_off(0);
1835 output += output_d.blk_off(0);
1837 const int N = input_d.dims()[0];
1838 const size_t is = input_d.blocking_desc().strides[0][0];
1839 const size_t os = output_d.blocking_desc().strides[0][0];
1840 const size_t nelems_no_d0 = nelems_no_dim_0(input_d);
1841 const size_t work_amount = N * nelems_no_d0;
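/* The work is a 2D space [N][nelems_no_d0]: balance211 assigns each thread a
 * flat [start, end) range, and nd_iterator_init/nd_iterator_jump translate it
 * back to (n, in-slice offset) coordinates, so a thread may start or finish
 * in the middle of a slice. */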
1843 if (alpha == 1.0 && beta == 0.0) {
1844 parallel(0, work_amount, [&](const int ithr, const int nthr) {
1845 size_t n{0}, dim1_s{0};
1846 size_t start{0}, end{0};
1847 balance211(work_amount, nthr, ithr, start, end);
1848 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1849 while(start < end) {
1850 size_t work_rem = end - start;
1851 size_t dim1_e = dim1_s + work_rem > nelems_no_d0
1852 ? nelems_no_d0 : dim1_s + work_rem;
1854 for (size_t e = dim1_s; e < dim1_e; ++e) {
1855 output[os * n + e] = _qz_a1b0<type_i, type_o>()(
1856 input[is * n + e], rmode);
1858 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1862 parallel(0, work_amount, [&](const int ithr, const int nthr) {
1863 size_t n{0}, dim1_s{0};
1864 size_t start{0}, end{0};
1865 balance211(work_amount, nthr, ithr, start, end);
1866 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1867 while(start < end) {
1868 size_t work_rem = end - start;
1869 size_t dim1_e =
1870 dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
1871 : dim1_s + work_rem;
1873 for (size_t e = dim1_s; e < dim1_e; ++e){
1874 output[os * n + e] = _qz<type_i, type_o>()(
1875 input[is * n + e], output[os * n + e], alpha,
1876 beta, rmode);
1877 }
1878 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
1879 }
1880 });
1882 return success;
1883 }
1887 static size_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
1888 const int ndims = data_d.ndims();
1889 if (ndims <= 1) return 1;
1890 return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
1891 }
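/* _size_no_dim_0() estimates the physical footprint (in elements) of one
 * dim-0 slice from the blocking descriptor: for every remaining dimension it
 * takes the larger of (padded blocks) * outer stride and block size * inner
 * stride. is_applicable() compares it with nelems_no_dim_0() to decide
 * whether a slice is dense. */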
1893 static size_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
1894 size_t max_size = 0;
1895 auto &blk = data_d.blocking_desc();
1896 for (int d = 1; d < data_d.ndims(); ++d) {
1897 auto block = blk.block_dims[d];
1898 max_size = nstl::max(max_size,
1899 size_t(size_t(blk.padding_dims[d] / block)
1900 * blk.strides[0][d]));
1901 if (block > 1)
1902 max_size = nstl::max(max_size,
1903 size_t(block * blk.strides[1][d]));
1904 }
1905 return max_size;
1906 }
1907 };
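/* spec::reference: the fallback that accepts any pair of blocking descriptors
 * by resolving a logical offset per element (off_l). It also supports
 * per-dimension output scales, provided the set bits of the scales mask form
 * one contiguous run (checked in is_applicable below). */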
1909 template <SIMPLE_REORDER_TEMPL_DECL>
1910 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1911 typename utils::enable_if<
1912 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1913 spec::reference>::type>
1914 {
1915 static bool is_applicable(const memory_desc_wrapper &input_d,
1916 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1917 /* supported smask: 0x0...011..10...0,
1918 * i.e. the set bits must form a single contiguous run */
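/* Example: mask 0b0110 (scales over dims 1 and 2) passes, because shifting
 * away the trailing zero and then the run of ones leaves 0; mask 0b0101
 * fails, because a non-zero residue remains after the two loops below. */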
1919 int smask = attr ? attr->output_scales_.mask_ : 0;
1920 for (; smask > 0 && !(smask & 0x1); smask >>= 1);
1921 for (; smask > 0 && smask & 0x1; smask >>= 1);
1922 return true
1923 && input_d.is_blocking_desc()
1924 && output_d.is_blocking_desc()
1925 && !output_d.is_additional_buffer()
1926 && !input_d.is_additional_buffer()
1927 && smask == 0;
1928 }
1930 GET_SCRATCHPAD_SIZE_ZERO();
1932 static status_t execute(const cpu_reorder_pd_t *pd,
1933 const data_t<type_i> *input, data_t<type_o> *output,
1934 const memory_tracking::grantor_t &scratchpad) {
1935 DECLARE_COMMON_PARAMS();
1937 const size_t nelems = input_d.nelems();
1939 int ndims_start = 0, ndims_mask = 0;
1940 int smask = pd->attr()->output_scales_.mask_;
1941 for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
1942 for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
1945 const ptrdiff_t D_start
1946 = utils::array_product(input_d.dims(), ndims_start);
1947 const ptrdiff_t D_mask
1948 = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
1949 const ptrdiff_t D_rest = nelems / D_start / D_mask;
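/* The element space is factored as [D_start][D_mask][D_rest]: D_start spans
 * the leading unscaled dimensions, D_mask spans the contiguous run of scaled
 * dimensions (one scale per index dm), and D_rest spans the remainder. The
 * linear index e below is rebuilt from (ds, dm, dr) in exactly that order. */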
1951 const float *scales = pd->attr()->output_scales_.scales_;
1953 parallel_nd(D_start, D_mask, D_rest,
1954 [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
1955 const float scale = scales[dm];
1957 const size_t e = (ds * D_mask + dm) * D_rest + dr;
1958 const auto &i = input[input_d.off_l(e)];
1959 auto &o = output[output_d.off_l(e)];
1961 o = _qz<type_i, type_o>()(i, o, scale, beta, rmode);
1962 });
1964 return success;
1965 }
1966 };
1969 /* high level class declaration */
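/* simple_reorder_t binds one simple_reorder_impl specialization to the
 * primitive framework: pd_t::create() re-checks the data types, the bf16
 * requirement on avx512_core and is_applicable(), then books the scratchpad
 * requested by the implementation; execute() casts the input/output memories
 * and forwards to simple_reorder_impl<...>::execute(). */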
1971 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
1972 struct simple_reorder_t: public cpu_primitive_t {
1973 struct pd_t: public cpu_reorder_pd_t {
1974 pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
1975 const primitive_attr_t *attr)
1976 : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
1978 DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
1980 static status_t create(reorder_pd_t **reorder_pd,
1981 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1982 const primitive_attr_t *attr) {
1983 assert(input_pd->engine()->kind() == engine_kind::cpu);
1984 assert(output_pd->engine()->kind() == engine_kind::cpu);
1985 bool args_ok = true
1986 && input_pd->desc()->data_type == type_i
1987 && output_pd->desc()->data_type == type_o
1988 && IMPLICATION(utils::one_of(data_type::bf16, type_i, type_o),
1989 mayiuse(avx512_core))
1990 && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
1991 is_applicable(input_pd->desc(), output_pd->desc(), attr);
1992 if (!args_ok)
1993 return invalid_arguments;
1995 auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
1996 (const cpu_memory_pd_t *)output_pd, attr);
1997 if (_pd == nullptr) return out_of_memory;
1998 if (_pd->init() != success) { delete _pd; return unimplemented; }
2000 const size_t scratchpad_sz_ =
2001 simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
2002 get_scratchpad_size(input_pd->desc(), output_pd->desc());
2003 auto scratchpad = _pd->scratchpad_registry().registrar();
2004 scratchpad.book(memory_tracking::names::key_reorder_space,
2005 scratchpad_sz_);
2006 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
2010 simple_reorder_t(const pd_t *apd, const input_vector &inputs,
2011 const output_vector &outputs)
2012 : cpu_primitive_t(apd, inputs, outputs) {}
2014 virtual void execute(event_t *e) const {
2015 auto input = reinterpret_cast<const data_t<type_i> *>(
2016 this->input_memory(0));
2017 auto output = reinterpret_cast<data_t<type_o> *>(this->memory());
2018 simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
2019 pd(), input, output, this->scratchpad());
2020 e->set_state(event_t::ready);
2024 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
2025 };
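/* Usage sketch (illustrative only, not part of this header): a specialization
 * is exposed to the reorder dispatcher through its pd_t::create, typically
 * along the lines of
 *
 *     simple_reorder_t<f32, any, f32, nChw16c, fmt_order::keep>::pd_t::create
 *
 * for a plain-to-blocked data reorder, or, for the catch-all that accepts any
 * pair of blocking descriptors,
 *
 *     simple_reorder_t<f32, any, f32, any, fmt_order::any,
 *             spec::reference>::pd_t::create
 */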
2027 #undef SIMPLE_REORDER_TEMPL_DECL
2028 #undef SIMPLE_REORDER_TEMPL_CALL
2030 }
2031 }
2032 }
2034 #endif
2036 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s