/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
/*
  General architecture

  for diff states, we have n_states + 1 as we have n_states diff
  to propagate to the previous iteration and 1 state to propagate
  to the previous layer
  index 0 is dh for cell(t-1, l) to consume
  index 1 is dc for cell(t-1, l) to consume
  index 2 is dh for cell(t, l-1) to consume
  this indexing enables using the same state indexing in the elemwise
  functions; only the cell execution function should be impacted
 */
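/* Illustrative sketch (not part of the build): for an LSTM cell (n_states = 2)
 * the backward pass reads and writes the diff-state slots as
 *     ws_diff_states(lay, dir, 0, iter, ...)        -> dh for cell(t-1, l)
 *     ws_diff_states(lay, dir, 1, iter, ...)        -> dc for cell(t-1, l)
 *     ws_diff_states(lay, dir, n_states, iter, ...) -> dh for cell(t, l-1)
 * so the element-wise kernels can use one indexing scheme for all states. */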
#include "math_utils.hpp"
#include "mkldnn_thread.hpp"

#include "ref_rnn.hpp"
#include "../gemm/gemm.hpp"
#include "../simple_q10n.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::memory_tracking::names;
using namespace rnn_utils;
#define AOC array_offset_calculator
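
/* gates_reduction: reduces the per-minibatch gate gradients kept in ws_gates
 * into diff_bias, i.e. diff_bias[i * dic + k] accumulates ws_gates over the
 * minibatch dimension. */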
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
        const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
        float *diff_bias_) const {
    auto body = [&](int i, int k) {
        for (int j = 0; j < rnn.mb; j++)
            diff_bias_[i * rnn.dic + k]
                    += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
    };

    // @todo block k on simd-width
#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
    /* icc 17.0 has a problem with simd collapse */ \
    && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
#pragma omp parallel for simd collapse(2)
    for (int i = 0; i < rnn.n_gates; i++)
        for (int k = 0; k < rnn.dic; k++)
            body(i, k);
#else
    parallel_nd(rnn.n_gates, rnn.dic, body);
#endif
}
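
/* Thin wrappers around the GEMM back-ends: the generic f32 version forwards
 * to extended_sgemm, the packed variants require USE_MKL_PACKED_GEMM, and the
 * int8 (u8s8) path supports only packed GEMM. */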
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
    assert(ldA * ldB * ldC != 0);
    extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
            &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
}

template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
    assert(!"non packed gemm is disabled for int8");
}
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    assert(transA == 'N');
    cblas_sgemm_compute(CblasColMajor, CblasPacked,
            (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
            ldB, beta, c_, ldC);
#else
    assert(!"packed gemm is disabled");
#endif
}
template <>
rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
#if (USE_MKL_PACKED_GEMM)
    int8_t offseta = 0, offsetb = 0;
    int32_t offsetc = 0;
    cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
            CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
            ldB, offsetb, beta, c_, ldC, &offsetc);
#else
    assert(!"packed gemm is disabled");
#endif
}
//*************** Grid computations strategy: linear ***************//
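/* linear_execution walks the grid of cells layer by layer and iteration by
 * iteration (in reverse order for the backward pass). When merge_gemm_layer /
 * merge_gemm_iter are set, the corresponding per-iteration GEMM work is
 * batched into a single GEMM spanning all iterations. */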
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_grid_execution_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
    AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
    AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1,
            rnn.states_nld * rnn.states_ws_ld);
    AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
            rnn.gates_nld * rnn.gates_ws_ld);
    AOC<weights_data_t *, 3> weights_input(
            weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
    AOC<weights_data_t *, 3> weights_states(
            weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    AOC<float *, 3> bias(
            bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
            rnn.n_dir,
            rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
    AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
            rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
    AOC<float, 3> diff_bias(
            diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    AOC<float, 4> ws_grid(
            ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);

    // We run the grid of computation
    for (int dir = 0; dir < rnn.n_dir; dir++) {
        for (int j = 0; j < rnn.n_layer; j++) {
            int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;

            if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
                        rnn.mb * rnn.n_iter, rnn.slc, 1.0,
                        weights_input(lay, dir, 0), rnn.weights_layer_ld,
                        &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
                        &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
            }

            for (int i = 0; i < rnn.n_iter; i++) {
                int iter = (aprop == prop_kind::forward) ? i
                                                         : rnn.n_iter - i - 1;
                (this->*cell_func)(rnn,
                        &(ws_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_c_states(lay + 1, dir, iter + 1, 0)),
                        &(ws_diff_states(lay, dir, 0, iter, 0)),
                        &(weights_input(lay, dir, 0)),
                        &(weights_states(lay, dir, 0)),
                        &(bias(lay, dir, 0)),
                        &(ws_states(lay, dir, iter + 1, 0)),
                        &(ws_states(lay + 1, dir, iter, 0)),
                        &(ws_c_states(lay + 1, dir, iter, 0)),
                        &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
                        &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
                        &(diff_weights_layer(lay, dir, 0)),
                        &(diff_weights_iter(lay, dir, 0)),
                        &(diff_bias(lay, dir, 0)),
                        &(ws_gates(lay, dir, iter, 0)),
                        &(ws_grid(lay, dir, iter, 0)),
                        ws_cell_);
            }
            if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
                (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
                        rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
                        rnn.weights_layer_ld,
                        (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld, 0.0,
                        (acc_data_t *)(&(ws_diff_states(
                                lay, dir, rnn.n_states, 0, 0))),
                        rnn.states_ws_ld);
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
                        rnn.diff_weights_layer_ld);
            }
            if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
                        rnn.mb * rnn.n_iter, 1.0,
                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
                        rnn.gates_ws_ld,
                        (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
                        rnn.states_ws_ld, 1.0,
                        (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
                        rnn.diff_weights_iter_ld);
            }
        }
    }
}

//********* GRID computations strategy: utility functions **********//
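/* copy_init_layer writes the input sequence x_t into the first-layer slots of
 * the states workspace, once in left-to-right order and, for bidirectional
 * execution, once in right-to-left order. The bwd specialization instead
 * seeds ws_diff_states with diff_dst_layer according to the execution
 * direction. */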
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
        const float *__restrict diff_dst_layer_) const {

    AOC<src_data_t, 4> ws_states(
            ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto xt_d = memory_desc_wrapper(pd()->src_pd(0));

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        auto xxt = xt_ + xt_d.blk_off(it, b);
        src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
        src_data_t *ws_r2l_ptr
                = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
        if (rnn.exec_dir != r2l)
            for (int c = 0; c < rnn.slc; c++)
                ws_l2r_ptr[c] = xxt[c];
        if (rnn.exec_dir != l2r)
            for (int c = 0; c < rnn.slc; c++)
                ws_r2l_ptr[c] = xxt[c];
    });
}

template <>
void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
        const float *diff_dst_layer_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_pd(0));

    switch (rnn.exec_dir) {
    case bi_concat:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(
                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[rnn.dic + s];
            }
        });
        break;
    case bi_sum:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
                ws_diff_states(
                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case l2r:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x
                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    case r2l:
        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
            auto diff_dst_layer_x = diff_dst_layer_
                    + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
            for (int s = 0; s < rnn.dic; s++) {
                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
                        = diff_dst_layer_x[s];
            }
        });
        break;
    default: assert(!"Unsupported direction"); break;
    }
}

/* For int8 configurations, input iteration states may be of type f32 or u8.
 * Internally, the h state is always stored as u8 and the c state is always
 * stored as f32. If the input states are u8, the h state is copied and the
 * c state is dequantized; if the input states are f32, the h state is
 * quantized and the c state is copied. */
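/* Sketch of the (de)quantization applied below, using the shift and scale
 * from the rnn_data_qparams attribute:
 *     quantize:   u8  = saturate(round(f32 * data_scale + data_shift))
 *     dequantize: f32 = (u8 - data_shift) / data_scale
 */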
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename input_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
        float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
        const input_data_t *__restrict firstit_states_,
        const float *__restrict diff_dst_iter_) const {
    AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
    round_mode_t rmode = pd()->attr()->round_mode_;

    const bool quantize
            = pd()->desc()->src_iter_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](input_data_t f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, src_data_t>()(qf, rmode);
        } else
            return (src_data_t)f;
    };

    const bool dequantize
            = pd()->desc()->src_iter_desc.data_type == data_type::u8;
    auto maybe_deq = [&](input_data_t s) {
        if (dequantize)
            return (((float)s - data_shift) / data_scale);
        else
            return (float)s;
    };

    auto firstit_states_d = memory_desc_wrapper(pd()->src_pd(1));
    if (firstit_states_) {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.sic; s++)
                        ws_states(lay + 1, dir, 0, b, s) = maybe_q(
                                firstit_states_[firstit_states_d.blk_off(
                                        lay, dir, 0, b, s)]);
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.sic; s++)
                            ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
                                    firstit_states_[firstit_states_d.blk_off(
                                            lay, dir, 1, b, s)]);
                });
    } else {
        parallel_nd(
                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
                    for (int j = 0; j < rnn.sic; j++) {
                        ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
                        ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
                    }
                });
    }
}

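/* The bwd specialization below seeds the last-iteration slots of
 * ws_diff_states with diff_dst_iter, or zeroes them when no diff_dst_iter is
 * provided. */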
template <>
template <typename input_data_t>
void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
        src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
        const input_data_t *firstit_states_,
        const float *diff_dst_iter_) const {
    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_pd(1));
    if (diff_dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    array_copy(&(ws_diff_states(
                                       lay, dir, state, rnn.n_iter, b, 0)),
                            diff_dst_iter_
                                    + diff_dst_iter_d.blk_off(
                                              lay, dir, state, b),
                            rnn.dic);
                });
    } else {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int i) {
                    for (int j = 0; j < rnn.dic; j++)
                        ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
                                = 0.0f;
                });
    }
}

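/* copy_res_layer writes the last layer's hidden states from the workspace to
 * dst_layer, dequantizing when an int8 configuration produces an f32
 * destination, and handling the bidirectional concat/sum cases. The bwd
 * specialization writes diff_src_layer from ws_diff_states instead. */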
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename dst_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {

    auto dst_layer_d = memory_desc_wrapper(pd()->dst_pd(0));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float shift = (pd()->attr()->rnn_data_qparams_.shift_);
    float scale = (pd()->attr()->rnn_data_qparams_.scale_);

    const bool dequantize
            = pd()->desc()->dst_layer_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (dst_data_t)(((float)s - shift) / scale);
        else
            return (dst_data_t)s;
    };

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        if (rnn.exec_dir != r2l) {
            for (int s = 0; s < rnn.dic; s++) {
                dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                        = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
            }
            dir = 1;
        }
        if (rnn.exec_dir != l2r) {
            for (int s = 0; s < rnn.dic; s++)
                switch (rnn.exec_dir) {
                case bi_sum:
                    dst_layer_[dst_layer_d.blk_off(it, b, s)]
                            += maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                    break;
                default:
                    dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
                            = maybe_deq(ws_states(
                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
                }
        }
    });
}

template <>
template <typename dst_data_t>
void ref_rnn_bwd_f32_t::copy_res_layer(
        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
        const src_data_t *ws_states_, const float *ws_diff_states_) const {
    auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_pd(0));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);

    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
        int dir = 0;
        for (int s = 0; s < rnn.slc; s++) {
            float *dst_addr = diff_src_layer_
                    + diff_src_layer_d.blk_off(
                              (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
                              b, dir * rnn.slc + s);
            float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
            if (rnn.n_dir > 1)
                res += ws_diff_states(
                        0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
            dst_addr[0] = res;
        }
    });
}

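/* copy_res_iter writes the last iteration's h state (and, for LSTM, the c
 * state) into dst_iter, quantizing or dequantizing as dictated by the
 * destination data type. The bwd specialization writes diff_src_iter from
 * ws_diff_states. */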
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
template <typename output_data_t>
void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto dst_iter_d = memory_desc_wrapper(pd()->dst_pd(1));
    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
    round_mode_t rmode = pd()->attr()->round_mode_;

    const bool quantize = pd()->desc()->dst_iter_desc.data_type == data_type::u8
            && rnn.dt_conf != all_f32;
    auto maybe_q = [&](float f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, output_data_t>()(qf, rmode);
        } else
            return (output_data_t)f;
    };

    const bool dequantize
            = pd()->desc()->dst_iter_desc.data_type == data_type::f32
            && rnn.dt_conf != all_f32;
    auto maybe_deq = [&](src_data_t s) {
        if (dequantize)
            return (output_data_t)(((float)s - data_shift) / data_scale);
        else
            return (output_data_t)s;
    };
    if (dst_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
                [&](int lay, int dir, int b) {
                    for (int s = 0; s < rnn.dic; s++) {
                        dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
                                = maybe_deq(ws_states(
                                        lay + 1, dir, rnn.n_iter, b, s));
                    }
                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
                        for (int s = 0; s < rnn.dic; s++) {
                            dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
                                    = maybe_q(ws_c_states(
                                            lay + 1, dir, rnn.n_iter, b, s));
                        }
                });
    }
}

template <>
template <typename output_data_t>
void ref_rnn_bwd_f32_t::copy_res_iter(
        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
        const src_data_t *ws_states_, float *ws_c_states_,
        const float *ws_diff_states_) const {
    auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_pd(1));
    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
            rnn.states_ws_ld);
    if (diff_src_iter_) {
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
                [&](int lay, int dir, int state, int b) {
                    for (int s = 0; s < rnn.sic; s++) {
                        diff_src_iter_[diff_src_iter_d.blk_off(
                                lay, dir, state, b, s)]
                                = ws_diff_states(lay, dir, state, 0, b, s);
                    }
                });
    }
}

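/* bias_prepare sets up one pointer per bias part. When rnn.copy_bias is set,
 * the user bias is first copied into the scratchpad so that bias_finalize can
 * later adjust it in place. */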
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
    /* Original set of bias provided by the user */
    AOC<const float, 3> b(
            b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
    /* Array of pointers initialized in packing */
    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
    AOC<float, 3> scratch_bias(
            scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);

    if (rnn.copy_bias) {
        parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
                [&](size_t i) { scratch_bias_[i] = b_[i]; });
    }

    for (int i = 0; i < rnn.n_layer; i++) {
        for (int d = 0; d < rnn.n_dir; d++) {
            int offset_bias = 0;
            for (int p = 0; p < rnn.n_parts_bias; p++) {
                bias(i, d, p) = rnn.copy_bias
                        ? (float *) &scratch_bias(i, d, offset_bias)
                        : (float *) &b(i, d, offset_bias);
                offset_bias += rnn.parts_bias[p] * rnn.dic;
            }
        }
    }
}

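/* bias_finalize (int8 configurations only): folds the precomputed weights
 * compensation into the scratchpad bias,
 *     bias[off] -= (w_iter_comp[off] + w_layer_comp[off]) * data_shift
 *                  / (weights_scale * data_scale)
 * using a per-output-channel weights_scale when the scales mask is non-zero. */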
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_bias_finalize_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
    if (rnn.dt_conf != all_f32) {
        float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
        float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
        float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
        bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
        for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
            for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
                size_t off = i * rnn.n_bias * rnn.dic + j;
                float weights_scale
                        = scale_per_oc ? weights_scales[j] : weights_scales[0];
                scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
                        * data_shift / (weights_scale * data_scale);
            }
    }
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
        weights_type>::assign_packed_weights)) {
    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);

    size_t offset_packed = 0;
    for (int l = 0; l < rnn.n_layer; l++)
        for (int d = 0; d < rnn.n_dir; d++) {
            for (int p = 0; p < n_parts; p++) {
                weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
                offset_packed
                        += part_weights_pack_size[p] / sizeof(weights_data_t);
            }
        }
}

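/* assign_weights points each weights part directly into the user-provided
 * (non-packed) weights array, advancing the offset by gates_per_part[p] *
 * OC_size elements (times ld for non-ldigo layouts). */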
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
rnn_weights_assign_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
    assert(nld * ld != 0);
    /* Original set of weights provided by the user */
    AOC<const weights_data_t, 3> w(w_, rnn.n_layer, rnn.n_dir, nld * ld);
    /* Array of pointers for each part of weights */
    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);

    for (int i = 0; i < rnn.n_layer; i++)
        for (int d = 0; d < rnn.n_dir; d++) {
            size_t offset_weights = 0;
            for (int p = 0; p < n_parts; p++) {
                weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
                offset_weights += fmt == memory_format::ldigo
                        ? gates_per_part[p] * OC_size
                        : gates_per_part[p] * OC_size * ld;
            }
        }
}

//********************* Execution function *********************//
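/* execute_ drives one primitive execution: fetch the input/output memories,
 * prepare the bias and weights pointers (packing or re-pointing them), copy
 * the initial states and input into the workspace, run the grid computation,
 * and copy the results back out, selecting f32 or u8 paths from rnn.dt_conf. */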
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_() const {
    const rnn_conf_t &rnn = this->pd()->rnn_;

    int input_idx = 0;
    int output_idx = 0;
    auto input = reinterpret_cast<const src_data_t *>(
            this->input_memory(input_idx++));
    auto states = pd()->with_src_iter() ? (this->input_memory(input_idx++)) :
                                          nullptr;

    const char *layer_weights_n_comp = this->input_memory(input_idx++);
    auto w_layer
            = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
    auto w_layer_comp = reinterpret_cast<const float *>(layer_weights_n_comp
            + rnn.weights_layer_comp_offset);
    const char *iter_weights_n_comp = this->input_memory(input_idx++);
    auto w_iter
            = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
    auto w_iter_comp = reinterpret_cast<const float *>(iter_weights_n_comp
            + rnn.weights_iter_comp_offset);
    auto bias = pd()->with_bias() ?
            reinterpret_cast<const float *>(this->input_memory(input_idx++)) :
            nullptr;

    auto dst_last_layer = rnn.is_fwd ? this->memory(output_idx++) :
                                       this->input_memory(input_idx++);
    auto dst_last_iter = pd()->with_dst_iter()
            ? (rnn.is_fwd
                      ? this->memory(output_idx++)
                      : this->input_memory(input_idx++))
            : nullptr;

    auto diff_dst_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<const float *>(this->input_memory(input_idx++));
    auto diff_dst_iter = rnn.is_fwd || !pd()->with_dst_iter() ?
            nullptr :
            reinterpret_cast<const float *>(this->input_memory(input_idx++));

    auto scratchpad = this->scratchpad();

    auto ptr_wei_layer
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
    auto ptr_wei_iter
            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
    auto ptr_bias =
            scratchpad.template get<float *>(key_rnn_ptrs_bia);
    // fetching buffers from the workspace
    // if no workspace was provided we use the scratchpad
    char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
    char *ws_ptr = nullptr;
    if (rnn.use_workspace)
        ws_ptr = rnn.is_fwd
                ? this->memory(output_idx++)
                : const_cast<char *>(this->input_memory(input_idx++));
    char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
    acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
    src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
    float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
    float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
    float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
    float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);

    auto diff_src_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_src_iter = rnn.is_fwd || !pd()->with_src_iter() ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_weights_layer = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_weights_iter = rnn.is_fwd ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    auto diff_bias = rnn.is_fwd || !pd()->with_bias() ?
            nullptr :
            reinterpret_cast<float *>(this->memory(output_idx++));
    // Fetching extra buffers from scratchpad
    float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);

    // initialize diff_states to 0
    if (aprop == prop_kind::backward)
        array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));

    /* Pack (if using the packed gemm API) or copy (if the input arrays have a
     * bad leading dimension) the weights and bias */
    (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);

    (this->*weights_iter_assign_func)(rnn, rnn.weights_iter_fmt,
            rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
            rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
            rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
            ptr_bias, bias, ws_bias);
    (this->*weights_layer_assign_func)(rnn, rnn.weights_layer_fmt,
            rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
            rnn.n_parts_weights_layer, rnn.parts_weights_layer,
            rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
            bias, ws_bias);

    (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);

    // we first need to copy the initial states and input into ws
    copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const float *)states, diff_dst_iter);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
                (const uint8_t *)states, diff_dst_iter);
    else
        assert(!"unimplemented");
    // run the execution on the grid
    (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
            ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
            diff_weights_layer, diff_weights_iter, diff_bias);

    // Finally we copy the results to the result buffers
    if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
                ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
        copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
                ws_states, ws_diff_states);
    else
        assert(!"unimplemented");

    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
            || rnn.dt_conf == all_f32)
        copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
        copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
                ws_c_states, ws_diff_states);
    else
        assert(!"unimplemented");
}

/* Fix for MSVS warning C4661 */
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;

#undef AOC
}
}
}