1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_SIMPLE_REORDER_HPP
18 #define CPU_SIMPLE_REORDER_HPP
22 #include "c_types_map.hpp"
23 #include "type_helpers.hpp"
24 #include "math_utils.hpp"
25 #include "mkldnn_thread.hpp"
28 #include "format_traits.hpp"
29 #include "cpu_reorder_pd.hpp"
30 #include "cpu_primitive.hpp"
32 #include "simple_q10n.hpp"
33 #include "cpu_isa_traits.hpp"
// Pull frequently used mkldnn implementation namespaces into this header's
// (presumably mkldnn::impl::cpu) namespace for the reorder kernels below.
39 using namespace mkldnn::impl::status;
40 using namespace mkldnn::impl::memory_format;
41 using namespace mkldnn::impl::data_type;
// Short aliases for the enums used in the enable_if dispatch conditions.
43 using dk = data_kind_t;
44 using bf = block_format_t;
46 using namespace mkldnn::impl::utils;
// Map a data_type_t enum value to its underlying C++ storage type.
49 template<impl::data_type_t type>
50 using data_t = typename prec_traits<type>::type;
// Quantization functor specialized for alpha == 1, beta == 0 (plain
// convert-and-round, no scaling of the destination).
52 template<impl::data_type_t type_i, impl::data_type_t type_o>
53 using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
// General quantization functor honoring alpha (input scale) and beta
// (accumulate-into-output scale).
55 template<impl::data_type_t type_i, impl::data_type_t type_o>
56 using _qz = qz<data_t<type_i>, data_t<type_o>>;
// Values for the order_keep template parameter: `keep` reorders
// fmt_i -> fmt_o, `reverse` swaps the roles; `any` is used as a wildcard
// source format in some specializations.
59 const bool keep = true;
60 const bool reverse = false;
61 const bool any = keep;
// Tag types selecting trivial-copy specializations elsewhere in the file.
65 struct direct_copy {};
66 struct direct_copy_except_dim_0 {};
// Template-parameter list shared by every simple_reorder_impl
// specialization: input/output data types, formats, and direction flag.
70 #define SIMPLE_REORDER_TEMPL_DECL \
71 impl::data_type_t type_i, impl::memory_format_t fmt_i, \
72 impl::data_type_t type_o, impl::memory_format_t fmt_o, bool order_keep
// Matching argument list for SIMPLE_REORDER_TEMPL_DECL.
73 #define SIMPLE_REORDER_TEMPL_CALL \
74 type_i, fmt_i, type_o, fmt_o, order_keep
// Boilerplate at the top of every execute(): fetch the memory descriptors
// and attr-derived scalars. MAYBE_UNUSED silences warnings in kernels that
// ignore alpha/beta/rmode.
76 #define DECLARE_COMMON_PARAMS() \
77 const memory_desc_wrapper &input_d = pd->input_pd(); \
78 const memory_desc_wrapper &output_d = pd->output_pd(); \
79 const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
80 const float beta = pd->beta(); MAYBE_UNUSED(beta); \
81 const round_mode_t rmode = pd->attr()->round_mode_; MAYBE_UNUSED(rmode);
83 /* specific reorders: common template */
// Primary template is intentionally empty: every concrete reorder is an
// enable_if-guarded partial specialization selected by (fmt_i, fmt_o, spec).
84 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
85 struct simple_reorder_impl {};
// True iff the (input, output) formats are exactly (fmt_i, fmt_o), with the
// pair swapped when order_keep is false (reverse reorder).
88 bool simple_fmt_check(bool order_keep, impl::memory_format_t fmt_i,
89 impl::memory_format_t fmt_o, const memory_desc_wrapper &input_d,
90 const memory_desc_wrapper &output_d) {
91 return input_d.format() == (order_keep ? fmt_i : fmt_o)
92 && output_d.format() == (order_keep ? fmt_o : fmt_i);
// Accept attributes only when their output scales are supported by the
// caller: any mask when many_scales_support, otherwise only a single common
// scale (mask == 0). NOTE(review): this view of the file is missing lines
// here (e.g. the body of the many_scales_support branch) — confirm against
// the full source.
94 bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
95 if (many_scales_support)
97 return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
101 /* specific reorders: implementation */
// Reorder f32/s8 weights of any plain format into hwio_s8s8 / hwigo_s8s8:
// quantize to s8 and fill the trailing int32 per-output-channel
// compensation buffer consumed by int8 (s8s8) convolution.
102 template <SIMPLE_REORDER_TEMPL_DECL>
103 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
104 typename utils::enable_if<fmt_i == any && (false
105 || fmt_o == hwio_s8s8
106 || fmt_o == hwigo_s8s8)>::type>
// Applicable when the destination already has the target s8s8 format and
// the scale mask implies either one common scale or one per (g * oc).
108 static bool is_applicable(const memory_desc_wrapper &input_d,
109 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
// D_mask = number of scale values implied by the attr's output-scales mask.
111 const size_t D_mask = utils::array_product(input_d.dims(),
112 math::ilog2q(attr->output_scales_.mask_ + 1));
113 const int oc = (input_d.dims()[fmt_o == hwigo_s8s8 + 0]);
114 const int g = (fmt_o == hwigo_s8s8) ? (input_d.dims()[0]) : 1;
116 return output_d.format() == fmt_o
117 && (input_d.data_type() == f32 || input_d.data_type() == s8)
118 && output_d.data_type() == s8
119 && (D_mask == 1 || D_mask == (size_t)g * oc);
122 static status_t execute(const cpu_reorder_pd_t *pd,
123 const data_t<type_i> *input, data_t<type_o> *output) {
124 DECLARE_COMMON_PARAMS();
// Grouped layout (hwigo) carries a leading groups dimension.
126 static constexpr bool w_groups = fmt_o == hwigo_s8s8;
128 const auto &dims = input_d.dims();
129 const auto &pdims = output_d.blocking_desc().padding_dims;
131 const int G = w_groups ? dims[0] : 1;
132 const int OC = dims[w_groups + 0];
133 const int IC = dims[w_groups + 1];
134 const int H = dims[w_groups + 2];
135 const int W = dims[w_groups + 3];
137 const float *scales = pd->attr()->output_scales_.scales_;
138 const size_t D_mask = utils::array_product(input_d.dims(),
139 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
// Without AVX512 VNNI the weights are halved to keep the u8*s8 intermediate
// products from overflowing (presumably compensated elsewhere) — see the
// same factor in the other s8s8 reorders below.
141 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0f : (1.0f / 2.0f);
// The compensation buffer lives immediately after the padded weights.
143 size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
144 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
// Parallel over (g, oc): each task owns one compensation slot, so the
// accumulation below is race-free.
146 parallel_nd(G, OC, [&](int g, int oc) {
148 for (int ic = 0; ic < IC; ic++)
149 for (int h = 0; h < H; h++)
150 for (int w = 0; w < W; w++) {
151 auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
152 auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
153 const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
155 o = qz_b0<data_t<type_i>, data_t<type_o>>()(
156 i, s * adj_scale, rmode);
// Accumulate -sum(quantized weights) per output channel ...
157 cp[g * OC + oc] -= (int32_t)o;
// ... then scale by 128 (the zero-point shift applied to u8 activations).
159 cp [g * OC + oc] *= 128;
// Reorder (g)oihw f32/s8 weights into the blocked s8s8 layouts used by the
// int8 convolution kernels (4i16o4i / 2i8o4i / 4o4i), computing the int32
// compensation buffer on the fly.
165 template <SIMPLE_REORDER_TEMPL_DECL>
166 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
167 typename utils::enable_if<
168 ((fmt_i == goihw || fmt_i == oihw)
169 && (format_traits<fmt_o>::blk_fmt == bf::_4i16o4i_s8s8
170 || format_traits<fmt_o>::blk_fmt == bf::_2i8o4i_s8s8
171 || format_traits<fmt_o>::blk_fmt == bf::_4o4i_s8s8))
// Same scale-mask constraint as the plain s8s8 reorder: one common scale or
// one per (g * oc).
174 static bool is_applicable(const memory_desc_wrapper &input_d,
175 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
177 const size_t D_mask = utils::array_product(input_d.dims(),
178 math::ilog2q(attr->output_scales_.mask_ + 1));
179 const int oc = (input_d.dims()[(fmt_i == goihw) + 0]);
180 const int g = (fmt_i == goihw) ? (input_d.dims()[0]) : 1;
182 return input_d.format() == fmt_i
183 && output_d.format() == fmt_o
184 && (input_d.data_type() == f32 || input_d.data_type() == s8)
185 && output_d.data_type() == s8
186 && (D_mask == 1 || D_mask == (size_t)g * oc);
189 static status_t execute(const cpu_reorder_pd_t *pd,
190 const data_t<type_i> *input, data_t<type_o> *output) {
191 DECLARE_COMMON_PARAMS();
192 static constexpr bool w_groups = fmt_i == goihw;
194 const int blksize = format_traits<fmt_o>::blk_size;
// The plain-format side (source when order_keep); used for flat strides.
197 const auto &_g_oihw_d = order_keep ? input_d : output_d;
198 const auto &dims = input_d.dims();
199 const auto &pdims = order_keep
200 ? output_d.blocking_desc().padding_dims
201 : input_d.blocking_desc().padding_dims;
203 const int G = w_groups ? dims[0] : 1;
204 const int OC = dims[w_groups + 0];
205 const int NB_OC = pdims[w_groups + 0] / blksize;
206 const int IC = dims[w_groups + 1];
207 const int NB_IC = pdims[w_groups + 1] / blksize;
208 const int H = dims[w_groups + 2];
209 const int W = dims[w_groups + 3];
211 const float *scales = pd->attr()->output_scales_.scales_;
212 const size_t D_mask = utils::array_product(input_d.dims(),
213 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
// Pre-VNNI: halve weights to avoid u8*s8 intermediate overflow.
215 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
// Offset of element (ic, oc) inside one blksize x blksize block.
// NOTE(review): `sblk` (the inner ic sub-block, likely 4) is declared on a
// line missing from this view — confirm against the full file.
217 auto index = [&](const int ic, const int oc) {
218 return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
// Quantize one (ic_block x oc_block) tile and update compensation `c`.
221 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
222 int32_t *c, const float *s, const int oc_block, const int ic_block) {
223 for (int ic = 0; ic < ic_block; ++ic) {
224 for (int oc = 0; oc < oc_block; ++oc) {
// Flat offset of (oc, ic) in the plain goihw/oihw layout.
225 const auto _g_oihw_off =
226 oc * _g_oihw_d.blocking_desc().strides[0][w_groups + 0]
227 + ic * _g_oihw_d.blocking_desc().strides[0][w_groups + 1];
229 = qz_b0<data_t<type_i>, data_t<type_o>>()(
230 inp[_g_oihw_off], s[oc] * adj_scale, rmode);
231 c[oc] -= (128 * (int32_t)(out[index(ic, oc)]));
236 constexpr int i_mult = blksize;
237 constexpr int o_mult = 1;
// Compensation buffer sits after the padded weight payload; zeroed in
// parallel before the main loop (zeroing body is on lines missing here).
239 size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
240 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
241 parallel_nd(G * NB_OC * blksize, [&](int i) {
// One task per (g, oc-block); each owns its compensation slots.
245 parallel_nd(G, NB_OC, [&](int g, int O) {
246 for (int I = 0; I < NB_IC; I++)
247 for (int h = 0; h < H; h++)
248 for (int w = 0; w < W; w++) {
249 auto i = &input[input_d.blk_off<!w_groups>(g,
250 i_mult * O, i_mult * I, h, w)];
251 auto o = &output[output_d.blk_off<!w_groups>(
252 g, o_mult * O, o_mult * I, h, w)];
// Tail handling: last block may be partial.
253 const int oc_block = nstl::min(blksize, OC - O * blksize);
254 const int ic_block = nstl::min(blksize, IC - I * blksize);
256 int _offset = (g * NB_OC + O) * blksize;
257 ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
258 &scales[(D_mask == 1) ? 0 : _offset],
// Reorder goihw f32/s8 depthwise weights into Goihw16g_s8s8 (groups blocked
// by 16) with the int32 compensation buffer for int8 convolution.
266 template <SIMPLE_REORDER_TEMPL_DECL>
267 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
268 typename utils::enable_if<true
269 && (fmt_i == goihw && fmt_o == Goihw16g_s8s8)>::type>
271 static bool is_applicable(const memory_desc_wrapper &input_d,
272 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
273 const size_t D_mask = utils::array_product(input_d.dims(),
274 math::ilog2q(attr->output_scales_.mask_ + 1));
275 const int oc = input_d.dims()[1];
276 const int g = input_d.dims()[0];
// NOTE(review): the first term(s) of this return expression are on lines
// missing from this view — confirm against the full file.
280 && input_d.format() == fmt_i
281 && output_d.format() == fmt_o
282 && (input_d.data_type() == f32 || input_d.data_type() == s8)
283 && output_d.data_type() == s8
284 && (D_mask == 1 || D_mask == (size_t)g * oc);
287 static status_t execute(const cpu_reorder_pd_t *pd,
288 const data_t<type_i> *input, data_t<type_o> *output) {
289 DECLARE_COMMON_PARAMS();
// Group-blocking factor of Goihw16g.
291 const int blksize = 16;
293 const auto &dims = input_d.dims();
294 const auto &pdims = output_d.blocking_desc().padding_dims;
295 const int G = dims[0];
296 const int Gp = pdims[0];
297 const int OC = dims[1];
298 const int IC = dims[2];
299 const int H = dims[3];
300 const int W = dims[4];
302 const size_t D_mask = utils::array_product(input_d.dims(),
303 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
304 const float *scales = pd->attr()->output_scales_.scales_;
// Pre-VNNI: halve weights to avoid u8*s8 intermediate overflow.
305 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.f : (1.f / 2.f);
// Quantize one run of g_block groups at a fixed (O, I, h, w) and update
// the per-channel compensation.
308 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
309 int32_t *cp, const float *s, const int g_block) {
311 for (int g = 0; g < g_block; g++) {
312 const auto i_off = g * input_d.blocking_desc().strides[0][0];
313 out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
314 inp[i_off], s[g * OC] * adj_scale, rmode);
315 cp[g * OC] -= 128 * (int32_t)(out[g]);
// Compensation buffer is the descriptor's "additional buffer" at the tail
// of the allocation; zero it before accumulating.
319 size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
320 int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
321 parallel_nd((Gp/blksize) * OC, [&](int ib) {
323 for (int i = 0; i < blksize; i++)
324 cp[ib * blksize + i] = 0;
// One task per (group-block, oc) — each owns its compensation slots.
327 parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
328 for (int I = 0; I < IC; I++) {
329 for (int h = 0; h < H; h++) {
330 for (int w = 0; w < W; w++) {
// Tail handling: last group block may be partial.
331 const int g_block = nstl::min(G - gb * blksize, blksize);
332 const auto inp = &input[input_d.blk_off(gb * blksize, O, I, h, w)];
333 const auto out = &output[output_d.blk_off(gb, O, I, h, w)];
334 int offset = gb * blksize + O;
335 ker(inp, out, &cp[offset],
336 &scales[(D_mask == 1) ? 0 : offset], g_block);
// Transpose weights between the two 16x16 interleaved blocked layouts
// _8i16o2i and _8o16i2o (used by different convolution kernel families);
// block contents are permuted, block order is unchanged.
345 template <SIMPLE_REORDER_TEMPL_DECL>
346 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
347 typename utils::enable_if<true
348 && format_traits<fmt_i>::blk_fmt == bf::_8i16o2i
349 && format_traits<fmt_o>::blk_fmt == bf::_8o16i2o>::type>
351 static bool is_applicable(const memory_desc_wrapper &input_d,
352 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
354 return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
355 && simple_attr_check(attr, false);
358 static status_t execute(const cpu_reorder_pd_t *pd,
359 const data_t<type_i> *input, data_t<type_o> *output) {
360 DECLARE_COMMON_PARAMS();
// Layout traits: grouped weights, spatial rank, and the 16-wide block.
362 static constexpr bool w_groups
363 = format_traits<fmt_o>::data_kind == dk::gwei;
364 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
365 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
366 constexpr int blksize = format_traits<fmt_o>::blk_size;
368 const auto &dims = input_d.dims();
370 const int G = w_groups ? dims[0] : 1;
371 const int NB_OC = dims[w_groups + 0] / blksize;
372 const int NB_IC = dims[w_groups + 1] / blksize;
373 const int D = is_3d ? dims[w_groups + 2] : 1;
374 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
375 const int W = dims[w_groups + 3 + is_3d - is_1d];
// In-block offsets: input pairs ic in twos (8i16o2i), output pairs oc in
// twos (8o16i2o).
377 auto idx_i = [&](const int oc, const int ic)
378 { return ((ic / 2) * blksize * 2 + 2 * oc + ic % 2); };
380 auto idx_o = [&](const int oc, const int ic)
381 { return ((oc / 2) * blksize * 2 + 2 * ic + oc % 2); };
// Permute one 16x16 block; fast path when no alpha/beta scaling needed.
383 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) -> void {
384 if (alpha == 1.0 && beta == 0.0) {
385 for (int ic = 0; ic < blksize; ++ic) {
386 for (int oc = 0; oc < blksize; ++oc) {
387 o[idx_o(oc, ic)] = _qz_a1b0<type_i, type_o>()(
388 i[idx_i(oc, ic)], rmode);
392 for (int ic = 0; ic < blksize; ++ic) {
393 for (int oc = 0; oc < blksize; ++oc) {
394 o[idx_o(oc, ic)] = _qz<type_i, type_o>()(
395 i[idx_i(oc, ic)], o[idx_o(oc, ic)], alpha,
// One task per block coordinate; blocks are independent.
402 parallel_nd(G, NB_OC, NB_IC, D, H, W,
403 [&](int g, int o, int i, int d, int h, int w) {
404 auto ptr_i = &input[wei_blk_off_like_gwei3D<fmt_i>(
405 input_d, g, o, i, d, h, w)];
406 auto ptr_o = &output[wei_blk_off_like_gwei3D<fmt_o>(
407 output_d, g, o, i, d, h, w)];
415 /* reorders with tail support */
// Unpack nChw8c (channel-blocked by 8) into plain nhwc, with per-channel
// (mask == 2) or common scaling and channel-tail handling.
417 template <SIMPLE_REORDER_TEMPL_DECL>
418 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
419 typename utils::enable_if<fmt_i == nChw8c && fmt_o == nhwc && order_keep>::type>
421 static bool is_applicable(const memory_desc_wrapper &input_d,
422 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
423 int smask = attr ? attr->output_scales_.mask_ : 0;
424 return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nChw8c && output_d._md->format == nhwc;
427 static status_t execute(const cpu_reorder_pd_t *pd,
428 const data_t<type_i> *input, data_t<type_o> *output) {
429 DECLARE_COMMON_PARAMS();
431 const auto &pdims = input_d.blocking_desc().padding_dims;
432 const auto &dims = input_d.dims();
433 constexpr int blksize = format_traits<fmt_i>::blk_size;
434 const int C = dims[1];
435 const int H = dims[2];
436 const int W = dims[3];
// Channel-index multipliers for blk_off on each side.
// NOTE(review): i_c_mult == 1 / o_c_mult == blksize here mirrors the
// inverse reorder below; verify against the full file's conventions.
438 constexpr int i_c_mult = 1;
439 constexpr int o_c_mult = blksize;
441 const float *scales = pd->attr()->output_scales_.scales_;
442 int smask = pd->attr()->output_scales_.mask_;
// Copy one (W x c_block) slab; the branch between scaled and plain paths
// is on lines missing from this view (presumably on smask).
444 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
445 const int nb_c, const int c_block) {
447 for (int w = 0; w < W; ++w) {
448 const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
450 for (int c = 0; c < c_block; ++c) {
// Per-channel scale (mask == 2 semantics).
451 const float scale = scales[nb_c * blksize + c];
453 o[flat_off + c] = _qz<type_i, type_o>()(i[w * blksize + c],
454 o[flat_off + c], scale, beta, rmode);
// Unscaled path: straight convert-and-round.
458 for (int w = 0; w < W; ++w) {
459 const ptrdiff_t flat_off = w * output_d.blocking_desc().strides[0][3];
461 for (int c = 0; c < c_block; ++c) {
462 o[flat_off + c] = _qz_a1b0<type_i, type_o>()(i[w * blksize + c], rmode);
// One task per (n, channel-block, h); c_block trims the channel tail.
468 parallel_nd(dims[0], pdims[1] / blksize, H,
469 [&](int n, int nb_c, int h) {
470 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
471 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
472 const int c_block = nstl::min(blksize, C - nb_c * blksize);
473 ker(i, o, nb_c, c_block);
// Pack plain nhwc into nChw8c (channel-blocked by 8) with per-channel
// scaling (mask == 2) and channel-tail handling; inverse of the reorder
// above.
480 template <SIMPLE_REORDER_TEMPL_DECL>
481 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
482 typename utils::enable_if<fmt_i == nhwc && fmt_o == nChw8c>::type>
484 static bool is_applicable(const memory_desc_wrapper &input_d,
485 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
486 int smask = attr ? attr->output_scales_.mask_ : 0;
487 return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nChw8c;
490 static status_t execute(const cpu_reorder_pd_t *pd,
491 const data_t<type_i> *input, data_t<type_o> *output) {
492 DECLARE_COMMON_PARAMS();
494 const auto &pdims = output_d.blocking_desc().padding_dims;
495 const auto &dims = input_d.dims();
496 constexpr int blksize = format_traits<fmt_o>::blk_size;
497 const int C = dims[1];
498 const int H = dims[2];
499 const int W = dims[3];
501 constexpr int i_c_mult = blksize;
502 constexpr int o_c_mult = 1;
504 const float *scales = pd->attr()->output_scales_.scales_;
505 int smask = pd->attr()->output_scales_.mask_;
// Copy one (W x c_block) slab from flat nhwc into one channel block.
// The branch between the scaled and plain paths is on lines missing from
// this view (presumably on smask).
507 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
508 const int nb_c, const int c_block) {
510 for (int w = 0; w < W; ++w) {
511 const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
513 for (int c = 0; c < c_block; ++c) {
// Per-channel scale (mask == 2 semantics).
514 const float scale = scales[nb_c * blksize + c];
516 o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off + c],
517 o[w * blksize + c], scale, beta, rmode);
// Unscaled path: straight convert-and-round.
521 for (int w = 0; w < W; ++w) {
522 const ptrdiff_t flat_off = w * input_d.blocking_desc().strides[0][3];
524 for (int c = 0; c < c_block; ++c) {
525 o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(i[flat_off + c], rmode);
// One task per (n, channel-block, h); c_block trims the channel tail.
531 parallel_nd(dims[0], pdims[1] / blksize, H,
532 [&](int n, int nb_c, int h) {
533 auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
534 auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
535 const int c_block = nstl::min(blksize, C - nb_c * blksize);
536 ker(i, o, nb_c, c_block);
// Same-layout nhwc -> nhwc copy that only applies per-channel output scales
// (mask == 2); excluded for binary outputs, which use the packing reorder
// further below.
543 template <SIMPLE_REORDER_TEMPL_DECL>
544 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
545 typename utils::enable_if<fmt_i == nhwc && fmt_o == nhwc && type_o != mkldnn_bin>::type>
547 static bool is_applicable(const memory_desc_wrapper &input_d,
548 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
549 int smask = attr ? attr->output_scales_.mask_ : 0;
550 return (smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nhwc;
553 static status_t execute(const cpu_reorder_pd_t *pd,
554 const data_t<type_i> *input, data_t<type_o> *output) {
555 DECLARE_COMMON_PARAMS();
557 const auto &dims = input_d.dims();
558 const int C = dims[1];
559 const int H = dims[2];
560 const int W = dims[3];
562 const float *scales = pd->attr()->output_scales_.scales_;
// Scale every channel of one pixel; channels are contiguous in nhwc.
564 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
565 for (int c = 0; c < C; ++c) {
566 const float scale = scales[c];
568 o[c] = _qz<type_i, type_o>()(i[c], o[c], scale, beta, rmode);
// One task per (n, h, w) pixel.
572 parallel_nd(dims[0], H, W,
573 [&](int n, int h, int w) {
574 auto i = &input[input_d.blk_off(n, 0, h, w)];
575 auto o = &output[output_d.blk_off(n, 0, h, w)];
// Transpose plain nchw into nhwc with optional per-channel scaling
// (mask == 0 or 2); binary data types use the packing reorder below instead.
583 template <SIMPLE_REORDER_TEMPL_DECL>
584 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
585 typename utils::enable_if<fmt_i == nchw && fmt_o == nhwc && type_i != mkldnn_bin && type_o != mkldnn_bin>::type>
587 static bool is_applicable(const memory_desc_wrapper &input_d,
588 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
589 int smask = attr ? attr->output_scales_.mask_ : 0;
590 return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nchw && output_d._md->format == nhwc;
593 static status_t execute(const cpu_reorder_pd_t *pd,
594 const data_t<type_i> *input, data_t<type_o> *output) {
595 DECLARE_COMMON_PARAMS();
597 const auto &dims = input_d.dims();
598 const int C = dims[1];
599 const int H = dims[2];
600 const int W = dims[3];
602 int smask = pd->attr()->output_scales_.mask_;
603 const float *scales = pd->attr()->output_scales_.scales_;
// Gather one pixel's channels from strided nchw into contiguous nhwc.
// The branch between the scaled and plain paths is on lines missing from
// this view (presumably on smask).
605 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
607 for (int c = 0; c < C; ++c) {
608 const float scale = scales[c];
// flat_off walks the channel stride of the nchw source.
610 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
612 o[c] = _qz<type_i, type_o>()(i[flat_off], o[c], scale, beta, rmode);
615 for (int c = 0; c < C; ++c) {
616 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
618 o[c] = _qz_a1b0<type_i, type_o>()(i[flat_off], rmode);
// One task per (n, h, w) pixel.
623 parallel_nd(dims[0], H, W,
624 [&](int n, int h, int w) {
625 auto i = &input[input_d.blk_off(n, 0, h, w)];
626 auto o = &output[output_d.blk_off(n, 0, h, w)];
// Binarizing reorder: pack nchw/nhwc data into nhwc with 8 channels per
// output byte (bit set iff the input value is positive), for the
// binary-convolution (mkldnn_bin) data type.
634 template <SIMPLE_REORDER_TEMPL_DECL>
635 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
636 typename utils::enable_if<(fmt_i == nchw || fmt_i == nhwc) && fmt_o == nhwc && (type_i == mkldnn_bin || type_o == mkldnn_bin)>::type>
638 static bool is_applicable(const memory_desc_wrapper &input_d,
639 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
640 int smask = attr ? attr->output_scales_.mask_ : 0;
641 return smask == 0 && order_keep && (input_d._md->format == nchw || input_d._md->format == nhwc) && output_d._md->format == nhwc;
644 static status_t execute(const cpu_reorder_pd_t *pd,
645 const data_t<type_i> *input, data_t<type_o> *output) {
646 DECLARE_COMMON_PARAMS();
648 const auto &dims = input_d.dims();
649 const int C = dims[1];
650 const int H = dims[2];
651 const int W = dims[3];
// CB = number of packed output bytes per pixel.
// NOTE(review): `nbits` (presumably 8, bits per byte) is declared on a line
// missing from this view — confirm against the full file.
654 const int CB = div_up(C, nbits);
// Pack one pixel: for each group of nbits channels build one byte,
// LSB = first channel; channel tail leaves the remaining bits zero.
656 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
657 for (int cb = 0; cb < CB; ++cb) {
658 uint8_t bin_val = 0x00;
659 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
660 const ptrdiff_t flat_off = c * input_d.blocking_desc().strides[0][1];
// Binarize: bit is 1 iff the source value is strictly positive.
662 auto bit = uint8_t((i[flat_off] > 0) ? 0x01 : 0x00);
663 bin_val |= (bit << shift);
// One task per (n, h, w) pixel; output index is divided by nbits because
// the destination stores packed bytes.
670 parallel_nd(dims[0], H, W,
671 [&](int n, int h, int w) {
672 auto iidx = input_d.blk_off(n, 0, h, w);
673 auto oidx = output_d.blk_off(n, 0, h, w);
675 auto i = &input[iidx];
676 auto o = &output[oidx / nbits];
// Transpose plain nhwc into nchw with optional per-channel scaling
// (mask == 0 or 2); inverse of the nchw -> nhwc reorder above.
684 template <SIMPLE_REORDER_TEMPL_DECL>
685 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
686 typename utils::enable_if<fmt_i == nhwc && fmt_o == nchw>::type>
688 static bool is_applicable(const memory_desc_wrapper &input_d,
689 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
690 int smask = attr ? attr->output_scales_.mask_ : 0;
691 return (smask == 0 || smask == 2) && order_keep && input_d._md->format == nhwc && output_d._md->format == nchw;
694 static status_t execute(const cpu_reorder_pd_t *pd,
695 const data_t<type_i> *input, data_t<type_o> *output) {
696 DECLARE_COMMON_PARAMS();
698 const auto &dims = input_d.dims();
699 const int C = dims[1];
700 const int H = dims[2];
701 const int W = dims[3];
703 int smask = pd->attr()->output_scales_.mask_;
704 const float *scales = pd->attr()->output_scales_.scales_;
// Scatter one pixel's contiguous nhwc channels into strided nchw.
// The branch between the scaled and plain paths is on lines missing from
// this view (presumably on smask).
706 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o) {
708 for (int c = 0; c < C; ++c) {
709 const float scale = scales[c];
// flat_off walks the channel stride of the nchw destination.
711 const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
713 o[flat_off] = _qz<type_i, type_o>()(i[c], o[flat_off], scale, beta, rmode);
716 for (int c = 0; c < C; ++c) {
717 const ptrdiff_t flat_off = c * output_d.blocking_desc().strides[0][1];
719 o[flat_off] = _qz_a1b0<type_i, type_o>()(i[c], rmode);
// One task per (n, h, w) pixel.
724 parallel_nd(dims[0], H, W,
725 [&](int n, int h, int w) {
726 auto i = &input[input_d.blk_off(n, 0, h, w)];
727 auto o = &output[output_d.blk_off(n, 0, h, w)];
// Re-block channel data between 8-wide (_8c) and 16-wide (_16c) blocked
// layouts: each 16-block is assembled from (up to) two 8-blocks, or split
// back when order_keep is false. Handles 1d/2d/3d spatial variants and
// channel tails.
735 template <SIMPLE_REORDER_TEMPL_DECL>
736 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
737 typename utils::enable_if<format_traits<fmt_i>::blk_fmt == bf::_8c
738 && format_traits<fmt_o>::blk_fmt == bf::_16c>::type>
740 static bool is_applicable(const memory_desc_wrapper &input_d,
741 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
743 return simple_fmt_check(order_keep, fmt_i, fmt_o, input_d, output_d)
744 && simple_attr_check(attr, false);
747 static status_t execute(const cpu_reorder_pd_t *pd,
748 const data_t<type_i> *input, data_t<type_o> *output) {
749 DECLARE_COMMON_PARAMS();
751 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
752 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
753 constexpr int blksize_16 = format_traits<fmt_o>::blk_size;
754 constexpr int blksize_8 = format_traits<fmt_i>::blk_size;
// Two 8-blocks map onto one 16-block (and vice versa for reverse).
755 constexpr int ic_mult = order_keep ? 2 : 1;
756 constexpr int oc_mult = order_keep ? 1 : 2;
758 const auto &nchw8c_d = order_keep ? input_d : output_d;
759 const auto &dims = input_d.dims();
760 const auto &pdims = order_keep ? output_d.blocking_desc().padding_dims
761 : input_d.blocking_desc().padding_dims;
// Strides of the 8-blocked side, used to step between its channel blocks.
762 const auto stride_8c = nchw8c_d.blocking_desc().strides[0];
764 const int C = dims[1];
765 const int D = is_3d ? dims[2] : 1;
766 const int H = is_1d ? 1 : dims[2 + is_3d];
767 const int W = dims[3 + is_3d - is_1d];
// Copy one 16-wide channel block as up to two 8-wide sub-blocks; the
// reverse-direction offsets are on lines missing from this view.
769 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
770 const int block_16) {
771 const int nb = (block_16 - 1) / blksize_8 + 1;
// Fast path: plain convert when no alpha/beta scaling is requested.
772 if (alpha == 1.0 && beta == 0.0) {
773 for (int b = 0; b < nb; ++b) {
774 const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
776 const ptrdiff_t o_off = order_keep ? b * blksize_8
// Tail within the 16-block: the second 8-sub-block may be partial.
778 const int block_8 = nstl::min(blksize_8,
779 block_16 - b * blksize_8);
780 for (int c = 0; c < block_8; ++c) {
781 o[o_off + c] = _qz_a1b0<type_i, type_o>()(
782 i[i_off + c], rmode);
786 for (int b = 0; b < nb; ++b) {
787 const ptrdiff_t i_off = order_keep ? b * stride_8c[1]
789 const ptrdiff_t o_off = order_keep ? b * blksize_8
791 const int block_8 = nstl::min(blksize_8,
792 block_16 - b * blksize_8);
793 for (int c = 0; c < block_8; ++c) {
794 o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
795 o[o_off + c], alpha, beta, rmode);
// Dimension-agnostic offset helper for 1d/2d/3d tensors.
801 # define data_blk_off(md, n, c, d, h, w) \
802 ( is_1d ? (md).blk_off(n, c, w) \
803 : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
// One task per (n, 16-channel-block, spatial point); tail via block_16.
805 parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
806 [&](int n, int nb_c, int d, int h, int w) {
807 auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
808 auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
809 const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
// Shared is_applicable() body for plain <-> blocked reorders: requires a
// single common scale, the blocked side to match fmt_o, and the other side
// to be any plain (non-blocked) format; direction follows order_keep.
819 #define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
820 static bool is_applicable(const memory_desc_wrapper &input_d, \
821 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
822 return simple_attr_check(attr, false) && (order_keep \
823 ? output_d.format() == fmt_o && input_d.is_plain() \
824 : input_d.format() == fmt_o && output_d.is_plain()); \
// Reorder any plain data format to/from the channel-blocked layouts
// _4c/_8c/_16c (direction per order_keep), for 1d/2d/3d spatial tensors,
// with alpha/beta scaling and channel-tail handling.
827 template <SIMPLE_REORDER_TEMPL_DECL>
828 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
829 typename utils::enable_if<fmt_i == any && (false
830 || format_traits<fmt_o>::blk_fmt == bf::_4c
831 || format_traits<fmt_o>::blk_fmt == bf::_8c
832 || format_traits<fmt_o>::blk_fmt == bf::_16c)>::type>
834 PLAIN_TO_BLOCKED_IS_APPLICABLE();
836 static status_t execute(const cpu_reorder_pd_t *pd,
837 const data_t<type_i> *input, data_t<type_o> *output) {
838 DECLARE_COMMON_PARAMS();
840 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
841 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
842 constexpr int blksize = format_traits<fmt_o>::blk_size;
// The plain-format side; its strides give the flat element offsets.
844 const auto &flat_d = order_keep ? input_d : output_d;
845 const auto &dims = input_d.dims();
846 const auto &pdims = order_keep
847 ? output_d.blocking_desc().padding_dims
848 : input_d.blocking_desc().padding_dims;
850 const int C = dims[1];
851 const int D = is_3d ? dims[2] : 1;
852 const int H = is_1d ? 1 : dims[2 + is_3d];
853 const int W = dims[3 + is_3d - is_1d];
// Copy one (W x c_block) slab between flat and blocked layouts; the
// direction branch (order_keep) around each assignment pair is on lines
// missing from this view.
855 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
// Fast path: plain convert when no alpha/beta scaling requested.
857 if (alpha == 1.0 && beta == 0.0) {
858 for (int w = 0; w < W; ++w)
859 for (int c = 0; c < c_block; ++c) {
860 const ptrdiff_t flat_off = 0
861 + c * flat_d.blocking_desc().strides[0][1]
862 + w * flat_d.blocking_desc().strides[0][3 + is_3d
865 o[w * blksize + c] = _qz_a1b0<type_i, type_o>()(
868 o[flat_off] = _qz_a1b0<type_i, type_o>()(
869 i[w * blksize + c], rmode);
873 for (int w = 0; w < W; ++w)
874 for (int c = 0; c < c_block; ++c) {
875 const ptrdiff_t flat_off = 0
876 + c * flat_d.blocking_desc().strides[0][1]
877 + w * flat_d.blocking_desc().strides[0][3 + is_3d
880 o[w * blksize + c] = _qz<type_i, type_o>()(i[flat_off],
881 o[w * blksize + c], alpha, beta, rmode);
883 o[flat_off] = _qz<type_i, type_o>()(i[w * blksize + c],
884 o[flat_off], alpha, beta, rmode);
// Channel multipliers: the blocked side advances one block per nb_c.
890 constexpr int i_c_mult = order_keep ? blksize : 1;
891 constexpr int o_c_mult = order_keep ? 1 : blksize;
// Dimension-agnostic offset helper for 1d/2d/3d tensors.
893 # define data_blk_off(md, n, c, d, h) \
894 ( is_1d ? (md).blk_off(n, c) \
895 : is_3d ? (md).blk_off(n, c, d, h) : (md).blk_off(n, c, h))
// One task per (n, channel-block, d, h); c_block trims the channel tail.
897 parallel_nd(dims[0], pdims[1] / blksize, D, H,
898 [&](int n, int nb_c, int d, int h) {
899 auto i = &input[data_blk_off(input_d, n, i_c_mult * nb_c, d, h)];
900 auto o = &output[data_blk_off(output_d, n, o_c_mult * nb_c, d, h)];
901 const int c_block = nstl::min(blksize, C - nb_c * blksize);
// Reorder (g)oihw f32/s8 weights into the (g)OhIw8o4i_s8s8 blocked layout
// (8-wide oc, 4-wide ic blocks) with the int32 compensation buffer for int8
// convolution.
911 template <SIMPLE_REORDER_TEMPL_DECL>
912 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
913 typename utils::enable_if<
914 (fmt_i == goihw && fmt_o == gOhIw8o4i_s8s8)
915 || (fmt_i == oihw && fmt_o == OhIw8o4i_s8s8)
// Scale-mask constraint: one common scale or one per (g * oc).
918 static bool is_applicable(const memory_desc_wrapper &input_d,
919 const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
921 const size_t D_mask = utils::array_product(input_d.dims(),
922 math::ilog2q(attr->output_scales_.mask_ + 1));
923 const int oc = (input_d.dims()[(fmt_i == goihw) + 0]);
924 const int g = (fmt_i == goihw) ? (input_d.dims()[0]) : 1;
926 return input_d.format() == fmt_i
927 && output_d.format() == fmt_o
928 && (input_d.data_type() == f32 || input_d.data_type() == s8)
929 && output_d.data_type() == s8
930 && (D_mask == 1 || D_mask == (size_t)g * oc);
933 static status_t execute(const cpu_reorder_pd_t *pd,
934 const data_t<type_i> *input, data_t<type_o> *output) {
935 DECLARE_COMMON_PARAMS();
937 static constexpr bool w_groups
938 = format_traits<fmt_o>::data_kind == dk::gwei;
// Asymmetric block sizes: 8 output channels x 4 input channels.
939 constexpr int blksize_o = 8;
940 constexpr int blksize_i = 4;
// The plain-format side; its strides give the flat element offsets.
942 const auto &flat_d = order_keep ? input_d : output_d;
943 const auto &dims = input_d.dims();
944 const auto &pdims = order_keep
945 ? output_d.blocking_desc().padding_dims
946 : input_d.blocking_desc().padding_dims;
948 const int G = w_groups ? dims[0] : 1;
949 const int OC = dims[w_groups + 0];
950 const int NB_OC = pdims[w_groups + 0] / blksize_o;
951 const int IC = dims[w_groups + 1];
952 const int NB_IC = pdims[w_groups + 1] / blksize_i;
953 const int H = dims[w_groups + 2];
954 const int W = dims[w_groups + 3];
956 const float *scales = pd->attr()->output_scales_.scales_;
957 const size_t D_mask = utils::array_product(input_d.dims(),
958 math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
// Pre-VNNI: halve weights to avoid u8*s8 intermediate overflow.
960 float adj_scale = (mayiuse(avx512_core_vnni)) ? 1.0 : (1.0 / 2.0);
// Quantize one (oc_block x ic_block) tile in either direction and update
// the compensation `c` (forward direction only; see nullptr below).
962 auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
963 int32_t *c, const float *s, const int oc_block, const int ic_block) {
964 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
966 for (int ic = 0; ic < ic_block; ++ic) {
967 for (int oc = 0; oc < oc_block; ++oc) {
// Flat offset of (oc, ic) in the plain goihw/oihw layout.
968 const auto _g_oihw_off = oc * flat_d.blocking_desc().strides[0][w_groups + 0] +
969 ic * flat_d.blocking_desc().strides[0][w_groups + 1];
// Forward: plain -> blocked, accumulate -128 * w into compensation.
972 out[blk_off(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[_g_oihw_off], s[oc] * adj_scale, rmode);
973 c[oc] -= (128 * (int32_t)(out[blk_off(oc, ic)]));
// Reverse: blocked -> plain.
975 out[_g_oihw_off] = qz_b0<data_t<type_i>, data_t<type_o>>()(inp[blk_off(oc, ic)], s[oc] * adj_scale, rmode);
976 c[oc] -= (128 * (int32_t)(out[_g_oihw_off]));
984 constexpr int i_mult_o = blksize_o;
985 constexpr int i_mult_i = blksize_i;
// Compensation buffer sits after the padded weight payload; zeroed in
// parallel before the main loop (zeroing body is on lines missing here).
987 size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
988 int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
989 parallel_nd(G * NB_OC * blksize_o, [&](int i) {
// One task per (g, oc-block); each owns its compensation slots.
993 parallel_nd(G, NB_OC, [&](int g, int O) {
994 for (int I = 0; I < NB_IC; I++) {
995 for (int h = 0; h < H; h++) {
996 for (int w = 0; w < W; w++) {
997 auto i = &input[input_d.blk_off<!w_groups>(g, i_mult_o * O, i_mult_i * I, h, w)];
998 auto o = &output[output_d.blk_off<!w_groups>(g, O, I, h, w)];
// Tail handling: last oc/ic blocks may be partial.
999 const int oc_block = nstl::min(blksize_o, OC - O * blksize_o);
1000 const int ic_block = nstl::min(blksize_i, IC - I * blksize_i);
1002 int _offset = (g * NB_OC + O) * blksize_o;
1003 ker(i, o, (order_keep) ? &cp[_offset] : nullptr, &scales[(D_mask == 1) ? 0 : _offset], oc_block,
1014 template <SIMPLE_REORDER_TEMPL_DECL>
// Weights reorder: flat ("any") layout <-> OhIw8o4i / gOhIw8o4i blocked layout.
// OC is blocked by 8 and IC by 4; the 'g' variant carries a leading groups dim.
// NOTE(review): several interior lines of this chunk are elided (e.g. the
// order_keep branch selectors); comments below describe only what is visible.
1015 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1016 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o4i || fmt_o == gOhIw8o4i)>::type>
1018 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1020 static status_t execute(const cpu_reorder_pd_t *pd,
1021 const data_t<type_i> *input, data_t<type_o> *output) {
1022 DECLARE_COMMON_PARAMS();
// true when fmt_o is a grouped-weights format; dims then start at index 1
1024 static constexpr bool w_groups
1025 = format_traits<fmt_o>::data_kind == dk::gwei;
1026 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1027 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
// block sizes are hard-coded to 8o4i here instead of format_traits::blk_size
1028 constexpr int blksize_o = 8;//format_traits<fmt_o>::blk_size;
1029 constexpr int blksize_i = 4;
// flat_d is whichever side has the plain (non-blocked) layout
1031 const auto &flat_d = order_keep ? input_d : output_d;
1032 const auto &dims = input_d.dims();
// padded dims are taken from the blocked side, so NB_OC/NB_IC divide evenly
1033 const auto &pdims = order_keep
1034 ? output_d.blocking_desc().padding_dims
1035 : input_d.blocking_desc().padding_dims;
1037 const int G = w_groups ? dims[0] : 1;
1038 const int OC = dims[w_groups + 0];
1039 const int NB_OC = pdims[w_groups + 0] / blksize_o;
1040 const int IC = dims[w_groups + 1];
1041 const int NB_IC = pdims[w_groups + 1] / blksize_i;
1042 const int D = is_3d ? dims[w_groups + 2] : 1;
1043 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1044 const int W = dims[w_groups + 3 + is_3d - is_1d];
// Copies one (oc_block x ic_block) tile between the flat and blocked layouts,
// quantizing each element; oc_block/ic_block may be partial at the tails.
1046 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1047 const int oc_block, const int ic_block) {
1048 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
// fast path: alpha == 1, beta == 0 -> plain copy/round, no scale/accumulate
1050 if (alpha == 1.0 && beta == 0.0) {
1051 for (int oc = 0; oc < oc_block; ++oc)
1052 for (int ic = 0; ic < ic_block; ++ic) {
// element offset in the plain layout, built from its runtime strides
1053 const ptrdiff_t flat_off = 0
1054 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1055 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
// flat -> blocked direction (presumably the order_keep branch; selector elided)
1057 o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1058 i[flat_off], rmode);
// blocked -> flat direction (presumably !order_keep; selector elided)
1060 o[flat_off] = _qz_a1b0<type_i, type_o>()(
1061 i[blk_off(oc, ic)], rmode);
// general path: out = quantize(alpha * in + beta * out)
1065 for (int oc = 0; oc < oc_block; ++oc)
1066 for (int ic = 0; ic < ic_block; ++ic) {
1067 const ptrdiff_t flat_off = 0
1068 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1069 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1071 o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1072 o[blk_off(oc, ic)], alpha, beta, rmode);
1074 o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1075 o[flat_off], alpha, beta, rmode);
// the flat side is addressed in plain elements, so a block index advances
// by a whole block's worth of OC/IC positions
1084 constexpr int i_mult_o = blksize_o;
1085 constexpr int i_mult_i = blksize_i;
// one parallel task per (group, OC-block, IC-block, spatial point)
1087 parallel_nd(G, NB_OC, NB_IC, D, H, W,
1088 [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1089 int i_off = wei_blk_off_like_gwei3D<fmt_o>(input_d,
1090 g, i_mult_o * nb_oc, i_mult_i * nb_ic, d, h, w);
1091 int o_off = wei_blk_off_like_gwei3D<fmt_o>(output_d,
1092 g, nb_oc, nb_ic, d, h, w);
1093 auto i = &input[i_off];
1094 auto o = &output[o_off];
// clamp tail blocks so the kernel never reads past the real OC/IC extent
1095 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1096 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1097 ker(i, o, oc_block, ic_block);
1104 template <SIMPLE_REORDER_TEMPL_DECL>
// Binary (1-bit, mkldnn_bin) weights reorder into OhIw8o32i / OhIw16o32i.
// Input is bit-packed in the flat layout (8 bits per byte); each output byte
// gathers 8 consecutive IC bits for one OC position of the blocked layout.
// NOTE(review): interior lines (the w-stride term of iidx, some braces) are
// elided in this chunk; comments describe only what is visible.
1105 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1106 typename utils::enable_if<fmt_i == any && (fmt_o == OhIw8o32i || fmt_o == OhIw16o32i) && type_i == mkldnn_bin && type_o == mkldnn_bin>::type>
1108 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1110 static status_t execute(const cpu_reorder_pd_t *pd,
1111 const data_t<type_i> *input, data_t<type_o> *output) {
1112 DECLARE_COMMON_PARAMS();
1114 static constexpr bool w_groups
1115 = format_traits<fmt_o>::data_kind == dk::gwei;
1116 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1117 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
// OC block differs between the two target formats; IC block is always 32
1118 constexpr int blksize_o = fmt_o == OhIw8o32i ? 8 : 16;
1119 constexpr int blksize_i = 32;
1121 const auto &dims = input_d.dims();
1122 const auto &pdims = order_keep
1123 ? output_d.blocking_desc().padding_dims
1124 : input_d.blocking_desc().padding_dims;
1126 const int G = w_groups ? dims[0] : 1;
1127 const int OC = dims[w_groups + 0];
1128 const int NB_OC = pdims[w_groups + 0] / blksize_o;
1129 const int IC = dims[w_groups + 1];
1130 const int NB_IC = pdims[w_groups + 1] / blksize_i;
1131 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1132 const int W = dims[w_groups + 3 + is_3d - is_1d];
1134 constexpr int i_mult_o = blksize_o;
1135 constexpr int i_mult_i = blksize_i;
// bits per packed byte
1136 constexpr int nbits = 8;
// returns bit number `bit` (0..7) of `val` as 0 or 1
1138 auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
1139 return (uint8_t) ((val >> bit) & 0x0001);
1142 parallel_nd(G, NB_OC, NB_IC, H, W,
1143 [&](int g, int nb_oc, int nb_ic, int h, int w) {
// clamp tail blocks to the real OC/IC extent
1144 const int oc_block = nstl::min(blksize_o, OC - nb_oc * blksize_o);
1145 const int ic_block = nstl::min(blksize_i, IC - nb_ic * blksize_i);
1147 for (int oc = 0; oc < oc_block; ++oc) {
// one output byte per group of up to 8 input-channel bits
1148 for (int icb = 0; icb < div_up(ic_block, nbits); ++icb) {
1150 uint8_t bin_val = 0x00;
// gather up to nbits IC bits; upper bound also clamps against total IC
1151 for (int ic = icb*nbits, shift = 0; ic < std::min(IC, (icb + 1)*nbits); ic++, shift++) {
// iidx is a *bit* index into the flat input: strides here count bits
1152 size_t iidx = (i_mult_o * nb_oc + oc) * input_d.blocking_desc().strides[0][0] +
1153 (i_mult_i * nb_ic + ic) *input_d.blocking_desc().strides[0][1] +
1154 h * input_d.blocking_desc().strides[0][2] +
// byte = iidx / 8, bit within byte = iidx % 8
1157 uint8_t bit = extract_bit(input[iidx / nbits], (uint8_t)(iidx % nbits));
1158 bin_val |= (bit << shift);
// oidx is likewise a bit index into the blocked output layout
1161 size_t oidx = wei_blk_off_like_gwei3D<fmt_o>(output_d, g, nb_oc, nb_ic, 0, h, w) + oc * blksize_i + icb * blksize_o;
1162 output[oidx / nbits] = bin_val;
1172 template <SIMPLE_REORDER_TEMPL_DECL>
// Generic weights reorder: flat ("any") layout <-> any format whose block
// descriptor covers two dims (OC and IC, both blocked by the same blk_size).
// The 8o4i and binary 32i formats above are excluded since they have their
// own specializations. NOTE(review): order_keep branch selectors and closing
// braces are elided in this chunk; comments describe only what is visible.
1173 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1174 typename utils::enable_if<fmt_i == any
1175 && block_format_traits<format_traits<fmt_o>::blk_fmt>::blk_ndims == 2
1176 && fmt_o != OhIw8o4i && fmt_o != gOhIw8o4i && fmt_o != OhIw8o32i && fmt_o != OhIw16o32i>::type>
1178 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1180 static status_t execute(const cpu_reorder_pd_t *pd,
1181 const data_t<type_i> *input, data_t<type_o> *output) {
1182 DECLARE_COMMON_PARAMS();
1184 static constexpr bool w_groups
1185 = format_traits<fmt_o>::data_kind == dk::gwei;
1186 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1187 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
// here the single blk_size applies to both the OC and IC dimensions
1188 constexpr int blksize = format_traits<fmt_o>::blk_size;
// flat_d is whichever side has the plain layout
1190 const auto &flat_d = order_keep ? input_d : output_d;
1191 const auto &dims = input_d.dims();
1192 const auto &pdims = order_keep
1193 ? output_d.blocking_desc().padding_dims
1194 : input_d.blocking_desc().padding_dims;
1196 const int G = w_groups ? dims[0] : 1;
1197 const int OC = dims[w_groups + 0];
1198 const int NB_OC = pdims[w_groups + 0] / blksize;
1199 const int IC = dims[w_groups + 1];
1200 const int NB_IC = pdims[w_groups + 1] / blksize;
1201 const int D = is_3d ? dims[w_groups + 2] : 1;
1202 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1203 const int W = dims[w_groups + 3 + is_3d - is_1d];
// Copies one (oc_block x ic_block) tile between flat and blocked layouts.
1205 auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
1206 const int oc_block, const int ic_block) {
1207 # define blk_off OI_blk_off<format_traits<fmt_o>::blk_fmt>
// fast path: pure copy/round, no scaling or accumulation
1209 if (alpha == 1.0 && beta == 0.0) {
1210 for (int oc = 0; oc < oc_block; ++oc)
1211 for (int ic = 0; ic < ic_block; ++ic) {
// element offset in the plain layout, from its runtime strides
1212 const ptrdiff_t flat_off = 0
1213 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1214 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
// flat -> blocked (presumably order_keep; selector elided)
1216 o[blk_off(oc, ic)] = _qz_a1b0<type_i, type_o>()(
1217 i[flat_off], rmode);
// blocked -> flat (presumably !order_keep; selector elided)
1219 o[flat_off] = _qz_a1b0<type_i, type_o>()(
1220 i[blk_off(oc, ic)], rmode);
// general path: out = quantize(alpha * in + beta * out)
1224 for (int oc = 0; oc < oc_block; ++oc)
1225 for (int ic = 0; ic < ic_block; ++ic) {
1226 const ptrdiff_t flat_off = 0
1227 + oc * flat_d.blocking_desc().strides[0][w_groups + 0]
1228 + ic * flat_d.blocking_desc().strides[0][w_groups + 1];
1230 o[blk_off(oc, ic)] = _qz<type_i, type_o>()(i[flat_off],
1231 o[blk_off(oc, ic)], alpha, beta, rmode);
1233 o[flat_off] = _qz<type_i, type_o>()(i[blk_off(oc, ic)],
1234 o[flat_off], alpha, beta, rmode);
// the flat side advances by whole blocks; the blocked side by block index
1243 constexpr int i_mult = order_keep ? blksize : 1;
1244 constexpr int o_mult = order_keep ? 1 : blksize;
// one parallel task per (group, OC-block, IC-block, spatial point)
1246 parallel_nd(G, NB_OC, NB_IC, D, H, W,
1247 [&](int g, int nb_oc, int nb_ic, int d, int h, int w) {
1248 auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1249 g, i_mult * nb_oc, i_mult * nb_ic, d, h, w)];
1250 auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1251 g, o_mult * nb_oc, o_mult * nb_ic, d, h, w)];
// clamp tail blocks so the kernel never reads past the real OC/IC extent
1252 const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
1253 const int ic_block = nstl::min(blksize, IC - nb_ic * blksize);
1254 ker(i, o, oc_block, ic_block);
1261 template <SIMPLE_REORDER_TEMPL_DECL>
// Weights reorder: flat ("any") layout <-> formats blocked only along OC
// (_4o/_8o/_16o). IC stays plain, so the inner kernel walks a single OC block.
// NOTE(review): order_keep branch selectors are elided in this chunk;
// comments describe only what is visible.
1262 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1263 typename utils::enable_if<fmt_i == any && (false
1264 || format_traits<fmt_o>::blk_fmt == bf::_4o
1265 || format_traits<fmt_o>::blk_fmt == bf::_8o
1266 || format_traits<fmt_o>::blk_fmt == bf::_16o)>::type>
1268 PLAIN_TO_BLOCKED_IS_APPLICABLE();
1270 static status_t execute(const cpu_reorder_pd_t *pd,
1271 const data_t<type_i> *input, data_t<type_o> *output) {
1272 DECLARE_COMMON_PARAMS();
1274 static constexpr bool w_groups
1275 = format_traits<fmt_o>::data_kind == dk::gwei;
1276 constexpr int is_1d = format_traits<fmt_o>::ndims_sp == 1;
1277 constexpr int is_3d = format_traits<fmt_o>::ndims_sp == 3;
1278 constexpr int blksize = format_traits<fmt_o>::blk_size;
// flat_d is whichever side has the plain layout
1280 const auto &flat_d = order_keep ? input_d : output_d;
1281 const auto &dims = input_d.dims();
1282 const auto &pdims = order_keep
1283 ? output_d.blocking_desc().padding_dims
1284 : input_d.blocking_desc().padding_dims;
1286 const int G = w_groups ? dims[0] : 1;
1287 const int OC = dims[w_groups + 0];
1288 const int IC = dims[w_groups + 1];
1289 const int D = is_3d ? dims[w_groups + 2] : 1;
1290 const int H = is_1d ? 1 : dims[w_groups + 2 + is_3d];
1291 const int W = dims[w_groups + 3 + is_3d - is_1d];
1293 constexpr int i_mult = order_keep ? blksize : 1;
1294 constexpr int o_mult = order_keep ? 1 : blksize;
// OC stride of the plain side: distance between consecutive output channels
1295 const auto strd_oc = flat_d.blocking_desc().strides[0][w_groups];
// one parallel task per (group, OC-block, input channel, spatial point)
1297 parallel_nd(G, pdims[w_groups + 0] / blksize, IC, D, H, W,
1298 [&](int g, int nb_oc, int ic, int d, int h, int w) {
1299 auto i = &input[wei_blk_off_like_gwei3D<fmt_o>(input_d,
1300 g, i_mult * nb_oc, ic, d, h, w)];
1301 auto o = &output[wei_blk_off_like_gwei3D<fmt_o>(output_d,
1302 g, o_mult * nb_oc, ic, d, h, w)];
// clamp the tail block to the real OC extent
1303 const int oc_block = nstl::min(blksize, OC - nb_oc * blksize);
// fast path: pure copy/round
1305 if (alpha == 1.0 && beta == 0.0) {
1306 for (int oc = 0; oc < oc_block; ++oc) {
1307 const auto off = oc * strd_oc;
// flat -> blocked (presumably order_keep; selector elided)
1309 o[oc] = _qz_a1b0<type_i, type_o>()(i[off], rmode);
// blocked -> flat (presumably !order_keep; selector elided)
1311 o[off] = _qz_a1b0<type_i, type_o>()(i[oc], rmode);
// general path: out = quantize(alpha * in + beta * out)
1315 for (int oc = 0; oc < oc_block; ++oc) {
1316 const auto off = oc * strd_oc;
1318 o[oc] = _qz<type_i, type_o>()(i[off], o[oc], alpha,
1321 o[off] = _qz<type_i, type_o>()(i[oc], o[off], alpha,
1332 /* generic and direct-copy reorders */
1334 template <SIMPLE_REORDER_TEMPL_DECL>
// Direct copy: both sides are dense and layout-compatible, so the reorder
// degenerates to a flat element-wise pass over nelems values (data type
// conversion and alpha/beta scaling still apply via the qz helpers).
1335 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1336 typename utils::enable_if<
1337 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1338 spec::direct_copy>::type>
// Applicable only when layouts are similar enough to share a linear index
// and both descriptors are dense (no padding holes).
1340 static bool is_applicable(const memory_desc_wrapper &input_d,
1341 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1342 /* FIXME: is the formula correct? */
1343 return input_d.similar_to(output_d, true, false, 0)
1344 && input_d.is_dense() && output_d.is_dense()
1345 && simple_attr_check(attr, false);
1348 static status_t execute(const cpu_reorder_pd_t *pd,
1349 const data_t<type_i> *input, data_t<type_o> *output) {
1350 DECLARE_COMMON_PARAMS();
1352 assert(input_d.is_dense());
// skip any initial offset in either descriptor
1354 input += input_d.blk_off(0);
1355 output += output_d.blk_off(0);
1357 const size_t nelems = input_d.nelems();
// work is split across threads in chunks of block_size elements;
// the remainder tail is handled by the last thread below
1359 constexpr int block_size = 16;
1360 const auto num_blocks = nelems / block_size;
1361 const auto rem_elems = nelems % block_size;
1363 parallel(0, [&](const int ithr, const int nthr) {
1364 size_t start{0}, end{0};
1365 balance211(num_blocks, nthr, ithr, start, end);
1366 start = start * block_size;
1367 end = end * block_size;
// four specializations keyed on alpha/beta so the hot loop stays minimal
1369 if (alpha == 1.0 && beta == 0.0) {
1371 for (size_t e = start; e < end; ++e) {
1372 output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
1375 } else if (alpha == 1.0) {
1377 for (size_t e = start; e < end; ++e) {
1378 output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
1379 (input[e], output[e], beta, rmode);
1381 } else if (beta == 0.0) {
1383 for (size_t e = start; e < end; ++e) {
1384 output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
1385 (input[e], alpha, rmode);
1389 for (size_t e = start; e < end; ++e) {
1390 output[e] = qz<data_t<type_i>, data_t<type_o>>()
1391 (input[e], output[e], alpha, beta, rmode);
// tail: the last thread finishes the trailing nelems % block_size elements,
// using the same alpha/beta dispatch as the main loop
1395 if (rem_elems != 0 && ithr == nthr - 1){
1396 if (alpha == 1.0 && beta == 0.0) {
1398 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1399 output[e] = qz_a1b0<data_t<type_i>,
1400 data_t<type_o>>()(input[e], rmode);
1402 } else if (alpha == 1.0) {
1404 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1405 output[e] = qz_a1<data_t<type_i>,
1406 data_t<type_o>>()(input[e], output[e], beta, rmode);
1408 } else if (beta == 0.0) {
1410 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1411 output[e] = qz_b0<data_t<type_i>,
1412 data_t<type_o>>()(input[e], alpha, rmode);
1416 for (size_t e = nelems - rem_elems; e < nelems; ++e) {
1417 output[e] = qz<data_t<type_i>, data_t<type_o>>()
1418 (input[e], output[e], alpha, beta, rmode);
1427 template <SIMPLE_REORDER_TEMPL_DECL>
// Direct copy when everything past dim 0 is dense on both sides: copy a
// contiguous slab per outer index n, where only the dim-0 strides differ
// between input and output.
1428 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1429 typename utils::enable_if<
1430 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1431 spec::direct_copy_except_dim_0>::type>
1433 static bool is_applicable(const memory_desc_wrapper &input_d,
1434 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
// dense past dim 0 <=> logical element count equals the physical footprint
1435 auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
1436 return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
1438 /* FIXME: is the formula correct? */
1439 return input_d.similar_to(output_d, true, false, 1)
1440 && is_dense_no_0(input_d) && is_dense_no_0(output_d)
1441 && simple_attr_check(attr, false);
1444 static status_t execute(const cpu_reorder_pd_t *pd,
1445 const data_t<type_i> *input, data_t<type_o> *output) {
1446 DECLARE_COMMON_PARAMS();
// skip any initial offset in either descriptor
1448 input += input_d.blk_off(0);
1449 output += output_d.blk_off(0);
1451 const int N = input_d.dims()[0];
// only the dim-0 strides may differ between the two layouts
1452 const size_t is = input_d.blocking_desc().strides[0][0];
1453 const size_t os = output_d.blocking_desc().strides[0][0];
1454 const size_t nelems_no_d0 = nelems_no_dim_0(input_d);
1455 const size_t work_amount = N * nelems_no_d0;
// fast path: plain copy/round, balanced over (n, inner-offset) pairs
1457 if (alpha == 1.0 && beta == 0.0) {
1458 parallel(0, [&](const int ithr, const int nthr) {
1459 size_t n{0}, dim1_s{0};
1460 size_t start{0}, end{0};
1461 balance211(work_amount, nthr, ithr, start, end);
1462 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1463 while(start < end) {
1464 size_t work_rem = end - start;
// process at most the rest of the current n-slab before advancing n
1465 size_t dim1_e = dim1_s + work_rem > nelems_no_d0
1466 ? nelems_no_d0 : dim1_s + work_rem;
1468 for (size_t e = dim1_s; e < dim1_e; ++e) {
1469 output[os * n + e] = _qz_a1b0<type_i, type_o>()(
1470 input[is * n + e], rmode);
1472 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
// general path: out = quantize(alpha * in + beta * out), same iteration scheme
1476 parallel(0, [&](const int ithr, const int nthr) {
1477 size_t n{0}, dim1_s{0};
1478 size_t start{0}, end{0};
1479 balance211(work_amount, nthr, ithr, start, end);
1480 nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
1481 while(start < end) {
1482 size_t work_rem = end - start;
1484 dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
1485 : dim1_s + work_rem;
1487 for (size_t e = dim1_s; e < dim1_e; ++e){
1488 output[os * n + e] = _qz<type_i, type_o>()(
1489 input[is * n + e], output[os * n + e], alpha,
1492 nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
// logical number of elements in one dim-0 slab (product of dims[1..])
1501 static size_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
1502 const int ndims = data_d.ndims();
1503 if (ndims <= 1) return 1;
1504 return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
// physical footprint (in elements) of one dim-0 slab: the largest extent
// reachable through either the outer-block or inner-block strides
1507 static size_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
1508 size_t max_size = 0;
1509 auto &blk = data_d.blocking_desc();
1510 for (int d = 1; d < data_d.ndims(); ++d) {
1511 auto block = blk.block_dims[d];
1512 max_size = nstl::max(max_size,
1513 size_t(size_t(blk.padding_dims[d] / block)
1514 * blk.strides[0][d]));
1516 max_size = nstl::max(max_size,
1517 size_t(block * blk.strides[1][d]));
1523 template <SIMPLE_REORDER_TEMPL_DECL>
// Reference (catch-all) reorder: walks every logical element via off_l(),
// so it works for any pair of blocking layouts, with per-element output
// scales. Slowest path; used when no specialized implementation applies.
1524 struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1525 typename utils::enable_if<
1526 fmt_i == any && fmt_o == any && order_keep == fmt_order::any,
1527 spec::reference>::type>
1529 static bool is_applicable(const memory_desc_wrapper &input_d,
1530 const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1531 /* supported smask: 0x0...011..10...0,
1532 * i.e. 1 should be contiguous */
// strip trailing zeros, then the contiguous run of ones; anything left
// over means the mask's set bits were not contiguous (return elided here)
1533 int smask = attr ? attr->output_scales_.mask_ : 0;
1534 for (; smask > 0 && !(smask & 0x1); smask >>= 1);
1535 for (; smask > 0 && smask & 0x1; smask >>= 1);
1537 && input_d.is_blocking_desc()
1538 && output_d.is_blocking_desc()
1539 && !output_d.is_additional_buffer()
1540 && !input_d.is_additional_buffer()
1544 static status_t execute(const cpu_reorder_pd_t *pd,
1545 const data_t<type_i> *input, data_t<type_o> *output) {
1546 DECLARE_COMMON_PARAMS();
1548 const size_t nelems = input_d.nelems();
// split dims into [start | mask | rest]: ndims_start dims precede the
// scale mask, ndims_mask dims are covered by it (one scale per index)
1550 int ndims_start = 0, ndims_mask = 0;
1551 int smask = pd->attr()->output_scales_.mask_;
1552 for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
1553 for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
1556 const ptrdiff_t D_start
1557 = utils::array_product(input_d.dims(), ndims_start);
1558 const ptrdiff_t D_mask
1559 = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
1560 const ptrdiff_t D_rest = nelems / D_start / D_mask;
1562 const float *scales = pd->attr()->output_scales_.scales_;
// linear index e is decomposed as ((ds * D_mask + dm) * D_rest + dr);
// the scale depends only on the masked component dm
1564 parallel_nd(D_start, D_mask, D_rest,
1565 [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
1566 const float scale = scales[dm];
1568 const size_t e = (ds * D_mask + dm) * D_rest + dr;
// off_l maps the logical index to each side's physical offset
1569 const auto &i = input[input_d.off_l(e)];
1570 auto &o = output[output_d.off_l(e)];
1572 o = _qz<type_i, type_o>()(i, o, scale, beta, rmode);
1580 /* high level class declaration */
1582 template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
// Primitive wrapper tying one simple_reorder_impl specialization into the
// mkldnn primitive machinery: pd_t::create() checks applicability and data
// types, execute() forwards the raw buffers to the impl's static execute().
1583 struct simple_reorder_t: public cpu_primitive_t {
1584 struct pd_t: public cpu_reorder_pd_t {
1585 pd_t(const cpu_memory_pd_t *input_pd, const cpu_memory_pd_t *output_pd,
1586 const primitive_attr_t *attr)
1587 : cpu_reorder_pd_t(input_pd, output_pd, attr) {}
1589 DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
// Factory: succeeds only if both memory descriptors match the template's
// data types and the impl reports is_applicable() for this layout pair.
1591 static status_t create(reorder_pd_t **reorder_pd,
1592 const memory_pd_t *input_pd, const memory_pd_t *output_pd,
1593 const primitive_attr_t *attr) {
1594 assert(input_pd->engine()->kind() == engine_kind::cpu);
1595 assert(output_pd->engine()->kind() == engine_kind::cpu);
1597 && input_pd->desc()->data_type == type_i
1598 && output_pd->desc()->data_type == type_o
1599 && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
1600 is_applicable(input_pd->desc(), output_pd->desc(), attr);
1602 return invalid_arguments;
1604 auto _pd = new pd_t((const cpu_memory_pd_t *)input_pd,
1605 (const cpu_memory_pd_t *)output_pd, attr);
1606 if (_pd == nullptr) return out_of_memory;
1607 if (_pd->init() != success) { delete _pd; return unimplemented; }
1608 return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
1612 simple_reorder_t(const pd_t *apd, const input_vector &inputs,
1613 const output_vector &outputs)
1614 : cpu_primitive_t(apd, inputs, outputs) {}
// Runtime entry point: cast the opaque buffers to the template's data
// types, run the impl, and mark the event complete.
1616 virtual void execute(event_t *e) const {
1617 auto input = reinterpret_cast<const data_t<type_i> *>(
1618 this->input_memory(0));
1619 auto output = reinterpret_cast<data_t<type_o> *>(this->memory());
1620 simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
1621 pd(), input, output);
1622 e->set_state(event_t::ready);
1626 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
1629 #undef SIMPLE_REORDER_TEMPL_DECL
1630 #undef SIMPLE_REORDER_TEMPL_CALL
1638 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s