/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "math_utils.hpp"
#include "mkldnn_thread.hpp"
#include "simple_q10n.hpp"

#include "gemm_x8s8s32x_inner_product.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace math;
using namespace memory_format;
using namespace memory_tracking::names;
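
// pp_kernel_t post-processes the int32 GEMM accumulator: it adds the bias
// (if any), applies the output scales and an optional ReLU, and converts the
// result to the destination data type. On avx512_core this is done by a JIT
// kernel; otherwise the scalar fallback in operator() is used.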
template <data_type_t src_type, data_type_t dst_type>
gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::pp_kernel_t(
        const pd_t *pd, bool dst_is_acc)
    : ker_(nullptr), OC_(pd->OC())
    , bias_data_type_(data_type::undef), bias_data_type_size_(0)
    , scale_idx_mult_(0), rmode_(round_mode::nearest)
    , do_bias_(false), do_relu_(false)
{
    using namespace types;

    scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
    rmode_ = pd->attr()->round_mode_;

    auto &post_ops = pd->attr()->post_ops_;
    do_relu_ = post_ops.len_ == 1;

    do_bias_ = pd->with_bias();
    bias_data_type_ = pd->desc()->bias_desc.data_type;
    if (do_bias_) {
        assert(bias_data_type_ != data_type::undef);
        bias_data_type_size_ = data_type_size(bias_data_type_);
    }

    if (!mayiuse(avx512_core))
        // use fallback code for older CPUs since they do not have an
        // optimized x8s8s32 GEMM anyway. The configuration variables above
        // are used by the fallback code.
        return;

    generate();
}
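
// generate() emits the avx512 post-processing kernel with Xbyak. The kernel
// processes a contiguous [start, end) slice of the MB x OC output matrix and
// uses opmask registers to handle partial (tail) vectors of fewer than 16
// elements.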
template <data_type_t src_type, data_type_t dst_type>
void gemm_x8s8s32x_inner_product_fwd_t<src_type,
     dst_type>::pp_kernel_t::generate()
{
    using namespace Xbyak;
    using namespace utils;
    using namespace round_mode;

    Reg64 reg_param = abi_param1;
    Reg64 reg_dst = rdx;
    Reg64 reg_acc = rax;
    Reg64 reg_bias = rbx;
    Reg64 reg_scales = rsi;

    Reg64 reg_len = r8;
    Reg64 reg_tmp = rcx; // intentional for shifting purposes
    Reg64 reg_oc_offset = r9;
    Reg64 reg_rem_mask = r10;
    Opmask kreg_rem_mask = k1;
    Opmask kreg_relu_cmp = k2;

    const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);

    Zmm vreg_zero = Zmm(0);
    Zmm vreg_scale = Zmm(1);
    Zmm vreg_nslope = Zmm(2);

    auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
    auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
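
    // Zmm(0..2) are reserved: zero (for ReLU and u8 saturation), the scale
    // (broadcast once when a single common scale is used, reloaded per vector
    // for per-oc scales), and the ReLU negative slope. Each unroll index idx
    // then owns a (dst, bias) pair of zmm registers from Zmm(3) up.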

    preamble();

#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
    vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
    if (scale_idx_mult_ == 0)
        vbroadcastss(vreg_scale, dword[reg_scales]);
#undef PARAM_OFF

    if (do_relu_ || dst_type == data_type::u8)
        vxorps(vreg_zero, vreg_zero, vreg_zero);

    // Load accumulated value, convert to float, apply bias (if any), scaling,
    // and relu (if any); then convert to destination type and store
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];

        if (scale_idx_mult_ > 0) {
            assert(scale_idx_mult_ == 1);
            auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
            auto vreg_scale_ = vreg_scale;
            if (apply_mask)
                vreg_scale_ = vreg_scale_ | kreg_rem_mask;
            vmovups(vreg_scale_, scale_addr);
        }

        auto vreg_dst_ = vreg_dst(idx);
        if (apply_mask)
            vreg_dst_ = vreg_dst_ | kreg_rem_mask;
        vcvtdq2ps(vreg_dst_, acc_addr);

        if (do_bias_) {
            auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
            auto vreg_bias_ = vreg_bias(idx);
            if (apply_mask)
                vreg_bias_ = vreg_bias_ | kreg_rem_mask;

            switch (bias_data_type_) {
            case data_type::s8:
                vpmovsxbd(vreg_bias_, bias_addr);
                break;
            case data_type::u8:
                vpmovzxbd(vreg_bias_, bias_addr);
                break;
            case data_type::s32:
            case data_type::f32:
                vmovups(vreg_bias_, bias_addr);
                break;
            default: assert(!"unimplemented");
            }
            if (bias_data_type_ != data_type::f32)
                vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
            vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
        }

        vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
        if (do_relu_) {
            vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
            vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
        }

        if (dst_type == data_type::u8)
            vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);

        if (dst_type != data_type::f32) {
            auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
            vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
        }

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
        switch (dst_type) {
        case data_type::s8:
            vpmovsdb(dst_addr, vreg_dst_);
            break;
        case data_type::u8:
            vpmovusdb(dst_addr, vreg_dst_);
            break;
        case data_type::f32:
        case data_type::s32:
            vmovups(dst_addr, vreg_dst_);
            break;
        default: assert(!"unimplemented");
        }
    };
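
    // The helpers below keep the pointers in sync: dst and acc advance with
    // the absolute output position, while bias and scales are indexed by
    // output channel only and must be rewound at the end of each row.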

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            add(reg_scales, offset * sizeof(float));
        }
        if (do_bias_)
            add(reg_bias, offset * bias_data_type_size_);
    };

    // Advance all pointers by a value stored in a register
    auto advance_ptrs_reg = [&](Reg64 offset) {
        lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
        lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
        }
        if (do_bias_)
            lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
    };

    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    auto rewind_ptrs = [&]() {
        if (do_bias_)
            sub(reg_bias, OC_ * bias_data_type_size_);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            sub(reg_scales, OC_ * sizeof(float));
        }
    };

    //      <-------------------- OC ------------------------------->
    //
    // ^    +....................+----------------------------------+
    // |    :   not accessed     |          Prologue loop           |
    // |    +--------------------+----------------------------------+
    //      |                                                       |
    // M    |                 Main loop (unrolled)                  |
    //      |                                                       |
    //      +--------------------------------+----------------------+
    // |    |       Epilogue loop            |     not accessed     :
    // v    +--------------------------------+......................+
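
    // A thread's [start, end) range may begin and end in the middle of a row
    // of OC outputs: the prologue finishes the partially started row, the
    // main loop handles whole rows, and the epilogue handles the remainder.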

    Label prologue_end;
    cmp(reg_oc_offset, 0);
    je(prologue_end, T_NEAR);

    // Prologue loop
    {
        mov(reg_tmp, OC_);
        sub(reg_tmp, reg_oc_offset);
        cmp(reg_tmp, reg_len);
        cmovg(reg_tmp, reg_len);
        sub(reg_len, reg_tmp);

        Label prologue_loop, prologue_loop_tail, prologue_loop_end;
        cmp(reg_tmp, vlen);
        jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
        L(prologue_loop); {
            compute(0, 0, false);
            advance_ptrs_imm(vlen);
            sub(reg_tmp, vlen);
            cmp(reg_tmp, vlen);
            jge(prologue_loop, T_NEAR);
        }

        L(prologue_loop_tail);
        mov(reg_rem_mask, 1);
        shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
        sub(reg_rem_mask, 1);
        jz(prologue_loop_end, T_NEAR);

        kmovq(kreg_rem_mask, reg_rem_mask);
        compute(0, 0, true);
        advance_ptrs_reg(reg_tmp);

        L(prologue_loop_end);
        rewind_ptrs();
    }
    L(prologue_end);
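
    // Main loop: each iteration covers one full row of OC outputs. Rows
    // shorter than max_unroll * vlen are fully unrolled; longer rows are
    // processed in chunks of def_unroll vectors plus a masked tail.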
    Label main_loop_end;
    {
        cmp(reg_len, OC_);
        jle(main_loop_end, T_NEAR);

        Label main_loop;
        L(main_loop); {
            size_t def_unroll = 4;
            size_t max_unroll = 13;

            size_t OC_loop, OC_tail;
            if (OC_ < max_unroll * vlen) {
                // Fully unroll small loops
                OC_loop = 0;
                OC_tail = OC_;
            } else {
                OC_loop = vlen * def_unroll;
                OC_tail = OC_ % OC_loop;
            }
            assert(!!OC_loop || !!OC_tail);

            if (OC_tail % vlen) {
                int vlen_tail = OC_tail % vlen;
                unsigned tail_mask = (1 << vlen_tail) - 1;
                mov(reg_tmp, tail_mask);
                kmovq(kreg_rem_mask, reg_tmp);
            }

            if (OC_loop) {
                mov(reg_tmp, rnd_dn(OC_, OC_loop));
                Label oc_loop;
                L(oc_loop); {
                    for (size_t offset = 0; offset < OC_loop; offset += vlen)
                        compute(offset, offset / vlen, false);
                    advance_ptrs_imm(OC_loop);
                    sub(reg_tmp, OC_loop);
                    jnz(oc_loop);
                }
            }

            if (OC_tail) {
                for (size_t offset = 0; offset < OC_tail; offset += vlen) {
                    bool use_mask = (offset + vlen) > OC_tail;
                    compute(offset, offset / vlen, use_mask);
                }
                advance_ptrs_imm(OC_tail);
            }

            rewind_ptrs();
            sub(reg_len, OC_);
            cmp(reg_len, OC_);
            jge(main_loop, T_NEAR);
        }
    }
    L(main_loop_end);
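
    // Epilogue loop: fewer than OC outputs remain here, i.e. at most one
    // partial row, handled with a full-vector loop plus a masked tail.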
    Label epilogue_end;
    {
        cmp(reg_len, 0);
        je(epilogue_end, T_NEAR);

        Label epilogue_loop, epilogue_loop_tail;
        cmp(reg_len, vlen);
        jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
        L(epilogue_loop); {
            compute(0, 0, false);
            sub(reg_len, vlen);
            advance_ptrs_imm(vlen);
            cmp(reg_len, vlen);
            jge(epilogue_loop, T_NEAR);
        }

        L(epilogue_loop_tail);
        mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
        mov(reg_rem_mask, 1);
        shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
        sub(reg_rem_mask, 1);
        jz(epilogue_end, T_NEAR);
        kmovq(kreg_rem_mask, reg_rem_mask);
        compute(0, 0, true);
    }
    L(epilogue_end);

    postamble();

    ker_ = getCode<decltype(ker_)>();
}
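
// operator() dispatches to the JIT kernel when one was generated
// (avx512_core), and to the scalar reference loop otherwise.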
template <data_type_t src_type, data_type_t dst_type>
void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::operator ()(
        dst_data_t *dst, const acc_data_t *acc,
        const char *bias, const float *scales, float nslope,
        size_t start, size_t end)
{
    using math::get_bias;

    if (end <= start)
        return;

    if (ker_) {
        // JIT kernel
        ker_args args;
        size_t oc_offset = start % OC_;
        args.dst = dst + start;
        args.acc = acc + start;
        args.bias = bias + oc_offset * bias_data_type_size_;
        args.scales = scales + scale_idx_mult_ * oc_offset;
        args.nslope = nslope;
        args.len = end - start;
        args.oc_offset = oc_offset;
        ker_(&args);
    } else {
        // Scalar fallback
        size_t oc = start % OC_;
        for (size_t i = start; i < end; i++) {
            float d = (float)acc[i];
            float b = get_bias(bias, oc, bias_data_type_);
            d += b;
            d *= scales[oc * scale_idx_mult_];
            if (do_relu_ && d < 0)
                d *= nslope;
            dst[i] = qz_a1b0<float, dst_data_t>()(d, rmode_);
            oc = (oc == OC_ - 1) ? 0 : oc + 1;
        }
    }
}
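
// Forward pass: a single x8s8s32 GEMM computes the int32 accumulator
// acc (M x N, column-major) with M = OC, N = MB, K = IC_total_padded(),
// then pp_kernel_ applies bias, scales, and ReLU and converts to the
// destination type in parallel over the result.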
template <data_type_t src_type, data_type_t dst_type>
void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type
        >::execute_forward() const {
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    const int MB = pd()->MB();
    const int OC = pd()->OC();

    bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
            oi, oihw, oidhw);

    const int M = OC;
    const int N = MB;
    const int K = pd()->IC_total_padded();
    const int8_t off_a = 0, off_b = 0;
    const int32_t off_c = 0;

    const float *scales = pd()->attr()->output_scales_.scales_;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_relu = post_ops.len_ == 1;
    const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;

    acc_data_t *acc = pd()->dst_is_acc_
        ? (acc_data_t *)dst
        : scratchpad().template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);

    const float onef = 1.0, zerof = 0.0;

    if (src_type == data_type::u8) {
        mkldnn_gemm_s8u8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
                weights, wei_tr ? &K : &M, &off_a, (uint8_t *)src, &K, &off_b,
                &zerof, acc, &M, &off_c);
    } else if (src_type == data_type::s8) {
        mkldnn_gemm_s8s8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef,
                weights, wei_tr ? &K : &M, &off_a, (int8_t *)src, &K, &off_b,
                &zerof, acc, &M, &off_c);
    } else {
        assert(!"incorrect src type");
    }

    const bool force_sequential = MB * OC < 2000;
    parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
        size_t start, end;
        balance211((size_t)OC * MB, nthr, ithr, start, end);
        (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end);
    });
}

using namespace data_type;

template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;

}
}
}