Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_x8s8s32x_inner_product.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "math_utils.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "simple_q10n.hpp"
20 #include "gemm_x8s8s32x_inner_product.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25
26 using namespace math;
27 using namespace memory_format;
28 using namespace memory_tracking::names;
29
30 template<data_type_t src_type, data_type_t dst_type>
31 gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::pp_kernel_t(
32         const pd_t *pd, bool dst_is_acc)
33     : ker_(nullptr), OC_(pd->OC())
34     , bias_data_type_(data_type::undef), bias_data_type_size_(0)
35     , scale_idx_mult_(0), rmode_(round_mode::nearest)
36     , do_bias_(false), do_relu_(false)
37 {
38     using namespace types;
39
40     scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
41     rmode_ = pd->attr()->round_mode_;
42
43     auto &post_ops = pd->attr()->post_ops_;
44     do_relu_ = post_ops.len_ == 1;
45     do_bias_ = pd->with_bias();
46     bias_data_type_ = pd->desc()->bias_desc.data_type;
47     if (do_bias_) {
48         assert(bias_data_type_ != data_type::undef);
49         bias_data_type_size_ = data_type_size(bias_data_type_);
50     }
51
52     if (!mayiuse(avx512_core))
53         // use fallback code for older CPUs since they do not have optimized
54         // x8s8s32 GEMM anyways. The configuration variables above are used by
55         // the fallback code.
56         return;
57     else
58         generate();
59 }
60
61 template<data_type_t src_type, data_type_t dst_type>
62 void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::generate()
63 {
64     using namespace Xbyak;
65     using namespace utils;
66     using namespace round_mode;
67
68     // TODO: clean-up
69     Reg64 reg_param = abi_param1;
70     Reg64 reg_dst = rdx;
71     Reg64 reg_acc = rax;
72     Reg64 reg_bias = rbx;
73     Reg64 reg_scales = rsi;
74
75     Reg64 reg_len = r8;
76     Reg64 reg_tmp = rcx; // intentional for shifting purposes
77     Reg64 reg_oc_offset = r9;
78     Reg64 reg_rem_mask = r10;
79     Opmask kreg_rem_mask = k1;
80     Opmask kreg_relu_cmp = k2;
81
82     const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
83
84     Zmm vreg_zero = Zmm(0);
85     Zmm vreg_scale = Zmm(1);
86     Zmm vreg_nslope = Zmm(2);
87
88     auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
89     auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
90
91     preamble();
92
93 #define PARAM_OFF(x) offsetof(ker_args, x)
94     mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
95     mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
96     mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
97     mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
98     mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
99     mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
100     vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
101     if (scale_idx_mult_ == 0)
102         vbroadcastss(vreg_scale, dword[reg_scales]);
103 #undef PARAM_OFF
104
105     if (do_relu_ || dst_type == data_type::u8)
106         vxorps(vreg_zero, vreg_zero, vreg_zero);
107
108     // Load accumulated value, convert to float, apply bias (if any), scaling,
109     // and relu (if any); then convert to destination type and store
110     auto compute = [&](size_t offset, int idx, bool apply_mask) {
111         auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
112
113         if (scale_idx_mult_ > 0) {
114             assert(scale_idx_mult_ == 1);
115             auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
116             auto vreg_scale_ = vreg_scale;
117             if (apply_mask)
118                 vreg_scale_ = vreg_scale_ | kreg_rem_mask;
119             vmovups(vreg_scale, scale_addr);
120         }
121
122         auto vreg_dst_ = vreg_dst(idx);
123         if (apply_mask)
124             vreg_dst_ = vreg_dst_ | kreg_rem_mask;
125         vcvtdq2ps(vreg_dst_, acc_addr);
126
127         if (do_bias_) {
128             auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
129             auto vreg_bias_ = vreg_bias(idx);
130             if (apply_mask)
131                 vreg_bias_ = vreg_bias_ | kreg_rem_mask;
132
133             switch (bias_data_type_) {
134             case data_type::s8:
135                 vpmovsxbd(vreg_bias_, bias_addr);
136                 break;
137             case data_type::u8:
138                 vpmovzxbd(vreg_bias_, bias_addr);
139                 break;
140             case data_type::s32:
141             case data_type::f32:
142                 vmovups(vreg_bias_, bias_addr);
143                 break;
144             default: assert(!"unimplemented");
145             }
146             if (bias_data_type_ != data_type::f32)
147                 vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
148             vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
149         }
150
151         vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
152         if (do_relu_) {
153             vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
154             vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
155         }
156
157         if (dst_type == data_type::u8)
158             vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);
159
160         if (dst_type != data_type::f32) {
161             auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
162             vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
163         }
164
165         auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
166         switch (dst_type) {
167         case data_type::s8:
168             vpmovsdb(dst_addr, vreg_dst_);
169             break;
170         case data_type::u8:
171             vpmovusdb(dst_addr, vreg_dst_);
172             break;
173         case data_type::f32:
174         case data_type::s32:
175             vmovups(dst_addr, vreg_dst_);
176             break;
177         default: assert(!"unimplemented");
178         }
179     };
180
181     // Advance all pointers by an immediate
182     auto advance_ptrs_imm = [&](size_t offset) {
183         add(reg_dst, offset * sizeof(dst_data_t));
184         add(reg_acc, offset * sizeof(acc_data_t));
185         if (scale_idx_mult_) {
186             assert(scale_idx_mult_ == 1);
187             add(reg_scales, offset * sizeof(float));
188         }
189         if (do_bias_)
190             add(reg_bias, offset * bias_data_type_size_);
191     };
192
193     // Advance all pointers by a value stored in a register
194     auto advance_ptrs_reg = [&](Reg64 offset) {
195         lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
196         lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
197         if (scale_idx_mult_) {
198             assert(scale_idx_mult_ == 1);
199             lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
200         }
201         if (do_bias_)
202             lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
203     };
204
205     // Rewind pointers that point to data that is indixed by output channel
206     // (bias or per-oc scaling factors)
207     auto rewind_ptrs = [&]() {
208         if (do_bias_)
209             sub(reg_bias, OC_ * bias_data_type_size_);
210         if (scale_idx_mult_) {
211             assert(scale_idx_mult_ == 1);
212             sub(reg_scales, OC_ * sizeof(float));
213         }
214     };
215
216     //      <-------------------- OC ------------------------------->
217     //
218     // ^    +....................+----------------------------------+
219     // |    :   not accessed     |          Prologue loop           |
220     // |    +--------------------+----------------------------------+
221     //      |                                                       |
222     // M    |                 Main loop (unrolled)                  |
223     // B    |                                                       |
224     //      +--------------------------------+----------------------+
225     // |    |       Epilogue loop            |      not accessed    :
226     // v    +--------------------------------+......................+
227
228     Label prologue_end;
229     cmp(reg_oc_offset, 0);
230     je(prologue_end, T_NEAR);
231
232     // Prologue loop
233     {
234         mov(reg_tmp, OC_);
235         sub(reg_tmp, reg_oc_offset);
236         cmp(reg_tmp, reg_len);
237         cmovg(reg_tmp, reg_len);
238         sub(reg_len, reg_tmp);
239
240         Label prologue_loop, prologue_loop_tail, prologue_loop_end;
241         cmp(reg_tmp, vlen);
242         jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
243         L(prologue_loop); {
244             compute(0, 0, false);
245             advance_ptrs_imm(vlen);
246             sub(reg_tmp, vlen);
247             cmp(reg_tmp, vlen);
248             jge(prologue_loop, T_NEAR);
249         }
250
251         L(prologue_loop_tail);
252         mov(reg_rem_mask, 1);
253         shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
254         sub(reg_rem_mask, 1);
255         jz(prologue_loop_end, T_NEAR);
256
257         kmovq(kreg_rem_mask, reg_rem_mask);
258         compute(0, 0, true);
259         advance_ptrs_reg(reg_tmp);
260
261         L(prologue_loop_end);
262         rewind_ptrs();
263     }
264     L(prologue_end);
265
266     // Main loop
267     Label main_loop_end;
268     {
269         cmp(reg_len, OC_);
270         jle(main_loop_end, T_NEAR);
271
272         Label main_loop;
273         L(main_loop); {
274             size_t def_unroll = 4;
275             size_t max_unroll = 13;
276
277             size_t OC_loop, OC_tail;
278             if (OC_ < max_unroll * vlen) {
279                 // Fully unroll small loops
280                 OC_loop = 0;
281                 OC_tail = OC_;
282             } else {
283                 OC_loop = vlen * def_unroll;
284                 OC_tail = OC_ % OC_loop;
285             }
286
287             assert(!!OC_loop || !!OC_tail);
288
289             if (OC_tail % vlen) {
290                 int vlen_tail = OC_tail % vlen;
291                 unsigned tail_mask = (1 << vlen_tail) - 1;
292                 mov(reg_tmp, tail_mask);
293                 kmovq(kreg_rem_mask, reg_tmp);
294             }
295
296             if (OC_loop) {
297                 mov(reg_tmp, rnd_dn(OC_, OC_loop));
298                 Label oc_loop;
299                 L(oc_loop); {
300                     for (size_t offset = 0; offset < OC_loop; offset += vlen)
301                         compute(offset, offset / vlen, false);
302                     advance_ptrs_imm(OC_loop);
303                     sub(reg_tmp, OC_loop);
304                     jnz(oc_loop);
305                 }
306             }
307
308             if (OC_tail) {
309                 for (size_t offset = 0; offset < OC_tail; offset += vlen) {
310                     bool use_mask = (offset + vlen) > OC_tail;
311                     compute(offset, offset / vlen, use_mask);
312                 }
313                 advance_ptrs_imm(OC_tail);
314             }
315
316             rewind_ptrs();
317             sub(reg_len, OC_);
318             cmp(reg_len, OC_);
319             jge(main_loop, T_NEAR);
320         }
321     }
322     L(main_loop_end);
323
324     // Epilogue loop
325     Label epilogue_end;
326     {
327         cmp(reg_len, 0);
328         je(epilogue_end, T_NEAR);
329
330         Label epilogue_loop, epilogue_loop_tail;
331         cmp(reg_len, vlen);
332         jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
333         L(epilogue_loop); {
334             compute(0, 0, false);
335             sub(reg_len, vlen);
336             advance_ptrs_imm(vlen);
337             cmp(reg_len, vlen);
338             jge(epilogue_loop, T_NEAR);
339         }
340
341         L(epilogue_loop_tail);
342         mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
343         mov(reg_rem_mask, 1);
344         shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
345         sub(reg_rem_mask, 1);
346         jz(epilogue_end, T_NEAR);
347         kmovq(kreg_rem_mask, reg_rem_mask);
348         compute(0, 0, true);
349     }
350
351     L(epilogue_end);
352
353     postamble();
354
355     ker_ = getCode<decltype(ker_)>();
356 }
357
358 template<data_type_t src_type, data_type_t dst_type>
359 void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::operator ()(
360         dst_data_t *dst, const acc_data_t *acc,
361         const char *bias, const float *scales, float nslope,
362         size_t start, size_t end)
363 {
364     using math::get_bias;
365
366     if (end <= start)
367         return;
368
369     if (ker_) {
370         // JIT
371         ker_args args;
372         size_t oc_offset = start % OC_;
373         args.dst = dst + start;
374         args.acc = acc + start;
375         args.bias = bias + oc_offset * bias_data_type_size_;
376         args.scales = scales + scale_idx_mult_ * oc_offset;
377         args.nslope = nslope;
378         args.len = end - start;
379         args.oc_offset = oc_offset;
380         ker_(&args);
381     } else {
382         // Fallback
383         size_t oc = start % OC_;
384         for (size_t i = start; i < end; i++) {
385             float d = (float)acc[i];
386             float b = get_bias(bias, oc, bias_data_type_);
387             d = d + b;
388             d *= scales[oc * scale_idx_mult_];
389             if (do_relu_ && d < 0)
390                 d *= nslope;
391             dst[i] = qz_a1b0<float, dst_data_t>()(d, rmode_);
392             oc = (oc == OC_ - 1) ? 0 : oc + 1;
393         }
394     }
395 };
396
397 template <data_type_t src_type, data_type_t dst_type>
398 void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type
399         >::execute_forward() const {
400     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
401     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
402     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
403     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
404
405     const int MB = pd()->MB();
406     const int OC = pd()->OC();
407
408     bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
409              oihw, oidhw, oi);
410
411     const int M = OC;
412     const int N = MB;
413     const int K = pd()->IC_total_padded();
414     const int8_t off_a = 0, off_b = 0;
415     const int32_t off_c = 0;
416
417     const float *scales = pd()->attr()->output_scales_.scales_;
418
419     const auto &post_ops = pd()->attr()->post_ops_;
420     const bool do_relu = post_ops.len_ == 1;
421     const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
422
423     acc_data_t *acc = pd()->dst_is_acc_
424         ? (acc_data_t *)dst
425         : scratchpad().template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);
426
427     const float onef = 1.0, zerof = 0.0;
428
429     if (src_type == data_type::u8) {
430         mkldnn_gemm_s8u8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
431                 weights, wei_tr ? &K : &M, &off_a, (uint8_t *)src, &K, &off_b, &zerof,
432                 acc, &M, &off_c);
433     } else if (src_type == data_type::s8) {
434         mkldnn_gemm_s8s8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
435                 weights, wei_tr ? &K : &M, &off_a, (int8_t *)src, &K, &off_b, &zerof,
436                 acc, &M, &off_c);
437     } else {
438         assert(!"incorrect src type");
439     }
440
441     const bool force_sequential = MB * OC < 2000;
442     parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
443             size_t start, end;
444             balance211((size_t)OC * MB, nthr, ithr, start, end);
445             (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end);
446             });
447 }
448
449 using namespace data_type;
450
451 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
452 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
453 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
454 template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
455 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
456 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
457 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
458 template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;
459 }
460 }
461 }