1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_REF_RNN_HPP
18 #define CPU_REF_RNN_HPP
22 #include "c_types_map.hpp"
23 #include "cpu_engine.hpp"
24 #include "cpu_rnn_pd.hpp"
25 #include "cpu_isa_traits.hpp"
26 #include "scratchpad.hpp"
27 #include "type_helpers.hpp"
30 #include "gemm/os_blas.hpp"
// elemwise_sig(f): common signature for the per-iteration elementwise
// kernels (rnn_elemwise / lstm_elemwise / gru_lbr_elemwise below): apply
// the gate non-linearities to the GEMM results staged in ws_gates_ and
// update states_t_l_ (plus the corresponding diff_* tensors on backward).
#define elemwise_sig(f) \
    void f(int dic, int wic, int batch, int n_states, int iter_stride, int n_gates, \
            float *ws_gates_, float *states_t_l_, float *states_t_lm1_, \
            float *states_tm1_l_, float *diff_states_t_l_, \
            float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
            const float *bias_, float *ws_grid_, float *ws_cell_)
// cell_execution_sig(f): signature for one RNN cell evaluation
// (cell_execution / cell_execution_gru / cell_execution_gru_lbr below):
// the layer/iteration GEMMs followed by the elementwise pass for a single
// (layer, direction, iteration) grid point.
#define cell_execution_sig(f) \
    void f(int dic, int slc, int sic, int wic, int batch, int n_gates, \
            int n_states, int iter_stride, float *states_t_l_, float *diff_states_t_l_, \
            float **w_input_, float **w_state_, const float *bias_, \
            float *states_t_lm1_, float *states_tm1_l_, \
            float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
            float *diff_w_input_, float *diff_w_state_, float *diff_bias_, \
            float *ws_gates_, float *ws_grid_, float *ws_cell_)
// grid_execution_sig(f): signature for a full grid traversal
// (linear_execution below): walks layers x directions x iterations,
// dispatching the bound cell function on each point and accumulating the
// diff_weights_* / diff_bias_ tensors on backward.
#define grid_execution_sig(f) \
    void f(int dic, int slc, int sic, int wic, int batch, int n_layer, \
            int n_direction, int n_iter, int n_gates, int n_states, \
            int n_bias, float **weights_input_, int n_parts_wei_i, \
            float **weights_states_, int n_parts_wei_st, \
            const float *bias_, float *ws_states_, float *ws_diff_states_, \
            float *ws_gates_, float *ws_cell_, float *ws_grid_, \
            int ws_per_cell, float *diff_weights_layer_, \
            float *diff_weights_iter_, float *diff_bias_)
// gemm_sig(f): signature for the GEMM member functions (gemm /
// packed_gemm below): C = beta * C + op(A) * B, with explicit leading
// dimensions ("strides") for each matrix; is_B_trans selects whether B is
// transposed.
// NOTE(review): the "#define gemm_sig(f)" introducer line was missing
// here even though gemm_sig is used below (typedef gemm_sig(...) and
// gemm_sig(packed_gemm)) — restored so the header compiles.
#define gemm_sig(f) \
    void f(int m, int n, int k, int strideA_m, int strideA_k, int strideB_n, \
            int strideB_k, int strideC_m, int strideC_n, const float *a_, \
            float *b_, float *c_, bool is_B_trans, float beta)
// packing_sig(f): signature for the weight-preparation member functions
// (pack_weights / no_pack_weights below): either pack the weights through
// the BLAS packed-GEMM API or lay plain per-part pointers into
// scratch_mem (optionally copying when do_copy is set).
#define packing_sig(f) \
    void f(int n_layer, int n_direction, int n_weights, int n_gates, \
            int batch, int OC_size, int IC_size, float **weights_, \
            int n_parts, int *gates_per_part, const float *w_, \
            float * scratch_mem, bool do_copy)
// free_packed_sig(f): signature for the weight-release member functions
// (free_packed_weights / free_no_packed_weights below): release whatever
// pack_weights / no_pack_weights installed in the
// n_layer x n_direction x n_parts table of per-part weight pointers.
// NOTE(review): the macro ended in a dangling '\' with its final
// continuation line missing — restored the closing parameter
// (float **weights_), matching the ptr_wei_input_/ptr_wei_state_ tables.
#define free_packed_sig(f) void f(int n_layer, int n_direction, int n_parts, \
        float **weights_)
/// Scalar gate activation, specialized per eltwise kind and propagation
/// kind (forward computes the activation of pre-activation value s;
/// backward presumably computes its derivative scaled by dd — confirm in
/// the .cpp). alpha/cliping are the eltwise parameters.
template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float cliping, float dd);
/// Reference RNN implementation shared by the forward and backward
/// specializations (see ref_rnn_fwd_t / ref_rnn_bwd_t at the bottom of
/// this file). All kernels are bound as pointers-to-member at
/// construction time depending on cell kind, direction and prop kind.
template <prop_kind_t aprop>
struct _ref_rnn_common_t : public cpu_primitive_t {
    using class_name = _ref_rnn_common_t<aprop>;
    // Grid walk order: bottom-to-top over layers combined with
    // left-to-right / right-to-left / bidirectional over iterations.
    // NOTE(review): the enumerator list (b2t_l2r, b2t_r2l, b2t_bi_concat,
    // b2t_bi_sum, ...) is elided in this excerpt — see the direction
    // switch in the constructor for the values actually used.
    typedef enum execution_direction_ {
    } execution_direction;
    // Pointer-to-member types for the runtime-dispatched kernels; the
    // signatures come from the *_sig macros above.
    typedef elemwise_sig((class_name::*elemwise_f));
    typedef cell_execution_sig((class_name::*cell_execution_f));
    typedef grid_execution_sig((class_name::*grid_execution_f));
    typedef gemm_sig((class_name::*gemm_t));
    typedef packing_sig((class_name::*packing_t));
    typedef free_packed_sig((class_name::*free_packed_t));
    // Select the fwd/bwd primitive-descriptor base by propagation kind.
    // NOTE(review): the "using base_pd_t =" introducer is elided in this
    // excerpt; pd_t below derives from it.
    typename utils::conditional<false || aprop == prop_kind::forward,
            cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
    /// Primitive descriptor: validates the RNN descriptor against what
    /// this reference implementation supports and sets up the workspace
    /// memory descriptor for training.
    struct pd_t : public base_pd_t {
        pd_t(engine_t *engine, const rnn_desc_t *adesc,
                const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_pd)
            : base_pd_t(engine, adesc, attr, hint_pd) {}

        DECLARE_COMMON_PD_T("ref:any", class_name);

        // NOTE(review): the enclosing "status_t init()" opener and a few
        // statements are elided in this excerpt; the body below validates
        // the descriptor and returns success/unimplemented.
            using namespace prop_kind;
            using namespace utils;
            using namespace memory_format;
            assert(this->engine()->kind() == engine_kind::cpu);
            const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
            // Supported cell kinds, and a prop_kind consistent with this
            // template specialization: the fwd specialization accepts
            // training and inference, the bwd one only backward.
                    && one_of(cell_kind, alg_kind::vanilla_rnn,
                            alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
                            alg_kind::gru_linear_before_reset)
                    && IMPLICATION(aprop == prop_kind::forward,
                            one_of(this->desc()->prop_kind, forward_training,
                    && IMPLICATION(aprop == backward,
                            one_of(this->desc()->prop_kind, backward))
                    && this->set_default_params() == status::success;
                return status::unimplemented;
            ok = ok && utils::one_of(cell_kind, alg_kind::vanilla_rnn,
                            alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
                            alg_kind::gru_linear_before_reset);

            /// @todo check data layouts for all input tensors
            ok = ok && this->desc()->src_layer_desc.format == tnc
                    && this->desc()->dst_layer_desc.format == tnc;
            ok = ok && this->with_bias();
            // Weight format checks per direction: forward expects ldigo
            // (optionally pre-packed ldigo_p), backward expects ldgoi
            // (optionally ldgoi_p). Switch opener elided in this excerpt.
                case (prop_kind::forward):
                    ok = ok && utils::one_of(this->desc()->prop_kind,
                                     forward_training, forward_inference);
                    ok = ok && utils::one_of(
                                       this->desc()->weights_layer_desc.format, any,
                            && utils::one_of(this->desc()->weights_iter_desc.format,
                                       any, ldigo, ldigo_p);
                case (prop_kind::backward):
                    ok = ok && utils::one_of(this->desc()->prop_kind, backward);
                    ok = ok && utils::one_of(
                                       this->desc()->weights_layer_desc.format, any,
                            && utils::one_of(this->desc()->weights_iter_desc.format,
                                       any, ldgoi, ldgoi_p);

            // Check dimensions consistency
            // (the multiplier is 2 when both directions are concatenated
            // into the destination layer, 1 otherwise).
                    = (this->direction() == mkldnn_bidirectional_concat) ? 2 :
            ok = ok && (ls_multiplier * this->DIC() == this->DLC())
                    && ((ls_multiplier * this->SLC()) == this->DLC()
                    && (this->SIC() == this->DIC() || (this->T() == 1));

            // initialize the workspace_pd if needed
            if (this->desc()->prop_kind != forward_inference){
                dims_t ws_dims = { (dim_t)this->get_ws_size() };
                mkldnn_memory_desc_init(
                        &ws_d, 1, ws_dims, impl::data_type::f32, memory_format::x);
                this->ws_pd_ = cpu_memory_t::pd_t(this->engine(), &ws_d);

            return ok ? status::success : status::unimplemented;
    /// Constructor: binds all kernel function pointers (cell, elementwise,
    /// gemm, packing), computes workspace/scratchpad offsets and allocates
    /// the per-(layer, direction, part) weight pointer tables.
    _ref_rnn_common_t(const pd_t *pd, const input_vector &inputs,
            const output_vector &outputs)
        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
        /// @todo set max_feature_size assuming that we limit the number of
        /// iterations and layer to one if slc != dic and sic != dic
        // Packed weights are ldigo_p on forward, ldgoi_p on backward.
        // (Switch opener elided in this excerpt.)
        memory_format_t packed_format;
            case prop_kind::forward_inference:
            case prop_kind::forward_training:
                packed_format = memory_format::ldigo_p;
            case prop_kind::backward: packed_format = memory_format::ldgoi_p; break;
            default: assert(false);
        // Heuristics: merge the layer GEMM across iterations for small
        // batches on forward and always on backward; merge the iteration
        // GEMM on backward except for the GRU variants.
        merge_gemm_layer = ((aprop == prop_kind::forward) && (conf_.MB() < 128))
                || (aprop == prop_kind::backward);
        merge_gemm_iter = (aprop == prop_kind::backward)
                && (!utils::one_of(conf_.cell_kind(), alg_kind::vanilla_gru,
                           alg_kind::gru_linear_before_reset));
        // Bind a (gemm, pack, free) triplet depending on whether the
        // weights are (or should be) packed.
        auto set_pack_funcs = [](bool packed_gemm, gemm_t &g, bool pack_w,
                packing_t &p, free_packed_t &f) {
            g = packed_gemm ? &class_name::packed_gemm : &class_name::gemm;
            p = pack_w ? &class_name::pack_weights :
                         &class_name::no_pack_weights;
            f = pack_w ? &class_name::free_packed_weights :
                         &class_name::free_no_packed_weights;
#ifdef USE_MKL_PACKED_GEMM
        // Only pack for the one shape where packing was seen to pay off.
        const bool weights_pack_cond =
            (conf_.T() > 1) && (conf_.MB() == 32) &&
            (conf_.SIC() == 512) &&(conf_.SLC() == 512) && (conf_.DIC() == 512);
        const bool weights_pack_cond = false;

        const bool is_weights_state_packed = conf_.desc()->weights_iter_desc.format == packed_format;
        set_pack_funcs(weights_pack_cond || is_weights_state_packed,
                gemm_state_func, weights_pack_cond && !is_weights_state_packed,
                weights_state_pack_func, weights_state_free_packed_func);

        const bool is_weights_input_packed = conf_.desc()->weights_layer_desc.format == packed_format;
        set_pack_funcs(weights_pack_cond || is_weights_input_packed,
                gemm_input_func, weights_pack_cond && !is_weights_input_packed,
                weights_input_pack_func, weights_input_free_packed_func);

        // Select the cell and elementwise kernels by cell kind; vanilla
        // RNN additionally binds the scalar activation by eltwise kind.
        switch (conf_.cell_kind()) {
        case alg_kind::vanilla_lstm:
            cell_func = &class_name::cell_execution;
            elemwise_func = &class_name::lstm_elemwise;
        case alg_kind::vanilla_rnn: // @todo switch on cell kind
            cell_func = &class_name::cell_execution;
            elemwise_func = &class_name::rnn_elemwise;
            switch (conf_.activation_kind()) {
            case alg_kind::eltwise_relu:
                activation_func = &activation<alg_kind::eltwise_relu, aprop>;
            case alg_kind::eltwise_tanh:
                activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
            case alg_kind::eltwise_logistic:
                activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
        case alg_kind::vanilla_gru:
            cell_func = &class_name::cell_execution_gru;
        case alg_kind::gru_linear_before_reset:
            cell_func = &class_name::cell_execution_gru_lbr;
            elemwise_func = &class_name::gru_lbr_elemwise;

        // Map the user-visible direction onto a grid walk order.
                = (conf_.direction() == mkldnn_bidirectional_concat) ? 2 : 1;
        switch (conf_.direction()) {
        case mkldnn_unidirectional_left2right: exec_dir = b2t_l2r; break;
        case mkldnn_unidirectional_right2left: exec_dir = b2t_r2l; break;
        case mkldnn_bidirectional_concat: exec_dir = b2t_bi_concat; break;
        case mkldnn_bidirectional_sum: exec_dir = b2t_bi_sum; break;

        /// @todo put a heuristic to choose between linear execution and
        grid_computation = &class_name::linear_execution;

        // we need to allocate memory for:
        // - the states to compute a pass.
        // - the intermediate results from the gates.
        // - the diff_states to compute the backward pass (training only)
        // These should be allocated on scratchpad if fwd inference
        // or on a workspace provided by the user for training.
        /// @todo shall we require the workspace for training or make it
        // if no workspace is provided on forward, we use a scratchpad
        // NOTE: here we use a large workspace for simplicity:
        // - TODO: allocate only n_iter * dic + dic for linear execution
        // - TODO: allocate only n_layer_wav * (2*dic) for wavefront
        //   execution (inference)
        // - TODO: allocate only batch * n_gates * dic for linear execution
        // - TODO: allocate only n_layer_wav * batch * n_gates * dic for
        //   wavefront execution (inference)
        use_jit_sgemm_ = ((aprop == prop_kind::forward_inference)
                || (conf_.is_training() && conf_.DIC() < 500))
                && !mayiuse(avx512_mic);

        // Copies are needed whenever the local leading dimension differs
        // from the global one for a given weights tensor.
        copy_weights_layer_ = (conf_.WL_LD() != conf_.WL_GLD());
        copy_weights_iter_ = (conf_.WI_LD() != conf_.WI_GLD());
        copy_diff_weights_layer_ = (aprop == prop_kind::backward)
                && (conf_.DWL_LD() != conf_.DWL_GLD());
        copy_diff_weights_iter_ = (aprop == prop_kind::backward)
                && (conf_.DWI_LD() != conf_.DWI_GLD());

        use_workspace_ = (conf_.desc()->prop_kind != prop_kind::forward_inference);
        // Lay out all sections (gates, states, diff_states, grid/cell
        // scratch, optional weight copies) and get the total float count.
        size_t scratchpad_size = conf_.set_offsets(use_workspace_,
                ws_gates_offset_, ws_states_offset_, ws_diff_states_offset_,
                ws_grid_comp_offset_,
                conf_.is_lbr(), ws_cell_comp_offset_,
                copy_weights_layer_, ws_weights_layer_offset_,
                copy_weights_iter_, ws_weights_iter_offset_,
                copy_diff_weights_layer_, ws_diff_weights_layer_offset_,
                copy_diff_weights_iter_, ws_diff_weights_iter_offset_);
        create_scratchpad(scratchpad_size * sizeof(float));

        // Per-(layer, direction[, part]) weight pointer tables; GRU splits
        // the iteration weights into two parts.
        // NOTE(review): this two-argument malloc(size, alignment) is
        // presumably the project's aligned allocator, not the C library
        // one — confirm against the utils header.
        int max_nparts = (conf_.cell_kind() == alg_kind::vanilla_gru) ? 2 : 1;
        int ptr_wei_sz = conf_.L() * conf_.D() * max_nparts;
        ptr_wei_input_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
        ptr_wei_state_ = (float **)malloc(sizeof(float *) * ptr_wei_sz, 64);
    // Destructor: releases the pointer tables (the scratchpad teardown is
    // elided in this excerpt).
    ~_ref_rnn_common_t() {
        free(ptr_wei_input_);
        free(ptr_wei_state_);
    // typedef typename prec_traits::type data_t;

    /// Runs the primitive (dispatch elided in this excerpt) and marks the
    /// completion event ready.
    virtual void execute(event_t *e) {
        e->set_state(event_t::ready);

    // Kernel implementations whose addresses are bound at construction
    // time; signatures come from the *_sig macros at the top of the file.
    grid_execution_sig(linear_execution);
    // grid_execution_sig(wavefront_execution);
    cell_execution_sig(cell_execution);
    cell_execution_sig(cell_execution_gru);
    cell_execution_sig(cell_execution_gru_lbr);
    elemwise_sig(rnn_elemwise);
    elemwise_sig(lstm_elemwise);
    elemwise_sig(gru_lbr_elemwise);
    gemm_sig(packed_gemm);
    packing_sig(pack_weights);
    packing_sig(no_pack_weights);
    free_packed_sig(free_packed_weights);
    free_packed_sig(free_no_packed_weights);

    // Scalar activation selected in the constructor by eltwise kind.
    float (*activation_func)(float dd, float s, float alpha, float cliping);

    // Copy user layer/iteration tensors into the workspace layout (init)
    // and back out (res); the diff_* counterparts are used on backward.
    void copy_init_layer(bool lr, bool rl, int n_direction, int n_layer,
            int n_iter, int batch, int slc, int dic, int dlc, int wic,
            int n_states, float *ws_states_, float *ws_diff_states_,
            const float *xt_, const float *diff_dst_layer);
    void copy_init_iter(int n_layer, int n_direction, int n_states, int batch,
            int sic, int dic, int wic, int n_iter, float *ws_states_,
            float *ws_diff_states_, const float *firstit_states_,
            const float *diff_dst_iter);
    void copy_res_layer(bool lr, bool rl, int n_layer, int n_direction,
            int n_iter, int batch, int n_output_features, int slc, int dlc,
            int wic, int n_states, mkldnn_rnn_direction_t direction,
            float *dst_layer_, float *diff_src_layer, const float *ws_states_,
            const float *ws_diff_states_);
    void copy_res_iter(int n_layer, int n_direction, int n_states, int batch,
            int sic, int dic, int wic, int n_iter, float *dst_iter_,
            float *diff_src_iter, const float *ws_states_,
            const float *ws_diff_states_);
    // Reduce the per-batch gate gradients in ws_gates_ into diff_bias_.
    void gates_reduction(int n_gates, int dic, int wic, int batch,
            const float *ws_gates_, float *diff_bias_);
    // Scratchpad backing the workspace sections on inference.
    scratchpad_t *scratchpad_;

    // Offsets (in floats) of each section inside the workspace/scratchpad,
    // filled by conf_.set_offsets() in the constructor.
    size_t ws_gates_offset_;
    size_t ws_states_offset_;
    size_t ws_weights_layer_offset_;
    size_t ws_weights_iter_offset_;
    size_t ws_diff_states_offset_;
    size_t ws_diff_weights_layer_offset_;
    size_t ws_diff_weights_iter_offset_;
    size_t ws_grid_comp_offset_;
    size_t ws_cell_comp_offset_;
    // Base pointers into the workspace sections (some declarations are
    // elided in this excerpt).
    float *ws_diff_states_;
    float *ws_weights_layer_;
    float *ws_weights_iter_;
    float *ws_diff_weights_layer_;
    float *ws_diff_weights_iter_;
    int n_output_features;

    // Per-(layer, direction[, part]) weight pointer tables, allocated in
    // the constructor and released in the destructor.
    float **ptr_wei_input_;
    float **ptr_wei_state_;

    // Runtime-dispatched kernels, all bound in the constructor.
    execution_direction exec_dir;
    grid_execution_f grid_computation;
    cell_execution_f cell_func;

    // Flags computed in the constructor from the descriptor/heuristics.
    bool copy_weights_layer_;
    bool copy_weights_iter_;
    bool copy_diff_weights_layer_;
    bool copy_diff_weights_iter_;
    bool merge_gemm_layer;
    bool merge_gemm_iter;

    packing_t weights_input_pack_func;
    packing_t weights_state_pack_func;
    gemm_t gemm_input_func;
    gemm_t gemm_state_func;
    elemwise_f elemwise_func;
    free_packed_t weights_input_free_packed_func;
    free_packed_t weights_state_free_packed_func;

// Public aliases for the two specializations.
using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>;
using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>;
440 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s