/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef RNN_UTILS_HPP
#define RNN_UTILS_HPP

#include "mkldnn.h"

#include "cpu_rnn_pd.hpp"

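/* The rnn_*_sig macros below expand to the signatures of the member functions
 * an RNN implementation provides: elementwise processing, cell and grid
 * execution, GEMM, bias preparation/finalization and weights assignment.
 * Keeping them as macros lets the f32 and int8 specializations (src_data_t,
 * weights_data_t, acc_data_t) declare and define these hooks uniformly. */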
#define rnn_elemwise_sig(f)                                               \
    void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_,   \
            src_data_t *states_t_l_, float *c_states_t_l_,            \
            src_data_t *states_tm1_l_, float *c_states_tm1_l_,        \
            float *diff_states_t_l_, float *diff_states_t_lp1_,       \
            float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \
            float *ws_cell_) const

#define rnn_cell_execution_sig(f)                                             \
    void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_,     \
            float *c_states_t_l_, float *diff_states_t_l_,                \
            weights_data_t **w_layer_, weights_data_t **w_iter_,          \
            float **bias_, src_data_t *states_t_lm1_,                     \
            src_data_t *states_tm1_l_, float *c_states_tm1_l_,            \
            float *diff_states_t_lp1_, float *diff_states_tp1_l_,         \
            float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \
            acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const

#define rnn_grid_execution_sig(f)                                                 \
    void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \
            weights_data_t **weights_states_, float **bias_,                  \
            src_data_t *ws_states_, float *ws_c_states_,                      \
            float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_,   \
            float *ws_grid_, float *diff_weights_layer_,                      \
            float *diff_weights_iter_, float *diff_bias_) const

#define rnn_gemm_sig(f)                                                     \
    void f(const char transA, const char transB, int m, int n, int k,   \
            const float alpha, const weights_data_t *a_, const int ldA, \
            const src_data_t *b_, const int ldB, const float beta,      \
            acc_data_t *c_, const int ldC) const

#define rnn_bias_prepare_sig(f)                                                  \
    void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
            float *scratch_bias_) const

#define rnn_bias_finalize_sig(f)                                       \
    void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
            const float *w_iter_comp, const float *w_layer_comp) const

#define rnn_weights_assign_sig(f)                                                \
    void f(const rnn_utils::rnn_conf_t &rnn, memory_format_t fmt, int nld,   \
            int ld, int OC_size, int IC_size, const int n_parts,             \
            const int *gates_per_part, const size_t *part_weights_pack_size, \
            weights_data_t **weights_, const weights_data_t *w_,             \
            float **bias_, const float *b_, float *scratch_bias_) const

namespace mkldnn {
namespace impl {
namespace cpu {

namespace rnn_utils {

using namespace mkldnn::impl::utils;

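/* Direction in which the grid is executed over the time dimension:
 * left-to-right, right-to-left, or both directions with the two outputs
 * either concatenated or summed. */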
enum execution_direction_t {
    l2r,
    r2l,
    bi_concat,
    bi_sum,
};

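/* Data type configurations: all_f32 is the plain f32 path; the remaining
 * values describe int8 inference paths in which some tensors are u8 and the
 * rest stay f32 (each enumerator spells out one such combination). */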
enum data_type_conf_t {
    all_f32,
    u8u8u8f32,
    f32u8f32f32,
    u8u8u8u8,
    f32u8f32u8
};

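/* Descriptor of the RNN problem shared by all kernels: execution direction,
 * data type configuration, problem sizes (mb = minibatch; slc/sic = source
 * layer/iteration channels; dic/dlc = destination iteration/layer channels),
 * leading dimensions of the various buffers, packed-GEMM bookkeeping, and
 * workspace sizes in bytes. */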
struct rnn_conf_t {
    execution_direction_t exec_dir;
    data_type_conf_t dt_conf;
    int n_layer, n_iter, n_dir, n_gates, n_states;
    int mb;
    int slc, sic, dic, dlc;
    int gates_ld, gates_nld, gates_ws_ld;
    int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS];
    int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS];
    int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS];
    size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS],
            part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS];
    bool weights_layer_is_packed, weights_iter_is_packed;
    /* Size of packed data in bytes */
    size_t weights_layer_comp_offset, weights_layer_pack_size,
        weights_iter_comp_offset, weights_iter_pack_size;

    bool copy_bias;
    int weights_layer_ld, weights_layer_nld;
    int diff_weights_layer_ld, diff_weights_layer_nld;
    int weights_iter_ld, weights_iter_nld;
    int diff_weights_iter_ld, diff_weights_iter_nld;
    int states_nld, states_ws_ld;
    int weights_iter_compensation_size, weights_layer_compensation_size;
    bool is_fwd, is_training, is_lbr;
    bool use_workspace;

    /* Size of workspace for each tensor in bytes */
    size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size,
            ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size;
    bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm,
        use_iter_packed_gemm;
    memory_format_t weights_layer_fmt, weights_iter_fmt, diff_weights_layer_fmt,
            diff_weights_iter_fmt;
};

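/* Helpers for filling and using rnn_conf_t: get_good_ld picks a padded leading
 * dimension for the given element size, init_conf/set_conf populate the
 * descriptor from the RNN and memory descriptors, set_offsets and
 * get_scratchpad_and_workspace_sizes lay out the scratchpad/workspace buffers,
 * and set_expected_desc/set_good_strides adjust the weights memory
 * descriptors. */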
int get_good_ld(int dim, int sizeof_dt);

void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d);

void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d);

void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_h_state_offset, size_t &ws_c_state_offset,
        size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
        size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
        size_t &scratchpad_size, size_t &workspace_size);

void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
        size_t &scratchpad_size, size_t &workspace_size);
status_t set_expected_desc(
        rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter);
status_t set_good_strides(memory_desc_t &weights_md);

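/* The *_aoc ("array offset calculator") helpers below provide multi-dimensional
 * indexing into the flat workspace buffers, using the leading dimensions stored
 * in rnn_conf_t. A minimal usage sketch (pointer and index names here are
 * illustrative only):
 *
 *     ws_gates_aoc_t gates(rnn, ws_gates_ptr);
 *     float &g = gates(mb_idx, gate_idx, dic_idx); // (batch, gate, channel)
 */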
template <typename T>
struct ws_gates_aoc {
    ws_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {}
    T &operator()(int batch, int gate, int dic) {
        return gates_(batch, gate * DIC_ + dic);
    }

private:
    mkldnn::impl::utils::array_offset_calculator<T, 2> gates_;
    int DIC_;
};
using ws_gates_aoc_t = ws_gates_aoc<float>;
using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;

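/* Read-only (n_bias, dic) view over the bias data. */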
struct bias_aoc_t {
    bias_aoc_t(const rnn_conf_t &rnn, const float *data)
        : bias_(data, rnn.n_bias, rnn.dic) {}
    const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); }

private:
    mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_;
};

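/* (batch, channel) view over a slice of the states workspace; instantiated
 * below for f32 and for u8 (int8 path). */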
template <typename T>
struct ws_states_aoc {
    ws_states_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.states_nld, rnn.states_ws_ld) {}
    T &operator()(int batch, int dic) { return state_(batch, dic); }

private:
    mkldnn::impl::utils::array_offset_calculator<T, 2> state_;
};
using ws_states_aoc_t = ws_states_aoc<float>;
using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>;

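/* (state, iteration, batch, channel) view over the diff-states workspace.
 * operator() fixes the iteration index at 0, so the pointer passed in is
 * expected to already point at the iteration of interest. */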
struct ws_diff_states_aoc_t {
    ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data)
        : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld,
                  rnn.states_ws_ld) {}
    float &operator()(int state_n, int batch, int dic) {
        return diff_states_(state_n, 0, batch, dic);
    }

private:
    mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_;
};

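/* View over the diff of the iteration (recurrent) weights, indexed as
 * (sic, gate, dic) with gate and channel folded into the second dimension. */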
struct ws_diff_w_iter_aoc_t {
    ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
        : diff_weights_iter_(
                  data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
        , DIC_(rnn.dic) {}
    float &operator()(int sic, int gate, int dic) {
        return diff_weights_iter_(sic, gate * DIC_ + dic);
    }

private:
    mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_;
    int DIC_;
};

} // namespace rnn_utils
} // namespace cpu
} // namespace impl
} // namespace mkldnn

#endif // RNN_UTILS_HPP