updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_inner_product_utils.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "math_utils.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "simple_q10n.hpp"
20 #include "gemm_inner_product_utils.hpp"
21 #include "jit_uni_eltwise.hpp"
22
23 namespace mkldnn {
24 namespace impl {
25 namespace cpu {
26
27 namespace inner_product_utils {
28
29 using namespace alg_kind;
30 using namespace math;
31
32 template <data_type_t acc_type, data_type_t dst_type>
33 pp_kernel_t<acc_type, dst_type>::pp_kernel_t(
34         const cpu_inner_product_fwd_pd_t *pd)
35     : ker_(nullptr)
36     , eltwise_injector_(nullptr)
37     , ref_eltwise_(nullptr)
38     , OC_(pd->OC())
39     , bias_data_type_(data_type::undef)
40     , bias_data_type_size_(0)
41     , scale_idx_mult_(0)
42     , rmode_(round_mode::nearest)
43     , do_bias_(pd->with_bias())
44     , do_eltwise_(false) {
45     using namespace types;
46
47     scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
48     rmode_ = pd->attr()->round_mode_;
49
50     auto &p = pd->attr()->post_ops_;
51     const int eltwise_ind = p.find(primitive_kind::eltwise);
52     do_eltwise_ = eltwise_ind != -1;
53     if (do_eltwise_)
54         eltwise_ = p.entry_[eltwise_ind].eltwise;
55
56     bias_data_type_ = pd->desc()->bias_desc.data_type;
57     if (do_bias_) {
58         assert(bias_data_type_ != data_type::undef);
59         bias_data_type_size_ = data_type_size(bias_data_type_);
60     }
61
62     if (!mayiuse(avx512_core)) {
63         // use fallback code for older CPUs since they do not have optimized
64         // x8s8s32 GEMM anyways. The configuration variables above are used by
65         // the fallback code.
66         if (do_eltwise_)
67             ref_eltwise_ = new ref_eltwise_scalar_fwd_t(
68                     eltwise_.alg, eltwise_.alpha, eltwise_.beta);
69         return;
70     } else {
71         if (do_eltwise_)
72             eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
73                     this, eltwise_, true, Xbyak::util::rax, Xbyak::Opmask(2));
74         generate();
75     }
76 }
77
78 template<data_type_t acc_type, data_type_t dst_type>
79 void pp_kernel_t<acc_type, dst_type>::generate()
80 {
81     using namespace Xbyak;
82     using namespace utils;
83     using namespace round_mode;
84
85     // TODO: clean-up
86     Reg64 reg_param = abi_param1;
87     Reg64 reg_dst = rdx;
88     Reg64 reg_acc = rax;
89     Reg64 reg_bias = rbx;
90     Reg64 reg_scales = rsi;
91
92     Reg64 reg_len = r8;
93     Reg64 reg_tmp = rcx; // intentional for shifting purposes
94     Reg64 reg_oc_offset = r9;
95     Reg64 reg_rem_mask = r10;
96     Opmask kreg_rem_mask = k1;
97
98     const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
99
100     Zmm vreg_zero = Zmm(0);
101     Zmm vreg_scale = Zmm(1);
102
103     auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
104     auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
105
106     preamble();
107
108 #define PARAM_OFF(x) offsetof(ker_args, x)
109     mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
110     mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
111     mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
112     mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
113     mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
114     mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
115     if (scale_idx_mult_ == 0)
116         vbroadcastss(vreg_scale, dword[reg_scales]);
117 #undef PARAM_OFF
118
119     if (dst_type == data_type::u8)
120         vxorps(vreg_zero, vreg_zero, vreg_zero);
121
122     // Load accumulated value, convert to float, apply bias (if any), scaling,
123     // and eltwise (if any); then convert to destination type and store
124     auto compute = [&](size_t offset, int idx, bool apply_mask) {
125         auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
126
127         if (scale_idx_mult_ > 0) {
128             assert(scale_idx_mult_ == 1);
129             auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
130             auto vreg_scale_ = vreg_scale;
131             if (apply_mask)
132                 vreg_scale_ = vreg_scale_ | kreg_rem_mask;
133             vmovups(vreg_scale, scale_addr);
134         }
135
136         auto vreg_dst_ = vreg_dst(idx);
137         if (apply_mask)
138             vreg_dst_ = vreg_dst_ | kreg_rem_mask;
139
140         switch (acc_type) {
141         case data_type::s32: vcvtdq2ps(vreg_dst_, acc_addr); break;
142         case data_type::f32: vmovups(vreg_dst_, acc_addr); break;
143         }
144
145         if (do_bias_) {
146             auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
147             auto vreg_bias_ = vreg_bias(idx);
148             if (apply_mask)
149                 vreg_bias_ = vreg_bias_ | kreg_rem_mask;
150
151             switch (bias_data_type_) {
152             case data_type::s8:
153                 vpmovsxbd(vreg_bias_, bias_addr);
154                 break;
155             case data_type::u8:
156                 vpmovzxbd(vreg_bias_, bias_addr);
157                 break;
158             case data_type::s32:
159             case data_type::f32:
160                 vmovups(vreg_bias_, bias_addr);
161                 break;
162             default: assert(!"unimplemented");
163             }
164             if (bias_data_type_ != data_type::f32)
165                 vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
166             vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
167         }
168
169         vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
170         if (do_eltwise_)
171             eltwise_injector_->compute_vector(vreg_dst(idx).getIdx());
172
173         if (dst_type == data_type::u8)
174             vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);
175
176         if (dst_type != data_type::f32) {
177             auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
178             vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
179         }
180
181         auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
182         switch (dst_type) {
183         case data_type::s8:
184             vpmovsdb(dst_addr, vreg_dst_);
185             break;
186         case data_type::u8:
187             vpmovusdb(dst_addr, vreg_dst_);
188             break;
189         case data_type::f32:
190         case data_type::s32:
191             vmovups(dst_addr, vreg_dst_);
192             break;
193         default: assert(!"unimplemented");
194         }
195     };
196
197     // Advance all pointers by an immediate
198     auto advance_ptrs_imm = [&](size_t offset) {
199         add(reg_dst, offset * sizeof(dst_data_t));
200         add(reg_acc, offset * sizeof(acc_data_t));
201         if (scale_idx_mult_) {
202             assert(scale_idx_mult_ == 1);
203             add(reg_scales, offset * sizeof(float));
204         }
205         if (do_bias_)
206             add(reg_bias, offset * bias_data_type_size_);
207     };
208
209     // Advance all pointers by a value stored in a register
210     auto advance_ptrs_reg = [&](Reg64 offset) {
211         lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
212         lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
213         if (scale_idx_mult_) {
214             assert(scale_idx_mult_ == 1);
215             lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
216         }
217         if (do_bias_)
218             lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
219     };
220
221     // Rewind pointers that point to data that is indixed by output channel
222     // (bias or per-oc scaling factors)
223     auto rewind_ptrs = [&]() {
224         if (do_bias_)
225             sub(reg_bias, OC_ * bias_data_type_size_);
226         if (scale_idx_mult_) {
227             assert(scale_idx_mult_ == 1);
228             sub(reg_scales, OC_ * sizeof(float));
229         }
230     };
231
232     //      <-------------------- OC ------------------------------->
233     //
234     // ^    +....................+----------------------------------+
235     // |    :   not accessed     |          Prologue loop           |
236     // |    +--------------------+----------------------------------+
237     //      |                                                       |
238     // M    |                 Main loop (unrolled)                  |
239     // B    |                                                       |
240     //      +--------------------------------+----------------------+
241     // |    |       Epilogue loop            |      not accessed    :
242     // v    +--------------------------------+......................+
243
244     Label prologue_end;
245     cmp(reg_oc_offset, 0);
246     je(prologue_end, T_NEAR);
247
248     // Prologue loop
249     {
250         mov(reg_tmp, OC_);
251         sub(reg_tmp, reg_oc_offset);
252         cmp(reg_tmp, reg_len);
253         cmovg(reg_tmp, reg_len);
254         sub(reg_len, reg_tmp);
255
256         Label prologue_loop, prologue_loop_tail, prologue_loop_end;
257         cmp(reg_tmp, vlen);
258         jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
259         L(prologue_loop); {
260             compute(0, 0, false);
261             advance_ptrs_imm(vlen);
262             sub(reg_tmp, vlen);
263             cmp(reg_tmp, vlen);
264             jge(prologue_loop, T_NEAR);
265         }
266
267         L(prologue_loop_tail);
268         mov(reg_rem_mask, 1);
269         shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
270         sub(reg_rem_mask, 1);
271         jz(prologue_loop_end, T_NEAR);
272
273         kmovq(kreg_rem_mask, reg_rem_mask);
274         compute(0, 0, true);
275         advance_ptrs_reg(reg_tmp);
276
277         L(prologue_loop_end);
278         rewind_ptrs();
279     }
280     L(prologue_end);
281
282     // Main loop
283     Label main_loop_end;
284     {
285         cmp(reg_len, OC_);
286         jle(main_loop_end, T_NEAR);
287
288         Label main_loop;
289         L(main_loop); {
290             size_t def_unroll = 4;
291             size_t max_unroll = 13;
292
293             size_t OC_loop, OC_tail;
294             if (OC_ < max_unroll * vlen) {
295                 // Fully unroll small loops
296                 OC_loop = 0;
297                 OC_tail = OC_;
298             } else {
299                 OC_loop = vlen * def_unroll;
300                 OC_tail = OC_ % OC_loop;
301             }
302
303             assert(!!OC_loop || !!OC_tail);
304
305             if (OC_tail % vlen) {
306                 int vlen_tail = OC_tail % vlen;
307                 unsigned tail_mask = (1 << vlen_tail) - 1;
308                 mov(reg_tmp, tail_mask);
309                 kmovq(kreg_rem_mask, reg_tmp);
310             }
311
312             if (OC_loop) {
313                 mov(reg_tmp, rnd_dn(OC_, OC_loop));
314                 Label oc_loop;
315                 L(oc_loop); {
316                     for (size_t offset = 0; offset < OC_loop; offset += vlen)
317                         compute(offset, offset / vlen, false);
318                     advance_ptrs_imm(OC_loop);
319                     sub(reg_tmp, OC_loop);
320                     jnz(oc_loop);
321                 }
322             }
323
324             if (OC_tail) {
325                 for (size_t offset = 0; offset < OC_tail; offset += vlen) {
326                     bool use_mask = (offset + vlen) > OC_tail;
327                     compute(offset, offset / vlen, use_mask);
328                 }
329                 advance_ptrs_imm(OC_tail);
330             }
331
332             rewind_ptrs();
333             sub(reg_len, OC_);
334             cmp(reg_len, OC_);
335             jge(main_loop, T_NEAR);
336         }
337     }
338     L(main_loop_end);
339
340     // Epilogue loop
341     Label epilogue_end;
342     {
343         cmp(reg_len, 0);
344         je(epilogue_end, T_NEAR);
345
346         Label epilogue_loop, epilogue_loop_tail;
347         cmp(reg_len, vlen);
348         jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
349         L(epilogue_loop); {
350             compute(0, 0, false);
351             sub(reg_len, vlen);
352             advance_ptrs_imm(vlen);
353             cmp(reg_len, vlen);
354             jge(epilogue_loop, T_NEAR);
355         }
356
357         L(epilogue_loop_tail);
358         mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
359         mov(reg_rem_mask, 1);
360         shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
361         sub(reg_rem_mask, 1);
362         jz(epilogue_end, T_NEAR);
363         kmovq(kreg_rem_mask, reg_rem_mask);
364         compute(0, 0, true);
365     }
366
367     L(epilogue_end);
368
369     postamble();
370
371     if (do_eltwise_)
372         eltwise_injector_->prepare_table();
373
374     ker_ = getCode<decltype(ker_)>();
375 }
376
377 template <data_type_t acc_type, data_type_t dst_type>
378 void pp_kernel_t<acc_type, dst_type>::operator()(dst_data_t *dst,
379         const acc_data_t *acc, const char *bias, const float *scales,
380         size_t start, size_t end) {
381     using math::get_bias;
382
383     if (end <= start)
384         return;
385
386     if (ker_) {
387         // JIT
388         ker_args args;
389         size_t oc_offset = start % OC_;
390         args.dst = dst + start;
391         args.acc = acc + start;
392         args.bias = bias + oc_offset * bias_data_type_size_;
393         args.scales = scales + scale_idx_mult_ * oc_offset;
394         args.len = end - start;
395         args.oc_offset = oc_offset;
396         ker_(&args);
397     } else {
398         // Fallback
399         size_t oc = start % OC_;
400         for (size_t i = start; i < end; i++) {
401             float d = (float)acc[i];
402             float b = get_bias(bias, oc, bias_data_type_);
403             d = d + b;
404             d *= scales[oc * scale_idx_mult_];
405             if (do_eltwise_)
406                 d = ref_eltwise_->compute_scalar(d);
407             dst[i] = qz_a1b0<float, dst_data_t>()(d, rmode_);
408             oc = (oc == OC_ - 1) ? 0 : oc + 1;
409         }
410     }
411 };
412
413
414 using namespace data_type;
415 template class pp_kernel_t<f32, f32>;
416 template class pp_kernel_t<s32, f32>;
417 template class pp_kernel_t<s32, s32>;
418 template class pp_kernel_t<s32, s8>;
419 template class pp_kernel_t<s32, u8>;
420 }
421
422 }
423 }
424 }