Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cell_gru.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 /*
18  * Cell execution GRU
19  */
20
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
23
24 #include "ref_rnn.hpp"
25
26 namespace mkldnn {
27 namespace impl {
28 namespace cpu {
29
30 using namespace mkldnn::impl::utils;
31 using namespace mkldnn::impl::math;
32 using namespace rnn_utils;
33
34 #define AOC array_offset_calculator
35 template <>
36 rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) {
37     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
38     bias_aoc_t bias(rnn, bias_[0]);
39     ws_states_aoc_t states_t_l(rnn, states_t_l_);
40     ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
41
42     // 1. gemm Wx[0-2],x
43     if (!rnn.merge_gemm_layer) {
44         (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
45                 rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
46                 states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
47                 rnn.gates_ws_ld);
48     }
49
50     // 2. gemm Wh[0-1],h
51     (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb,
52             rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
53             rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
54
55     // 3. activation zt and rt + elemwise multiplication rt,ht-1
56     parallel_nd(rnn.mb, [&](int i) {
57         PRAGMA_OMP_SIMD()
58         for (int j = 0; j < rnn.dic; j++) {
59             ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
60             ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
61             states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
62         }
63     });
64
65     // 4. gemm Wh[2],h~t
66     (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1],
67             rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0,
68             &(ws_gates(0, 2, 0)), rnn.gates_ws_ld);
69
70     // 5. activation h~t + calculate ht
71     parallel_nd(rnn.mb, [&](int i) {
72         PRAGMA_OMP_SIMD()
73         for (int j = 0; j < rnn.dic; j++) {
74             ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
75             states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
76                     + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
77         }
78     });
79 }
80
81 template <>
82 rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) {
83     assert(!"GRU int8 is not supported");
84 }
85
86 template <>
87 rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
88     ws_gates_aoc_t ws_gates(rnn, ws_gates_);
89     ws_states_aoc_t states_t_l(rnn, states_t_l_);
90     ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
91     ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
92     ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
93     ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
94     ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
95
96     // use state memory for intermediate computations
97     // TODO: use cell ws for that
98     float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0));
99     float *hG1_ = dhG1_;
100     AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld);
101     AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld);
102
103     // 1. calculate dG2, dG1, and part of dht-1
104     // dG2^ = dh * (1 - G0) * (1 - G2^2)
105     // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
106     // dht-1 (part) = dh * G0
107     parallel_nd(rnn.mb, [&](int i) {
108         PRAGMA_OMP_SIMD()
109         for (int j = 0; j < rnn.dic; j++) {
110             float h = states_tm1_l(i, j);
111             float dHt = diff_states_tp1_l(0, i, j)
112                     + diff_states_t_lp1(rnn.n_states, i, j);
113             float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
114                     * one_m_square(ws_gates(i, 2, j));
115             float dG0 = (h - ws_gates(i, 2, j)) * dHt
116                     * x_m_square(ws_gates(i, 0, j));
117
118             diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
119             ws_gates(i, 0, j) = dG0;
120             ws_gates(i, 2, j) = dG2;
121         }
122     });
123
124     // 2. calculate intermediate d(hG1)
125     // d(hG1) = dG2 * W2h^t
126     (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1],
127             rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0,
128             dhG1_, rnn.states_ws_ld);
129
130     // 3. calculate dG1^ and part of dht-1
131     // dG1^ = d(hG1) * h * G1 * (1 - G1)
132     // dht-1 (part) += d(hG1) * G1
133     // h * G1 (required for dWh)
134     parallel_nd(rnn.mb, [&](int i) {
135         PRAGMA_OMP_SIMD()
136         for (int j = 0; j < rnn.dic; j++) {
137             float h = states_tm1_l(i, j);
138             float G1 = ws_gates(i, 1, j);
139             diff_states_t_l(0, i, j) += dhG1(i, j) * G1;
140             ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1);
141             hG1(i, j) = G1 * h;
142         }
143     });
144
145     // 4. calculate diff weights
146     // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
147     gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
148             rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
149             rnn.diff_weights_iter_ld);
150     gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)),
151             rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0,
152             &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld);
153
154     // 5. calculate diff states
155     // dht-1 += dG1 * W1h + dG0 * W0h
156     (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
157             (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0],
158             rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0,
159             diff_states_t_l_, rnn.states_ws_ld);
160
161     if (!rnn.merge_gemm_layer) {
162         // dWx += [dG0 dG1 dG2] * [x]
163         gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
164                 rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
165                 diff_w_layer_, rnn.diff_weights_layer_ld);
166         // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
167         (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
168                 rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
169                 rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
170                 &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld);
171     }
172
173     // 6. calculate diff bias
174     gates_reduction(rnn, ws_gates_, diff_bias_);
175 }
176 #undef AOC
177
178 }
179 }
180 }