Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cell_common.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 /*
18  * Common for RNN and LSTM cell execution
19  */
20 #include "ref_rnn.hpp"
21
22 namespace mkldnn {
23 namespace impl {
24 namespace cpu {
25 using namespace rnn_utils;
26
27 template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
28 rnn_cell_execution_sig(
29         (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) {
30     if (!rnn.merge_gemm_layer) {
31         (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
32                 rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
33                 states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
34                 rnn.gates_ws_ld);
35     }
36     (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
37             1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
38             rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
39
40     if (rnn_postgemm_ != nullptr)
41         rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_,
42             states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
43             diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
44             ws_cell_);
45     else
46         (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
47                 states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
48                 diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
49                 ws_cell_);
50 }
51 template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
52 template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
53
54 template <>
55 rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) {
56     ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
57     (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
58             states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
59             diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
60             ws_cell_);
61
62     /// bwd by data on the cell
63     (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
64             1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld,
65             0.0, diff_states_t_l_, rnn.states_ws_ld);
66
67     if (!rnn.merge_gemm_layer) {
68         (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
69                 rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
70                 rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
71                 &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
72
73         /// bwd by weights on the cell
74         gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
75                 rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
76                 diff_w_layer_, rnn.diff_weights_layer_ld);
77     }
78
79     if (!rnn.merge_gemm_iter)
80         gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
81                 rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0,
82                 diff_w_iter_, rnn.diff_weights_iter_ld);
83
84     /// bwd by bias we just accumulate diffs from the gates
85     gates_reduction(rnn, ws_gates_, diff_bias_);
86 }
87
88 }
89 }
90 }