Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_x8s8s32x_convolution.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_types.h"
18
19 #include "c_types_map.hpp"
20 #include "utils.hpp"
21 #include "type_helpers.hpp"
22 #include "mkldnn_thread.hpp"
23 #include "math_utils.hpp"
24
25 #include "simple_q10n.hpp"
26
27 #include "gemm_x8s8s32x_convolution.hpp"
28
29 namespace mkldnn {
30 namespace impl {
31 namespace cpu {
32
33 using namespace mkldnn::impl::utils;
34 using namespace mkldnn::impl::math;
35 using namespace mkldnn::impl::memory_tracking::names;
36
37 template <data_type_t src_type, data_type_t dst_type>
38 void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
39 execute_forward() const {
40     auto src_base = reinterpret_cast<const src_data_t *>(this->input_memory(0));
41     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
42     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
43     auto dst_base = reinterpret_cast<dst_data_t *>(this->memory());
44
45     auto scratchpad = this->scratchpad();
46
47     const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
48
49     auto col = scratchpad.template get<uint8_t>(key_conv_gemm_col);
50     parallel_nd(jcp.im2col_sz * jcp.nthr, [&](ptrdiff_t i) {
51         col[i] = jcp.signed_input ? (uint8_t)128 : (uint8_t)0;
52     });
53
54     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
55         execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base,
56                 scratchpad);
57     });
58 }
59
60 template <data_type_t src_type, data_type_t dst_type>
61 _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t(
62     const pd_t *pd)
63     : ker_(nullptr)
64     , jcp_(pd->jcp_)
65     , OC_(pd->jcp_.oc)
66     , OS_(pd->jcp_.os)
67     , bias_data_type_(data_type::undef)
68     , bias_data_type_size_(0)
69     , scale_idx_mult_(0)
70     , rmode_(round_mode::nearest)
71     , do_bias_(false)
72     , do_relu_(false)
73     , do_sum_(false)
74 {
75     using namespace types;
76
77     const auto dst_md = memory_desc_wrapper(pd->dst_pd());
78     dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1);
79
80     scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
81     rmode_ = pd->attr()->round_mode_;
82
83     auto &post_ops = pd->attr()->post_ops_;
84
85     int entry_idx = -1;
86     for (int idx = 0; idx < post_ops.len_; ++idx) {
87         const auto &e = post_ops.entry_[idx];
88         if (e.is_relu(true, false)) {
89             entry_idx = idx;
90             break;
91         }
92     }
93     do_relu_ = entry_idx >= 0;
94
95     do_signed_scaling_ = jcp_.signed_input;
96
97     do_sum_ = post_ops.contain(primitive_kind::sum, 0);
98     do_bias_ = pd->with_bias();
99     bias_data_type_ = pd->desc()->bias_desc.data_type;
100     if (do_bias_) {
101         assert(bias_data_type_ != data_type::undef);
102         bias_data_type_size_ = data_type_size(bias_data_type_);
103     }
104     const size_t vlen_start
105             = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
106
107     for (size_t i = vlen_start; i > 0; i--) {
108         if (OC_ % i == 0) {
109             vlen_ = i;
110             break;
111         }
112     }
113
114     if (!mayiuse(avx512_core))
115         // use fallback code for older CPUs
116         return;
117     else
118         generate();
119 }
120
121 template <data_type_t src_type, data_type_t dst_type>
122 void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate()
123 {
124     using namespace Xbyak;
125     using namespace utils;
126     using namespace round_mode;
127
128     // TODO: clean-up
129     Reg64 reg_param = abi_param1;
130     Reg64 reg_dst = rdx;
131     Reg64 reg_acc = rax;
132     Reg64 reg_bias = rbx;
133     Reg64 reg_scales = rsi;
134
135     Reg64 reg_len = r8;
136     Reg64 reg_tmp = rcx; // intentional for shifting purposes
137     Reg64 reg_oc_offset = r9;
138     Reg64 reg_rem_mask_short = r10;
139     Reg64 reg_rem_mask_vlen = r11;
140     Opmask kreg_rem_mask_short = k1;
141     Opmask kreg_rem_mask_vlen = k3;
142     Opmask kreg_relu_cmp = k2;
143
144     const size_t vlen = 4;
145
146     Zmm vreg_zero = Zmm(0);
147     Zmm vreg_scale = Zmm(1);
148     Zmm vreg_nslope = Zmm(2);
149     Zmm vreg_sum_scale = Zmm(3);
150     Zmm vreg_signed_scale = Zmm(4);
151
152     size_t def_unroll = 4;
153     size_t max_unroll = 12;
154     size_t zmm_step = 2;
155     if (do_sum_) {
156         max_unroll = 8;
157         zmm_step = 3;
158     }
159
160     auto vreg_dst = [&](int idx) {
161         return Zmm(5 + idx * zmm_step + 0);
162     };
163     auto vreg_bias = [&](int idx) {
164         return Zmm(5 + idx * zmm_step + 1);
165     };
166     auto vreg_prev_dst = [&](int idx) {
167         return Zmm(5 + idx * zmm_step + 2);
168     };
169
170     preamble();
171
172 #define PARAM_OFF(x) offsetof(ker_args, x)
173     mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
174     mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
175     mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
176     mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
177     mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
178     mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
179     vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
180     vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
181     vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]);
182     if (scale_idx_mult_ == 0)
183         vbroadcastss(vreg_scale, dword[reg_scales]);
184
185 #undef PARAM_OFF
186
187     mov(reg_rem_mask_vlen, 1);
188     shl(reg_rem_mask_vlen, vlen);
189     sub(reg_rem_mask_vlen, 1);
190     kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen);
191
192     if (do_relu_ || dst_type == data_type::u8)
193         vxorps(vreg_zero, vreg_zero, vreg_zero);
194
195     // Load accumulated value, convert to float, apply sum (if any),
196     // bias (if any), scaling, and relu (if any);
197     // then convert to destination type and store
198     auto compute = [&](size_t offset, int idx, bool apply_mask) {
199         auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
200
201         if (scale_idx_mult_ > 0) {
202             assert(scale_idx_mult_ == 1);
203             auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
204             auto vreg_scale_ = vreg_scale;
205             if (apply_mask)
206                 vreg_scale_ = vreg_scale_ | kreg_rem_mask_short;
207             else
208                 vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen;
209             vmovups(vreg_scale_, scale_addr);
210         }
211
212         auto vreg_dst_ = vreg_dst(idx);
213         if (apply_mask)
214             vreg_dst_ = vreg_dst_ | kreg_rem_mask_short;
215         else
216             vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen;
217         vcvtdq2ps(vreg_dst_, acc_addr);
218
219         if (do_signed_scaling_)
220             vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale);
221
222         if (do_bias_) {
223             auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
224             auto vreg_bias_ = vreg_bias(idx);
225             if (apply_mask)
226                 vreg_bias_ = vreg_bias_ | kreg_rem_mask_short;
227             else
228                 vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen;
229
230             switch (bias_data_type_) {
231             case data_type::s8:
232                 vpmovsxbd(vreg_bias_, bias_addr);
233                 break;
234             case data_type::u8:
235                 vpmovzxbd(vreg_bias_, bias_addr);
236                 break;
237             case data_type::s32:
238                 vcvtdq2ps(vreg_bias_, bias_addr);
239                 break;
240             case data_type::f32:
241                 vmovups(vreg_bias_, bias_addr);
242                 break;
243             default: assert(!"unimplemented");
244             }
245             vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
246         }
247
248         vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
249
250         auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
251
252         if (do_sum_)
253         {
254             auto vreg_prev_dst_ = vreg_prev_dst(idx);
255             if (apply_mask)
256                 vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short;
257             else
258                 vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen;
259
260             switch (dst_type) {
261             case data_type::f32:
262             case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break;
263             case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break;
264             case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break;
265             default: assert(!"unsupported data type");
266             }
267             if (dst_type != data_type::f32)
268                 vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
269
270             vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
271         }
272
273         if (do_relu_) {
274             vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
275             vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
276         }
277
278         if (dst_type != data_type::f32) {
279             auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
280             vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
281         }
282
283         if (dst_type == data_type::u8)
284             vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero);
285
286         switch (dst_type) {
287         case data_type::s8:
288             vpmovsdb(dst_addr, vreg_dst_);
289             break;
290         case data_type::u8:
291             vpmovusdb(dst_addr, vreg_dst_);
292             break;
293         case data_type::f32:
294         case data_type::s32:
295             vmovups(dst_addr, vreg_dst_);
296             break;
297         default: assert(!"unimplemented");
298         }
299     };
300
301     // Advance all pointers by an immediate
302     auto advance_ptrs_imm = [&](size_t offset) {
303         add(reg_dst, offset * sizeof(dst_data_t));
304         add(reg_acc, offset * sizeof(acc_data_t));
305         if (scale_idx_mult_) {
306             assert(scale_idx_mult_ == 1);
307             add(reg_scales, offset * sizeof(float));
308         }
309         if (do_bias_)
310             add(reg_bias, offset * bias_data_type_size_);
311     };
312
313     // Advance all pointers by a value stored in a register
314     auto advance_ptrs_reg = [&](Reg64 offset) {
315         lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
316         lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
317         if (scale_idx_mult_) {
318             assert(scale_idx_mult_ == 1);
319             lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
320         }
321         if (do_bias_)
322             lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
323     };
324
325     // Rewind pointers that point to data that is indexed by output channel
326     // (bias or per-oc scaling factors)
327     auto rewind_ptrs = [&]() {
328         if (do_bias_)
329             sub(reg_bias, OC_ * bias_data_type_size_);
330         if (scale_idx_mult_) {
331             assert(scale_idx_mult_ == 1);
332             sub(reg_scales, OC_ * sizeof(float));
333         }
334         add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t));
335     };
336
337     //                    <--------- OC --------------->
338     //
339     // ^  ................+..............+-------------+.......................
340     // |  .               : not accessed |Prologue loop|                      .
341     // |  .               +--------------+-------------+                      .
342     //    .               |                            |                      .
343     // O  .               |  Main loop (unrolled)      |                      .
344     // S  .               |                            |                      .
345     //    .               +--------------+-------------+                      .
346     // |  .               | Epilogue loop|not accessed :                      .
347     // v  ................+--------------+.............+.......................
348
349     Label prologue_end;
350     cmp(reg_oc_offset, 0);
351     je(prologue_end, T_NEAR);
352
353     // Prologue loop
354     {
355         mov(reg_tmp, OC_);
356         sub(reg_tmp, reg_oc_offset);
357         cmp(reg_tmp, reg_len);
358         cmovg(reg_tmp, reg_len);
359         sub(reg_len, reg_tmp);
360
361         Label prologue_loop, prologue_loop_tail, prologue_loop_end;
362         cmp(reg_tmp, vlen);
363         jle(prologue_loop_tail, T_NEAR);
364         L(prologue_loop); {
365             compute(0, 0, false);
366             advance_ptrs_imm(vlen);
367             sub(reg_tmp, vlen);
368             cmp(reg_tmp, vlen);
369             jge(prologue_loop, T_NEAR);
370         }
371
372         L(prologue_loop_tail);
373         mov(reg_rem_mask_short, 1);
374         // cl == reg_tmp because reg_tmp <= vlen here
375         shl(reg_rem_mask_short, cl);
376         sub(reg_rem_mask_short, 1);
377         jz(prologue_loop_end, T_NEAR);
378
379         kmovq(kreg_rem_mask_short, reg_rem_mask_short);
380         compute(0, 0, true);
381         advance_ptrs_reg(reg_tmp);
382
383         L(prologue_loop_end);
384         rewind_ptrs();
385     }
386     L(prologue_end);
387
388     // Main loop
389     Label main_loop_end;
390     {
391         cmp(reg_len, OC_);
392         jle(main_loop_end, T_NEAR);
393
394         Label main_loop;
395         L(main_loop); {
396             size_t OC_loop, OC_tail;
397             if (OC_ < max_unroll * vlen) {
398                 // Fully unroll small loops
399                 OC_loop = 0;
400                 OC_tail = OC_;
401             }
402             else {
403                 OC_loop = vlen * def_unroll;
404                 OC_tail = OC_ % OC_loop;
405             }
406
407             assert(!!OC_loop || !!OC_tail);
408
409             if (OC_tail % vlen) {
410                 int vlen_tail = OC_tail % vlen;
411                 unsigned tail_mask = (1 << vlen_tail) - 1;
412                 mov(reg_tmp, tail_mask);
413                 kmovq(kreg_rem_mask_short, reg_tmp);
414             }
415
416             if (OC_loop) {
417                 mov(reg_tmp, rnd_dn(OC_, OC_loop));
418                 Label oc_loop;
419                 L(oc_loop); {
420                     for (size_t offset = 0; offset < OC_loop; offset += vlen)
421                         compute(offset, offset / vlen, false);
422                     advance_ptrs_imm(OC_loop);
423                     sub(reg_tmp, OC_loop);
424                     jnz(oc_loop);
425                 }
426             }
427
428             if (OC_tail) {
429                 for (size_t offset = 0; offset < OC_tail; offset += vlen) {
430                     bool use_mask = (offset + vlen) > OC_tail;
431                     compute(offset, offset / vlen, use_mask);
432                 }
433                 advance_ptrs_imm(OC_tail);
434             }
435
436             rewind_ptrs();
437             sub(reg_len, OC_);
438             cmp(reg_len, OC_);
439             jge(main_loop, T_NEAR);
440         }
441     }
442     L(main_loop_end);
443
444     // Epilogue loop
445     Label epilogue_end;
446     {
447         cmp(reg_len, 0);
448         je(epilogue_end, T_NEAR);
449
450         Label epilogue_loop, epilogue_loop_tail;
451         cmp(reg_len, vlen);
452         jle(epilogue_loop_tail, T_NEAR);
453         L(epilogue_loop); {
454             compute(0, 0, false);
455             sub(reg_len, vlen);
456             advance_ptrs_imm(vlen);
457             cmp(reg_len, vlen);
458             jge(epilogue_loop, T_NEAR);
459         }
460
461         L(epilogue_loop_tail);
462         mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
463         mov(reg_rem_mask_short, 1);
464         shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen
465         sub(reg_rem_mask_short, 1);
466         jz(epilogue_end, T_NEAR);
467         kmovq(kreg_rem_mask_short, reg_rem_mask_short);
468         compute(0, 0, true);
469     }
470
471     L(epilogue_end);
472
473     postamble();
474
475     ker_ = getCode<decltype(ker_)>();
476 }
477
478 template <data_type_t src_type, data_type_t dst_type>
479 void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
480     (dst_data_t *dst, const acc_data_t *acc, const char *bias,
481         const float *scales, float nslope, float sum_scale, float signed_scale,
482         int g, size_t start, size_t end)
483 {
484     using math::get_bias;
485
486     if (end <= start)
487         return;
488
489     if (ker_) {
490         // JIT
491         ker_args args;
492         size_t oc_offset = start % OC_;
493         size_t os_offset = start / OC_;
494         args.acc = acc + start;
495         args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
496         args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
497         args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
498         args.nslope = nslope;
499         args.sum_scale = sum_scale;
500         args.signed_scale = signed_scale;
501         args.len = end - start;
502         args.oc_offset = oc_offset;
503         ker_(&args);
504     }
505     else {
506         // Fallback
507         const size_t first_oc = start % OC_;
508         const size_t last_oc = (end - 1) % OC_;
509         const size_t first_os = start / OC_;
510         const size_t last_os = (end - 1) / OC_;
511         for (size_t os = first_os; os <= last_os; os++) {
512             const size_t start_oc = (os == first_os) ? first_oc : 0;
513             const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
514             for (size_t oc = start_oc; oc <= end_oc; oc++) {
515                 const size_t acc_off = os * jcp_.oc + oc;
516                 const size_t dst_off = os * dst_os_stride_ + oc;
517
518                 float d = (float)(acc[acc_off]);
519                 if (jcp_.signed_input)
520                     d *= signed_scale;
521
522                 if (do_bias_)
523                     d += get_bias(bias, g * jcp_.oc + oc,
524                         bias_data_type_);
525
526                 d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
527                 if (do_sum_)
528                     d += sum_scale * dst[dst_off];
529                 if (do_relu_ && d < 0)
530                     d *= nslope;
531                 dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode_);
532             }
533         }
534     }
535 };
536
537 template <data_type_t src_type, data_type_t dst_type>
538 void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
539 execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base,
540         const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
541         const memory_tracking::grantor_t &scratchpad) const {
542     const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
543
544     const auto src_md = memory_desc_wrapper(pd()->src_pd());
545     const size_t src_mb_stride = src_md.blk_off(1);
546     const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
547
548     const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
549     const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
550
551     const auto dst_md = memory_desc_wrapper(pd()->dst_pd());
552     const size_t dst_mb_stride = dst_md.blk_off(1);
553     const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
554
555     const float *scales = pd()->attr()->output_scales_.scales_;
556
557     const auto &post_ops = pd()->attr()->post_ops_;
558     const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
559     const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
560
561     float nslope = 0;
562     for (int idx = 0; idx < post_ops.len_; ++idx) {
563         const auto &e = post_ops.entry_[idx];
564         if (e.is_relu(true, false)) {
565             nslope = e.eltwise.alpha;
566             break;
567         }
568     }
569
570     auto col = scratchpad.get<uint8_t>(key_conv_gemm_col)
571         + (ptrdiff_t)ithr * jcp.im2col_sz;
572     auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
573         + (ptrdiff_t)ithr * jcp.os * jcp.oc;
574
575     const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
576     const int32_t *_wei_comp = (const int32_t *)(wei_base + offset);
577
578     int n{0}, g{0};
579     size_t start = 0, end = 0;
580
581     const size_t work_amount = jcp.ngroups * jcp.mb;
582     balance211(work_amount, nthr, ithr, start, end);
583     nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
584
585     for (size_t iwork = start; iwork < end; ++iwork) {
586         const src_data_t *src = src_base + n * src_mb_stride
587             + g * src_g_stride;
588         const wei_data_t *wei = wei_base + g * wei_g_stride;
589         dst_data_t *dst = dst_base + n * dst_mb_stride + g * dst_g_stride;
590         const int32_t *wei_comp = _wei_comp + g * jcp.oc;
591
592         if (jcp.im2col_sz)
593             jit_gemm_convolution_utils::im2col_u8<src_data_t>(jcp, src, col);
594
595         const int M = jcp.oc;
596         const int K = jcp.ks * jcp.ic;
597         const int N = jcp.os;
598         const int LD = M * jcp.ngroups;
599         const int8_t off_a = 0, off_b = 0;
600         const int32_t off_c = 0;
601         const float onef = 1.0, zerof = 0.0;
602
603         mkldnn_gemm_s8u8s32("N", "N", jcp.signed_input ? "C" : "F",
604                 &M, &N, &K, &onef, wei, &LD, &off_a,
605                 jcp.im2col_sz ? col : (uint8_t *)src, &K, &off_b,
606                 &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);
607
608         parallel(0, [&](int ithr, int nthr) {
609             size_t start, end;
610             balance211((size_t)jcp.os * jcp.oc, nthr, ithr, start, end);
611             (*pp_ker_)(dst, acc, bia_base, scales, nslope, sum_scale,
612                     jcp.signed_input ? 1.f / jcp.wei_adj_scale : 1.f,
613                     g, start, end);
614         });
615
616         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
617     }
618 }
619
620 template <data_type_t dst_type>
621 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
622 execute_backward_data() const {
623     auto diff_dst_base = reinterpret_cast<const diff_dst_data_t *>
624             (this->input_memory(0));
625     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
626     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
627     auto diff_src_base = reinterpret_cast<diff_src_data_t *>(this->memory());
628
629     auto scratchpad = this->scratchpad();
630
631     const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
632
633     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
634         execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
635                 bia_base, diff_src_base, scratchpad);
636     });
637 }
638
639 template <data_type_t dst_type>
640 void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
641 execute_backward_data_thr(const int ithr, const int nthr,
642         const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
643         const char *bia_base, diff_src_data_t *diff_src_base,
644         const memory_tracking::grantor_t &scratchpad) const
645 {
646     const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
647
648     const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_pd());
649     const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
650     const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
651
652     const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
653     const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
654
655     const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_pd());
656     const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
657     const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
658     const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
659
660     /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
661     const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
662     const float *scales = pd()->attr()->output_scales_.scales_;
663     const auto rmode = pd()->attr()->round_mode_;
664     const size_t work_amount = jcp.ngroups * jcp.mb;
665
666     auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
667         + (ptrdiff_t)ithr * jcp.im2col_sz;
668     auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
669         + (ptrdiff_t)ithr * jcp.is * jcp.ic;
670
671     int n{0}, g{0};
672     size_t start = 0, end = 0;
673
674     balance211(work_amount, nthr, ithr, start, end);
675     nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
676
677     for (size_t iwork = start; iwork < end; ++iwork) {
678         const diff_dst_data_t *diff_dst = diff_dst_base
679             + n * diff_dst_mb_stride + g * diff_dst_g_stride;
680         const wei_data_t *wei = wei_base + g * wei_g_stride;
681         diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
682             + g * diff_src_g_stride;
683
684         const int M = jcp.ks * jcp.ic;
685         const int N = jcp.os;
686         const int K = jcp.oc;
687         const int8_t off_a = 0, off_b = 0;
688         const int32_t off_c = 0;
689         const float onef = 1.0, zerof = 0.0;
690         const int LD = K * jcp.ngroups;
691
692         mkldnn_gemm_s8u8s32("T", "N", "F", &M, &N, &K, &onef,
693                 wei, &LD, &off_a, diff_dst, &LD, &off_b,
694                 &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);
695
696         if (jcp.im2col_sz)
697             jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
698
699         parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
700             float d = (float)acc[is * jcp.ic + ic];
701             if (jcp.with_bias)
702                 d += get_bias(bia_base, g * jcp.ic + ic,
703                         pd()->desc()->bias_desc.data_type);
704             d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
705             const size_t diff_src_off = is * diff_src_os_stride + ic;
706             diff_src[diff_src_off] =
707                 qz_a1b0<float, diff_src_data_t>()(d, rmode);
708         });
709         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
710     }
711 }
712
713 using namespace data_type;
714
715 template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
716 template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
717 template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
718 template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;
719
720 template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
721 template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
722 template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
723 template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;
724
725 template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
726 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
727 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
728 template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
729 }
730 }
731 }