/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "rnn_utils.hpp"
#include "../jit_generator.hpp"
#include "../jit_uni_eltwise.hpp"
#include "c_types_map.hpp"
#include "mkldnn_thread.hpp"

struct jit_uni_rnn_postgemm_kernel : public jit_generator {

    typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_,
            void *c_states_tm1_l_, void *c_states_t_l_);

    jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn,
            const primitive_attr_t *attr)
        : rnn_(rnn), attr_(attr) {}

    virtual void init() = 0;

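    // Note: init() must be called once before execute(); the derived kernels
    // below use it to build the eltwise injectors, generate the JIT code and
    // cache the resulting entry point in kernel_.
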
    template <typename src_data_t, typename acc_data_t>
    rnn_elemwise_sig(execute) {
        rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_);
        rnn_utils::bias_aoc_t bias(rnn, bias_);
        rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_);
        rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
        rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);

        // TODO: add parallelization on dic for the batch 1 case
        // Assumption: the kernel runs a loop over dic elements
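        // The generated kernel implements the standard LSTM cell update, with
        // the gates laid out in ws_gates as G0..G3 = [i, f, g, o]:
        //   i = sigmoid(G0 + b0), f = sigmoid(G1 + b1),
        //   g = tanh(G2 + b2),    o = sigmoid(G3 + b3)
        //   c_t = f * c_{t-1} + i * g
        //   h_t = o * tanh(c_t)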
        parallel_nd(rnn.mb, [&](int i) {
            auto b_ = &bias(0, 0);
            auto g_ = &ws_gates(i, 0, 0);
            auto s_tl_ = &states_t_l(i, 0);
            auto c_tl_ = &c_states_t_l(i, 0);
            auto c_tm1l_ = &c_states_tm1_l(i, 0);
            kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_);
        });
    }

    kernel_t kernel_;
    const rnn_utils::rnn_conf_t &rnn_;
    const primitive_attr_t *attr_;
};

template <cpu_isa_t isa, impl::data_type_t src_data_t>
struct jit_uni_lstm_postgemm_kernel_fwd : public jit_uni_rnn_postgemm_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd)

    typedef typename utils::conditional<src_data_t == data_type::u8, int32_t,
            float>::type acc_data_t;
    typedef typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_common>,
            jit_uni_eltwise_injector_f32<isa>>::type injector_t;
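    // Note: for avx512_core the injector is still instantiated with
    // avx512_common; presumably the eltwise approximations only need the
    // common AVX-512 subset, so the narrower ISA tag is sufficient.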

    jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn,
            const primitive_attr_t *attr)
        : jit_uni_rnn_postgemm_kernel(rnn, attr) {}

    void init() override {
        // we use rax for both injectors as they share the same constant table
        sigmoid_injector_ = new injector_t(this,
                alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax);
        tanh_injector_ = new injector_t(this,
                alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax);
        generate();
        kernel_ = (kernel_t) this->getCode();
    }

    injector_t *sigmoid_injector_;
    injector_t *tanh_injector_;

    // register size in bytes
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    size_t vlen = cpu_isa_traits<isa>::vlen;
    size_t vlen_dst = (src_data_t == data_type::u8) ? vlen / 4 : vlen;
    size_t cstate_dt_size = sizeof(float);
    size_t hstate_dt_size
            = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float);
    size_t gate_dt_size
            = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float);
    size_t qscale_dt_size = sizeof(float);
    size_t bias_dt_size = sizeof(float);
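    // In the u8 (int8) configuration the gates come from the int8 GEMM as s32
    // (gate_dt_size == 4), the h state is stored as u8 (so one store covers
    // vlen / 4 bytes), while the c state and the biases stay in f32.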

    void generate() {
        using namespace Xbyak;

        int mask = attr_->rnn_weights_qparams_.mask_;
        float *weights_scales = attr_->rnn_weights_qparams_.scales_;
        float data_scale = attr_->rnn_data_qparams_.scale_;
        float data_shift = attr_->rnn_data_qparams_.shift_;
        round_mode_t rmode = attr_->round_mode_;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_end_label;
        Label table_label;

        Reg64 loop_cnt(r11); // loop counter
        Reg64 table_reg(rbx); // table is used for data scale and shifts
        Reg64 tmp_reg(r12); // used as temporary to customize mxcsr
        Reg64 weights_scales_reg(r13);
        // We skip vmm0 as it can be used by the injector for masks on sse4.2
        Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7);

        Address saved_csr_addr = ptr[rsp];
        Address modified_csr_addr = ptr[rsp + sizeof(int64_t)];
        size_t stack_size = 2 * sizeof(int64_t);

        // constant table map
        Address dscale_off_addr = ptr[table_reg];
        Address dshift_off_addr = ptr[table_reg + vlen];
        Address ymm_perm_mask_addr = ptr[table_reg + 2 * vlen];
        Address zmm_perm_mask_addr
                = ptr[table_reg + 2 * vlen + cpu_isa_traits<avx>::vlen];
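        // Expected layout of the data emitted at table_label below:
        //   [0, vlen)              broadcast data_scale
        //   [vlen, 2 * vlen)       broadcast data_shift
        //   [2 * vlen, +32 bytes)  dword permutation mask for Ymm
        //   [2 * vlen + 32, ...)   dword permutation mask for Zmm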

        // quantize from float to u8
        auto q_d = [&](Vmm f, Vmm tmp_vmm, Reg64 tmp_reg) {
            sub(rsp, stack_size);
            stmxcsr(saved_csr_addr); // save the current mxcsr

            // set the rounding mode appropriately
            mov(tmp_reg, saved_csr_addr);
            and_(tmp_reg, 0xffff9fff); // clear the rc bits (rc = RNE)
            if (rmode == round_mode::down)
                or_(tmp_reg, 0x00002000); // set rc = 01 (round down)
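            // (MXCSR bits 13:14 hold the rounding control: 00 = round to
            // nearest even, 01 = round toward -infinity, hence the two
            // constants above.)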
            mov(modified_csr_addr, tmp_reg);
            ldmxcsr(modified_csr_addr);

            uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm);
            uni_vmulps(f, f, dscale_off_addr); // apply scale
            uni_vaddps(f, f, dshift_off_addr); // apply shift
            uni_vcvtps2dq(f, f); // convert to int32 with mxcsr rounding
            uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16
            uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation
            // Note: the pack instructions work per 128-bit lane, so the result
            // bytes are interleaved by 128-bit chunks and have to be merged
            // back together
            switch (vlen) {
            case 64: { // avx512
                Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx());
                uni_vmovups(tmpz, zmm_perm_mask_addr);
                vpermd(fz, tmpz, fz);
                break;
            }
            case 32: { // avx
                Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx());
                uni_vmovups(tmpy, ymm_perm_mask_addr);
                vpermd(fy, tmpy, fy);
                break;
            }
            case 16: // sse: nothing to do
                break;
            default: assert(!"Unsupported case");
            }

            ldmxcsr(saved_csr_addr); // restore the original mxcsr
            add(rsp, stack_size);
        };

        auto fast_recip = [&](Vmm s, Vmm tmp, bool packed) {
            if (packed)
                uni_vrcpps(tmp, s);
            else
                uni_vrcpss(tmp, s); // prevent divide by zero
            // we add one Newton-Raphson iteration to improve precision
            uni_vmulps(s, s, tmp);
            uni_vmulps(s, s, tmp); // s <- s * tmp^2
            uni_vaddps(tmp, tmp, tmp);
            uni_vsubps(tmp, tmp, s);
            uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
        };
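        // fast_recip computes an approximate 1/s from the hardware reciprocal
        // estimate r0 plus one Newton-Raphson refinement:
        //   r1 = r0 * (2 - s * r0) = 2 * r0 - s * r0^2 (with tmp = r0 above).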

        // dequantize from s32 to float
        auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) {
            // TODO: if mask is 0, precompute the multiplier and its inverse
            if (mask == 0)
                uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
            else
                uni_vmovups(tmp1, ptr[weights_scales_reg
                        + gate * rnn_.dic * qscale_dt_size]);
            uni_vmulps(tmp1, tmp1, dscale_off_addr);
            fast_recip(tmp1, tmp2, packed);
            uni_vmulps(s, s, tmp1);
        };
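        // Net effect of deq_w: gate_f32 = gate_s32 / (weights_scale * data_scale),
        // undoing the scaling applied on the int8 path by the weights and data
        // quantization factors.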

        // We start code generation here
        preamble();

        // extract the addresses passed as parameters
#ifdef _WIN32
        auto addr_ws_gates_reg = abi_param1;
        auto addr_bias_reg = abi_param2;
        auto addr_states_t_l_reg = abi_param3;
        auto addr_c_states_tm1_l_reg = abi_param4;
        auto addr_c_states_t_l_reg = r10;
        // Here we cannot use rbp to get the initial stack pointer, so we use
        // rsp and offset it by the size of the registers pushed in the
        // preamble
        mov(addr_c_states_t_l_reg,
                ptr[rsp + get_size_of_abi_save_regs() + 40]);
#else
        auto addr_ws_gates_reg = abi_param1;
        auto addr_bias_reg = abi_param2;
        auto addr_states_t_l_reg = abi_param3;
        auto addr_c_states_tm1_l_reg = abi_param4;
        auto addr_c_states_t_l_reg = abi_param5;
#endif
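        // (On the System V x86-64 ABI all five pointer arguments arrive in
        // registers; the Win64 ABI passes only four in registers, so the fifth
        // is loaded from the caller's stack in the branch above.)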

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        mov(weights_scales_reg, size_t(weights_scales));
        // both sigmoid and tanh use the same table so load address just once in rax
        sigmoid_injector_->load_table_addr();

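        // The gates are processed in two passes: a vectorized loop consuming
        // vlen bytes of gate data per iteration, then a scalar remainder loop
        // for the tail when dic * gate_dt_size is not a multiple of vlen.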
        mov(loop_cnt, rnn_.dic * gate_dt_size);
        cmp(loop_cnt, vlen);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);

        uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
        uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
        uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
        uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);

        // dequantize the gates from s32 to f32 if needed
        if (src_data_t == data_type::u8) {
            deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true);
            deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true);
            deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true);
            deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true);
        }

        uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);

        // inject eltwise code
        sigmoid_injector_->compute_vector(G0.getIdx());
        sigmoid_injector_->compute_vector(G1.getIdx());
        tanh_injector_->compute_vector(G2.getIdx());
        sigmoid_injector_->compute_vector(G3.getIdx());

        // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
        uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]);
        uni_vmulps(tmp1_vmm, tmp1_vmm, G1);
        uni_vfmadd231ps(tmp1_vmm, G0, G2);
        uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm);

        // states_t_l = G3 * tanh(c_states_t_l)
        tanh_injector_->compute_vector(tmp1_vmm.getIdx());
        uni_vmulps(tmp1_vmm, tmp1_vmm, G3);

        // if int8, we quantize the resulting state
        if (src_data_t == data_type::u8) {
            q_d(tmp1_vmm, tmp2_vmm, tmp_reg);
        }

        // write back the result
        if (vlen_dst == vlen)
            uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm);
        else
            // we write only 1/4 of the register for the u8 output
            switch (vlen_dst) {
            case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
            case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
            case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
            default: assert(!"Unsupported vector length for quantization");
            }

        // increment address pointers
        add(addr_ws_gates_reg, vlen);
        add(addr_bias_reg, vlen);
        add(addr_states_t_l_reg, vlen_dst);
        add(addr_c_states_tm1_l_reg, vlen);
        add(addr_c_states_t_l_reg, vlen);
        if (mask != 0)
            add(weights_scales_reg, vlen);
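        // (weights_scales points to per-output-channel scales when mask != 0,
        // so it advances along with the gates; with mask == 0 there is a
        // single scale and the pointer stays put.)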

        // decrement the loop counter
        sub(loop_cnt, vlen);
        cmp(loop_cnt, vlen);
        jge(vector_loop_start_label);

        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        // Same code as above, except that we use scalar (movss) accesses for
        // the inputs and outputs
        // TODO: smarter handling of the tail with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);

        // remap the registers to Xmms for the scalar tail
        Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx());
        Xmm tmp1s_vmm(tmp1_vmm.getIdx());

        uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
        uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
        uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
        uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);

        // dequantize the gates from s32 to f32 if needed
        if (src_data_t == data_type::u8) {
            deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false);
            deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false);
            deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false);
            deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false);
        }

        uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G0s, G0s, tmp1s_vmm);
        uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G1s, G1s, tmp1s_vmm);
        uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G2s, G2s, tmp1s_vmm);
        uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
        uni_vaddps(G3s, G3s, tmp1s_vmm);

        // inject eltwise code
        sigmoid_injector_->compute_vector(G0s.getIdx());
        sigmoid_injector_->compute_vector(G1s.getIdx());
        tanh_injector_->compute_vector(G2s.getIdx());
        sigmoid_injector_->compute_vector(G3s.getIdx());

        // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
        uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]);
        uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s);
        uni_vfmadd231ps(tmp1s_vmm, G0s, G2s);
        uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm);

        // states_t_l = G3 * tanh(c_states_t_l)
        tanh_injector_->compute_vector(tmp1s_vmm.getIdx());
        uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s);

        // if int8, we quantize the resulting state
        if (src_data_t == data_type::u8) {
            q_d(tmp1_vmm, tmp2_vmm, tmp_reg);
        }

        // write back the result
        if (vlen_dst == vlen)
            uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm);
        else
            // we write only 1/4 of the register for the u8 output
            switch (vlen_dst) {
            case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
            case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
            case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
            default: assert(!"Unsupported vector length for quantization");
            }

        // increment address pointers
        add(addr_ws_gates_reg, gate_dt_size);
        add(addr_bias_reg, bias_dt_size);
        add(addr_states_t_l_reg, hstate_dt_size);
        add(addr_c_states_tm1_l_reg, cstate_dt_size);
        add(addr_c_states_t_l_reg, cstate_dt_size);
        if (mask != 0)
            add(weights_scales_reg, qscale_dt_size);

        // decrement the loop counter
        sub(loop_cnt, gate_dt_size);
        cmp(loop_cnt, 0);
        jg(rem_loop_start_label);

        L(rem_loop_end_label);

        postamble();

        // Again, only one table is needed and shared between sigmoid and tanh
        sigmoid_injector_->prepare_table(false);
        tanh_injector_->prepare_table(true);

        L(table_label);
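        // The table holds the broadcast data_scale and data_shift followed by
        // the dword permutation masks used by q_d. After vpackssdw/vpackuswb
        // each 128-bit lane keeps its quantized bytes in its lowest dword
        // (dwords 0/4 of a Ymm, dwords 0/4/8/12 of a Zmm); the masks below
        // move those dwords to the front so the bytes end up contiguous. Only
        // the leading mask entries matter, the rest are don't-cares.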
        for (size_t i = 0; i < vlen / sizeof(float); i++)
            dd(float2int(data_scale));
        for (size_t i = 0; i < vlen / sizeof(float); i++)
            dd(float2int(data_shift));
        // perm mask for ymm
        dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7);
        // perm mask for zmm
        dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7);
        dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14);
    }
};

template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>;

template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>;