/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "mkldnn_types.h"

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "mkldnn_thread.hpp"
#include "math_utils.hpp"

#include "simple_q10n.hpp"

#include "gemm_x8s8s32x_convolution.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::math;
using namespace mkldnn::impl::memory_tracking::names;
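
// Int8 convolution implemented as im2col + int8 GEMM + post-processing:
// the source is expanded into an unsigned 8-bit column buffer, one
// s8u8s32 GEMM per (image, group) produces int32 accumulators, and the
// post-processing kernel (pp_ker_t) converts those to the destination type
// while applying output scales, bias, sum and ReLU post-ops.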

template <data_type_t src_type, data_type_t dst_type>
void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
execute_forward() const {
    auto src_base = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst_base = reinterpret_cast<dst_data_t *>(this->memory());

    auto scratchpad = this->scratchpad();

    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    auto col = scratchpad.template get<uint8_t>(key_conv_gemm_col);
    // pre-fill the im2col buffer: 128 shifts a signed int8 source into the
    // unsigned range expected by the u8 GEMM input
    parallel_nd(jcp.im2col_sz * jcp.nthr, [&](ptrdiff_t i) {
        col[i] = jcp.signed_input ? (uint8_t)128 : (uint8_t)0;
    });

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base,
                scratchpad);
    });
}
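
// pp_ker_t: post-processing of the int32 GEMM accumulators. On avx512_core
// a JIT kernel is generated (see generate() below); on older CPUs
// operator() falls back to a plain C++ loop that performs the same steps.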

template <data_type_t src_type, data_type_t dst_type>
_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t(
    const pd_t *pd, const jit_gemm_conv_conf_t &jcp)
    : ker_(nullptr)
    , jcp_(jcp)
    , OC_(jcp.oc)
    , bias_data_type_(data_type::undef)
    , bias_data_type_size_(0)
    , scale_idx_mult_(0)
    , rmode_(round_mode::nearest)
    , do_bias_(false)
    , do_relu_(false)
    , do_sum_(false)
    , do_signed_scaling_(false)
{
    using namespace types;

    const auto dst_md = memory_desc_wrapper(pd->dst_pd());
    dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1);

    scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
    rmode_ = pd->attr()->round_mode_;

    auto &post_ops = pd->attr()->post_ops_;

    int entry_idx = -1;
    for (int idx = 0; idx < post_ops.len_; ++idx) {
        const auto &e = post_ops.entry_[idx];
        if (e.is_relu(true, false)) {
            entry_idx = idx;
            break;
        }
    }
    do_relu_ = entry_idx >= 0;

    do_signed_scaling_ = jcp_.signed_input;

    do_sum_ = post_ops.contain(primitive_kind::sum, 0);
    do_bias_ = pd->with_bias();
    bias_data_type_ = pd->desc()->bias_desc.data_type;
    if (do_bias_) {
        assert(bias_data_type_ != data_type::undef);
        bias_data_type_size_ = data_type_size(bias_data_type_);
    }

    const size_t vlen_start
            = cpu_isa_traits<avx512_common>::vlen / sizeof(float);

    // pick the largest vector length that evenly divides OC
    for (size_t i = vlen_start; i > 0; i--) {
        if (OC_ % i == 0) {
            vlen_ = i;
            break;
        }
    }

    if (!mayiuse(avx512_core))
        // use fallback code for older CPUs
        return;

    generate();
}
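
// The generated kernel walks a contiguous [start, end) range of the
// (os x oc) output viewed as rows of OC elements: a prologue handles a
// partial first row, the main loop processes whole rows (unrolled), and an
// epilogue handles the remaining tail. Opmask registers k1/k3 hold the
// short-tail and full-vector masks used by compute().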

template <data_type_t src_type, data_type_t dst_type>
void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate()
{
    using namespace Xbyak;
    using namespace utils;
    using namespace round_mode;

    Reg64 reg_param = abi_param1;
    Reg64 reg_dst = rdx;
    Reg64 reg_acc = rax;
    Reg64 reg_bias = rbx;
    Reg64 reg_scales = rsi;

    Reg64 reg_len = r8;
    Reg64 reg_tmp = rcx; // intentional for shifting purposes
    Reg64 reg_oc_offset = r9;
    Reg64 reg_rem_mask_short = r10;
    Reg64 reg_rem_mask_vlen = r11;
    Opmask kreg_rem_mask_short = k1;
    Opmask kreg_rem_mask_vlen = k3;
    Opmask kreg_relu_cmp = k2;

    const size_t vlen = 4;

    Zmm vreg_zero = Zmm(0);
    Zmm vreg_scale = Zmm(1);
    Zmm vreg_nslope = Zmm(2);
    Zmm vreg_sum_scale = Zmm(3);
    Zmm vreg_signed_scale = Zmm(4);

    size_t def_unroll = 4;
    size_t max_unroll = 12;
    size_t zmm_step = 2;
    if (do_sum_) {
        max_unroll = 8;
        zmm_step = 3;
    }

    auto vreg_dst = [&](int idx) {
        return Zmm(5 + idx * zmm_step + 0);
    };
    auto vreg_bias = [&](int idx) {
        return Zmm(5 + idx * zmm_step + 1);
    };
    auto vreg_prev_dst = [&](int idx) {
        return Zmm(5 + idx * zmm_step + 2);
    };

    preamble();

#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
    vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
    vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
    vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]);
    if (scale_idx_mult_ == 0)
        vbroadcastss(vreg_scale, dword[reg_scales]);
#undef PARAM_OFF

    mov(reg_rem_mask_vlen, 1);
    shl(reg_rem_mask_vlen, vlen);
    sub(reg_rem_mask_vlen, 1);
    kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen);
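
    // kreg_rem_mask_vlen covers a full vector of vlen elements and is used
    // when no tail handling is required; kreg_rem_mask_short is rebuilt in
    // the prologue/epilogue for partial vectors. compute() picks one of the
    // two through its apply_mask argument.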

    if (do_relu_ || dst_type == data_type::u8)
        vxorps(vreg_zero, vreg_zero, vreg_zero);

    // Load accumulated value, convert to float, apply sum (if any),
    // bias (if any), scaling, and relu (if any);
    // then convert to destination type and store
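    //
    // Roughly: dst = convert(relu(scale * (signed_scale * acc + bias)
    //                             + sum_scale * prev_dst))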
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];

        if (scale_idx_mult_ > 0) {
            assert(scale_idx_mult_ == 1);
            auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
            auto vreg_scale_ = vreg_scale;
            if (apply_mask)
                vreg_scale_ = vreg_scale_ | kreg_rem_mask_short;
            else
                vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen;
            vmovups(vreg_scale_, scale_addr);
        }

        auto vreg_dst_ = vreg_dst(idx);
        if (apply_mask)
            vreg_dst_ = vreg_dst_ | kreg_rem_mask_short;
        else
            vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen;
        vcvtdq2ps(vreg_dst_, acc_addr);

        if (do_signed_scaling_)
            vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale);

        if (do_bias_) {
            auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
            auto vreg_bias_ = vreg_bias(idx);
            if (apply_mask)
                vreg_bias_ = vreg_bias_ | kreg_rem_mask_short;
            else
                vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen;

            switch (bias_data_type_) {
            case data_type::s8:
                vpmovsxbd(vreg_bias_, bias_addr);
                break;
            case data_type::u8:
                vpmovzxbd(vreg_bias_, bias_addr);
                break;
            case data_type::s32:
                vcvtdq2ps(vreg_bias_, bias_addr);
                break;
            case data_type::f32:
                vmovups(vreg_bias_, bias_addr);
                break;
            default: assert(!"unimplemented");
            }
            vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
        }

        vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];

        if (do_sum_) {
            auto vreg_prev_dst_ = vreg_prev_dst(idx);
            if (apply_mask)
                vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short;
            else
                vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen;

            switch (dst_type) {
            case data_type::f32:
            case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break;
            case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break;
            case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break;
            default: assert(!"unsupported data type");
            }
            if (dst_type != data_type::f32)
                vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));

            vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
        }

        if (do_relu_) {
            vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
            vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
        }

        if (dst_type != data_type::f32) {
            auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
            vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
        }

        if (dst_type == data_type::u8)
            vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero);

        switch (dst_type) {
        case data_type::s8: vpmovsdb(dst_addr, vreg_dst_); break;
        case data_type::u8: vpmovusdb(dst_addr, vreg_dst_); break;
        case data_type::f32:
        case data_type::s32: vmovups(dst_addr, vreg_dst_); break;
        default: assert(!"unimplemented");
        }
    };

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            add(reg_scales, offset * sizeof(float));
        }
        if (do_bias_)
            add(reg_bias, offset * bias_data_type_size_);
    };

    // Advance all pointers by a value stored in a register
    auto advance_ptrs_reg = [&](Reg64 offset) {
        lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
        lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
        }
        if (do_bias_)
            lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
    };

    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    auto rewind_ptrs = [&]() {
        if (do_bias_)
            sub(reg_bias, OC_ * bias_data_type_size_);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            sub(reg_scales, OC_ * sizeof(float));
        }
        add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t));
    };

    //      <------------------ OC ------------------->
    //
    // ^  ................+..............+-------------+.......................
    // |  .               : not accessed |Prologue loop|                      .
    // |  .               +--------------+-------------+                      .
    // |  .               |                            |                      .
    // O  .               |    Main loop (unrolled)    |                      .
    // S  .               |                            |                      .
    // |  .               +--------------+-------------+                      .
    // |  .               | Epilogue loop|not accessed :                      .
    // v  ................+--------------+.............+.......................
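    //
    // reg_oc_offset holds the starting position within the OC row
    // (oc_offset = start % OC) and reg_len the total number of output
    // elements to process.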

    // Prologue loop
    Label prologue_end;
    cmp(reg_oc_offset, 0);
    je(prologue_end, T_NEAR);

    mov(reg_tmp, OC_);
    sub(reg_tmp, reg_oc_offset);
    cmp(reg_tmp, reg_len);
    cmovg(reg_tmp, reg_len);
    sub(reg_len, reg_tmp);

    Label prologue_loop, prologue_loop_tail, prologue_loop_end;
    cmp(reg_tmp, vlen);
    jle(prologue_loop_tail, T_NEAR);
    L(prologue_loop);
    compute(0, 0, false);
    advance_ptrs_imm(vlen);
    sub(reg_tmp, vlen);
    cmp(reg_tmp, vlen);
    jge(prologue_loop, T_NEAR);

    L(prologue_loop_tail);
    mov(reg_rem_mask_short, 1);
    // cl == reg_tmp because reg_tmp <= vlen here
    shl(reg_rem_mask_short, cl);
    sub(reg_rem_mask_short, 1);
    jz(prologue_loop_end, T_NEAR);

    kmovq(kreg_rem_mask_short, reg_rem_mask_short);
    compute(0, 0, true);
    advance_ptrs_reg(reg_tmp);

    L(prologue_loop_end);
    rewind_ptrs();

    L(prologue_end);

    // Main loop
    Label main_loop, main_loop_end;
    cmp(reg_len, OC_);
    jle(main_loop_end, T_NEAR);

    L(main_loop);

    size_t OC_loop, OC_tail;
    if (OC_ < max_unroll * vlen) {
        // Fully unroll small loops
        OC_loop = 0;
        OC_tail = OC_;
    } else {
        OC_loop = vlen * def_unroll;
        OC_tail = OC_ % OC_loop;
    }

    assert(!!OC_loop || !!OC_tail);
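
    // Small OC: emit the whole row fully unrolled (OC_loop == 0).
    // Otherwise unroll def_unroll vectors per iteration and handle the
    // remainder of the row as a (possibly masked) tail.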

    if (OC_tail % vlen) {
        int vlen_tail = OC_tail % vlen;
        unsigned tail_mask = (1 << vlen_tail) - 1;
        mov(reg_tmp, tail_mask);
        kmovq(kreg_rem_mask_short, reg_tmp);
    }

    if (OC_loop) {
        mov(reg_tmp, rnd_dn(OC_, OC_loop));
        Label oc_loop;
        L(oc_loop);
        for (size_t offset = 0; offset < OC_loop; offset += vlen)
            compute(offset, offset / vlen, false);
        advance_ptrs_imm(OC_loop);
        sub(reg_tmp, OC_loop);
        jnz(oc_loop, T_NEAR);
    }

    if (OC_tail) {
        for (size_t offset = 0; offset < OC_tail; offset += vlen) {
            bool use_mask = (offset + vlen) > OC_tail;
            compute(offset, offset / vlen, use_mask);
        }
        advance_ptrs_imm(OC_tail);
    }

    rewind_ptrs();
    sub(reg_len, OC_);
    cmp(reg_len, OC_);
    jge(main_loop, T_NEAR);

    L(main_loop_end);

    // Epilogue loop
    Label epilogue_end;
    cmp(reg_len, 0);
    je(epilogue_end, T_NEAR);

    Label epilogue_loop, epilogue_loop_tail;
    cmp(reg_len, vlen);
    jle(epilogue_loop_tail, T_NEAR);
    L(epilogue_loop);
    compute(0, 0, false);
    sub(reg_len, vlen);
    advance_ptrs_imm(vlen);
    cmp(reg_len, vlen);
    jge(epilogue_loop, T_NEAR);

    L(epilogue_loop_tail);
    mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
    mov(reg_rem_mask_short, 1);
    shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen
    sub(reg_rem_mask_short, 1);
    jz(epilogue_end, T_NEAR);
    kmovq(kreg_rem_mask_short, reg_rem_mask_short);
    compute(0, 0, true);

    L(epilogue_end);

    postamble();

    ker_ = getCode<decltype(ker_)>();
}

template <data_type_t src_type, data_type_t dst_type>
void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
    (dst_data_t *dst, const acc_data_t *acc, const char *bias,
        const float *scales, float nslope, float sum_scale, float signed_scale,
        int g, size_t start, size_t end)
{
    using math::get_bias;

    if (end <= start)
        return;

    if (ker_) {
        // JIT kernel: process the whole [start, end) range in one call
        ker_args args;
        size_t oc_offset = start % OC_;
        size_t os_offset = start / OC_;
        args.acc = acc + start;
        args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
        args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
        args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
        args.nslope = nslope;
        args.sum_scale = sum_scale;
        args.signed_scale = signed_scale;
        args.len = end - start;
        args.oc_offset = oc_offset;
        ker_(&args);
    } else {
        // Reference path for CPUs without avx512_core
        const size_t first_oc = start % OC_;
        const size_t last_oc = (end - 1) % OC_;
        const size_t first_os = start / OC_;
        const size_t last_os = (end - 1) / OC_;
        for (size_t os = first_os; os <= last_os; os++) {
            const size_t start_oc = (os == first_os) ? first_oc : 0;
            const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
            for (size_t oc = start_oc; oc <= end_oc; oc++) {
                const size_t acc_off = os * jcp_.oc + oc;
                const size_t dst_off = os * dst_os_stride_ + oc;

                float d = (float)(acc[acc_off]);
                if (jcp_.signed_input)
                    d *= signed_scale;

                if (do_bias_)
                    d += get_bias(bias, g * jcp_.oc + oc,
                            bias_data_type_);

                d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
                if (do_sum_)
                    d += sum_scale * dst[dst_off];
                if (do_relu_ && d < 0)
                    d *= nslope;
                dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode_);
            }
        }
    }
}
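
// Per-thread forward pass: each thread owns a set of (mini-batch, group)
// pairs. For every pair it expands the source into the im2col buffer (when
// needed), runs a single int8 GEMM, and then applies the post-processing
// kernel above to turn the int32 accumulators into the destination type.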

template <data_type_t src_type, data_type_t dst_type>
void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base,
        const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    const auto src_md = memory_desc_wrapper(pd()->src_pd());
    const size_t src_mb_stride = src_md.blk_off(1);
    const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;

    const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto dst_md = memory_desc_wrapper(pd()->dst_pd());
    const size_t dst_mb_stride = dst_md.blk_off(1);
    const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;

    const float *scales = pd()->attr()->output_scales_.scales_;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    float nslope = 0.f;
    for (int idx = 0; idx < post_ops.len_; ++idx) {
        const auto &e = post_ops.entry_[idx];
        if (e.is_relu(true, false)) {
            nslope = e.eltwise.alpha;
            break;
        }
    }

    auto col = scratchpad.get<uint8_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.os * jcp.oc;

    const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
    const int32_t *_wei_comp = (const int32_t *)(wei_base + offset);

    int n{0}, g{0};
    size_t start = 0, end = 0;

    const size_t work_amount = jcp.ngroups * jcp.mb;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const src_data_t *src = src_base + n * src_mb_stride
                + g * src_g_stride;
        const wei_data_t *wei = wei_base + g * wei_g_stride;
        dst_data_t *dst = dst_base + n * dst_mb_stride + g * dst_g_stride;
        const int32_t *wei_comp = _wei_comp + g * jcp.oc;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::im2col_u8<src_data_t>(jcp, src, col);

        const int M = jcp.oc;
        const int K = jcp.ks * jcp.ic;
        const int N = jcp.os;
        const int LD = M * jcp.ngroups;
        const int8_t off_a = 0, off_b = 0;
        const int32_t off_c = 0;
        const float onef = 1.0, zerof = 0.0;
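
        // One GEMM per (image, group): acc(M x N) = wei(M x K) * col(K x N)
        // with M = oc, N = os, K = ks * ic. For a signed int8 source the
        // "C" offset mode adds the per-output-channel compensation vector
        // wei_comp to every column, compensating for the +128 shift applied
        // to make the source unsigned.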

        mkldnn_gemm_s8u8s32("N", "N", jcp.signed_input ? "C" : "F",
                &M, &N, &K, &onef, wei, &LD, &off_a,
                jcp.im2col_sz ? col : (uint8_t *)src, &K, &off_b,
                &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);

        parallel(0, [&](int ithr, int nthr) {
            size_t start, end;
            balance211((size_t)jcp.os * jcp.oc, nthr, ithr, start, end);
            (*pp_ker_)(dst, acc, bia_base, scales, nslope, sum_scale,
                    jcp.signed_input ? 1.f / jcp.wei_adj_scale : 1.f,
                    g, start, end);
        });

        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
}

template <data_type_t dst_type>
void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
execute_backward_data() const {
    auto diff_dst_base = reinterpret_cast<const diff_dst_data_t *>
            (this->input_memory(0));
    auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
    auto diff_src_base = reinterpret_cast<diff_src_data_t *>(this->memory());

    auto scratchpad = this->scratchpad();

    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
                bia_base, diff_src_base, scratchpad);
    });
}
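
// Per-thread backward-data pass: for each (mini-batch, group) pair the int8
// GEMM computes col = wei^T * diff_dst, col2im_s32 scatters the columns back
// into the diff_src accumulator, and a per-element epilogue applies bias,
// output scales and the requested rounding before quantization.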

template <data_type_t dst_type>
void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
execute_backward_data_thr(const int ithr, const int nthr,
        const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
        const char *bia_base, diff_src_data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const
{
    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_pd());
    const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
    const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;

    const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_pd());
    const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
    const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
    const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);

    /* scale_idx_mult = 1 for per_oc scales and 0 otherwise */
    const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
    const float *scales = pd()->attr()->output_scales_.scales_;
    const auto rmode = pd()->attr()->round_mode_;
    const size_t work_amount = jcp.ngroups * jcp.mb;

    auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;

    int n{0}, g{0};
    size_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (size_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *wei = wei_base + g * wei_g_stride;
        diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
                + g * diff_src_g_stride;

        const int M = jcp.ks * jcp.ic;
        const int N = jcp.os;
        const int K = jcp.oc;
        const int8_t off_a = 0, off_b = 0;
        const int32_t off_c = 0;
        const float onef = 1.0, zerof = 0.0;
        const int LD = K * jcp.ngroups;

        mkldnn_gemm_s8u8s32("T", "N", "F", &M, &N, &K, &onef,
                wei, &LD, &off_a, diff_dst, &LD, &off_b,
                &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);

        parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
            float d = (float)acc[is * jcp.ic + ic];
            if (jcp.with_bias)
                d += get_bias(bia_base, g * jcp.ic + ic,
                        pd()->desc()->bias_desc.data_type);
            d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
            const size_t diff_src_off = is * diff_src_os_stride + ic;
            diff_src[diff_src_off] =
                    qz_a1b0<float, diff_src_data_t>()(d, rmode);
        });
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
}

using namespace data_type;

template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;

template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;

template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;

}
}
}