Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cell_gru_lbr.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 /*
18  * Cell execution GRU with linear before reset
19  */
20
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "ref_rnn.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::utils;
31 using namespace mkldnn::impl::math;
32 using namespace rnn_utils;
33 #define AOC array_offset_calculator
34
35 template <>
36 rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) {
37     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
38     bias_aoc_t bias(rnn, bias_);
39     ws_states_aoc_t states_t_l(rnn, states_t_l_);
40     ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
41     ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_);
42     AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
43
44     parallel_nd(rnn.mb, [&](int i) {
45         PRAGMA_OMP_SIMD()
46         for (int j = 0; j < rnn.dic; j++) {
47             float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j);
48             ws_gates(i, 0, j) = logistic_fwd(
49                     ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j));
50             ws_gates(i, 1, j) = logistic_fwd(
51                     ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j));
52             ws_gates(i, 2, j) = tanh_fwd(
53                     ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j));
54             states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
55                     + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
56             if (rnn.is_training)
57                 ws_Wh_b(i, j) = Wh_b;
58         }
59     });
60 }
61
// int8 (u8/s8) forward: GRU with linear-before-reset has no int8
// implementation; this specialization must never be reached.
template <>
rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) {
    assert(!"GRU LBR int8 is not supported");
}
66
// Forward execution of one GRU-LBR cell: two GEMMs (layer and iteration)
// followed by the elementwise gate computation.
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) {
    // Layer GEMM: ws_gates_ = W_layer * x_t. Skipped when the layer GEMM was
    // merged over time steps and already performed by the caller.
    if (!rnn.merge_gemm_layer) {
        (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
                rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
                states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
                rnn.gates_ws_ld);
    }
    // Iteration GEMM: ws_cell_ = W_iter * h_{t-1}. Kept in a separate buffer
    // (beta = 0.0) because the candidate gate needs its recurrent part
    // multiplied by the reset gate before being added in (linear before reset).
    (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
            1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
            rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld);
    // Gate nonlinearities and state update (gru_lbr_elemwise above).
    (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
            states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
            diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
            ws_cell_);
}
83
// int8 (u8/s8) forward: GRU with linear-before-reset has no int8
// implementation; this specialization must never be reached.
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) {
    assert(!"GRU LBR int8 is not supported");
}
88
89 template <>
90 rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) {
91     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
92     ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
93     ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
94     ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
95     ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
96     ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
97     AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
98
99     // 1. calculate dG1 dG2 dG3
100     // dG0 = (dht - G2) * dht * (1 - G0) * G0
101     // dG1 = (W*h + b) * dG2 * (1 - G1) * G1
102     // dG2 = (1 - G0) * dht * (1 - G2*G2)
103     parallel_nd(rnn.mb, [&](int i) {
104         PRAGMA_OMP_SIMD()
105         for (int j = 0; j < rnn.dic; j++) {
106             float h = states_tm1_l(i, j);
107             float dHt = diff_states_tp1_l(0, i, j)
108                     + diff_states_t_lp1(rnn.n_states, i, j);
109             float dG0 = (h - ws_gates(i, 2, j)) * dHt
110                     * x_m_square(ws_gates(i, 0, j));
111             float dG2 = (1.0f - ws_gates(i, 0, j))
112                     * one_m_square(ws_gates(i, 2, j)) * dHt;
113             float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));
114
115             diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
116             ws_gates(i, 2, j) = dG2;
117             ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j);
118             ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0;
119             ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1;
120         }
121     });
122 }
123
124 template <>
125 rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) {
126     ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
127     ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
128
129     (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
130             states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
131             diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
132             ws_cell_);
133
134     if (!rnn.merge_gemm_layer) {
135         //  dx = dG * Wx^t
136         (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
137                 rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
138                 rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
139                 &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
140         // dWx +=  dG^t * x
141         gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
142                 rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
143                 diff_w_layer_, rnn.diff_weights_layer_ld);
144     }
145     // dh +=  dGr * Wh^t
146     (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
147             1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld,
148             1.0, diff_states_t_l_, rnn.states_ws_ld);
149
150     // dWh += dGr^t * h
151     gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_,
152             rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
153             rnn.diff_weights_layer_ld);
154
155     // db1-3 += e * dG
156     // db4 += e * (r * dG2)
157     gates_reduction(rnn, ws_gates_, diff_bias_);
158
159     parallel_nd(rnn.dic, [&](int j) {
160         for (int i = 0; i < rnn.mb; i++) {
161             diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j);
162         }
163     });
164 }
165
166 #undef AOC
167
168 }
169 }
170 }