1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef CPU_REF_RNN_HPP
18 #define CPU_REF_RNN_HPP
22 #include "c_types_map.hpp"
23 #include "memory_tracking.hpp"
24 #include "type_helpers.hpp"
27 #include "../cpu_isa_traits.hpp"
28 #include "../gemm/os_blas.hpp"
30 #include "cpu_rnn_pd.hpp"
31 #include "rnn_utils.hpp"
32 #include "jit_uni_rnn_postgemm.hpp"
38 template <alg_kind_t alg_kind, prop_kind_t prop_kind>
39 float activation(float s, float alpha, float cliping, float dd);
// Reference RNN primitive, shared by the instantiations aliased at the end of
// this header (f32 forward/backward and u8-source/s8-weights forward).
// aprop selects forward vs backward; src_type / weights_type are the data
// types of the source/state and weights tensors.
41 template <prop_kind_t aprop, impl::data_type_t src_type,
42 impl::data_type_t weights_type>
43 struct _ref_rnn_common_t : public cpu_primitive_t {
// C++ element types corresponding to the template data-type parameters.
44 typedef typename prec_traits<src_type>::type src_data_t;
45 typedef typename prec_traits<weights_type>::type weights_data_t;
// Accumulator type: int32_t when the source data type is u8, float otherwise.
46 typedef typename utils::conditional<src_type == data_type::u8, int32_t,
47 float>::type acc_data_t;
49 using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>;
// Pointer-to-member types for the execution stages that the constructor
// binds at run time. The exact signatures come from the rnn_*_sig macros,
// which are declared elsewhere (presumably rnn_utils.hpp).
51 typedef rnn_elemwise_sig((class_name::*elemwise_f));
52 typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
53 typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
55 typedef rnn_gemm_sig((class_name::*gemm_t));
56 typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
57 typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
58 typedef rnn_weights_assign_sig((class_name::*weights_assign_t));
// Base primitive descriptor: the forward pd when aprop == forward, the
// backward pd otherwise. The leading `false ||` has no effect.
// NOTE(review): the line that opens this declaration (embedded number 60,
// presumably `using base_pd_t =`, since pd_t derives from base_pd_t) is
// absent from this listing. Confirm against the original file.
61 typename utils::conditional<false || aprop == prop_kind::forward,
62 cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
// Primitive descriptor for this reference RNN implementation. It validates
// the RNN descriptor against the template parameters, fixes the weights
// layouts, fills rnn_, and books workspace/scratchpad memory.
// NOTE(review): the embedded line numbers in this listing skip several lines
// (e.g. 72, 86, 92, 96, 99-100, 119, 130, 144, 147, 152, 155-156, 170-171).
// They presumably held the signature of the status-returning member function
// below, the declaration/initializer of `ok`, parts of the support check, the
// `if (!ok)` guard, the declaration of `ws_d`, and several closing braces or
// access specifiers. Confirm against the original file before editing.
64 struct pd_t : public base_pd_t {
65 pd_t(engine_t *engine, const rnn_desc_t *adesc,
66 const primitive_attr_t *attr,
67 const typename pd_t::hint_class *hint_pd)
68 : base_pd_t(engine, adesc, attr, hint_pd) {}
// Project macro supplying the common pd boilerplate; "ref:any" is the
// implementation name string passed to it.
70 DECLARE_COMMON_PD_T("ref:any", class_name);
73 using namespace prop_kind;
74 using namespace utils;
75 using namespace memory_format;
76 using namespace rnn_utils;
// Only CPU engines are expected here (checked in debug builds only).
77 assert(this->engine()->kind() == engine_kind::cpu);
78 const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
// Data types requested in the descriptor for the layer input and both
// weight tensors; compared with src_type / weights_type below.
80 data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type;
81 data_type_t weights_iter_dt
82 = this->desc()->weights_iter_desc.data_type;
83 data_type_t weights_layer_dt
84 = this->desc()->weights_layer_desc.data_type;
// Support check. The cell kind must be one of the four listed. A forward
// instance accepts only forward prop kinds; a backward instance accepts only
// `backward`. The layer input type must equal src_type, and the weight types
// must match weights_type (the head of that call is not visible here).
// set_default_params() must succeed.
87 && one_of(cell_kind, alg_kind::vanilla_rnn,
88 alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
89 alg_kind::gru_linear_before_reset)
90 && IMPLICATION(aprop == prop_kind::forward,
91 one_of(this->desc()->prop_kind, forward_training,
93 && IMPLICATION(aprop == backward,
94 one_of(this->desc()->prop_kind, backward))
95 && src_layer_dt == src_type
97 weights_type, weights_iter_dt, weights_layer_dt)
98 && this->set_default_params() == status::success
// Unsupported configuration for this implementation.
101 return status::unimplemented;
// Derive the RNN configuration (rnn_) from the descriptor and memory pds.
103 init_conf(rnn_, *this->desc(), this->src_pd(0), this->src_pd(1),
104 this->weights_pd(0), this->weights_pd(1), this->dst_pd(0));
// NOTE(review): `ok` is updated here, but no later test of `ok` is visible
// before `return status::success`. As a result, non-default attributes on an
// all-f32 configuration are not rejected. If confirmed, follow this with
// `if (!ok) return status::unimplemented;`.
106 if (rnn_.dt_conf == all_f32)
107 ok = ok && this->attr()->has_default_values();
109 // Set weights descriptors to desired format
// Layer weights: compute the expected descriptor (last argument false).
// If the user left the format as `any`, adopt it. If the user supplied
// rnn_packed, it must match exactly; otherwise the request is unsupported.
110 memory_desc_t weights_layer_md = *(this->weights_layer_pd_.desc());
111 CHECK(set_expected_desc(rnn_, weights_layer_md, false));
112 cpu_memory_t::pd_t new_weights_layer_pd(
113 this->engine_, &weights_layer_md);
114 if (this->weights_layer_pd_.desc()->format == any) {
115 this->weights_layer_pd_ = new_weights_layer_pd;
116 } else if (this->weights_layer_pd_.desc()->format == rnn_packed) {
117 if (!this->weights_layer_pd_.is_equal(&new_weights_layer_pd))
118 return status::unimplemented;
// Iteration weights: same procedure, with the last argument true.
121 memory_desc_t weights_iter_md = *(this->weights_iter_pd_.desc());
122 CHECK(set_expected_desc(rnn_, weights_iter_md, true));
123 cpu_memory_t::pd_t new_weights_iter_pd(
124 this->engine_, &weights_iter_md);
125 if (this->weights_iter_pd_.desc()->format == any) {
126 this->weights_iter_pd_ = new_weights_iter_pd;
127 } else if (this->weights_iter_pd_.desc()->format == rnn_packed) {
128 if (!this->weights_iter_pd_.is_equal(&new_weights_iter_pd))
129 return status::unimplemented;
132 CHECK(this->check_layout_consistency());
// Finish filling rnn_ now that the weights (and diff weights) pds are final.
134 set_conf(rnn_, *this->desc(), this->weights_pd(0),
135 this->weights_pd(1), this->diff_weights_pd(0),
136 this->diff_weights_pd(1));
138 size_t scratchpad_sz{0}, ws_sz{0};
139 get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);
141 // initialize the workspace_pd if needed
// When training, the workspace is described as a 1-D u8 buffer with ws_sz
// elements.
// NOTE(review): `(int)ws_sz` narrows the size_t and would truncate sizes
// above INT_MAX. Confirm the limits of dims_t for large problems.
142 if (rnn_.is_training) {
143 dims_t ws_dims = {(int)ws_sz};
145 mkldnn_memory_desc_init(&ws_d, 1, ws_dims, data_type::u8, x);
146 this->ws_pd_ = cpu_memory_t::pd_t(this->engine(), &ws_d);
149 init_scratchpad(scratchpad_sz);
151 return status::success;
// RNN configuration filled above; the primitive reads it through pd()->rnn_.
154 rnn_utils::rnn_conf_t rnn_;
// Books scratchpad memory. key_rnn_space holds scratchpad_sz floats; the
// 4096 argument is presumably an alignment. Three pointer arrays are booked
// for the layer weights, iteration weights and biases, each sized
// n_layer * n_dir * max_nparts, where max_nparts is 2 for vanilla_gru and
// 1 otherwise.
157 void init_scratchpad(size_t scratchpad_sz) {
158 using namespace memory_tracking::names;
159 auto scratchpad = this->scratchpad_registry().registrar();
160 scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096);
162 int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1;
163 int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
164 scratchpad.book(key_rnn_ptrs_wei_layer,
165 sizeof(float *) * ptr_wei_sz);
166 scratchpad.book(key_rnn_ptrs_wei_iter,
167 sizeof(float *) * ptr_wei_sz);
168 scratchpad.book(key_rnn_ptrs_bia,
169 sizeof(float *) * ptr_wei_sz);
// Builds the primitive from its descriptor and binds the member-function
// pointers used at execution time (bias, GEMM, weights assignment, cell,
// element-wise, grid), based on pd()->rnn_ and the cell kind.
// NOTE(review): the embedded line numbers in this listing skip 178-179,
// 182-183, 185, 188, 191-192, 214, 216, 223, 226, 229-232, 235 and 239-242.
// Those lines presumably held the head of the `set_gemm_funcs` lambda
// declaration, the lambda's if/else, closing braces, and the `break` /
// `default` labels of both switches. Do not read the cases below as
// intentional fall-through; confirm against the original file.
173 _ref_rnn_common_t(const pd_t *apd, const input_vector &inputs,
174 const output_vector &outputs)
175 : cpu_primitive_t(apd, inputs, outputs, true), rnn_postgemm_(nullptr) {
176 /// @todo set max_feature_size assuming that we limit the number of
177 /// iterations and layer to one if slc != dic and sic != dic
180 bias_preparation_func = &class_name::bias_prepare;
181 bias_finalization_func = &class_name::bias_finalize;
// Helper lambda: binds the packed GEMM / weights-assignment pair when
// packed_gemm is true, and the plain pair otherwise.
184 = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) {
186 g = &class_name::packed_gemm;
187 a = &class_name::assign_packed_weights;
189 g = &class_name::gemm;
190 a = &class_name::assign_weights;
// The iteration and layer GEMMs are configured independently.
193 set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
194 weights_iter_assign_func);
196 set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
197 weights_layer_assign_func);
199 switch (pd()->cell_kind()) {
200 case alg_kind::vanilla_lstm:
201 cell_func = &class_name::cell_execution;
// Forward LSTM: create a JIT post-GEMM kernel for the first available ISA
// among avx512_core, avx2 and sse42. rnn_postgemm_ is allocated with `new`
// here and must be released by the destructor.
// NOTE(review): if none of these ISAs is available, rnn_postgemm_ stays
// nullptr. The assert below is compiled out in NDEBUG builds, so init() is
// then called through a null pointer. Confirm whether a non-JIT fallback or
// an explicit error is intended.
202 if (aprop == prop_kind::forward) {
203 if (mayiuse(avx512_core))
204 rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>(
205 pd()->rnn_, pd()->attr());
206 else if (mayiuse(avx2))
207 rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>(
208 pd()->rnn_, pd()->attr());
209 else if (mayiuse(sse42))
210 rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>(
211 pd()->rnn_, pd()->attr());
212 assert(rnn_postgemm_ != nullptr);
213 rnn_postgemm_->init();
215 elemwise_func = &class_name::lstm_elemwise;
217 case alg_kind::vanilla_rnn: // @todo switch on cell kind
218 cell_func = &class_name::cell_execution;
219 elemwise_func = &class_name::rnn_elemwise;
// activation_func is chosen from the descriptor's activation kind. It is
// assigned only in this branch and is not in the initializer list.
220 switch (pd()->activation_kind()) {
221 case alg_kind::eltwise_relu:
222 activation_func = &activation<alg_kind::eltwise_relu, aprop>;
224 case alg_kind::eltwise_tanh:
225 activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
227 case alg_kind::eltwise_logistic:
228 activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
// vanilla_gru binds only cell_func here; elemwise_func is not assigned for it.
233 case alg_kind::vanilla_gru:
234 cell_func = &class_name::cell_execution_gru;
236 case alg_kind::gru_linear_before_reset:
237 cell_func = &class_name::cell_execution_gru_lbr;
238 elemwise_func = &class_name::gru_lbr_elemwise;
// Only the linear grid execution strategy is bound.
243 grid_computation = &class_name::linear_execution;
// Compute the workspace offsets for this configuration.
// NOTE(review): scratchpad_size and workspace_size are not used after this
// call in the visible code.
245 size_t scratchpad_size, workspace_size;
246 rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_,
247 ws_c_states_offset_, ws_diff_states_offset_,
248 ws_grid_comp_offset_, ws_cell_comp_offset_,
249 ws_bias_offset_, scratchpad_size, workspace_size);
252 ~_ref_rnn_common_t() {}
254 // typedef typename prec_traits::type data_t;
256 virtual void execute(event_t *e) const {
258 e->set_state(event_t::ready);
// Computation entry point called from execute(); its definition is not part
// of this header as shown (presumably in the accompanying .cpp file).
262 void execute_() const;
// Candidate implementations for the member-function pointers bound in the
// constructor. Their signatures come from the rnn_*_sig macros.
263 rnn_grid_execution_sig(linear_execution);
264 rnn_cell_execution_sig(cell_execution);
265 rnn_cell_execution_sig(cell_execution_gru);
266 rnn_cell_execution_sig(cell_execution_gru_lbr);
267 rnn_elemwise_sig(rnn_elemwise);
268 rnn_elemwise_sig(lstm_elemwise);
269 rnn_elemwise_sig(gru_lbr_elemwise);
// NOTE(review): the constructor binds `&class_name::gemm`, but no
// `rnn_gemm_sig(gemm);` declaration appears in this listing (its embedded
// numbering skips line 270). Confirm that it is declared.
271 rnn_gemm_sig(packed_gemm);
272 rnn_bias_prepare_sig(bias_prepare);
273 rnn_bias_finalize_sig(bias_finalize);
274 rnn_weights_assign_sig(assign_weights);
275 rnn_weights_assign_sig(assign_packed_weights);
// Scalar activation for the vanilla RNN cell. It is assigned only in the
// vanilla_rnn branch of the constructor and is otherwise left uninitialized.
// The parameter order documented here is (dd, s, alpha, cliping); keep the
// `activation` template's parameter names consistent with it.
277 float (*activation_func)(float dd, float s, float alpha, float cliping);
279 void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
280 src_data_t *ws_states_, float *ws_diff_states_,
281 const src_data_t *xt_, const float *diff_dst_layer) const;
283 template <typename input_data_t>
284 void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
285 src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_,
286 const input_data_t *firstit_states_,
287 const float *diff_dst_iter) const;
289 template <typename dst_data_t>
290 void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
291 dst_data_t *dst_layer_, float *diff_src_layer,
292 const src_data_t *ws_states_, const float *ws_diff_states_) const;
294 template <typename output_data_t>
295 void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
296 output_data_t *dst_iter_, float *diff_src_iter,
297 const src_data_t *ws_states_, float *ws_c_states,
298 const float *ws_diff_states_) const;
300 void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
301 const acc_data_t *ws_gates_, float *diff_bias_) const;
303 const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
// Workspace offsets, passed to rnn_utils::set_offsets in the constructor
// (presumably as outputs).
305 size_t ws_gates_offset_;
306 size_t ws_states_offset_;
307 size_t ws_c_states_offset_;
308 size_t ws_bias_offset_;
309 size_t ws_diff_states_offset_;
310 size_t ws_grid_comp_offset_;
311 size_t ws_cell_comp_offset_;
// JIT post-GEMM kernel. It is allocated with `new` in the constructor for
// forward LSTM and is nullptr otherwise; the destructor must release it.
312 jit_uni_rnn_postgemm_kernel *rnn_postgemm_;
// Execution-strategy pointers, all bound in the constructor.
314 grid_execution_f grid_computation;
315 cell_execution_f cell_func;
317 bias_prepare_t bias_preparation_func;
318 bias_finalize_t bias_finalization_func;
319 weights_assign_t weights_layer_assign_func;
320 weights_assign_t weights_iter_assign_func;
322 gemm_t gemm_layer_func;
323 gemm_t gemm_iter_func;
// Not assigned for vanilla_gru in the visible constructor code.
324 elemwise_f elemwise_func;
// NOTE(review): the embedded numbering skips 325-326, which presumably held
// the closing `};` of _ref_rnn_common_t. Confirm.
327 using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
328 using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
329 using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
335 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s