/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "math_utils.hpp"
#include "mkldnn_thread.hpp"
#include "simple_q10n.hpp"

#include "gemm_inner_product_utils.hpp"
#include "jit_uni_eltwise.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

namespace inner_product_utils {

using namespace alg_kind;
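
// pp_kernel_t implements the post-processing pass that GEMM-based inner
// products apply to the accumulator. Per element it effectively computes
//     dst[i] = quantize(eltwise(scale[oc] * ((float)acc[i] + bias[oc])))
// with oc = i % OC_; bias and eltwise are optional, and scale is either one
// common value or a per-output-channel array (output_scales mask == 1 << 1).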
template <data_type_t acc_type, data_type_t dst_type>
pp_kernel_t<acc_type, dst_type>::pp_kernel_t(
        const cpu_inner_product_fwd_pd_t *pd)
    : ker_(nullptr)
    , eltwise_injector_(nullptr)
    , ref_eltwise_(nullptr)
    , OC_(pd->OC())
    , bias_data_type_(data_type::undef)
    , bias_data_type_size_(0)
    , scale_idx_mult_(0)
    , rmode_(round_mode::nearest)
    , do_bias_(pd->with_bias())
    , do_eltwise_(false) {
    using namespace types;

    scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
    rmode_ = pd->attr()->round_mode_;

    auto &p = pd->attr()->post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    do_eltwise_ = eltwise_ind != -1;
    if (do_eltwise_)
        eltwise_ = p.entry_[eltwise_ind].eltwise;

    bias_data_type_ = pd->desc()->bias_desc.data_type;
    if (do_bias_) {
        assert(bias_data_type_ != data_type::undef);
        bias_data_type_size_ = data_type_size(bias_data_type_);
    }

    if (!mayiuse(avx512_core)) {
        // use fallback code for older CPUs since they do not have optimized
        // x8s8s32 GEMM anyway. The configuration variables above are used by
        // the fallback code.
        if (do_eltwise_)
            ref_eltwise_ = new ref_eltwise_scalar_fwd_t(
                    eltwise_.alg, eltwise_.alpha, eltwise_.beta);
        return;
    }

    if (do_eltwise_)
        eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
                this, eltwise_, true, Xbyak::util::rax, Xbyak::Opmask(2));

    generate();
}
template <data_type_t acc_type, data_type_t dst_type>
void pp_kernel_t<acc_type, dst_type>::generate() {
    using namespace Xbyak;
    using namespace utils;
    using namespace round_mode;
    Reg64 reg_param = abi_param1;
    Reg64 reg_dst = rdx;
    Reg64 reg_acc = rax;
    Reg64 reg_bias = rbx;
    Reg64 reg_scales = rsi;

    Reg64 reg_len = r8;
    Reg64 reg_tmp = rcx; // intentional for shifting purposes
    Reg64 reg_oc_offset = r9;
    Reg64 reg_rem_mask = r10;
    Opmask kreg_rem_mask = k1;
    const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);

    Zmm vreg_zero = Zmm(0);
    Zmm vreg_scale = Zmm(1);

    auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
    auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
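
    // Register budget (derived from the allocation above): Zmm(0) holds the
    // zero constant for u8 saturation, Zmm(1) the scale, and each unroll step
    // idx consumes the pair Zmm(3 + 2*idx) / Zmm(3 + 2*idx + 1) for the
    // destination and bias lanes. With max_unroll = 13 (see the main loop
    // below) the highest register touched is Zmm(28), leaving headroom for
    // the eltwise injector's temporaries.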
    preamble();

#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
    if (scale_idx_mult_ == 0)
        vbroadcastss(vreg_scale, dword[reg_scales]);
#undef PARAM_OFF

    if (dst_type == data_type::u8)
        vxorps(vreg_zero, vreg_zero, vreg_zero);
    // Load accumulated value, convert to float, apply bias (if any), scaling,
    // and eltwise (if any); then convert to destination type and store
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];

        if (scale_idx_mult_ > 0) {
            assert(scale_idx_mult_ == 1);
            auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
            auto vreg_scale_ = vreg_scale;
            if (apply_mask)
                vreg_scale_ = vreg_scale_ | kreg_rem_mask;
            vmovups(vreg_scale_, scale_addr);
        }

        auto vreg_dst_ = vreg_dst(idx);
        if (apply_mask)
            vreg_dst_ = vreg_dst_ | kreg_rem_mask;

        switch (acc_type) {
        case data_type::s32: vcvtdq2ps(vreg_dst_, acc_addr); break;
        case data_type::f32: vmovups(vreg_dst_, acc_addr); break;
        default: assert(!"unimplemented");
        }

        if (do_bias_) {
            auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
            auto vreg_bias_ = vreg_bias(idx);
            if (apply_mask)
                vreg_bias_ = vreg_bias_ | kreg_rem_mask;

            switch (bias_data_type_) {
            case data_type::s8:
                vpmovsxbd(vreg_bias_, bias_addr);
                break;
            case data_type::u8:
                vpmovzxbd(vreg_bias_, bias_addr);
                break;
            case data_type::s32:
            case data_type::f32:
                vmovups(vreg_bias_, bias_addr);
                break;
            default: assert(!"unimplemented");
            }
            if (bias_data_type_ != data_type::f32)
                vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
            vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
        }

        vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
        if (do_eltwise_)
            eltwise_injector_->compute_vector(vreg_dst(idx).getIdx());

        if (dst_type == data_type::u8)
            vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);

        if (dst_type != data_type::f32) {
            auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
            vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
        }

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
        switch (dst_type) {
        case data_type::s8:
            vpmovsdb(dst_addr, vreg_dst_);
            break;
        case data_type::u8:
            vpmovusdb(dst_addr, vreg_dst_);
            break;
        case data_type::f32:
        case data_type::s32:
            vmovups(dst_addr, vreg_dst_);
            break;
        default: assert(!"unimplemented");
        }
    };
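
    // Per 16-lane vector, compute() thus emits the pipeline: load/convert
    // acc -> (+ bias) -> (* scale) -> (eltwise) -> (clamp at zero for u8)
    // -> (round + convert for integer dst) -> down-convert and store. With
    // apply_mask set, every memory access goes through kreg_rem_mask, so
    // partial vectors never touch out-of-range elements.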
    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            add(reg_scales, offset * sizeof(float));
        }
        if (do_bias_)
            add(reg_bias, offset * bias_data_type_size_);
    };
    // Advance all pointers by a value stored in a register
    auto advance_ptrs_reg = [&](Reg64 offset) {
        lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
        lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
        }
        if (do_bias_)
            lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
    };
    // Rewind pointers that point to data that is indexed by output channel
    // (bias or per-oc scaling factors)
    auto rewind_ptrs = [&]() {
        if (do_bias_)
            sub(reg_bias, OC_ * bias_data_type_size_);
        if (scale_idx_mult_) {
            assert(scale_idx_mult_ == 1);
            sub(reg_scales, OC_ * sizeof(float));
        }
    };
    //      <-------------------- OC ------------------------------->
    //
    // ^    +....................+----------------------------------+
    // |    :   not accessed     |          Prologue loop           |
    // |    +--------------------+----------------------------------+
    //      |                                                       |
    // M    |                 Main loop (unrolled)                  |
    //      |                                                       |
    //      +--------------------------------+----------------------+
    // |    |         Epilogue loop          |     not accessed     :
    // v    +--------------------------------+......................+
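    //
    // A work item [start, end) is a contiguous run through this M x OC
    // row-major matrix: the prologue finishes a partially started row
    // (oc_offset > 0), the main loop processes whole rows, and the epilogue
    // handles the trailing partial row; bias/scale pointers rewind at every
    // row boundary via rewind_ptrs().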
    Label prologue_end;
    cmp(reg_oc_offset, 0);
    je(prologue_end, T_NEAR);

    // Prologue loop: process elements up to the next output-channel
    // boundary, but no more than reg_len of them
    mov(reg_tmp, OC_);
    sub(reg_tmp, reg_oc_offset);
    cmp(reg_tmp, reg_len);
    cmovg(reg_tmp, reg_len);
    sub(reg_len, reg_tmp);
    Label prologue_loop, prologue_loop_tail, prologue_loop_end;
    cmp(reg_tmp, vlen);
    jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
    L(prologue_loop); {
        compute(0, 0, false);
        advance_ptrs_imm(vlen);
        sub(reg_tmp, vlen);
        cmp(reg_tmp, vlen);
        jge(prologue_loop, T_NEAR);
    }

    L(prologue_loop_tail);
    mov(reg_rem_mask, 1);
    shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
    sub(reg_rem_mask, 1);
    jz(prologue_loop_end, T_NEAR);

    kmovq(kreg_rem_mask, reg_rem_mask);
    compute(0, 0, true);
    advance_ptrs_reg(reg_tmp);

    L(prologue_loop_end);
    rewind_ptrs();
    L(prologue_end);
    // Main loop: process full rows of OC_ elements each
    Label main_loop_end;
    cmp(reg_len, OC_);
    jle(main_loop_end, T_NEAR);
    Label main_loop;
    L(main_loop); {
        size_t def_unroll = 4;
        size_t max_unroll = 13;

        size_t OC_loop, OC_tail;
        if (OC_ < max_unroll * vlen) {
            // Fully unroll small loops
            OC_loop = 0;
            OC_tail = OC_;
        } else {
            OC_loop = vlen * def_unroll;
            OC_tail = OC_ % OC_loop;
        }
        assert(!!OC_loop || !!OC_tail);

        if (OC_tail % vlen) {
            int vlen_tail = OC_tail % vlen;
            unsigned tail_mask = (1 << vlen_tail) - 1;
            mov(reg_tmp, tail_mask);
            kmovq(kreg_rem_mask, reg_tmp);
        }

        if (OC_loop) {
            mov(reg_tmp, rnd_dn(OC_, OC_loop));
            Label oc_loop;
            L(oc_loop); {
                for (size_t offset = 0; offset < OC_loop; offset += vlen)
                    compute(offset, offset / vlen, false);
                advance_ptrs_imm(OC_loop);
                sub(reg_tmp, OC_loop);
                jnz(oc_loop, T_NEAR);
            }
        }

        if (OC_tail) {
            for (size_t offset = 0; offset < OC_tail; offset += vlen) {
                bool use_mask = (offset + vlen) > OC_tail;
                compute(offset, offset / vlen, use_mask);
            }
            advance_ptrs_imm(OC_tail);
        }

        rewind_ptrs();
        sub(reg_len, OC_);
        cmp(reg_len, OC_);
        jge(main_loop, T_NEAR);
    }
    L(main_loop_end);
    // Epilogue loop: process the remaining reg_len < OC_ elements
    Label epilogue_end;
    cmp(reg_len, 0);
    je(epilogue_end, T_NEAR);

    Label epilogue_loop, epilogue_loop_tail;
    cmp(reg_len, vlen);
    jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
    L(epilogue_loop); {
        compute(0, 0, false);
        sub(reg_len, vlen);
        advance_ptrs_imm(vlen);
        cmp(reg_len, vlen);
        jge(epilogue_loop, T_NEAR);
    }

    L(epilogue_loop_tail);
    mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
    mov(reg_rem_mask, 1);
    shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
    sub(reg_rem_mask, 1);
    jz(epilogue_end, T_NEAR);
    kmovq(kreg_rem_mask, reg_rem_mask);
    compute(0, 0, true);

    L(epilogue_end);

    postamble();

    if (do_eltwise_)
        eltwise_injector_->prepare_table();
    ker_ = getCode<decltype(ker_)>();
}
template <data_type_t acc_type, data_type_t dst_type>
void pp_kernel_t<acc_type, dst_type>::operator()(dst_data_t *dst,
        const acc_data_t *acc, const char *bias, const float *scales,
        size_t start, size_t end) {
    using math::get_bias;

    if (end <= start)
        return;

    if (ker_) {
        // JIT kernel (avx512_core)
        ker_args args;
        size_t oc_offset = start % OC_;
        args.dst = dst + start;
        args.acc = acc + start;
        args.bias = bias + oc_offset * bias_data_type_size_;
        args.scales = scales + scale_idx_mult_ * oc_offset;
        args.len = end - start;
        args.oc_offset = oc_offset;
        ker_(&args);
    } else {
        // Reference fallback for CPUs without avx512_core
        size_t oc = start % OC_;
        for (size_t i = start; i < end; i++) {
            float d = (float)acc[i];
            float b = get_bias(bias, oc, bias_data_type_);
            d = d + b;
            d *= scales[oc * scale_idx_mult_];
            if (do_eltwise_)
                d = ref_eltwise_->compute_scalar(d);
            dst[i] = qz_a1b0<float, dst_data_t>()(d, rmode_);
            oc = (oc == OC_ - 1) ? 0 : oc + 1;
        }
    }
}
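
// Typical call site (a sketch, not code from this file): the GEMM-based
// inner product first computes the s32/f32 accumulator with a GEMM call,
// then runs this kernel over it, parallelized over the flattened MB x OC
// range, e.g.:
//
//     pp_kernel_t<data_type::s32, data_type::u8> pp(pd);
//     parallel(0, [&](const int ithr, const int nthr) {
//         size_t start = 0, end = 0;
//         balance211((size_t)MB * OC, nthr, ithr, start, end);
//         pp(dst, acc, bias, scales, start, end);
//     });
//
// parallel() and balance211() come from mkldnn_thread.hpp; MB, OC, and pd
// here are placeholders for the caller's problem dimensions and primitive
// descriptor.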
using namespace data_type;
template class pp_kernel_t<f32, f32>;
template class pp_kernel_t<s32, f32>;
template class pp_kernel_t<s32, s32>;
template class pp_kernel_t<s32, s8>;
template class pp_kernel_t<s32, u8>;

} // namespace inner_product_utils

} // namespace cpu
} // namespace impl
} // namespace mkldnn