/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution LSTM
 */
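
/*
 * For reference, a scalar sketch of the elementwise (post-GEMM) step these
 * kernels compute on the f32 path, with the gate order used below
 * (G0: input, G1: forget, G2: candidate, G3: output); illustrative only,
 * not compiled:
 *
 *   i   = sigmoid(G0 + b0)
 *   f   = sigmoid(G1 + b1)
 *   g   = tanh   (G2 + b2)
 *   o   = sigmoid(G3 + b3)
 *   c_t = f * c_{t-1} + i * g
 *   h_t = o * tanh(c_t)
 */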

#include "rnn_utils.hpp"
#include "../jit_generator.hpp"
#include "../jit_uni_eltwise.hpp"
#include "c_types_map.hpp"
#include "utils.hpp"

#include "mkldnn_thread.hpp"


namespace mkldnn {
namespace impl {
namespace cpu {

struct jit_uni_rnn_postgemm_kernel : public jit_generator {

    typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_,
                     void *c_states_tm1_l_, void *c_states_t_l_);

    jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn,
            const primitive_attr_t *attr)
        : rnn_(rnn), attr_(attr) {}

    virtual void init() = 0;

    template <typename src_data_t, typename acc_data_t>
    rnn_elemwise_sig(execute) {
        rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_);
        rnn_utils::bias_aoc_t bias(rnn, bias_);
        rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_);
        rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
        rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);

        // TODO: add parallelization on dic for the batch 1 case
        // Assumption: the kernel runs a loop over the dic elements
        parallel_nd(rnn.mb, [&](int i) {
                auto b_ = &bias(0, 0);
                auto g_ = &ws_gates(i, 0, 0);
                auto s_tl_ = &states_t_l(i, 0);
                auto c_tl_ = &c_states_t_l(i, 0);
                auto c_tm1l_ = &c_states_tm1_l(i, 0);
                kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_);
            });
    }

protected:
    kernel_t kernel_;
    const rnn_utils::rnn_conf_t &rnn_;
    const primitive_attr_t *attr_;
};

template <cpu_isa_t isa, impl::data_type_t src_data_t>
struct jit_uni_lstm_postgemm_kernel_fwd : public jit_uni_rnn_postgemm_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd)

    typedef typename utils::conditional<src_data_t == data_type::u8, int32_t,
            float>::type acc_data_t;
    typedef typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_common>,
            jit_uni_eltwise_injector_f32<isa>>::type injector_t;

    jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn,
            const primitive_attr_t *attr)
        : jit_uni_rnn_postgemm_kernel(rnn, attr) {}

    void init() override {
        // both injectors use the same constant table, so we pass rax to both
        sigmoid_injector_ = new injector_t(this,
                alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax);
        tanh_injector_ = new injector_t(this,
                alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax);
        generate();
        kernel_ = (kernel_t) this->getCode();
    }

92
93 protected:
94     injector_t *sigmoid_injector_;
95     injector_t *tanh_injector_;
96
97     // register size in bytes
98     using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
99     size_t vlen = cpu_isa_traits<isa>::vlen;
100     size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen;
101     size_t cstate_dt_size = sizeof(float);
102     size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float);
103     size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float);
104     size_t qscale_dt_size = sizeof(float);
105     size_t bias_dt_size = sizeof(float);
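
    // Note on the sizes above: with u8 states, the gates arrive as s32
    // accumulators (hence the 4-byte gate_dt_size), the h-state is stored
    // as u8 (hence vlen_dst = vlen / 4), while the c-state, biases and
    // quantization scales stay in f32.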

    void generate() {
        using namespace Xbyak;

        int mask = attr_->rnn_weights_qparams_.mask_;
        float *weights_scales = attr_->rnn_weights_qparams_.scales_;
        float data_scale = attr_->rnn_data_qparams_.scale_;
        float data_shift = attr_->rnn_data_qparams_.shift_;
        round_mode_t rmode = attr_->round_mode_;

        // Label declarations
        Label vector_loop_start_label, vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_end_label;
        Label table_label;

        // Register map
        Reg64 loop_cnt(r11);  // loop counter
        Reg64 table_reg(rbx); // used for the constant table (scales, shifts, masks)
        Reg64 tmp_reg(r12);   // temporary, used to modify mxcsr
        Reg64 weights_scales_reg(r13);
        // We skip vmm0 as the injector may use it for masks on sse4.2
        Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7);

        // Stack map
        Address saved_csr_addr = ptr[rsp];
        Address modified_csr_addr = ptr[rsp + sizeof(int64_t)];
        size_t stack_size = 2 * sizeof(int64_t);

        // Constant table map
        Address dscale_off_addr = ptr[table_reg];
        Address dshift_off_addr = ptr[table_reg + vlen];
        Address ymm_perm_mask_addr = ptr[table_reg + 2 * vlen];
        Address zmm_perm_mask_addr
                = ptr[table_reg + 2 * vlen + cpu_isa_traits<avx>::vlen];

        // quantize from float to u8
        auto q_d = [&](Vmm f, Vmm tmp_vmm, Reg64 tmp_reg) {
            sub(rsp, stack_size);
            stmxcsr(saved_csr_addr); // save the mxcsr

            // set the rounding mode appropriately
            mov(tmp_reg, saved_csr_addr);
            and_(tmp_reg, 0xffff9fff); // clear mxcsr rc bits 13:14 (rc = RNE)
            if (rmode == round_mode::down)
                or_(tmp_reg, 0x00002000); // set rc = 01 (round down) if RD
            mov(modified_csr_addr, tmp_reg);
            ldmxcsr(modified_csr_addr);

            uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm);
            uni_vmulps(f, f, dscale_off_addr); // apply scale
            uni_vaddps(f, f, dshift_off_addr); // apply shift
            uni_vcvtps2dq(f, f); // convert to s32 with mxcsr rounding
            uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16
            uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation
            // the pack instructions interleave results by 128-bit chunks,
            // so we have to permute them back together
            switch (vlen) {
            case 64: { // avx512
                Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx());
                uni_vmovups(tmpz, zmm_perm_mask_addr);
                vpermd(fz, tmpz, fz);
                break; }
            case 32: { // avx
                Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx());
                uni_vmovups(tmpy, ymm_perm_mask_addr);
                vpermd(fy, tmpy, fy);
                break; }
            case 16: // sse: nothing to do
                break;
            default: assert(!"Unsupported case");
            };

            ldmxcsr(saved_csr_addr); // restore the original mxcsr
            add(rsp, stack_size);
        };
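
        // A scalar sketch of q_d above (illustrative only, not compiled):
        //   u8_out = saturate_u8((s32)nearbyint(f * data_scale + data_shift))
        // where nearbyint() follows the mxcsr rounding mode set above
        // (round-to-nearest-even, or round-down when rmode == round_mode::down).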

        auto fast_recip = [&](Vmm s, Vmm tmp, bool packed) {
            if (packed)
                uni_vrcpps(tmp, s);
            else
                uni_vrcpss(tmp, s); // prevent divide by zero
            // we add one Newton iteration
            uni_vmulps(s, s, tmp);
            uni_vmulps(s, s, tmp); // s <- s * tmp^2
            uni_vaddps(tmp, tmp, tmp);
            uni_vsubps(tmp, tmp, s);
            uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
        };
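
        // The Newton-Raphson step above refines the estimate y0 ~= 1/s as
        //   y1 = y0 * (2 - s * y0) = 2 * y0 - s * y0^2,
        // which maps onto the sequence: tmp holds 2 * y0 (vaddps) and
        // s holds s * y0^2 (the two vmulps) before the final vsubps.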

        // dequantize from s32 to float
        auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) {
            // TODO: if mask is 0, precompute mul and inverse
            if (mask == 0)
                uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
            else
                uni_vmovups(tmp1,
                        ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]);
            uni_vcvtdq2ps(s, s);
            uni_vmulps(tmp1, tmp1, dscale_off_addr);
            fast_recip(tmp1, tmp2, packed);
            uni_vmulps(s, s, tmp1);
        };
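
        // A scalar sketch of deq_w above (illustrative only):
        //   f32_gate = (float)s32_gate / (weights_scale * data_scale)
        // with a single broadcast scale when mask == 0, and per-channel
        // scales (indexed along dic) otherwise; the division is
        // approximated by fast_recip.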

        // code generation starts here
        preamble();

        // extract the addresses passed as parameters
#ifdef _WIN32
        auto addr_ws_gates_reg = abi_param1;
        auto addr_bias_reg = abi_param2;
        auto addr_states_t_l_reg = abi_param3;
        auto addr_c_states_tm1_l_reg = abi_param4;
        auto addr_c_states_t_l_reg = r10;
        // Since rbp does not hold the initial stack pointer here, we use
        // rsp and offset it by the size of the registers pushed in the
        // preamble.
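        // (The Win64 ABI passes only four integer arguments in registers;
        // the fifth lives on the stack. The extra 40 bytes skip the 32-byte
        // shadow space plus the 8-byte return address.)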
        mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]);
#else
        auto addr_ws_gates_reg = abi_param1;
        auto addr_bias_reg = abi_param2;
        auto addr_states_t_l_reg = abi_param3;
        auto addr_c_states_tm1_l_reg = abi_param4;
        auto addr_c_states_t_l_reg = abi_param5;
#endif

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        mov(weights_scales_reg, size_t(weights_scales));
        // sigmoid and tanh share the same table, so the address is loaded
        // just once into rax
        sigmoid_injector_->load_table_addr();

        mov(loop_cnt, rnn_.dic * gate_dt_size);
        cmp(loop_cnt, vlen);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            // load G0 G1 G2 G3
            uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
            uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
            uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
            uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);

            // dequantize the gates from s32 to f32 if needed
            if (src_data_t == data_type::u8) {
                deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true);
                deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true);
                deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true);
                deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true);
            }

            // add biases
            uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);

            // inject eltwise code
            sigmoid_injector_->compute_vector(G0.getIdx());
            sigmoid_injector_->compute_vector(G1.getIdx());
            tanh_injector_->compute_vector(G2.getIdx());
            sigmoid_injector_->compute_vector(G3.getIdx());

            // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
            uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]);
            uni_vmulps(tmp1_vmm, tmp1_vmm, G1);
            uni_vfmadd231ps(tmp1_vmm, G0, G2);
            uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm);

            // states_t_l = G3 * tanh(c_states_t_l)
            tanh_injector_->compute_vector(tmp1_vmm.getIdx());
            uni_vmulps(tmp1_vmm, tmp1_vmm, G3);

            // if int8, we quantize the resulting state
            if (src_data_t == data_type::u8) {
                q_d(tmp1_vmm, tmp2_vmm, tmp_reg);
            }

            // write back the result
            if (vlen_dst == vlen)
                uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm);
            else
                // we write only 1/4 of the register for u8 states
                switch (vlen_dst) {
                case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
                case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
                case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
                default:
                    assert(!"Unsupported vector length for quantization");
                }

            // increment address pointers
            add(addr_ws_gates_reg, vlen);
            add(addr_bias_reg, vlen);
            add(addr_states_t_l_reg, vlen_dst);
            add(addr_c_states_tm1_l_reg, vlen);
            add(addr_c_states_t_l_reg, vlen);
            if (mask != 0)
                add(weights_scales_reg, vlen);

            // decrement loop counter
            sub(loop_cnt, vlen);
            cmp(loop_cnt, vlen);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, except that inputs are accessed with scalar movss
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            // remapping registers to Xmms
            Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx());
            Xmm tmp1s_vmm(tmp1_vmm.getIdx());

            // load G0 G1 G2 G3
            uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
            uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
            uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
            uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);

            // dequantize the gates from s32 to f32 if needed
            if (src_data_t == data_type::u8) {
                deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false);
                deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false);
                deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false);
                deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false);
            }

            // add biases
            uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G0s, G0s, tmp1s_vmm);
            uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G1s, G1s, tmp1s_vmm);
            uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G2s, G2s, tmp1s_vmm);
            uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
            uni_vaddps(G3s, G3s, tmp1s_vmm);

            // inject eltwise code
            sigmoid_injector_->compute_vector(G0s.getIdx());
            sigmoid_injector_->compute_vector(G1s.getIdx());
            tanh_injector_->compute_vector(G2s.getIdx());
            sigmoid_injector_->compute_vector(G3s.getIdx());

            // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
            // (scalar load: only one c-state element is valid per iteration)
            uni_vmovss(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]);
            uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s);
            uni_vfmadd231ps(tmp1s_vmm, G0s, G2s);
            uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm);

            // states_t_l = G3 * tanh(c_states_t_l)
            tanh_injector_->compute_vector(tmp1s_vmm.getIdx());
            uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s);

            // if int8, we quantize the resulting state
            if (src_data_t == data_type::u8) {
                q_d(tmp1_vmm, tmp2_vmm, tmp_reg);
            }

            // write back the result
            if (vlen_dst == vlen)
                uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm);
            else
                // we write only 1/4 of the register for u8 states
                switch (vlen_dst) {
                case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
                case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
                case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
                default:
                    assert(!"Unsupported vector length for quantization");
                }

            // increment address pointers
            add(addr_ws_gates_reg, gate_dt_size);
            add(addr_bias_reg, bias_dt_size);
            add(addr_states_t_l_reg, hstate_dt_size);
            add(addr_c_states_tm1_l_reg, cstate_dt_size);
            add(addr_c_states_t_l_reg, cstate_dt_size);
            if (mask != 0)
                add(weights_scales_reg, qscale_dt_size);

            // decrement loop counter
            sub(loop_cnt, gate_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        // Again, only one table is needed and shared between sigmoid and tanh
        sigmoid_injector_->prepare_table(false);
        tanh_injector_->prepare_table(true);

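        // Layout of the constant table emitted below, matching the address
        // map declared at the top of generate(): vlen bytes of broadcast
        // data_scale, vlen bytes of broadcast data_shift, then the ymm and
        // zmm permutation masks used by q_d to gather the valid u8 lanes
        // after vpackssdw/vpackuswb.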
        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(data_scale));
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(data_shift));
            // perm mask for ymm
            dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7);
            // perm mask for zmm
            dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7);
            dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14);
        }
    }

};

template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>;

template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>;
template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>;
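
/*
 * A minimal usage sketch (illustrative only; assumes an rnn_conf_t `rnn`
 * and primitive attributes `attr` prepared by the caller):
 *
 *   jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32> k(rnn, attr);
 *   k.init(); // generates the code and sets kernel_
 *   // the postgemm is then driven per cell through the base class'
 *   // execute<src_data_t, acc_data_t>(...) entry point
 */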

} // namespace cpu
} // namespace impl
} // namespace mkldnn