Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / ref_rnn.hpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_REF_RNN_HPP
18 #define CPU_REF_RNN_HPP
19
20 #include <assert.h>
21
22 #include "c_types_map.hpp"
23 #include "memory_tracking.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "../cpu_isa_traits.hpp"
28 #include "../gemm/os_blas.hpp"
29
30 #include "cpu_rnn_pd.hpp"
31 #include "rnn_utils.hpp"
32 #include "jit_uni_rnn_postgemm.hpp"
33
34 namespace mkldnn {
35 namespace impl {
36 namespace cpu {
37
// Scalar activation applied after the gemm in vanilla-RNN cells; one
// specialization per (activation algorithm, propagation kind) pair, selected
// at runtime via the activation_func member of _ref_rnn_common_t.
// NOTE(review): the parameter names here (s, alpha, cliping, dd) are ordered
// differently from the activation_func pointer-to-function declaration
// (dd, s, alpha, cliping). All parameters are float so this still binds, but
// confirm the intended argument order against the definitions in ref_rnn.cpp.
template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float cliping, float dd);
40
41 template <prop_kind_t aprop, impl::data_type_t src_type,
42         impl::data_type_t weights_type>
43 struct _ref_rnn_common_t : public cpu_primitive_t {
44     typedef typename prec_traits<src_type>::type src_data_t;
45     typedef typename prec_traits<weights_type>::type weights_data_t;
46     typedef typename utils::conditional<src_type == data_type::u8, int32_t,
47             float>::type acc_data_t;
48
49     using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>;
50
51     typedef rnn_elemwise_sig((class_name::*elemwise_f));
52     typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
53     typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
54
55     typedef rnn_gemm_sig((class_name::*gemm_t));
56     typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
57     typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
58     typedef rnn_weights_assign_sig((class_name::*weights_assign_t));
59
60     using base_pd_t =
61             typename utils::conditional<false || aprop == prop_kind::forward,
62                     cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
63
64     struct pd_t : public base_pd_t {
65         pd_t(engine_t *engine, const rnn_desc_t *adesc,
66                 const primitive_attr_t *attr,
67                 const typename pd_t::hint_class *hint_pd)
68             : base_pd_t(engine, adesc, attr, hint_pd) {}
69
70         DECLARE_COMMON_PD_T("ref:any", class_name);
71
72         status_t init() {
73             using namespace prop_kind;
74             using namespace utils;
75             using namespace memory_format;
76             using namespace rnn_utils;
77             assert(this->engine()->kind() == engine_kind::cpu);
78             const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
79
80             data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type;
81             data_type_t weights_iter_dt
82                     = this->desc()->weights_iter_desc.data_type;
83             data_type_t weights_layer_dt
84                     = this->desc()->weights_layer_desc.data_type;
85
86             bool ok = true
87                     && one_of(cell_kind, alg_kind::vanilla_rnn,
88                                alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
89                                alg_kind::gru_linear_before_reset)
90                     && IMPLICATION(aprop == prop_kind::forward,
91                                one_of(this->desc()->prop_kind, forward_training,
92                                            forward_inference))
93                     && IMPLICATION(aprop == backward,
94                                one_of(this->desc()->prop_kind, backward))
95                     && src_layer_dt == src_type
96                     && everyone_is(
97                                weights_type, weights_iter_dt, weights_layer_dt)
98                     && this->set_default_params() == status::success
99                     && this->with_bias();
100             if (!ok)
101                 return status::unimplemented;
102
103             init_conf(rnn_, *this->desc(), this->src_pd(0), this->src_pd(1),
104                     this->weights_pd(0), this->weights_pd(1), this->dst_pd(0));
105
106             if (rnn_.dt_conf == all_f32)
107                 ok = ok && this->attr()->has_default_values();
108
109             // Set weights descriptors to desired format
110             memory_desc_t weights_layer_md = *(this->weights_layer_pd_.desc());
111             CHECK(set_expected_desc(rnn_, weights_layer_md, false));
112             cpu_memory_t::pd_t new_weights_layer_pd(
113                     this->engine_, &weights_layer_md);
114             if (this->weights_layer_pd_.desc()->format == any) {
115                 this->weights_layer_pd_ = new_weights_layer_pd;
116             } else if (this->weights_layer_pd_.desc()->format == rnn_packed) {
117                 if (!this->weights_layer_pd_.is_equal(&new_weights_layer_pd))
118                     return status::unimplemented;
119             }
120
121             memory_desc_t weights_iter_md = *(this->weights_iter_pd_.desc());
122             CHECK(set_expected_desc(rnn_, weights_iter_md, true));
123             cpu_memory_t::pd_t new_weights_iter_pd(
124                     this->engine_, &weights_iter_md);
125             if (this->weights_iter_pd_.desc()->format == any) {
126                 this->weights_iter_pd_ = new_weights_iter_pd;
127             } else if (this->weights_iter_pd_.desc()->format == rnn_packed) {
128                 if (!this->weights_iter_pd_.is_equal(&new_weights_iter_pd))
129                     return status::unimplemented;
130             }
131
132             CHECK(this->check_layout_consistency());
133
134             set_conf(rnn_, *this->desc(), this->weights_pd(0),
135                     this->weights_pd(1), this->diff_weights_pd(0),
136                     this->diff_weights_pd(1));
137
138             size_t scratchpad_sz{0}, ws_sz{0};
139             get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);
140
141             // initialize the workspace_pd if needed
142             if (rnn_.is_training) {
143                 dims_t ws_dims = {(int)ws_sz};
144                 memory_desc_t ws_d;
145                 mkldnn_memory_desc_init(&ws_d, 1, ws_dims, data_type::u8, x);
146                 this->ws_pd_ = cpu_memory_t::pd_t(this->engine(), &ws_d);
147             }
148
149             init_scratchpad(scratchpad_sz);
150
151             return status::success;
152         }
153
154         rnn_utils::rnn_conf_t rnn_;
155
156     private:
157         void init_scratchpad(size_t scratchpad_sz) {
158             using namespace memory_tracking::names;
159             auto scratchpad = this->scratchpad_registry().registrar();
160             scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096);
161
162             int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1;
163             int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
164             scratchpad.book(key_rnn_ptrs_wei_layer,
165                     sizeof(float *) * ptr_wei_sz);
166             scratchpad.book(key_rnn_ptrs_wei_iter,
167                     sizeof(float *) * ptr_wei_sz);
168             scratchpad.book(key_rnn_ptrs_bia,
169                     sizeof(float *) * ptr_wei_sz);
170         }
171     };
172
173     _ref_rnn_common_t(const pd_t *apd, const input_vector &inputs,
174             const output_vector &outputs)
175         : cpu_primitive_t(apd, inputs, outputs, true), rnn_postgemm_(nullptr) {
176         /// @todo set max_feature_size assuming that we limit the number of
177         /// iterations and layer to one if slc != dic and sic != dic
178         /// respectively
179
180         bias_preparation_func = &class_name::bias_prepare;
181         bias_finalization_func = &class_name::bias_finalize;
182
183         auto set_gemm_funcs
184                 = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) {
185                       if (packed_gemm) {
186                           g = &class_name::packed_gemm;
187                           a = &class_name::assign_packed_weights;
188                       } else {
189                           g = &class_name::gemm;
190                           a = &class_name::assign_weights;
191                       }
192                   };
193         set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
194                 weights_iter_assign_func);
195
196         set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
197                 weights_layer_assign_func);
198
199         switch (pd()->cell_kind()) {
200         case alg_kind::vanilla_lstm:
201             cell_func = &class_name::cell_execution;
202             if (aprop == prop_kind::forward) {
203                 if (mayiuse(avx512_core))
204                     rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>(
205                         pd()->rnn_, pd()->attr());
206                 else if (mayiuse(avx2))
207                     rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>(
208                         pd()->rnn_, pd()->attr());
209                 else if (mayiuse(sse42))
210                     rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>(
211                         pd()->rnn_, pd()->attr());
212                 assert(rnn_postgemm_ != nullptr);
213                 rnn_postgemm_->init();
214             }
215             elemwise_func = &class_name::lstm_elemwise;
216             break;
217         case alg_kind::vanilla_rnn: // @todo switch on cell kind
218             cell_func = &class_name::cell_execution;
219             elemwise_func = &class_name::rnn_elemwise;
220             switch (pd()->activation_kind()) {
221             case alg_kind::eltwise_relu:
222                 activation_func = &activation<alg_kind::eltwise_relu, aprop>;
223                 break;
224             case alg_kind::eltwise_tanh:
225                 activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
226                 break;
227             case alg_kind::eltwise_logistic:
228                 activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
229                 break;
230             default: break;
231             }
232             break;
233         case alg_kind::vanilla_gru:
234             cell_func = &class_name::cell_execution_gru;
235             break;
236         case alg_kind::gru_linear_before_reset:
237             cell_func = &class_name::cell_execution_gru_lbr;
238             elemwise_func = &class_name::gru_lbr_elemwise;
239             break;
240         default: break;
241         }
242
243         grid_computation = &class_name::linear_execution;
244
245         size_t scratchpad_size, workspace_size;
246         rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_,
247                 ws_c_states_offset_, ws_diff_states_offset_,
248                 ws_grid_comp_offset_, ws_cell_comp_offset_,
249                 ws_bias_offset_, scratchpad_size, workspace_size);
250     }
251
252     ~_ref_rnn_common_t() {}
253
254     // typedef typename prec_traits::type data_t;
255
256     virtual void execute(event_t *e) const {
257         execute_();
258         e->set_state(event_t::ready);
259     }
260
261 private:
262     void execute_() const;
263     rnn_grid_execution_sig(linear_execution);
264     rnn_cell_execution_sig(cell_execution);
265     rnn_cell_execution_sig(cell_execution_gru);
266     rnn_cell_execution_sig(cell_execution_gru_lbr);
267     rnn_elemwise_sig(rnn_elemwise);
268     rnn_elemwise_sig(lstm_elemwise);
269     rnn_elemwise_sig(gru_lbr_elemwise);
270     rnn_gemm_sig(gemm);
271     rnn_gemm_sig(packed_gemm);
272     rnn_bias_prepare_sig(bias_prepare);
273     rnn_bias_finalize_sig(bias_finalize);
274     rnn_weights_assign_sig(assign_weights);
275     rnn_weights_assign_sig(assign_packed_weights);
276
277     float (*activation_func)(float dd, float s, float alpha, float cliping);
278
279     void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
280             src_data_t *ws_states_, float *ws_diff_states_,
281             const src_data_t *xt_, const float *diff_dst_layer) const;
282
283     template <typename input_data_t>
284     void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
285             src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_,
286             const input_data_t *firstit_states_,
287             const float *diff_dst_iter) const;
288
289     template <typename dst_data_t>
290     void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
291             dst_data_t *dst_layer_, float *diff_src_layer,
292             const src_data_t *ws_states_, const float *ws_diff_states_) const;
293
294     template <typename output_data_t>
295     void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
296             output_data_t *dst_iter_, float *diff_src_iter,
297             const src_data_t *ws_states_, float *ws_c_states,
298             const float *ws_diff_states_) const;
299
300     void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
301             const acc_data_t *ws_gates_, float *diff_bias_) const;
302
303     const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
304
305     size_t ws_gates_offset_;
306     size_t ws_states_offset_;
307     size_t ws_c_states_offset_;
308     size_t ws_bias_offset_;
309     size_t ws_diff_states_offset_;
310     size_t ws_grid_comp_offset_;
311     size_t ws_cell_comp_offset_;
312     jit_uni_rnn_postgemm_kernel *rnn_postgemm_;
313
314     grid_execution_f grid_computation;
315     cell_execution_f cell_func;
316
317     bias_prepare_t bias_preparation_func;
318     bias_finalize_t bias_finalization_func;
319     weights_assign_t weights_layer_assign_func;
320     weights_assign_t weights_iter_assign_func;
321
322     gemm_t gemm_layer_func;
323     gemm_t gemm_iter_func;
324     elemwise_f elemwise_func;
325 };
326
// Concrete instantiations registered with the implementation list:
// f32 forward and backward, plus the quantized u8-src / s8-weights forward.
using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
330 }
331 }
332 }
333 #endif
334
335 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s