/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REF_RNN_HPP
#define CPU_REF_RNN_HPP

#include <assert.h>

#include "c_types_map.hpp"
#include "cpu_engine.hpp"
#include "cpu_rnn_pd.hpp"
#include "cpu_isa_traits.hpp"
#include "scratchpad.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "gemm/os_blas.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

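// The *_sig macros below keep the long argument lists of the RNN building
// blocks in a single place, so that member declarations, definitions and
// the pointer-to-member typedefs further down all stay in sync. For
// example, a sketch of what
//
//     elemwise_sig(rnn_elemwise);
//
// declares (trailing parameters elided; they follow the macro verbatim):
//
//     void rnn_elemwise(int dic, int wic, int batch, int n_states,
//             int iter_stride, int n_gates, float *ws_gates_, ...);
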
#define elemwise_sig(f)                                                 \
    void f(int dic, int wic, int batch, int n_states, int iter_stride, int n_gates, \
            float *ws_gates_, float *states_t_l_, float *states_t_lm1_, \
            float *states_tm1_l_, float *diff_states_t_l_,              \
            float *diff_states_t_lp1_, float *diff_states_tp1_l_,       \
            const float *bias_, float *ws_grid_, float *ws_cell_)

#define cell_execution_sig(f)                                                 \
    void f(int dic, int slc, int sic, int wic, int batch, int n_gates,        \
            int n_states, int iter_stride, float *states_t_l_, float *diff_states_t_l_, \
            float **w_input_, float **w_state_, const float *bias_,           \
            float *states_t_lm1_, float *states_tm1_l_,                       \
            float *diff_states_t_lp1_, float *diff_states_tp1_l_,             \
            float *diff_w_input_, float *diff_w_state_, float *diff_bias_,    \
            float *ws_gates_, float *ws_grid_, float *ws_cell_)

#define grid_execution_sig(f)                                              \
    void f(int dic, int slc, int sic, int wic, int batch, int n_layer,     \
            int n_direction, int n_iter, int n_gates, int n_states,        \
            int n_bias, float **weights_input_, int n_parts_wei_i,         \
            float **weights_states_, int n_parts_wei_st,                   \
            const float *bias_, float *ws_states_, float *ws_diff_states_, \
            float *ws_gates_, float *ws_cell_, float *ws_grid_,            \
            int ws_per_cell, float *diff_weights_layer_,                   \
            float *diff_weights_iter_, float *diff_bias_)

#define gemm_sig(f)                                                          \
    void f(int m, int n, int k, int strideA_m, int strideA_k, int strideB_n, \
            int strideB_k, int strideC_m, int strideC_n, const float *a_,    \
            float *b_, float *c_, bool is_B_trans, float beta)

#define packing_sig(f)                                               \
    void f(int n_layer, int n_direction, int n_weights, int n_gates, \
            int batch, int OC_size, int IC_size, float **weights_,   \
            int n_parts, int *gates_per_part, const float *w_,       \
            float *scratch_mem, bool do_copy)

#define free_packed_sig(f) void f(int n_layer, int n_direction, int n_parts, \
            float **weights_)

template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float clipping, float dd);

template <prop_kind_t aprop>
struct _ref_rnn_common_t : public cpu_primitive_t {
    using class_name = _ref_rnn_common_t<aprop>;
    typedef enum execution_direction_ {
        b2t_l2r,
        b2t_r2l,
        b2t_bi_concat,
        b2t_bi_sum,
        t2b_l2r,
        t2b_r2l,
        t2b_bi_concat,
        t2b_bi_sum
    } execution_direction;
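    // Naming sketch: the first token of execution_direction is the walk
    // over layers (b2t = bottom-to-top, t2b = top-to-bottom), the second
    // the walk over iterations (l2r = left-to-right, r2l = right-to-left,
    // bi_* = bidirectional, concatenating or summing both directions); see
    // the mapping from mkldnn_rnn_direction_t in the constructor.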
    typedef elemwise_sig((class_name::*elemwise_f));
    typedef cell_execution_sig((class_name::*cell_execution_f));
    typedef grid_execution_sig((class_name::*grid_execution_f));

    typedef gemm_sig((class_name::*gemm_t));
    typedef packing_sig((class_name::*packing_t));
    typedef free_packed_sig((class_name::*free_packed_t));
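
    // These typedefs are pointers to members of class_name, so dispatch
    // goes through the object. A sketch of a call through the cell_func
    // member declared below (remaining arguments as in cell_execution_sig):
    //
    //     (this->*cell_func)(dic, slc, sic, wic, batch, n_gates, ...);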

    using base_pd_t =
            typename utils::conditional<aprop == prop_kind::forward,
                    cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
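    // Note: via utils::conditional, base_pd_t resolves to cpu_rnn_fwd_pd_t
    // for the forward specialization and to cpu_rnn_bwd_pd_t for backward.
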
    struct pd_t : public base_pd_t {
        pd_t(engine_t *engine, const rnn_desc_t *adesc,
                const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_pd)
            : base_pd_t(engine, adesc, attr, hint_pd) {}

        DECLARE_COMMON_PD_T("ref:any", class_name);

        status_t init() {
            using namespace prop_kind;
            using namespace utils;
            using namespace memory_format;
            assert(this->engine()->kind() == engine_kind::cpu);
            const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;

            bool ok = true
                    && one_of(cell_kind, alg_kind::vanilla_rnn,
                               alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
                               alg_kind::gru_linear_before_reset)
                    && IMPLICATION(aprop == prop_kind::forward,
                               one_of(this->desc()->prop_kind, forward_training,
                                       forward_inference))
                    && IMPLICATION(aprop == backward,
                               one_of(this->desc()->prop_kind, backward))
                    && this->set_default_params() == status::success;
            if (!ok)
                return status::unimplemented;

            /// @todo check data layouts for all input tensors
            ok = ok && this->desc()->src_layer_desc.format == tnc
                    && this->desc()->dst_layer_desc.format == tnc;

            ok = ok && this->with_bias();
            switch (aprop) {
            case (prop_kind::forward):
                ok = ok && utils::one_of(this->desc()->prop_kind,
                                   forward_training, forward_inference);
                ok = ok && utils::one_of(
                                   this->desc()->weights_layer_desc.format, any,
                                   ldigo, ldigo_p)
                        && utils::one_of(this->desc()->weights_iter_desc.format,
                                   any, ldigo, ldigo_p);
                break;
            case (prop_kind::backward):
                ok = ok && utils::one_of(this->desc()->prop_kind, backward);
                ok = ok && utils::one_of(
                                   this->desc()->weights_layer_desc.format, any,
                                   ldgoi, ldgoi_p)
                        && utils::one_of(this->desc()->weights_iter_desc.format,
                                   any, ldgoi, ldgoi_p);
                break;
            default: ok = false;
            }

            // Check dimensions consistency
            int ls_multiplier
                    = (this->direction() == mkldnn_bidirectional_concat)
                    ? 2 : 1;

            ok = ok && (ls_multiplier * this->DIC() == this->DLC())
                    && ((ls_multiplier * this->SLC()) == this->DLC()
                               || (this->L() == 1))
                    && (this->SIC() == this->DIC() || (this->T() == 1));
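            // e.g. with mkldnn_bidirectional_concat the outputs of both
            // directions are concatenated, so DLC must equal 2 * DIC; for
            // stacked layers, layer l+1 consumes the output of layer l, so
            // SLC must match DLC unless there is a single layer, and SIC
            // must match DIC unless there is a single iteration.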

            // initialize the workspace_pd if needed
            if (this->desc()->prop_kind != forward_inference) {
                dims_t ws_dims = { (dim_t)this->get_ws_size() };
                memory_desc_t ws_d;
                mkldnn_memory_desc_init(&ws_d, 1, ws_dims,
                        impl::data_type::f32, memory_format::x);
                this->ws_pd_ = cpu_memory_t::pd_t(this->engine(), &ws_d);
            }

            return ok ? status::success : status::unimplemented;
        }
    };

    _ref_rnn_common_t(const pd_t *pd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
        /// @todo set max_feature_size assuming that we limit the number of
        /// iterations and layers to one if slc != dic and sic != dic,
        /// respectively

        memory_format_t packed_format;
        switch (aprop) {
        case prop_kind::forward_inference:
        case prop_kind::forward_training:
            packed_format = memory_format::ldigo_p;
            break;
        case prop_kind::backward: packed_format = memory_format::ldgoi_p; break;
        default: assert(false);
        }

        merge_gemm_layer = ((aprop == prop_kind::forward) && (conf_.MB() < 128))
                || (aprop == prop_kind::backward);
        merge_gemm_iter = (aprop == prop_kind::backward)
                && (!utils::one_of(conf_.cell_kind(), alg_kind::vanilla_gru,
                            alg_kind::gru_linear_before_reset));
        auto set_pack_funcs = [](bool packed_gemm, gemm_t &g, bool pack_w,
                packing_t &p, free_packed_t &f) {
            g = packed_gemm ? &class_name::packed_gemm : &class_name::gemm;
            p = pack_w ? &class_name::pack_weights :
                         &class_name::no_pack_weights;
            f = pack_w ? &class_name::free_packed_weights :
                         &class_name::free_no_packed_weights;
        };
#ifdef USE_MKL_PACKED_GEMM
        const bool weights_pack_cond = (conf_.T() > 1) && (conf_.MB() == 32)
                && (conf_.SIC() == 512) && (conf_.SLC() == 512)
                && (conf_.DIC() == 512);
#else
        const bool weights_pack_cond = false;
#endif
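        // The shape test above (multiple iterations, mb 32, 512 channels)
        // looks like a heuristic tuned for one specific workload; packing
        // is also used whenever the weights already come in a packed
        // format, as set up below.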

        const bool is_weights_state_packed
                = conf_.desc()->weights_iter_desc.format == packed_format;
        set_pack_funcs(weights_pack_cond || is_weights_state_packed,
                gemm_state_func, weights_pack_cond && !is_weights_state_packed,
                weights_state_pack_func, weights_state_free_packed_func);

        const bool is_weights_input_packed
                = conf_.desc()->weights_layer_desc.format == packed_format;
        set_pack_funcs(weights_pack_cond || is_weights_input_packed,
                gemm_input_func, weights_pack_cond && !is_weights_input_packed,
                weights_input_pack_func, weights_input_free_packed_func);

        switch (conf_.cell_kind()) {
        case alg_kind::vanilla_lstm:
            cell_func = &class_name::cell_execution;
            elemwise_func = &class_name::lstm_elemwise;
            break;
        case alg_kind::vanilla_rnn: // @todo switch on cell kind
            cell_func = &class_name::cell_execution;
            elemwise_func = &class_name::rnn_elemwise;
            switch (conf_.activation_kind()) {
            case alg_kind::eltwise_relu:
                activation_func = &activation<alg_kind::eltwise_relu, aprop>;
                break;
            case alg_kind::eltwise_tanh:
                activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
                break;
            case alg_kind::eltwise_logistic:
                activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
                break;
            default: break;
            }
            break;
        case alg_kind::vanilla_gru:
            cell_func = &class_name::cell_execution_gru;
            break;
        case alg_kind::gru_linear_before_reset:
            cell_func = &class_name::cell_execution_gru_lbr;
            elemwise_func = &class_name::gru_lbr_elemwise;
            break;
        default: break;
        }
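        // (vanilla_gru sets no elemwise_func: its elementwise work is
        // presumably folded into cell_execution_gru itself.)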

        n_output_features
                = (conf_.direction() == mkldnn_bidirectional_concat) ? 2 : 1;
        switch (conf_.direction()) {
        case mkldnn_unidirectional_left2right: exec_dir = b2t_l2r; break;
        case mkldnn_unidirectional_right2left: exec_dir = b2t_r2l; break;
        case mkldnn_bidirectional_concat: exec_dir = b2t_bi_concat; break;
        case mkldnn_bidirectional_sum: exec_dir = b2t_bi_sum; break;
        default: break;
        }

        /// @todo put a heuristic to choose between linear execution and
        /// wavefront
        grid_computation = &class_name::linear_execution;

        // we need to allocate memory for:
        // - the states to compute a pass.
        // - the intermediate results from the gates.
        // - the diff_states to compute the backward pass (training only)
        // These should be allocated on the scratchpad for forward inference,
        // or on a workspace provided by the user for training.
        /// @todo shall we require the workspace for training or make it
        /// optional?

        // if no workspace is provided on forward, we use a scratchpad
        // NOTE: here we use a large workspace for simplicity:
        // - for states:
        //   - TODO: allocate only n_iter * dic + dic for linear execution
        //   (inference)
        //   - TODO: allocate only n_layer_wav * (2*dic) for wavefront
        //   execution (inference)
        // - for gates:
        //   - TODO: allocate only batch * n_gates * dic for linear execution
        //   (inference)
        //   - TODO: allocate only n_layer_wav * batch * n_gates * dic for
        //   wavefront execution (inference)

        use_jit_sgemm_ = ((aprop == prop_kind::forward_inference)
                || (conf_.is_training() && conf_.DIC() < 500))
                && !mayiuse(avx512_mic);

        copy_weights_layer_ = (conf_.WL_LD() != conf_.WL_GLD());
        copy_weights_iter_ = (conf_.WI_LD() != conf_.WI_GLD());

        copy_diff_weights_layer_ = (aprop == prop_kind::backward)
                && (conf_.DWL_LD() != conf_.DWL_GLD());
        copy_diff_weights_iter_ = (aprop == prop_kind::backward)
                && (conf_.DWI_LD() != conf_.DWI_GLD());

        use_workspace_
                = (conf_.desc()->prop_kind != prop_kind::forward_inference);

        size_t scratchpad_size = conf_.set_offsets(use_workspace_,
                ws_gates_offset_, ws_states_offset_, ws_diff_states_offset_,
                ws_grid_comp_offset_, conf_.is_lbr(), ws_cell_comp_offset_,
                copy_weights_layer_, ws_weights_layer_offset_,
                copy_weights_iter_, ws_weights_iter_offset_,
                copy_diff_weights_layer_, ws_diff_weights_layer_offset_,
                copy_diff_weights_iter_, ws_diff_weights_iter_offset_);

        scratchpad_ = create_scratchpad(scratchpad_size * sizeof(float));

        int max_nparts = (conf_.cell_kind() == alg_kind::vanilla_gru) ? 2 : 1;
        int ptr_wei_sz = conf_.L() * conf_.D() * max_nparts;
        ptr_wei_input_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
        ptr_wei_state_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
    }
    ~_ref_rnn_common_t() {
        delete scratchpad_;
        free(ptr_wei_input_);
        free(ptr_wei_state_);
    }

    // typedef typename prec_traits::type data_t;

    virtual void execute(event_t *e) {
        execute_();
        e->set_state(event_t::ready);
    }

private:
    void execute_();
    grid_execution_sig(linear_execution);
    // grid_execution_sig(wavefront_execution);
    cell_execution_sig(cell_execution);
    cell_execution_sig(cell_execution_gru);
    cell_execution_sig(cell_execution_gru_lbr);
    elemwise_sig(rnn_elemwise);
    elemwise_sig(lstm_elemwise);
    elemwise_sig(gru_lbr_elemwise);
    gemm_sig(gemm);
    gemm_sig(packed_gemm);
    packing_sig(pack_weights);
    packing_sig(no_pack_weights);
    free_packed_sig(free_packed_weights);
    free_packed_sig(free_no_packed_weights);

    float (*activation_func)(float dd, float s, float alpha, float clipping);

    void copy_init_layer(bool lr, bool rl, int n_direction, int n_layer,
            int n_iter, int batch, int slc, int dic, int dlc, int wic,
            int n_states, float *ws_states_, float *ws_diff_states_,
            const float *xt_, const float *diff_dst_layer);
    void copy_init_iter(int n_layer, int n_direction, int n_states, int batch,
            int sic, int dic, int wic, int n_iter, float *ws_states_,
            float *ws_diff_states_, const float *firstit_states_,
            const float *diff_dst_iter);
    void copy_res_layer(bool lr, bool rl, int n_layer, int n_direction,
            int n_iter, int batch, int n_output_features, int slc, int dlc,
            int wic, int n_states, mkldnn_rnn_direction_t direction,
            float *dst_layer_, float *diff_src_layer, const float *ws_states_,
            const float *ws_diff_states_);
    void copy_res_iter(int n_layer, int n_direction, int n_states, int batch,
            int sic, int dic, int wic, int n_iter, float *dst_iter_,
            float *diff_src_iter, const float *ws_states_,
            const float *ws_diff_states_);
    void gates_reduction(int n_gates, int dic, int wic, int batch,
            const float *ws_gates_, float *diff_bias_);

    pd_t conf_;
    bool use_workspace_;
    scratchpad_t *scratchpad_;

    size_t ws_gates_offset_;
    size_t ws_states_offset_;
    size_t ws_weights_layer_offset_;
    size_t ws_weights_iter_offset_;
    size_t ws_diff_states_offset_;
    size_t ws_diff_weights_layer_offset_;
    size_t ws_diff_weights_iter_offset_;
    size_t ws_grid_comp_offset_;
    size_t ws_cell_comp_offset_;

    float *ws_gates_;
    float *ws_states_;
    float *ws_diff_states_;
    float *ws_cell_;
    float *ws_grid_;
    float *ws_weights_layer_;
    float *ws_weights_iter_;
    float *ws_diff_weights_layer_;
    float *ws_diff_weights_iter_;
    int n_output_features;

    float **ptr_wei_input_;
    float **ptr_wei_state_;

    execution_direction exec_dir;
    grid_execution_f grid_computation;
    cell_execution_f cell_func;

    bool copy_weights_layer_;
    bool copy_weights_iter_;
    bool copy_diff_weights_layer_;
    bool copy_diff_weights_iter_;
    bool merge_gemm_layer;
    bool merge_gemm_iter;
    bool use_jit_sgemm_;

    packing_t weights_input_pack_func;
    packing_t weights_state_pack_func;

    gemm_t gemm_input_func;
    gemm_t gemm_state_func;
    elemwise_f elemwise_func;

    free_packed_t weights_input_free_packed_func;
    free_packed_t weights_state_free_packed_func;
};

using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>;
using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>;
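
// A sketch of how these types are typically exposed (assumption: this
// follows the library's cpu_engine.cpp convention, where the CPU
// implementation list is built with an INSTANCE macro):
//
//     INSTANCE(ref_rnn_fwd_t),
//     INSTANCE(ref_rnn_bwd_t),
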
} // namespace cpu
} // namespace impl
} // namespace mkldnn

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s