1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
21 #include "math_utils.hpp"
22 #include "mkldnn_thread.hpp"
24 #include "ref_rnn.hpp"
30 using namespace mkldnn::impl::utils;
31 using namespace mkldnn::impl::math;
32 using namespace rnn_utils;
// Short alias for array_offset_calculator. In this file it is instantiated as
// AOC<float, 2> over a raw float pointer plus two extents and then indexed
// with (i, j).
34 #define AOC array_offset_calculator
// Forward GRU cell step for the f32 reference RNN (ref_rnn_fwd_f32_t).
// The parameter list is supplied by the rnn_cell_execution_sig macro, which is
// defined outside this section; the names used below (rnn, ws_gates_, bias_,
// states_*_, w_layer_, w_iter_) presumably come from it or from the class.
//
// Flow of the visible code:
//   - GEMMs write layer and iteration contributions into ws_gates_;
//   - gates 0 and 1 get bias + logistic_fwd, and states_t_l is temporarily
//     filled with states_tm1_l * gate 1 so it can feed the gate-2 GEMM;
//   - gate 2 gets bias + tanh_fwd, and states_t_l is overwritten with
//     states_tm1_l * G0 + (1 - G0) * G2.
// The GEMM calls look BLAS-style (transa, transb, m, n, k, alpha, A, lda, B,
// ldb, beta, C, ldc) -- TODO confirm against the gemm function declarations.
36 rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) {
// Indexable views over the raw gate workspace, bias and state buffers.
37 ws_gates_aoc_t ws_gates(rnn, ws_gates_);
38 bias_aoc_t bias(rnn, bias_[0]);
39 ws_states_aoc_t states_t_l(rnn, states_t_l_);
40 ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
// The layer GEMM (w_layer_[0] with states_t_lm1_, all n_gates slices) is
// issued here only when rnn.merge_gemm_layer is false; presumably the merged
// variant is computed by the caller -- TODO confirm. The 0.0 in the beta
// position suggests ws_gates_ is overwritten rather than accumulated into.
43 if (!rnn.merge_gemm_layer) {
44 (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
45 rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
46 states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
// Iteration GEMM (w_iter_[0] with states_tm1_l_) over the first
// (n_gates - 1) gate slices only; the 1.0 in the beta position suggests the
// product is added onto what ws_gates_ already holds.
51 (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb,
52 rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
53 rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
55 // 3. activation zt and rt + elemwise multiplication rt,ht-1
// The lambda handles one index i of rnn.mb (parallel_nd is declared
// elsewhere; presumably i covers [0, rnn.mb) across threads).
56 parallel_nd(rnn.mb, [&](int i) {
58 for (int j = 0; j < rnn.dic; j++) {
59 ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
60 ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
// states_t_l is used as scratch here: it holds gate 1 * previous state, which
// the next GEMM reads through states_t_l_.
61 states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
// Gate-2 GEMM: w_iter_[1] applied to the scaled state in states_t_l_, added
// (beta position 1.0) into the gate-2 slice starting at ws_gates(0, 2, 0).
66 (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1],
67 rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0,
68 &(ws_gates(0, 2, 0)), rnn.gates_ws_ld);
70 // 5. activation h~t + calculate ht
71 parallel_nd(rnn.mb, [&](int i) {
73 for (int j = 0; j < rnn.dic; j++) {
74 ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
// Final output: gate 0 weights the previous state, (1 - gate 0) weights the
// tanh candidate in gate 2. This replaces the scratch value written in step 3.
75 states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
76 + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
// Forward GRU cell for the u8s8 (int8) reference RNN: not implemented.
// The visible body only trips an assertion with the message below; in builds
// where assert is compiled out (NDEBUG) that statement has no effect.
82 rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) {
83 assert(!"GRU int8 is not supported");
// Backward GRU cell step for the f32 reference RNN (ref_rnn_bwd_f32_t).
// The parameter list is supplied by the rnn_cell_execution_sig macro, which is
// defined outside this section.
//
// The gate workspace ws_gates is updated in place: on entry it is read as the
// forward gate values (G0, G1, G2) and by the end of step 3 each slice has
// been replaced by the corresponding gate gradient (dG0, dG1, dG2). Those
// gradients then feed the weight, state and bias gradient computations.
// Statement order therefore matters: every read of a forward gate value must
// happen before that slice is overwritten.
// The GEMM calls look BLAS-style (transa, transb, m, n, k, alpha, A, lda, B,
// ldb, beta, C, ldc) -- TODO confirm against the gemm function declarations.
87 rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
// Indexable views over the raw workspace, state and gradient buffers.
88 ws_gates_aoc_t ws_gates(rnn, ws_gates_);
89 ws_states_aoc_t states_t_l(rnn, states_t_l_);
90 ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
91 ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
92 ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
93 ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
94 ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
96 // use state memory for intermediate computations
97 // TODO: use cell ws for that
// dhG1_ aliases the diff_states_t_l slot at index rnn.n_states. The same
// address is the destination of the dx GEMM at the end of this function
// (beta position 0.0), so the scratch contents are presumably not needed
// past that point.
98 float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0));
100 AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld);
// NOTE(review): hG1_ is used here and by the gate-2 dWh GEMM below, but its
// declaration is not visible in this section -- confirm where it is defined
// and which buffer it aliases.
101 AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld);
103 // 1. calculate dG2, dG0, and part of dht-1
104 // dG2^ = dh * (1 - G0) * (1 - G2^2)
105 // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
// ("u" in the formula above presumably denotes G0. one_m_square and
// x_m_square are defined elsewhere; from these formulas they are presumably
// 1 - x^2 and x * (1 - x) -- TODO confirm.)
106 // dht-1 (part) = dh * G0
107 parallel_nd(rnn.mb, [&](int i) {
109 for (int j = 0; j < rnn.dic; j++) {
110 float h = states_tm1_l(i, j);
// Incoming gradient for this cell's output: the sum of the term read from
// diff_states_tp1_l (slot 0) and the term read from diff_states_t_lp1
// (slot rnn.n_states).
111 float dHt = diff_states_tp1_l(0, i, j)
112 + diff_states_t_lp1(rnn.n_states, i, j);
113 float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
114 * one_m_square(ws_gates(i, 2, j));
115 float dG0 = (h - ws_gates(i, 2, j)) * dHt
116 * x_m_square(ws_gates(i, 0, j));
// The forward G0 is read one last time on the next line; only after that are
// the G0 and G2 slices overwritten with their gradients.
118 diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
119 ws_gates(i, 0, j) = dG0;
120 ws_gates(i, 2, j) = dG2;
124 // 2. calculate intermediate d(hG1)
125 // d(hG1) = dG2 * W2h^t
// Reads dG2 from the gate-2 slice and writes the product with w_iter_[1] into
// the dhG1_ scratch (beta position 0.0, i.e. presumably overwrite).
126 (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1],
127 rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0,
128 dhG1_, rnn.states_ws_ld);
130 // 3. calculate dG1^ and part of dht-1
131 // dG1^ = d(hG1) * h * G1 * (1 - G1)
132 // dht-1 (part) += d(hG1) * G1
133 // h * G1 (required for dWh)
134 parallel_nd(rnn.mb, [&](int i) {
136 for (int j = 0; j < rnn.dic; j++) {
137 float h = states_tm1_l(i, j);
// G1 is copied to a local first because the G1 slice is overwritten with dG1
// two lines below.
138 float G1 = ws_gates(i, 1, j);
139 diff_states_t_l(0, i, j) += dhG1(i, j) * G1;
140 ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1);
// NOTE(review): the step-3 comment says h * G1 is required for dWh, and the
// gate-2 dWh GEMM below reads hG1_, but no statement storing h * G1 is
// visible in this section -- confirm against the full file.
145 // 4. calculate diff weights
146 // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
// (The comment above numbers the gates from 1, whereas the code indexes them
// from 0.) The first GEMM accumulates (beta position 1.0) the first
// (n_gates - 1) gate gradients against states_tm1_l_ into diff_w_iter_. The
// second handles the gate-2 slice separately, pairing dG2 with hG1_ instead.
147 gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
148 rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
149 rnn.diff_weights_iter_ld);
150 gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)),
151 rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0,
152 &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld);
154 // 5. calculate diff states
155 // dht-1 += dG1 * W1h + dG0 * W0h
// Accumulates (beta position 1.0) onto the partial dht-1 values that steps 1
// and 3 already wrote to diff_states_t_l slot 0.
156 (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
157 (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0],
158 rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0,
159 diff_states_t_l_, rnn.states_ws_ld);
// The layer-side gradients are computed here only when rnn.merge_gemm_layer
// is false; presumably the caller handles the merged case -- TODO confirm.
161 if (!rnn.merge_gemm_layer) {
162 // dWx += [dG0 dG1 dG2] * [x]
163 gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
164 rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
165 diff_w_layer_, rnn.diff_weights_layer_ld);
166 // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
// The destination is the same address that dhG1_ aliased as scratch. With 0.0
// in the beta position the scratch contents are presumably discarded here.
167 (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
168 rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
169 rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
170 &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld);
173 // 6. calculate diff bias
// gates_reduction is defined elsewhere; it is handed the gate-gradient
// workspace and diff_bias_ (presumably it reduces the gradients over the
// minibatch -- TODO confirm).
174 gates_reduction(rnn, ws_gates_, diff_bias_);