Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cell_lstm.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 /*
18  * Cell execution LSTM
19  */
20
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "../simple_q10n.hpp"
25 #include "ref_rnn.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace cpu {
30
31 using namespace mkldnn::impl::utils;
32 using namespace mkldnn::impl::math;
33 using namespace rnn_utils;
34
35 template <>
36 rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) {
37     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
38     bias_aoc_t bias(rnn, bias_);
39     ws_states_aoc_t states_t_l(rnn, states_t_l_);
40     ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
41     ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
42
43     parallel_nd(rnn.mb, [&](int i) {
44 // WA. Loss of correctnes in case of simd loop unrolling with icc 18
45 #if !defined(__INTEL_COMPILER)
46         PRAGMA_OMP_SIMD()
47 #endif
48         for (int j = 0; j < rnn.dic; j++) {
49             ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
50             ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
51             ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
52             ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j));
53
54             float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j)
55                     + ws_gates(i, 0, j) * ws_gates(i, 2, j);
56             states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp);
57             c_states_t_l(i, j) = tmp;
58         }
59     });
60 }
61
62 template <>
63 rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) {
64     ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_);
65     bias_aoc_t bias(rnn, bias_);
66     ws_states_aoc_u8_t states_t_l(rnn, states_t_l_);
67     ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
68     ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
69
70     float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
71     float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
72     float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
73     round_mode_t rmode = pd()->attr()->round_mode_;
74
75     auto q_d = [&](float f) {
76         float qf = f * data_scale + data_shift;
77         return qz_a1b0<float, src_data_t>()(qf, rmode);
78     };
79
80     auto deq_w = [&](acc_data_t s, int gate, int j) {
81         return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ?
82                 saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) :
83                 saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j]
84                                                    * data_scale));
85     };
86
87     parallel_nd(rnn.mb, [&](int i) {
88         PRAGMA_OMP_SIMD()
89         for (int j = 0; j < rnn.dic; j++) {
90             float G0 = logistic_fwd<float>(
91                     deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j));
92             float G1 = logistic_fwd<float>(
93                     deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j));
94             float G2 = tanh_fwd<float>(
95                     deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j));
96             float G3 = logistic_fwd<float>(
97                     deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j));
98             float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2;
99             states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp));
100             c_states_t_l(i, j) = tmp;
101         }
102     });
103 }
104
105 template <>
106 rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) {
107     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
108     bias_aoc_t bias(rnn, bias_);
109     ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
110     ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
111     ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
112     ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
113     ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
114
115     parallel_nd(rnn.mb, [&](int i) {
116         PRAGMA_OMP_SIMD()
117         for (int j = 0; j < rnn.dic; j++) {
118             float Ct = c_states_t_l(i, j);
119             /// @todo save it in the workspace in fwd pass or recompute it to
120             /// save bw
121             float tanhCt = tanh_fwd(Ct);
122             // we have 2 incoming diffs on Ht
123             float dHt = diff_states_tp1_l(0, i, j)
124                     + diff_states_t_lp1(rnn.n_states, i, j);
125             float dCt = diff_states_tp1_l(1, i, j)
126                     + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;
127
128             float dG1 = c_states_tm1_l(i, j) * dCt
129                     * x_m_square(ws_gates(i, 1, j));
130             float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
131             float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));
132             float dG2
133                     = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j));
134
135             diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j);
136
137             ws_gates(i, 0, j) = dG0;
138             ws_gates(i, 1, j) = dG1;
139             ws_gates(i, 2, j) = dG2;
140             ws_gates(i, 3, j) = dG3;
141         }
142     });
143 }
144
145 }
146 }
147 }